From 08fc6895e39dedb4836b3cc576432e3e87d564eb Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Tue, 8 Aug 2023 09:59:40 +0000
Subject: [PATCH] Cache images and crop

---
 .gitlab-ci.yml                   |   4 ++
 .gitmodules                      |   3 +
 README.md                        |  16 ++++-
 dan/datasets/extract/__init__.py |   8 +++
 dan/datasets/extract/db.py       |  68 ++----------------
 dan/datasets/extract/extract.py  | 115 +++++++++++++++++++++++++------
 dan/datasets/extract/utils.py    |  79 ++++++++++++---------
 docs/get_started/index.md        |  14 +++-
 pyproject.toml                   |   2 +-
 requirements.txt                 |   3 +-
 setup.py                         |  31 +++++++--
 teklia_line_image_extractor      |   1 +
 tox.ini                          |   3 +
 13 files changed, 223 insertions(+), 124 deletions(-)
 create mode 100644 .gitmodules
 create mode 160000 teklia_line_image_extractor

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 8d5fabc7..6a4824e6 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -3,6 +3,10 @@ stages:
   - build
   - deploy
 
+variables:
+  # Submodule clone
+  GIT_SUBMODULE_STRATEGY: recursive
+
 lint:
   image: python:3.10
   stage: test
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 00000000..33e01713
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "line_image_extractor"]
+	path = teklia_line_image_extractor
+	url = ../line_image_extractor.git
diff --git a/README.md b/README.md
index c9bb3d16..40c71bc6 100644
--- a/README.md
+++ b/README.md
@@ -4,9 +4,21 @@
 
 ## Documentation
 
-To use DAN in your own scripts, install it using pip:
+To use DAN in your own environment, you need to first clone with its submodules via:
 
-```console
+```shell
+git clone --recurse-submodules git@gitlab.teklia.com:atr/dan.git
+```
+
+If you forgot the `--recurse-submodules`, you can initialize the submodule using:
+
+```shell
+git submodule update --init
+```
+
+Then you can install it via pip:
+
+```shell
 pip install -e .
 ```
 
diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py
index 49d7c7e8..99f255d3 100644
--- a/dan/datasets/extract/__init__.py
+++ b/dan/datasets/extract/__init__.py
@@ -138,4 +138,12 @@ def add_extract_parser(subcommands) -> None:
         help="Images larger than this height will be resized to this width.",
     )
 
+    parser.add_argument(
+        "--cache",
+        dest="cache_dir",
+        type=pathlib.Path,
+        help="Where the images should be cached.",
+        default=pathlib.Path(".cache"),
+    )
+
     parser.set_defaults(func=run)
diff --git a/dan/datasets/extract/db.py b/dan/datasets/extract/db.py
index 52bf4d17..a8933d06 100644
--- a/dan/datasets/extract/db.py
+++ b/dan/datasets/extract/db.py
@@ -1,14 +1,10 @@
 # -*- coding: utf-8 -*-
 
-import ast
-from dataclasses import dataclass
-from itertools import starmap
-from typing import List, Optional, Union
-from urllib.parse import urljoin
+from typing import List, Union
 
 from arkindex_export import Image
-from arkindex_export.models import Element as ArkindexElement
 from arkindex_export.models import (
+    Element,
     Entity,
     EntityType,
     Transcription,
@@ -17,51 +13,10 @@ from arkindex_export.models import (
 from arkindex_export.queries import list_children
 
 
-def bounding_box(polygon: list):
-    """
-    Returns a 4-tuple (x, y, width, height) for the bounding box of a Polygon (list of points)
-    """
-    all_x, all_y = zip(*polygon)
-    x, y = min(all_x), min(all_y)
-    width, height = max(all_x) - x, max(all_y) - y
-    return int(x), int(y), int(width), int(height)
-
-
-@dataclass
-class Element:
-    id: str
-    type: str
-    polygon: str
-    url: str
-    width: int
-    height: int
-
-    max_width: Optional[int] = None
-    max_height: Optional[int] = None
-
-    def __post_init__(self):
-        self.max_height = self.max_height or self.height
-        self.max_width = self.max_width or self.width
-
-    @property
-    def bounding_box(self):
-        return bounding_box(ast.literal_eval(self.polygon))
-
-    @property
-    def image_url(self):
-        x, y, width, height = self.bounding_box
-        return urljoin(
-            self.url + "/",
-            f"{x},{y},{width},{height}/!{self.max_width},{self.max_height}/0/default.jpg",
-        )
-
-
 def get_elements(
     parent_id: str,
     element_type: List[str],
-    max_width: Optional[int] = None,
-    max_height: Optional[int] = None,
-) -> List[Element]:
+):
     """
     Retrieve elements from an SQLite export of an Arkindex corpus
     """
@@ -69,23 +24,10 @@ def get_elements(
     query = (
         list_children(parent_id=parent_id)
         .join(Image)
-        .where(ArkindexElement.type.in_(element_type))
-        .select(
-            ArkindexElement.id,
-            ArkindexElement.type,
-            ArkindexElement.polygon,
-            Image.url,
-            Image.width,
-            Image.height,
-        )
+        .where(Element.type.in_(element_type))
     )
 
-    return list(
-        starmap(
-            lambda *x: Element(*x, max_width=max_width, max_height=max_height),
-            query.tuples(),
-        )
-    )
+    return query
 
 
 def build_worker_version_filter(ArkindexModel, worker_version):
diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py
index 0e1879e0..b5e94a0a 100644
--- a/dan/datasets/extract/extract.py
+++ b/dan/datasets/extract/extract.py
@@ -1,11 +1,13 @@
 # -*- coding: utf-8 -*-
 
+import json
 import random
 from collections import defaultdict
 from pathlib import Path
 from typing import List, Optional, Union
 from uuid import UUID
 
+import numpy as np
 from tqdm import tqdm
 
 from arkindex_export import open_database
@@ -27,10 +29,13 @@ from dan.datasets.extract.utils import (
     insert_token,
     parse_tokens,
 )
+from line_image_extractor.extractor import extract, read_img, save_img
+from line_image_extractor.image_utils import Extraction, polygon_to_bbox, resize
 
 IMAGES_DIR = "images"  # Subpath to the images directory.
 LABELS_DIR = "labels"  # Subpath to the labels directory.
 SPLIT_NAMES = ["train", "val", "test"]
+IIIF_URL_SUFFIX = "/full/full/0/default.jpg"
 
 
 class ArkindexExtractor:
@@ -51,6 +56,7 @@ class ArkindexExtractor:
         entity_worker_version: Optional[Union[str, bool]] = None,
         max_width: Optional[int] = None,
         max_height: Optional[int] = None,
+        cache_dir: Path = Path(".cache"),
     ) -> None:
         self.folders = folders
         self.element_type = element_type
@@ -64,6 +70,19 @@ class ArkindexExtractor:
         self.max_width = max_width
         self.max_height = max_height
 
+        self.cache_dir = cache_dir
+        # Create cache dir if non existent
+        self.cache_dir.mkdir(exist_ok=True, parents=True)
+
+    def find_image_in_cache(self, image_id: str) -> Path:
+        """Images are cached to avoid downloading them twice. They are stored under a specific name,
+        based on their Arkindex ID. Images are saved under the JPEG format.
+
+        :param image_id: ID of the image. The image is saved under this name.
+        :return: Where the image should be saved in the cache folder.
+        """
+        return self.cache_dir / f"{image_id}.jpg"
+
     def _keep_char(self, char: str) -> bool:
         # Keep all text by default if no separator was given
         return not self.entity_separators or char in self.entity_separators
@@ -148,6 +167,52 @@ class ArkindexExtractor:
         )
         return self.reconstruct_text(transcription.text, entities)
 
+    def retrieve_image(self, child: Element):
+        """Get or download image of the element. Checks in cache before downloading.
+
+        :param child: Processed element.
+        :return: The element's image.
+        """
+        cached_img_path = self.find_image_in_cache(child.image.id)
+        if not cached_img_path.exists():
+            # Save in cache
+            download_image(child.image.url + IIIF_URL_SUFFIX).save(
+                cached_img_path, format="jpeg"
+            )
+
+        return read_img(cached_img_path)
+
+    def get_image(self, child: Element, destination: Path) -> None:
+        """Save the element's image to the given path and applies any image operations needed.
+
+        :param child: Processed element.
+        :param destination: Where the image should be saved.
+        """
+        polygon = json.loads(str(child.polygon))
+
+        if self.max_height or self.max_width:
+            polygon = resize(
+                polygon,
+                self.max_width,
+                self.max_height,
+                scale_x=1.0,
+                scale_y_top=1.0,
+                scale_y_bottom=1.0,
+            )
+
+        # Extract the polygon in the image
+        image = extract(
+            img=self.retrieve_image(child),
+            polygon=np.array(polygon),
+            bbox=polygon_to_bbox(polygon),
+            # Hardcoded while we don't have a configuration file
+            extraction_mode=Extraction.deskew_min_area_rect,
+            max_deskew_angle=45,
+        )
+
+        # Save the image to disk
+        save_img(path=destination, img=image)
+
     def process_element(
         self,
         element: Element,
@@ -162,13 +227,14 @@ class ArkindexExtractor:
         base_path = Path(split, f"{element.type}_{element.id}")
         Path(self.output, LABELS_DIR, base_path).with_suffix(".txt").write_text(text)
 
-        download_image(
-            element, Path(self.output, LABELS_DIR, base_path).with_suffix(".jpg")
+        self.get_image(
+            element, Path(self.output, IMAGES_DIR, base_path).with_suffix(".jpg")
         )
         return element.id
 
     def process_parent(
         self,
+        pbar,
         parent: Element,
         split: str,
     ):
@@ -176,7 +242,10 @@ class ArkindexExtractor:
         Extract data from a parent element.
""" data = defaultdict(list) - + base_description = ( + f"Extracting data from {parent.type} ({parent.id}) for split ({split})" + ) + pbar.set_description(desc=base_description) if self.element_type == [parent.type]: try: data[parent.type].append(self.process_element(parent, split)) @@ -184,12 +253,15 @@ class ArkindexExtractor: logger.warning(f"Skipping {parent.id}: {str(e)}") # Extract children elements else: - for element in get_elements( + children = get_elements( parent.id, self.element_type, - max_width=self.max_width, - max_height=self.max_height, - ): + ) + + nb_children = children.count() + for idx, element in enumerate(children, start=1): + # Update description to update the children processing progress + pbar.set_description(desc=base_description + f" ({idx}/{nb_children})") try: data[element.type].append(self.process_element(element, split)) except ProcessingError as e: @@ -198,23 +270,24 @@ class ArkindexExtractor: def run(self): # Iterate over the subsets to find the page images and labels. - for idx, (folder_id, split) in enumerate( - zip(self.folders, SPLIT_NAMES), start=1 - ): - # Iterate over the pages to create splits at page level. - for parent in tqdm( + for folder_id, split in zip(self.folders, SPLIT_NAMES): + with tqdm( get_elements( folder_id, [self.parent_element_type], - max_width=self.max_width, - max_height=self.max_height, ), - desc=f"Processing {folder_id} {idx}/{len(self.subsets)}", - ): - self.process_parent( - parent=parent, - split=split, - ) + desc=f"Extracting data from ({folder_id}) for split ({split})", + ) as pbar: + # Iterate over the pages to create splits at page level. + for parent in pbar: + self.process_parent( + pbar=pbar, + parent=parent, + split=split, + ) + # Progress bar updates + pbar.update() + pbar.refresh() def run( @@ -232,6 +305,7 @@ def run( entity_worker_version: Optional[Union[str, bool]], max_width: Optional[int], max_height: Optional[int], + cache_dir: Path, ): assert database.exists(), f"No file found @ {database}" open_database(path=database) @@ -258,4 +332,5 @@ def run( entity_worker_version=entity_worker_version, max_width=max_width, max_height=max_height, + cache_dir=cache_dir, ).run() diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py index 608aa4a0..c39c7dca 100644 --- a/dan/datasets/extract/utils.py +++ b/dan/datasets/extract/utils.py @@ -1,20 +1,30 @@ # -*- coding: utf-8 -*- import logging -import time +from io import BytesIO from pathlib import Path from typing import NamedTuple -import cv2 -import imageio.v2 as iio +import requests import yaml -from numpy import ndarray - -from dan.datasets.extract.db import Element -from dan.datasets.extract.exceptions import ImageDownloadError +from PIL import Image +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) logger = logging.getLogger(__name__) -MAX_RETRIES = 5 +# See http://docs.python-requests.org/en/master/user/advanced/#timeouts +DOWNLOAD_TIMEOUT = (30, 60) + + +def _retry_log(retry_state, *args, **kwargs): + logger.warning( + f"Request to {retry_state.args[0]} failed ({repr(retry_state.outcome.exception())}), " + f"retrying in {retry_state.idle_for} seconds" + ) class EntityType(NamedTuple): @@ -26,29 +36,36 @@ class EntityType(NamedTuple): return len(self.start) + len(self.end) -def download_image(element: Element, im_path: Path): - if im_path.exists(): - return im_path - - tries = 1 - # retry loop - while True: - if tries > MAX_RETRIES: - raise ImageDownloadError(element.id, 
Exception("Maximum retries reached.")) - try: - image = iio.imread(element.image_url) - save_image(im_path, image) - return - except TimeoutError: - logger.warning("Timeout, retry in 1 second.") - time.sleep(1) - tries += 1 - except Exception as e: - raise ImageDownloadError(element.id, e) - - -def save_image(path: Path, image: ndarray): - cv2.imwrite(str(path), cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=2), + retry=retry_if_exception_type(requests.RequestException), + before_sleep=_retry_log, + reraise=True, +) +def _retried_request(url): + resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT) + resp.raise_for_status() + return resp + + +def download_image(url): + """ + Download an image and open it with Pillow + """ + assert url.startswith("http"), "Image URL must be HTTP(S)" + # Download the image + # Cannot use stream=True as urllib's responses do not support the seek(int) method, + # which is explicitly required by Image.open on file-like objects + resp = _retried_request(url) + + # Preprocess the image and prepare it for classification + image = Image.open(BytesIO(resp.content)) + logger.debug( + "Downloaded image {} - size={}x{}".format(url, image.size[0], image.size[1]) + ) + + return image def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -> str: diff --git a/docs/get_started/index.md b/docs/get_started/index.md index 85c869c6..8698d476 100644 --- a/docs/get_started/index.md +++ b/docs/get_started/index.md @@ -1,6 +1,18 @@ # Get started -To use DAN in your own environment, install it using pip: +To use DAN in your own environment, you need to first clone with its submodules via: + +```shell +git clone --recurse-submodules git@gitlab.teklia.com:atr/dan.git +``` + +If you forgot the `--recurse-submodules`, you can initialize the submodule using: + +```shell +git submodule update --init +``` + +Then you can install it via pip: ```shell pip install -e . 
diff --git a/pyproject.toml b/pyproject.toml
index 9e4a0537..71e69787 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ ignore = ["E501"]
 select = ["E", "F", "T1", "W", "I"]
 
 [tool.ruff.isort]
-known-first-party = ["arkindex_export"]
+known-first-party = ["arkindex_export", "line_image_extractor"]
 known-third-party = [
     "albumentations",
     "cv2",
diff --git a/requirements.txt b/requirements.txt
index 0a38d4fa..ed460a1e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,12 +1,13 @@
+-e ./teklia_line_image_extractor
 albumentations==1.3.1
 arkindex-export==0.1.3
 boto3==1.26.124
 editdistance==0.6.2
 imageio==2.26.1
 numpy==1.24.3
-opencv-python==4.7.0.72
 PyYAML==6.0
 scipy==1.10.1
+tenacity==8.2.2
 tensorboard==2.12.2
 torch==2.0.0
 torchvision==0.15.1
diff --git a/setup.py b/setup.py
index 5cb03b52..a4c14fc5 100755
--- a/setup.py
+++ b/setup.py
@@ -1,15 +1,36 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-import os
+from pathlib import Path
+from typing import List
 
 from setuptools import find_packages, setup
 
 
-def parse_requirements(path):
-    assert os.path.exists(path), "Missing requirements {}".format(path)
-    with open(path) as f:
-        return list(map(str.strip, f.read().splitlines()))
+def parse_requirements_line(line) -> str:
+    # Special case for git requirements
+    if line.startswith("git+http"):
+        assert "@" in line, "Branch should be specified with suffix (ex: @master)"
+        assert (
+            "#egg=" in line
+        ), "Package name should be specified with suffix (ex: #egg=kraken)"
+        package_name: str = line.split("#egg=")[-1]
+        return f"{package_name} @ {line}"
+    # Special case for submodule requirements
+    elif line.startswith("-e"):
+        package_path: str = line.split(" ")[-1]
+        package = Path(package_path).resolve()
+        return f"{package.name} @ file://{package}"
+    else:
+        return line
+
+
+def parse_requirements(filename: str) -> List[str]:
+    path = Path(__file__).parent.resolve() / filename
+    assert path.exists(), f"Missing requirements: {path}"
+    return list(
+        map(parse_requirements_line, map(str.strip, path.read_text().splitlines()))
+    )
 
 
 setup(
diff --git a/teklia_line_image_extractor b/teklia_line_image_extractor
new file mode 160000
index 00000000..210c6493
--- /dev/null
+++ b/teklia_line_image_extractor
@@ -0,0 +1 @@
+Subproject commit 210c64939d62a8d915dcedbca7dcd529652e5a8b
diff --git a/tox.ini b/tox.ini
index 23074f42..ef7a6a03 100644
--- a/tox.ini
+++ b/tox.ini
@@ -12,3 +12,6 @@ deps =
    -rrequirements.txt
 commands =
    pytest {tty:--color=yes} {posargs}
+
+[pytest]
+testpaths= tests
-- 
GitLab
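
Reviewer note: below is a minimal sketch of how the caching and cropping pieces introduced by this patch fit together outside the extractor, mirroring the calls made in `ArkindexExtractor.retrieve_image` and `ArkindexExtractor.get_image`. It assumes the branch is checked out with the `teklia_line_image_extractor` submodule and installed via `pip install -e .`; the image ID, IIIF base URL and polygon are placeholders, not values from a real Arkindex export.

```python
from pathlib import Path

import numpy as np

from dan.datasets.extract.utils import download_image
from line_image_extractor.extractor import extract, read_img, save_img
from line_image_extractor.image_utils import Extraction, polygon_to_bbox

IIIF_URL_SUFFIX = "/full/full/0/default.jpg"

cache_dir = Path(".cache")
cache_dir.mkdir(exist_ok=True, parents=True)

# Placeholder identifiers: any reachable IIIF image would do here.
image_id = "00000000-0000-0000-0000-000000000000"
image_url = "https://iiif.example.com/iiif/2/some-image"
polygon = [[0, 0], [0, 100], [200, 100], [200, 0]]

# 1. Download the full page once, then reuse the cached JPEG
#    (mirrors ArkindexExtractor.retrieve_image).
cached_img_path = cache_dir / f"{image_id}.jpg"
if not cached_img_path.exists():
    download_image(image_url + IIIF_URL_SUFFIX).save(cached_img_path, format="jpeg")
img = read_img(cached_img_path)

# 2. Crop and deskew the polygon out of the cached page
#    (mirrors ArkindexExtractor.get_image).
cropped = extract(
    img=img,
    polygon=np.array(polygon),
    bbox=polygon_to_bbox(polygon),
    extraction_mode=Extraction.deskew_min_area_rect,
    max_deskew_angle=45,
)
save_img(path=Path("line.jpg"), img=cropped)
```

Since the cache key is the Arkindex image ID, several elements cropped from the same page trigger a single download, which is the point of the change.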