Commit 08fc6895 authored by Yoann Schneider, committed by Mélodie Boillet

Cache images and crop

parent 8a519d32
1 merge request: !232 Cache images and crop
@@ -3,6 +3,10 @@ stages:
- build
- deploy
variables:
# Submodule clone
GIT_SUBMODULE_STRATEGY: recursive
lint:
image: python:3.10
stage: test
......
[submodule "line_image_extractor"]
path = teklia_line_image_extractor
url = ../line_image_extractor.git
@@ -4,9 +4,21 @@
## Documentation
To use DAN in your own scripts, install it using pip:
To use DAN in your own environment, you first need to clone the repository with its submodules:
```console
```shell
git clone --recurse-submodules git@gitlab.teklia.com:atr/dan.git
```
If you forgot the `--recurse-submodules`, you can initialize the submodule using:
```shell
git submodule update --init
```
Then you can install it via pip:
```shell
pip install -e .
```
......
@@ -138,4 +138,12 @@ def add_extract_parser(subcommands) -> None:
help="Images larger than this height will be resized to this width.",
)
parser.add_argument(
"--cache",
dest="cache_dir",
type=pathlib.Path,
help="Where the images should be cached.",
default=pathlib.Path(".cache"),
)
parser.set_defaults(func=run)
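For context, a minimal standalone sketch (not the project's actual parser) of how the new `--cache` option behaves: the value is parsed into a `pathlib.Path` and falls back to `.cache` when the flag is omitted.

```python
import argparse
import pathlib

# Standalone sketch of the new --cache option; names mirror the diff above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--cache",
    dest="cache_dir",
    type=pathlib.Path,
    default=pathlib.Path(".cache"),
    help="Where the images should be cached.",
)

# Explicit value and default, respectively.
assert parser.parse_args(["--cache", "/tmp/dan"]).cache_dir == pathlib.Path("/tmp/dan")
assert parser.parse_args([]).cache_dir == pathlib.Path(".cache")
```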
# -*- coding: utf-8 -*-
import ast
from dataclasses import dataclass
from itertools import starmap
from typing import List, Optional, Union
from urllib.parse import urljoin
from typing import List, Union
from arkindex_export import Image
from arkindex_export.models import Element as ArkindexElement
from arkindex_export.models import (
Element,
Entity,
EntityType,
Transcription,
@@ -17,51 +13,10 @@ from arkindex_export.models import (
from arkindex_export.queries import list_children
def bounding_box(polygon: list):
"""
Returns a 4-tuple (x, y, width, height) for the bounding box of a Polygon (list of points)
"""
all_x, all_y = zip(*polygon)
x, y = min(all_x), min(all_y)
width, height = max(all_x) - x, max(all_y) - y
return int(x), int(y), int(width), int(height)
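A worked example of the (now removed) `bounding_box` helper, using a made-up rectangular polygon:

```python
# Hypothetical polygon: a list of (x, y) points.
polygon = [(10, 20), (30, 20), (30, 60), (10, 60)]
# min x = 10, min y = 20, width = 30 - 10, height = 60 - 20
assert bounding_box(polygon) == (10, 20, 20, 40)
```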
@dataclass
class Element:
id: str
type: str
polygon: str
url: str
width: int
height: int
max_width: Optional[int] = None
max_height: Optional[int] = None
def __post_init__(self):
self.max_height = self.max_height or self.height
self.max_width = self.max_width or self.width
@property
def bounding_box(self):
return bounding_box(ast.literal_eval(self.polygon))
@property
def image_url(self):
x, y, width, height = self.bounding_box
return urljoin(
self.url + "/",
f"{x},{y},{width},{height}/!{self.max_width},{self.max_height}/0/default.jpg",
)
def get_elements(
parent_id: str,
element_type: List[str],
max_width: Optional[int] = None,
max_height: Optional[int] = None,
) -> List[Element]:
):
"""
Retrieve elements from an SQLite export of an Arkindex corpus
"""
@@ -69,23 +24,10 @@ def get_elements(
query = (
list_children(parent_id=parent_id)
.join(Image)
.where(ArkindexElement.type.in_(element_type))
.select(
ArkindexElement.id,
ArkindexElement.type,
ArkindexElement.polygon,
Image.url,
Image.width,
Image.height,
)
.where(Element.type.in_(element_type))
)
return list(
starmap(
lambda *x: Element(*x, max_width=max_width, max_height=max_height),
query.tuples(),
)
)
return query
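Since `get_elements` now returns a lazy peewee query instead of a materialized list, callers can both count it and iterate it. A hypothetical usage sketch (the parent ID is made up):

```python
# Hypothetical parent ID and element type.
children = get_elements("11111111-2222-3333-4444-555555555555", ["text_line"])
print(children.count())   # number of matching children
for element in children:  # rows are fetched lazily on iteration
    print(element.id, element.type)
```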
def build_worker_version_filter(ArkindexModel, worker_version):
......
# -*- coding: utf-8 -*-
import json
import random
from collections import defaultdict
from pathlib import Path
from typing import List, Optional, Union
from uuid import UUID
import numpy as np
from tqdm import tqdm
from arkindex_export import open_database
@@ -27,10 +29,13 @@ from dan.datasets.extract.utils import (
insert_token,
parse_tokens,
)
from line_image_extractor.extractor import extract, read_img, save_img
from line_image_extractor.image_utils import Extraction, polygon_to_bbox, resize
IMAGES_DIR = "images" # Subpath to the images directory.
LABELS_DIR = "labels" # Subpath to the labels directory.
SPLIT_NAMES = ["train", "val", "test"]
IIIF_URL_SUFFIX = "/full/full/0/default.jpg"
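Images are now fetched at full size over IIIF and cropped locally; the suffix above turns a base image URL into a full-resolution request. A small example (the URL is made up):

```python
# Hypothetical IIIF base URL of an Arkindex image.
url = "https://iiif.example.com/iiif/2/abcd"
print(url + IIIF_URL_SUFFIX)
# https://iiif.example.com/iiif/2/abcd/full/full/0/default.jpg
```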
class ArkindexExtractor:
@@ -51,6 +56,7 @@ class ArkindexExtractor:
entity_worker_version: Optional[Union[str, bool]] = None,
max_width: Optional[int] = None,
max_height: Optional[int] = None,
cache_dir: Path = Path(".cache"),
) -> None:
self.folders = folders
self.element_type = element_type
@@ -64,6 +70,19 @@
self.max_width = max_width
self.max_height = max_height
self.cache_dir = cache_dir
# Create the cache dir if it does not exist
self.cache_dir.mkdir(exist_ok=True, parents=True)
def find_image_in_cache(self, image_id: str) -> Path:
"""Images are cached to avoid downloading them twice. They are stored under a specific name,
based on their Arkindex ID. Images are saved under the JPEG format.
:param image_id: ID of the image. The image is saved under this name.
:return: Where the image should be saved in the cache folder.
"""
return self.cache_dir / f"{image_id}.jpg"
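For illustration, given an `extractor` instance with the default `cache_dir`, a made-up image ID maps to a stable path, so a second run reuses the file instead of re-downloading:

```python
# Hypothetical usage (the image ID is made up):
extractor.find_image_in_cache("0b372108-aaaa-bbbb-cccc-0123456789ab")
# -> Path(".cache/0b372108-aaaa-bbbb-cccc-0123456789ab.jpg")
```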
def _keep_char(self, char: str) -> bool:
# Keep all text by default if no separator was given
return not self.entity_separators or char in self.entity_separators
@@ -148,6 +167,52 @@ class ArkindexExtractor:
)
return self.reconstruct_text(transcription.text, entities)
def retrieve_image(self, child: Element):
"""Get or download image of the element. Checks in cache before downloading.
:param child: Processed element.
:return: The element's image.
"""
cached_img_path = self.find_image_in_cache(child.image.id)
if not cached_img_path.exists():
# Save in cache
download_image(child.image.url + IIIF_URL_SUFFIX).save(
cached_img_path, format="jpeg"
)
return read_img(cached_img_path)
def get_image(self, child: Element, destination: Path) -> None:
"""Save the element's image to the given path and applies any image operations needed.
:param child: Processed element.
:param destination: Where the image should be saved.
"""
polygon = json.loads(str(child.polygon))
if self.max_height or self.max_width:
polygon = resize(
polygon,
self.max_width,
self.max_height,
scale_x=1.0,
scale_y_top=1.0,
scale_y_bottom=1.0,
)
# Extract the polygon in the image
image = extract(
img=self.retrieve_image(child),
polygon=np.array(polygon),
bbox=polygon_to_bbox(polygon),
# Hardcoded while we don't have a configuration file
extraction_mode=Extraction.deskew_min_area_rect,
max_deskew_angle=45,
)
# Save the image to disk
save_img(path=destination, img=image)
def process_element(
self,
element: Element,
@@ -162,13 +227,14 @@
base_path = Path(split, f"{element.type}_{element.id}")
Path(self.output, LABELS_DIR, base_path).with_suffix(".txt").write_text(text)
download_image(
element, Path(self.output, LABELS_DIR, base_path).with_suffix(".jpg")
self.get_image(
element, Path(self.output, IMAGES_DIR, base_path).with_suffix(".jpg")
)
return element.id
def process_parent(
self,
pbar,
parent: Element,
split: str,
):
@@ -176,7 +242,10 @@
Extract data from a parent element.
"""
data = defaultdict(list)
base_description = (
f"Extracting data from {parent.type} ({parent.id}) for split ({split})"
)
pbar.set_description(desc=base_description)
if self.element_type == [parent.type]:
try:
data[parent.type].append(self.process_element(parent, split))
@@ -184,12 +253,15 @@
logger.warning(f"Skipping {parent.id}: {str(e)}")
# Extract children elements
else:
for element in get_elements(
children = get_elements(
parent.id,
self.element_type,
max_width=self.max_width,
max_height=self.max_height,
):
)
nb_children = children.count()
for idx, element in enumerate(children, start=1):
# Update the description to report progress over the children
pbar.set_description(desc=base_description + f" ({idx}/{nb_children})")
try:
data[element.type].append(self.process_element(element, split))
except ProcessingError as e:
@@ -198,23 +270,24 @@
def run(self):
# Iterate over the subsets to find the page images and labels.
for idx, (folder_id, split) in enumerate(
zip(self.folders, SPLIT_NAMES), start=1
):
# Iterate over the pages to create splits at page level.
for parent in tqdm(
for folder_id, split in zip(self.folders, SPLIT_NAMES):
with tqdm(
get_elements(
folder_id,
[self.parent_element_type],
max_width=self.max_width,
max_height=self.max_height,
),
desc=f"Processing {folder_id} {idx}/{len(self.subsets)}",
):
self.process_parent(
parent=parent,
split=split,
)
desc=f"Extracting data from ({folder_id}) for split ({split})",
) as pbar:
# Iterate over the pages to create splits at page level.
for parent in pbar:
self.process_parent(
pbar=pbar,
parent=parent,
split=split,
)
# Progress bar updates
pbar.update()
pbar.refresh()
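The progress reporting follows a common tqdm pattern: one bar over the parents, whose description is refined while children are processed. A simplified, self-contained sketch with dummy data standing in for Arkindex elements:

```python
from tqdm import tqdm

# Dummy parents and children standing in for Arkindex elements.
with tqdm(["page_1", "page_2"], desc="Extracting data") as pbar:
    for parent in pbar:
        for idx in range(1, 4):
            # Refine the description as each child is processed.
            pbar.set_description(f"Extracting data from ({parent}) ({idx}/3)")
```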
def run(
@@ -232,6 +305,7 @@ def run(
entity_worker_version: Optional[Union[str, bool]],
max_width: Optional[int],
max_height: Optional[int],
cache_dir: Path,
):
assert database.exists(), f"No file found @ {database}"
open_database(path=database)
@@ -258,4 +332,5 @@ def run(
entity_worker_version=entity_worker_version,
max_width=max_width,
max_height=max_height,
cache_dir=cache_dir,
).run()
# -*- coding: utf-8 -*-
import logging
import time
from io import BytesIO
from pathlib import Path
from typing import NamedTuple
import cv2
import imageio.v2 as iio
import requests
import yaml
from numpy import ndarray
from dan.datasets.extract.db import Element
from dan.datasets.extract.exceptions import ImageDownloadError
from PIL import Image
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
logger = logging.getLogger(__name__)
MAX_RETRIES = 5
# See http://docs.python-requests.org/en/master/user/advanced/#timeouts
DOWNLOAD_TIMEOUT = (30, 60)
def _retry_log(retry_state, *args, **kwargs):
logger.warning(
f"Request to {retry_state.args[0]} failed ({repr(retry_state.outcome.exception())}), "
f"retrying in {retry_state.idle_for} seconds"
)
class EntityType(NamedTuple):
@@ -26,29 +36,36 @@ class EntityType(NamedTuple):
return len(self.start) + len(self.end)
def download_image(element: Element, im_path: Path):
if im_path.exists():
return im_path
tries = 1
# retry loop
while True:
if tries > MAX_RETRIES:
raise ImageDownloadError(element.id, Exception("Maximum retries reached."))
try:
image = iio.imread(element.image_url)
save_image(im_path, image)
return
except TimeoutError:
logger.warning("Timeout, retry in 1 second.")
time.sleep(1)
tries += 1
except Exception as e:
raise ImageDownloadError(element.id, e)
def save_image(path: Path, image: ndarray):
cv2.imwrite(str(path), cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2),
retry=retry_if_exception_type(requests.RequestException),
before_sleep=_retry_log,
reraise=True,
)
def _retried_request(url):
resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
resp.raise_for_status()
return resp
def download_image(url):
"""
Download an image and open it with Pillow
"""
assert url.startswith("http"), "Image URL must be HTTP(S)"
# Download the image
# Cannot use stream=True as urllib's responses do not support the seek(int) method,
# which is explicitly required by Image.open on file-like objects
resp = _retried_request(url)
# Open the downloaded bytes with Pillow
image = Image.open(BytesIO(resp.content))
logger.debug(
"Downloaded image {} - size={}x{}".format(url, image.size[0], image.size[1])
)
return image
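A hypothetical usage of `download_image`, mirroring the caching logic in `retrieve_image` above (the URL and path are made up):

```python
from pathlib import Path

cached = Path(".cache/abcd.jpg")
if not cached.exists():
    # Retries up to 3 times with exponential backoff before re-raising.
    download_image(
        "https://iiif.example.com/iiif/2/abcd/full/full/0/default.jpg"
    ).save(cached, format="jpeg")
```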
def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -> str:
......
# Get started
To use DAN in your own environment, install it using pip:
To use DAN in your own environment, you first need to clone the repository with its submodules:
```shell
git clone --recurse-submodules git@gitlab.teklia.com:atr/dan.git
```
If you forgot the `--recurse-submodules`, you can initialize the submodule using:
```shell
git submodule update --init
```
Then you can install it via pip:
```shell
pip install -e .
......
......@@ -4,7 +4,7 @@ ignore = ["E501"]
select = ["E", "F", "T1", "W", "I"]
[tool.ruff.isort]
known-first-party = ["arkindex_export"]
known-first-party = ["arkindex_export", "line_image_extractor"]
known-third-party = [
"albumentations",
"cv2",
......
-e ./teklia_line_image_extractor
albumentations==1.3.1
arkindex-export==0.1.3
boto3==1.26.124
editdistance==0.6.2
imageio==2.26.1
numpy==1.24.3
opencv-python==4.7.0.72
PyYAML==6.0
scipy==1.10.1
tenacity==8.2.2
tensorboard==2.12.2
torch==2.0.0
torchvision==0.15.1
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
from pathlib import Path
from typing import List
from setuptools import find_packages, setup
def parse_requirements(path):
assert os.path.exists(path), "Missing requirements {}".format(path)
with open(path) as f:
return list(map(str.strip, f.read().splitlines()))
def parse_requirements_line(line) -> str:
# Special case for git requirements
if line.startswith("git+http"):
assert "@" in line, "Branch should be specified with suffix (ex: @master)"
assert (
"#egg=" in line
), "Package name should be specified with suffix (ex: #egg=kraken)"
package_name: str = line.split("#egg=")[-1]
return f"{package_name} @ {line}"
# Special case for submodule requirements
elif line.startswith("-e"):
package_path: str = line.split(" ")[-1]
package = Path(package_path).resolve()
return f"{package.name} @ file://{package}"
else:
return line
def parse_requirements(filename: str) -> List[str]:
path = Path(__file__).parent.resolve() / filename
assert path.exists(), f"Missing requirements: {path}"
return list(
map(parse_requirements_line, map(str.strip, path.read_text().splitlines()))
)
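Worked examples of `parse_requirements_line` (the absolute path in the second result is machine-specific):

```python
parse_requirements_line("albumentations==1.3.1")
# -> "albumentations==1.3.1"
parse_requirements_line("-e ./teklia_line_image_extractor")
# -> "teklia_line_image_extractor @ file:///path/to/dan/teklia_line_image_extractor"
parse_requirements_line("git+https://example.com/repo.git@master#egg=repo")
# -> "repo @ git+https://example.com/repo.git@master#egg=repo"
```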
setup(
......
Subproject commit 210c64939d62a8d915dcedbca7dcd529652e5a8b
@@ -12,3 +12,6 @@ deps =
-rrequirements.txt
commands =
pytest {tty:--color=yes} {posargs}
[pytest]
testpaths = tests