From 08fc6895e39dedb4836b3cc576432e3e87d564eb Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Tue, 8 Aug 2023 09:59:40 +0000
Subject: [PATCH] Cache images and crop

---
 .gitlab-ci.yml                   |   4 ++
 .gitmodules                      |   3 +
 README.md                        |  16 ++++-
 dan/datasets/extract/__init__.py |   8 +++
 dan/datasets/extract/db.py       |  68 ++----------------
 dan/datasets/extract/extract.py  | 115 +++++++++++++++++++++++++------
 dan/datasets/extract/utils.py    |  79 ++++++++++++---------
 docs/get_started/index.md        |  14 +++-
 pyproject.toml                   |   2 +-
 requirements.txt                 |   3 +-
 setup.py                         |  31 +++++++--
 teklia_line_image_extractor      |   1 +
 tox.ini                          |   3 +
 13 files changed, 223 insertions(+), 124 deletions(-)
 create mode 100644 .gitmodules
 create mode 160000 teklia_line_image_extractor

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 8d5fabc7..6a4824e6 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -3,6 +3,10 @@ stages:
   - build
   - deploy
 
+variables:
+  # Submodule clone
+  GIT_SUBMODULE_STRATEGY: recursive
+
 lint:
   image: python:3.10
   stage: test
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 00000000..33e01713
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "line_image_extractor"]
+	path = teklia_line_image_extractor
+	url = ../line_image_extractor.git
diff --git a/README.md b/README.md
index c9bb3d16..40c71bc6 100644
--- a/README.md
+++ b/README.md
@@ -4,9 +4,21 @@
 
 ## Documentation
 
-To use DAN in your own scripts, install it using pip:
+To use DAN in your own environment, you will first need to clone the repository along with its submodules:
 
-```console
+```shell
+git clone --recurse-submodules git@gitlab.teklia.com:atr/dan.git
+```
+
+If you forgot the `--recurse-submodules` flag, you can still initialize the submodule afterwards with:
+
+```shell
+git submodule update --init
+```
+
+Then you can install it via pip:
+
+```shell
 pip install -e .
 ```
 
diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py
index 49d7c7e8..99f255d3 100644
--- a/dan/datasets/extract/__init__.py
+++ b/dan/datasets/extract/__init__.py
@@ -138,4 +138,12 @@ def add_extract_parser(subcommands) -> None:
         help="Images larger than this height will be resized to this width.",
     )
 
+    parser.add_argument(
+        "--cache",
+        dest="cache_dir",
+        type=pathlib.Path,
+        help="Where the images should be cached.",
+        default=pathlib.Path(".cache"),
+    )
+
     parser.set_defaults(func=run)
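
For reference, a run of the extraction command with the new cache location could look like the sketch below; the `teklia-dan dataset extract` entry point and every option except `--cache` are assumed from the existing CLI rather than defined here, and the paths are placeholders.

```shell
# Hypothetical invocation: full page images are cached under .cache and
# only downloaded once, even across repeated runs.
teklia-dan dataset extract \
    path/to/export.sqlite \
    --element-type text_line \
    --output data/my_dataset \
    --cache .cache
```
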
diff --git a/dan/datasets/extract/db.py b/dan/datasets/extract/db.py
index 52bf4d17..a8933d06 100644
--- a/dan/datasets/extract/db.py
+++ b/dan/datasets/extract/db.py
@@ -1,14 +1,10 @@
 # -*- coding: utf-8 -*-
 
-import ast
-from dataclasses import dataclass
-from itertools import starmap
-from typing import List, Optional, Union
-from urllib.parse import urljoin
+from typing import List, Union
 
 from arkindex_export import Image
-from arkindex_export.models import Element as ArkindexElement
 from arkindex_export.models import (
+    Element,
     Entity,
     EntityType,
     Transcription,
@@ -17,51 +13,10 @@ from arkindex_export.models import (
 from arkindex_export.queries import list_children
 
 
-def bounding_box(polygon: list):
-    """
-    Returns a 4-tuple (x, y, width, height) for the bounding box of a Polygon (list of points)
-    """
-    all_x, all_y = zip(*polygon)
-    x, y = min(all_x), min(all_y)
-    width, height = max(all_x) - x, max(all_y) - y
-    return int(x), int(y), int(width), int(height)
-
-
-@dataclass
-class Element:
-    id: str
-    type: str
-    polygon: str
-    url: str
-    width: int
-    height: int
-
-    max_width: Optional[int] = None
-    max_height: Optional[int] = None
-
-    def __post_init__(self):
-        self.max_height = self.max_height or self.height
-        self.max_width = self.max_width or self.width
-
-    @property
-    def bounding_box(self):
-        return bounding_box(ast.literal_eval(self.polygon))
-
-    @property
-    def image_url(self):
-        x, y, width, height = self.bounding_box
-        return urljoin(
-            self.url + "/",
-            f"{x},{y},{width},{height}/!{self.max_width},{self.max_height}/0/default.jpg",
-        )
-
-
 def get_elements(
     parent_id: str,
     element_type: List[str],
-    max_width: Optional[int] = None,
-    max_height: Optional[int] = None,
-) -> List[Element]:
+):
     """
     Retrieve elements from an SQLite export of an Arkindex corpus
     """
@@ -69,23 +24,10 @@ def get_elements(
     query = (
         list_children(parent_id=parent_id)
         .join(Image)
-        .where(ArkindexElement.type.in_(element_type))
-        .select(
-            ArkindexElement.id,
-            ArkindexElement.type,
-            ArkindexElement.polygon,
-            Image.url,
-            Image.width,
-            Image.height,
-        )
+        .where(Element.type.in_(element_type))
     )
 
-    return list(
-        starmap(
-            lambda *x: Element(*x, max_width=max_width, max_height=max_height),
-            query.tuples(),
-        )
-    )
+    return query
 
 
 def build_worker_version_filter(ArkindexModel, worker_version):
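
Since `get_elements` now returns the peewee query itself instead of a list of `Element` dataclasses, callers can count it and iterate it lazily. A minimal sketch of the consuming side, mirroring the `extract.py` changes below (the export path and parent ID are placeholders):

```python
from arkindex_export import open_database

from dan.datasets.extract.db import get_elements

open_database(path="export.sqlite")

children = get_elements(parent_id="<parent element UUID>", element_type=["text_line"])

# The query can be counted without materializing it as a list...
nb_children = children.count()

# ...and iterated lazily, yielding Element model instances.
for idx, element in enumerate(children, start=1):
    print(f"{idx}/{nb_children}: {element.type} {element.id}")
```
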
diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py
index 0e1879e0..b5e94a0a 100644
--- a/dan/datasets/extract/extract.py
+++ b/dan/datasets/extract/extract.py
@@ -1,11 +1,13 @@
 # -*- coding: utf-8 -*-
 
+import json
 import random
 from collections import defaultdict
 from pathlib import Path
 from typing import List, Optional, Union
 from uuid import UUID
 
+import numpy as np
 from tqdm import tqdm
 
 from arkindex_export import open_database
@@ -27,10 +29,13 @@ from dan.datasets.extract.utils import (
     insert_token,
     parse_tokens,
 )
+from line_image_extractor.extractor import extract, read_img, save_img
+from line_image_extractor.image_utils import Extraction, polygon_to_bbox, resize
 
 IMAGES_DIR = "images"  # Subpath to the images directory.
 LABELS_DIR = "labels"  # Subpath to the labels directory.
 SPLIT_NAMES = ["train", "val", "test"]
+IIIF_URL_SUFFIX = "/full/full/0/default.jpg"
 
 
 class ArkindexExtractor:
@@ -51,6 +56,7 @@ class ArkindexExtractor:
         entity_worker_version: Optional[Union[str, bool]] = None,
         max_width: Optional[int] = None,
         max_height: Optional[int] = None,
+        cache_dir: Path = Path(".cache"),
     ) -> None:
         self.folders = folders
         self.element_type = element_type
@@ -64,6 +70,19 @@ class ArkindexExtractor:
         self.max_width = max_width
         self.max_height = max_height
 
+        self.cache_dir = cache_dir
+        # Create the cache dir if it does not exist
+        self.cache_dir.mkdir(exist_ok=True, parents=True)
+
+    def find_image_in_cache(self, image_id: str) -> Path:
+        """Images are cached to avoid downloading them twice. They are stored under a specific name,
+        based on their Arkindex ID. Images are saved under the JPEG format.
+
+        :param image_id: ID of the image. The image is saved under this name.
+        :return: Where the image should be saved in the cache folder.
+        """
+        return self.cache_dir / f"{image_id}.jpg"
+
     def _keep_char(self, char: str) -> bool:
         # Keep all text by default if no separator was given
         return not self.entity_separators or char in self.entity_separators
@@ -148,6 +167,52 @@ class ArkindexExtractor:
         )
         return self.reconstruct_text(transcription.text, entities)
 
+    def retrieve_image(self, child: Element):
+        """Get or download image of the element. Checks in cache before downloading.
+
+        :param child: Processed element.
+        :return: The element's image.
+        """
+        cached_img_path = self.find_image_in_cache(child.image.id)
+        if not cached_img_path.exists():
+            # Save in cache
+            download_image(child.image.url + IIIF_URL_SUFFIX).save(
+                cached_img_path, format="jpeg"
+            )
+
+        return read_img(cached_img_path)
+
+    def get_image(self, child: Element, destination: Path) -> None:
+        """Save the element's image to the given path and applies any image operations needed.
+
+        :param child: Processed element.
+        :param destination: Where the image should be saved.
+        """
+        polygon = json.loads(str(child.polygon))
+
+        if self.max_height or self.max_width:
+            polygon = resize(
+                polygon,
+                self.max_width,
+                self.max_height,
+                scale_x=1.0,
+                scale_y_top=1.0,
+                scale_y_bottom=1.0,
+            )
+
+        # Extract the polygon in the image
+        image = extract(
+            img=self.retrieve_image(child),
+            polygon=np.array(polygon),
+            bbox=polygon_to_bbox(polygon),
+            # Hardcoded until we have a configuration file
+            extraction_mode=Extraction.deskew_min_area_rect,
+            max_deskew_angle=45,
+        )
+
+        # Save the image to disk
+        save_img(path=destination, img=image)
+
     def process_element(
         self,
         element: Element,
@@ -162,13 +227,14 @@ class ArkindexExtractor:
         base_path = Path(split, f"{element.type}_{element.id}")
         Path(self.output, LABELS_DIR, base_path).with_suffix(".txt").write_text(text)
 
-        download_image(
-            element, Path(self.output, LABELS_DIR, base_path).with_suffix(".jpg")
+        self.get_image(
+            element, Path(self.output, IMAGES_DIR, base_path).with_suffix(".jpg")
         )
         return element.id
 
     def process_parent(
         self,
+        pbar,
         parent: Element,
         split: str,
     ):
@@ -176,7 +242,10 @@ class ArkindexExtractor:
         Extract data from a parent element.
         """
         data = defaultdict(list)
-
+        base_description = (
+            f"Extracting data from {parent.type} ({parent.id}) for split ({split})"
+        )
+        pbar.set_description(desc=base_description)
         if self.element_type == [parent.type]:
             try:
                 data[parent.type].append(self.process_element(parent, split))
@@ -184,12 +253,15 @@ class ArkindexExtractor:
                 logger.warning(f"Skipping {parent.id}: {str(e)}")
         # Extract children elements
         else:
-            for element in get_elements(
+            children = get_elements(
                 parent.id,
                 self.element_type,
-                max_width=self.max_width,
-                max_height=self.max_height,
-            ):
+            )
+
+            nb_children = children.count()
+            for idx, element in enumerate(children, start=1):
+                # Update the description to reflect progress over the children
+                pbar.set_description(desc=base_description + f" ({idx}/{nb_children})")
                 try:
                     data[element.type].append(self.process_element(element, split))
                 except ProcessingError as e:
@@ -198,23 +270,24 @@ class ArkindexExtractor:
 
     def run(self):
         # Iterate over the subsets to find the page images and labels.
-        for idx, (folder_id, split) in enumerate(
-            zip(self.folders, SPLIT_NAMES), start=1
-        ):
-            # Iterate over the pages to create splits at page level.
-            for parent in tqdm(
+        for folder_id, split in zip(self.folders, SPLIT_NAMES):
+            with tqdm(
                 get_elements(
                     folder_id,
                     [self.parent_element_type],
-                    max_width=self.max_width,
-                    max_height=self.max_height,
                 ),
-                desc=f"Processing {folder_id} {idx}/{len(self.subsets)}",
-            ):
-                self.process_parent(
-                    parent=parent,
-                    split=split,
-                )
+                desc=f"Extracting data from ({folder_id}) for split ({split})",
+            ) as pbar:
+                # Iterate over the pages to create splits at page level.
+                for parent in pbar:
+                    self.process_parent(
+                        pbar=pbar,
+                        parent=parent,
+                        split=split,
+                    )
+                    # Progress bar updates
+                    pbar.update()
+                    pbar.refresh()
 
 
 def run(
@@ -232,6 +305,7 @@ def run(
     entity_worker_version: Optional[Union[str, bool]],
     max_width: Optional[int],
     max_height: Optional[int],
+    cache_dir: Path,
 ):
     assert database.exists(), f"No file found @ {database}"
     open_database(path=database)
@@ -258,4 +332,5 @@ def run(
         entity_worker_version=entity_worker_version,
         max_width=max_width,
         max_height=max_height,
+        cache_dir=cache_dir,
     ).run()
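
The benefit of the cache shows up when several elements share the same page image: the full image is downloaded once from `<image.url>/full/full/0/default.jpg`, stored as `<cache_dir>/<image_id>.jpg`, and every later element on that page is cropped locally from the cached file. A minimal sketch, assuming `extractor` is an `ArkindexExtractor` and `line_a`/`line_b` are two elements whose `image` foreign key points to the same page:

```python
# Sketch only: line_a and line_b come from get_elements() and share one page image.
img_a = extractor.retrieve_image(line_a)  # cache miss: downloads the full page image
img_b = extractor.retrieve_image(line_b)  # cache hit: reads the cached JPEG from disk

# Both elements resolve to the same cached file, named after the Arkindex image ID.
assert extractor.find_image_in_cache(line_a.image.id) == extractor.find_image_in_cache(
    line_b.image.id
)
```
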
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
index 608aa4a0..c39c7dca 100644
--- a/dan/datasets/extract/utils.py
+++ b/dan/datasets/extract/utils.py
@@ -1,20 +1,30 @@
 # -*- coding: utf-8 -*-
 import logging
-import time
+from io import BytesIO
 from pathlib import Path
 from typing import NamedTuple
 
-import cv2
-import imageio.v2 as iio
+import requests
 import yaml
-from numpy import ndarray
-
-from dan.datasets.extract.db import Element
-from dan.datasets.extract.exceptions import ImageDownloadError
+from PIL import Image
+from tenacity import (
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
 
 logger = logging.getLogger(__name__)
 
-MAX_RETRIES = 5
+# See http://docs.python-requests.org/en/master/user/advanced/#timeouts
+DOWNLOAD_TIMEOUT = (30, 60)
+
+
+def _retry_log(retry_state, *args, **kwargs):
+    logger.warning(
+        f"Request to {retry_state.args[0]} failed ({repr(retry_state.outcome.exception())}), "
+        f"retrying in {retry_state.idle_for} seconds"
+    )
 
 
 class EntityType(NamedTuple):
@@ -26,29 +36,36 @@ class EntityType(NamedTuple):
         return len(self.start) + len(self.end)
 
 
-def download_image(element: Element, im_path: Path):
-    if im_path.exists():
-        return im_path
-
-    tries = 1
-    # retry loop
-    while True:
-        if tries > MAX_RETRIES:
-            raise ImageDownloadError(element.id, Exception("Maximum retries reached."))
-        try:
-            image = iio.imread(element.image_url)
-            save_image(im_path, image)
-            return
-        except TimeoutError:
-            logger.warning("Timeout, retry in 1 second.")
-            time.sleep(1)
-            tries += 1
-        except Exception as e:
-            raise ImageDownloadError(element.id, e)
-
-
-def save_image(path: Path, image: ndarray):
-    cv2.imwrite(str(path), cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=2),
+    retry=retry_if_exception_type(requests.RequestException),
+    before_sleep=_retry_log,
+    reraise=True,
+)
+def _retried_request(url):
+    resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
+    resp.raise_for_status()
+    return resp
+
+
+def download_image(url):
+    """
+    Download an image and open it with Pillow
+    """
+    assert url.startswith("http"), "Image URL must be HTTP(S)"
+    # Download the image
+    # Cannot use stream=True as urllib3's responses do not support the seek(int) method,
+    # which is explicitly required by Image.open on file-like objects
+    resp = _retried_request(url)
+
+    # Open the downloaded bytes with Pillow
+    image = Image.open(BytesIO(resp.content))
+    logger.debug(
+        "Downloaded image {} - size={}x{}".format(url, image.size[0], image.size[1])
+    )
+
+    return image
 
 
 def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -> str:
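
`download_image` now retries failed requests for up to three attempts with exponential backoff, logging each retry, and re-raises the original exception once the attempts are exhausted. A minimal usage sketch (the URL is a placeholder):

```python
import requests

from dan.datasets.extract.utils import download_image

url = "https://iiif.example.com/iiif/2/some-image/full/full/0/default.jpg"

try:
    image = download_image(url)  # returns a PIL.Image.Image
    image.save("page.jpg", format="jpeg")
except requests.RequestException as exc:
    # With reraise=True, the original exception propagates after the last attempt.
    print(f"Download failed after all retries: {exc}")
```
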
diff --git a/docs/get_started/index.md b/docs/get_started/index.md
index 85c869c6..8698d476 100644
--- a/docs/get_started/index.md
+++ b/docs/get_started/index.md
@@ -1,6 +1,18 @@
 # Get started
 
-To use DAN in your own environment, install it using pip:
+To use DAN in your own environment, you will first need to clone the repository along with its submodules:
+
+```shell
+git clone --recurse-submodules git@gitlab.teklia.com:atr/dan.git
+```
+
+If you forgot the `--recurse-submodules` flag, you can still initialize the submodule afterwards with:
+
+```shell
+git submodule update --init
+```
+
+Then you can install it via pip:
 
 ```shell
 pip install -e .
diff --git a/pyproject.toml b/pyproject.toml
index 9e4a0537..71e69787 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ ignore = ["E501"]
 select = ["E", "F", "T1", "W", "I"]
 
 [tool.ruff.isort]
-known-first-party = ["arkindex_export"]
+known-first-party = ["arkindex_export", "line_image_extractor"]
 known-third-party = [
     "albumentations",
     "cv2",
diff --git a/requirements.txt b/requirements.txt
index 0a38d4fa..ed460a1e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,12 +1,13 @@
+-e ./teklia_line_image_extractor
 albumentations==1.3.1
 arkindex-export==0.1.3
 boto3==1.26.124
 editdistance==0.6.2
 imageio==2.26.1
 numpy==1.24.3
-opencv-python==4.7.0.72
 PyYAML==6.0
 scipy==1.10.1
+tenacity==8.2.2
 tensorboard==2.12.2
 torch==2.0.0
 torchvision==0.15.1
diff --git a/setup.py b/setup.py
index 5cb03b52..a4c14fc5 100755
--- a/setup.py
+++ b/setup.py
@@ -1,15 +1,36 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-import os
+from pathlib import Path
+from typing import List
 
 from setuptools import find_packages, setup
 
 
-def parse_requirements(path):
-    assert os.path.exists(path), "Missing requirements {}".format(path)
-    with open(path) as f:
-        return list(map(str.strip, f.read().splitlines()))
+def parse_requirements_line(line) -> str:
+    # Special case for git requirements
+    if line.startswith("git+http"):
+        assert "@" in line, "Branch should be specified with suffix (ex: @master)"
+        assert (
+            "#egg=" in line
+        ), "Package name should be specified with suffix (ex: #egg=kraken)"
+        package_name: str = line.split("#egg=")[-1]
+        return f"{package_name} @ {line}"
+    # Special case for submodule requirements
+    elif line.startswith("-e"):
+        package_path: str = line.split(" ")[-1]
+        package = Path(package_path).resolve()
+        return f"{package.name} @ file://{package}"
+    else:
+        return line
+
+
+def parse_requirements(filename: str) -> List[str]:
+    path = Path(__file__).parent.resolve() / filename
+    assert path.exists(), f"Missing requirements: {path}"
+    return list(
+        map(parse_requirements_line, map(str.strip, path.read_text().splitlines()))
+    )
 
 
 setup(
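
With this change, `requirements.txt` lines are turned into PEP 508 direct references so the submodule can be declared like any other dependency. A rough illustration of what `parse_requirements_line` yields for each supported kind of line (package names and paths are examples, evaluated inside `setup.py`):

```python
parse_requirements_line("tenacity==8.2.2")
# -> "tenacity==8.2.2"  (regular pins pass through unchanged)

parse_requirements_line("git+https://example.com/repo.git@master#egg=somepkg")
# -> "somepkg @ git+https://example.com/repo.git@master#egg=somepkg"

parse_requirements_line("-e ./teklia_line_image_extractor")
# -> "teklia_line_image_extractor @ file:///<absolute path to the submodule>"
```
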
diff --git a/teklia_line_image_extractor b/teklia_line_image_extractor
new file mode 160000
index 00000000..210c6493
--- /dev/null
+++ b/teklia_line_image_extractor
@@ -0,0 +1 @@
+Subproject commit 210c64939d62a8d915dcedbca7dcd529652e5a8b
diff --git a/tox.ini b/tox.ini
index 23074f42..ef7a6a03 100644
--- a/tox.ini
+++ b/tox.ini
@@ -12,3 +12,6 @@ deps =
     -rrequirements.txt
 commands =
     pytest {tty:--color=yes} {posargs}
+
+[pytest]
+testpaths = tests
-- 
GitLab