diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index c169292e4318e617c929f25237a2e8af61813b16..c77c88be7d0b52106765f82120a19bead5a3258b 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -18,14 +18,10 @@ test: variables: PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip" - ARKINDEX_API_SCHEMA_URL: schema.yml before_script: - pip install tox - # Download OpenAPI schema from last backend build - - curl https://assets.teklia.com/arkindex/openapi.yml > schema.yml - except: - schedules diff --git a/README.md b/README.md index fc26487e578676cccc6fe9189b15e7a5fc44e2be..88097a8e6f3c4350969016c53b654d4da0b34bc4 100644 --- a/README.md +++ b/README.md @@ -5,20 +5,3 @@ and converts data to ATR format. It also generates reproducible train, val and test splits. A documentation is available at https://teklia.gitlab.io/atr/data-generator/. - -## Environment variables - -`ARKINDEX_API_TOKEN` and `ARKINDEX_API_URL` environment variables must be defined. - -You can create an alias by adding this line to your `~/.bashrc`: - -```sh -alias set_demo='export ARKINDEX_API_URL=https://demo.arkindex.org/;export ARKINDEX_API_TOKEN=my_api_token' -``` - -Then run: - -```sh -source ~/.bashrc -set_demo -``` diff --git a/atr_data_generator/arguments.py b/atr_data_generator/arguments.py index e88f72f867896f22285af4688a2d1bcbef228a2e..a7e4ee8891cac31e2d8a217d3ee794bab687dfe1 100644 --- a/atr_data_generator/arguments.py +++ b/atr_data_generator/arguments.py @@ -1,20 +1,12 @@ # -*- coding: utf-8 -*- -import json -from dataclasses import asdict, dataclass +from dataclasses import dataclass from pathlib import Path @dataclass class BaseArgs: - def __post_init__(self): - self._validate() - - def _validate(self): - """Override this method to add argument validation.""" - pass - - def dict(self): - return json.loads(json.dumps(asdict(self), default=str)) + def json(self): + return vars(self).copy() @dataclass @@ -35,6 +27,15 @@ class CommonArgs(BaseArgs): log_parameters: bool = True def __post_init__(self): - super().__post_init__() self.output_dir.mkdir(exist_ok=True, parents=True) self.cache_dir.mkdir(exist_ok=True, parents=True) + + def json(self): + data = super().json() + data.update( + { + "output_dir": str(self.output_dir), + "cache_dir": str(self.cache_dir), + } + ) + return data diff --git a/atr_data_generator/cli.py b/atr_data_generator/cli.py index 52971af96e3ac694fe1479755a98106bcf7deb69..ee265bf545b69a173e2ce941dbe8fdc6640a07c1 100644 --- a/atr_data_generator/cli.py +++ b/atr_data_generator/cli.py @@ -18,7 +18,9 @@ def main(): args = vars(parser.parse_args()) if "func" in args: - kwargs = args.pop("config_parser")(args.pop("config")) - args.pop("func")(**kwargs) + conf_args = args.pop("config_parser")(args.pop("config")) + _func = args.pop("func") + kwargs = {**args, **conf_args} + _func(**kwargs) else: parser.print_help() diff --git a/atr_data_generator/extract/__init__.py b/atr_data_generator/extract/__init__.py index 66acc121d366e78c131ef6fb3b43b69d2ab7523e..56b44c3b66adebfdb4e38ada25bf51e93067fec4 100644 --- a/atr_data_generator/extract/__init__.py +++ b/atr_data_generator/extract/__init__.py @@ -9,14 +9,13 @@ from teklia_toolbox.config import ConfigParser from atr_data_generator.arguments import CommonArgs from atr_data_generator.extract.arguments import ( + DEFAULT_RESCALE, ExtractionMode, FilterArgs, ImageArgs, SelectArgs, - Style, - TranscriptionType, ) -from atr_data_generator.extract.main import main +from atr_data_generator.extract.base import main def _float(value): @@ -25,12 +24,6 @@ def _float(value): return 
float(value) -def _style(value): - if value is None: - return None - return Style(value) - - def get_parser(): parser = ConfigParser() @@ -48,32 +41,26 @@ def get_parser(): type=ExtractionMode, default=ExtractionMode.deskew_min_area_rect, ) + image.add_option("fixed_height", type=int, default=None) image.add_option("max_deskew_angle", type=int, default=45) image.add_option("skew_angle", type=int, default=0) image.add_option("should_rotate", type=bool, default=False) image.add_option("grayscale", type=bool, default=True) scale = image.add_subparser("scale", default={}) - scale.add_option("x", type=_float, default=None) - scale.add_option("y_top", type=_float, default=None) - scale.add_option("y_bottom", type=_float, default=None) + scale.add_option("x", type=_float, default=DEFAULT_RESCALE) + scale.add_option("y_top", type=_float, default=DEFAULT_RESCALE) + scale.add_option("y_bottom", type=_float, default=DEFAULT_RESCALE) # Filters - filters = parser.add_subparser("filter") - filters.add_option("transcription_type", type=TranscriptionType) - filters.add_option("ignored_classes", type=str, many=True, default=[]) - filters.add_option("accepted_classes", type=str, many=True, default=[]) + filters = parser.add_subparser("filter", default={}) filters.add_option( "accepted_worker_version_ids", type=uuid.UUID, many=True, default=[] ) - filters.add_option("style", type=_style, default=None) filters.add_option("skip_vertical_lines", type=bool, default=False) - filters.add_option("accepted_metadatas", type=dict, default={}) - filters.add_option("filter_parent_metadatas", type=bool, default=False) # Select select = parser.add_subparser("select", default={}) - select.add_option("corpora", type=uuid.UUID, many=True, default=[]) select.add_option("folders", type=uuid.UUID, many=True, default=[]) select.add_option("parent_type", type=str, default=None) select.add_option("element_type", type=str, default=None) @@ -105,4 +92,5 @@ def add_extract_subparser(subcommands): help=__doc__, ) parser.add_argument("--config", type=Path, help="Configuration file") + parser.add_argument("--database-path", type=Path, help="Export path") parser.set_defaults(func=main, config_parser=config_parser) diff --git a/atr_data_generator/extract/arguments.py b/atr_data_generator/extract/arguments.py index e4df060c5a7be47f7c6a41ac1ff536af4965bc07..5743be6c97bbfe730e81ac1ec052fdd8e52ef184 100644 --- a/atr_data_generator/extract/arguments.py +++ b/atr_data_generator/extract/arguments.py @@ -1,6 +1,4 @@ # -*- coding: utf-8 -*- -import argparse -import getpass from dataclasses import dataclass, field from enum import Enum from typing import Dict, List, Optional, Union @@ -9,7 +7,6 @@ from line_image_extractor.extractor import Extraction from atr_data_generator.arguments import BaseArgs -USER = getpass.getuser() MANUAL = "manual" DEFAULT_RESCALE = 1.0 @@ -23,20 +20,6 @@ class ListedEnum(str, Enum): return self.value -class Style(ListedEnum): - """Font style of the text. - - Attributes: - handwritten: Handwritten font style. - typewritten: Printed font style. - other: Other font style. - """ - - handwritten: str = "handwritten" - typewritten: str = "typewritten" - other: str = "other" - - class ExtractionMode(ListedEnum): """Extraction mode of the image. @@ -63,47 +46,29 @@ class ExtractionMode(ListedEnum): return getattr(Extraction, self.value) -class TranscriptionType(ListedEnum): - """Arkindex type of the element to extract. 
- - Attributes: - word: - text_line: - half_subject_line: - text_zone: - paragraph: - act: - page: - text: - """ - - word: str = "word" - text_line: str = "text_line" - half_subject_line: str = "half_subject_line" - text_zone: str = "text_zone" - paragraph: str = "paragraph" - act: str = "act" - page: str = "page" - text: str = "text" - - @dataclass class SelectArgs(BaseArgs): """ Arguments to select elements from Arkindex Args: - corpora (list): List of corpus ids to be used. folders (list): List of folder ids to be used. Elements of `volume_type` will be searched recursively in these folders element_type (str): Filter elements to process by type parent_type (str): Filter elements parents to process by type """ - corpora: Optional[List[str]] = field(default_factory=list) - folders: Optional[List[str]] = field(default_factory=list) + folders: List[str] = field(default_factory=list) element_type: Optional[str] = None parent_type: Optional[str] = None + def __post_init__(self): + assert len(self.folders) > 0, "Please provide at least one folder." + + def json(self): + data = super().json() + data.update({"folders": list(map(str, self.folders))}) + return data + @dataclass class ScaleArgs(BaseArgs): """ @@ -115,19 +80,22 @@ class ScaleArgs(BaseArgs): y_bottom (float): Ratio of how much to scale the polygon vertically on the bottom (1.0 means no rescaling) """ - x: Optional[float] = None - y_top: Optional[float] = None - y_bottom: Optional[float] = None + x: float = DEFAULT_RESCALE + y_top: float = DEFAULT_RESCALE + y_bottom: float = DEFAULT_RESCALE def __post_init__(self): - super().__post_init__() if self.should_resize_polygons: # use 1.0 as default - no resize, if not specified self.x = self.y_top = self.y_bottom = DEFAULT_RESCALE @property def should_resize_polygons(self): - return self.x or self.y_top or self.y_bottom + return bool( + self.x != DEFAULT_RESCALE + or self.y_top != DEFAULT_RESCALE + or self.y_bottom != DEFAULT_RESCALE + ) @dataclass @@ -136,10 +104,10 @@ class ImageArgs(BaseArgs): Arguments related to image transformation. Args: - extraction_mode: Mode for extracting the line images: {[e.name for e in Extraction]}, - max_deskew_angle: Maximum angle by which deskewing is allowed to rotate the line image. + extraction_mode (str): Mode for extracting the line images, see the `ExtractionMode` class for options. + max_deskew_angle (int): Maximum angle by which deskewing is allowed to rotate the line image. If the angle determined by the deskew tool is bigger than this maximum, that line won't be deskewed/rotated. - skew_angle: Angle by which the line image will be rotated. Useful for data augmentation" + skew_angle (int): Angle by which the line image will be rotated. Useful for data augmentation, creating skewed text lines for a more robust model. Only used with skew_* extraction modes. should_rotate (bool): Use text line rotation class to rotate lines if possible grayscale (bool): Convert images to grayscale (By default grayscale) @@ -151,10 +119,20 @@ skew_angle: int = 0 should_rotate: bool = False grayscale: bool = True + fixed_height: Optional[int] = None def __post_init__(self): self.scale = ScaleArgs(**self.scale) - super().__post_init__() + + def json(self): + data = super().json() + data.update( + { + "extraction_mode": self.extraction_mode.value, + "scale": self.scale.json(), + } + ) + return data @dataclass @@ -163,65 +141,21 @@ class FilterArgs(BaseArgs): """ Arguments related to element filtering.
Args: - transcription_type: Which type of elements' transcriptions to use? (page, paragraph, text_line, etc) - ignored_classes: List of ignored ml_class names. Filter lines by class - accepted_classes: List of accepted ml_class names. Filter lines by class - accepted_worker_version_ids:List of accepted worker version ids. Filter transcriptions by worker version ids. - The order is important - only up to one transcription will be chosen per element (text_line) - and the worker version order defines the precedence. If there exists a transcription for - the first worker version then it will be chosen, otherwise will continue on to the next - worker version. - Use `--accepted_worker_version_ids manual` to get only manual transcriptions - style: Filter line images by style class. - accepted_metadatas: Key-value dictionary where each entry is a mandatory Arkindex metadata name/value. Filter lines by metadata. + accepted_worker_version_ids (List of str): List of accepted worker version ids. Filter transcriptions by worker version ids. + Use `manual` to get only manual transcriptions + skip_vertical_lines (bool): Skip vertical lines. """ - transcription_type: TranscriptionType = TranscriptionType.text_line - ignored_classes: List[str] = field(default_factory=list) - accepted_classes: List[str] = field(default_factory=list) accepted_worker_version_ids: List[str] = field(default_factory=list) skip_vertical_lines: bool = False - style: Optional[Style] = None - accepted_metadatas: dict = field(default_factory=dict) - filter_parent_metadatas: bool = False - - def _validate(self): - # Check overlap of accepted and ignored classes - accepted_classes = self.accepted_classes - ignored_classes = self.accepted_classes - if accepted_classes and ignored_classes: - if set(accepted_classes) & set(ignored_classes): - raise argparse.ArgumentTypeError( - f"--filter.accepted_classes and --filter.ignored_classes values must not overlap ({accepted_classes} - {ignored_classes})" - ) - if self.style and (accepted_classes or ignored_classes): - if set(Style.list()) & (set(accepted_classes) | set(ignored_classes)): - raise argparse.ArgumentTypeError( - f"--style class values ({Style.list()}) shouldn't be in the accepted_classes list " - f"(or ignored_classes list) " - "if --filter.style is used with either --filter.accepted_classes or --filter.ignored_classes." 
+ def json(self): + data = super().json() + data.update( + { + "accepted_worker_version_ids": list( + map(str, self.accepted_worker_version_ids) ) - - # Support manual worker_version_ids - if MANUAL in self.accepted_worker_version_ids: - # Replace with None - self.accepted_worker_version_ids[ - self.accepted_worker_version_ids.index(MANUAL) - ] = None - - @property - def should_filter_by_class(self): - return bool(self.accepted_classes) or bool(self.ignored_classes) - - @property - def should_filter_by_worker(self): - return bool(self.accepted_worker_version_ids) - - @property - def should_filter_by_style(self): - return bool(self.style) - - @property - def should_filter_by_metadatas(self): - return bool(self.accepted_metadatas) + } + ) + return data diff --git a/atr_data_generator/extract/arkindex.py b/atr_data_generator/extract/arkindex.py deleted file mode 100644 index 1232e14f0565201c4236f609b3919608324bb6dc..0000000000000000000000000000000000000000 --- a/atr_data_generator/extract/arkindex.py +++ /dev/null @@ -1,80 +0,0 @@ -# -*- coding: utf-8 -*- -from functools import cached_property -from typing import Optional - - -class MagicDict(dict): - """ - A dict whose items can be accessed like attributes. - """ - - def _magify(self, item): - """ - Automagically convert lists and dicts to MagicDicts and lists of MagicDicts - Allows for nested access: foo.bar.baz - """ - if isinstance(item, list): - return list(map(self._magify, item)) - if isinstance(item, dict): - return MagicDict(item) - return item - - def __getitem__(self, item): - item = super().__getitem__(item) - return self._magify(item) - - def __getattr__(self, name): - try: - return self[name] - except KeyError: - raise AttributeError( - "{} object has no attribute '{}'".format(self.__class__.__name__, name) - ) - - def __setattr__(self, name, value): - return super().__setitem__(name, value) - - def __delattr__(self, name): - try: - return super().__delattr__(name) - except AttributeError: - try: - return super().__delitem__(name) - except KeyError: - pass - raise - - def __dir__(self): - return super().__dir__() + list(self.keys()) - - -class Element(MagicDict): - """ - Describes an Arkindex element. - """ - - @cached_property - def image_url(self) -> Optional[str]: - """ - Build an URL to access the image. - When possible, will return the S3 URL for images, so an ML worker can bypass IIIF servers. - :param size: Subresolution of the image, following the syntax of the IIIF resize parameter. - :returns: An URL to the image, or None if the element does not have an image. 
- """ - if not self.get("zone"): - return - url = self.zone.image.get("s3_url") - if url: - return url - url = self.zone.image.url - if not url.endswith("/"): - url += "/" - return "{}full/full/0/default.jpg".format(url) - - @cached_property - def width(self): - return self.zone.image.width - - @cached_property - def height(self): - return self.zone.image.height diff --git a/atr_data_generator/extract/base.py b/atr_data_generator/extract/base.py new file mode 100644 index 0000000000000000000000000000000000000000..25dc0efacf2706410a8cf7444c826dd410bd752d --- /dev/null +++ b/atr_data_generator/extract/base.py @@ -0,0 +1,188 @@ +# -*- coding: utf-8 -*- +import json +import logging +from typing import Dict +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import numpy as np +from arkindex_export import Element, open_database +from line_image_extractor.extractor import extract, read_img, save_img +from line_image_extractor.image_utils import polygon_to_bbox, resize +from PIL import Image +from tqdm import tqdm + +from atr_data_generator.arguments import CommonArgs +from atr_data_generator.extract.arguments import FilterArgs, ImageArgs, SelectArgs +from atr_data_generator.extract.db import get_children, get_children_info +from atr_data_generator.extract.utils import _is_vertical, resize_image_height +from atr_data_generator.utils import download_image, export_parameters + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +IIIF_URL_SUFFIX = "/full/full/0/default.jpg" +EXPORT_PATH = "labels.json" + + +@dataclass +class DataGenerator: + common: CommonArgs + image: ImageArgs + filter: FilterArgs + select: SelectArgs + + data: Dict = field(default_factory=dict) + + def __post_init__(self): + # Create the cache dir if it does not exist + self.common.cache_dir.mkdir(exist_ok=True, parents=True) + + def connect_db(self, db_path: Path) -> None: + open_database(str(db_path)) + + def find_image_in_cache(self, image_id: str) -> Path: + """Images are cached under their image ID. If the cached file exists, the image is not downloaded again. + + :param image_id: ID of the image. The image is saved under this name. + :return: The path of the image if it exists in the cache. + """ + return self.common.cache_dir / f"{image_id}.jpg" + + def retrieve_image(self, child: Element) -> np.ndarray[Any, Any]: + """Get or download the image of the element. Checks the cache before downloading. + + :param child: Processed element. + :return: The element's image. + """ + cached_img_path = self.find_image_in_cache(child.image.id) + if cached_img_path.exists(): + return read_img(cached_img_path) + + child_full_image = download_image(child.image.url + IIIF_URL_SUFFIX) + + # Convert if needed + if self.image.grayscale: + child_full_image = child_full_image.convert("L") + + # Save in cache + child_full_image.save(cached_img_path, format="jpeg") + + return np.asarray(child_full_image) + + def get_image(self, child: Element, destination: Path) -> None: + """Save the element's image to the given path and apply any image operations needed. + + :param child: Processed element. + :param destination: Where the image should be saved.
+ """ + polygon = json.loads(str(child.polygon)) + + # Rescale if needed + if self.image.scale.should_resize_polygons: + polygon = resize( + polygon, + child.image.width, + child.image.height, + self.image.scale.x, + self.image.scale.y_top, + self.image.scale.y_bottom, + ) + bbox = polygon_to_bbox(polygon) + + # Skip if line is vertical and these are forbidden + if self.filter.skip_vertical_lines and _is_vertical(bbox): + return + + # Extract the polygon in the image + image = extract( + img=self.retrieve_image(child), + polygon=np.array(polygon), + bbox=bbox, + extraction_mode=self.image.extraction_mode.mode, + max_deskew_angle=self.image.max_deskew_angle, + skew_angle=self.image.skew_angle, + grayscale=self.image.grayscale, + ) + + # Resize to required height if needed + if self.image.fixed_height and bbox.height != self.image.fixed_height: + image = np.asarray( + resize_image_height( + img=Image.fromarray(image), + fixed_height=self.image.fixed_height, + ) + ) + + # Save the image to disk + save_img(path=destination, img=image) + + def process_parent(self, parent: Element): + """ + Process every children under this parent element. + """ + with tqdm( + get_children_info( + str(parent.id), + type=self.select.element_type, + sources=self.filter.accepted_worker_version_ids, + ), + desc=f"Extracting data from {parent.type} {parent.name}", + ) as pbar: + for child in pbar: + image_path = ( + self.common.output_dir + / "images" + / f"{parent.id}_{child.element_id}.jpg" + ) + # Store transcription + self.data[str(image_path)] = child.text + + # Extract the image + self.get_image(child.element, image_path) + + # Progress bar updates + pbar.update() + pbar.refresh() + + def export(self): + """ + Export data to disk. + """ + (self.common.output_dir / EXPORT_PATH).write_text( + json.dumps(self.data, indent=2) + ) + + def run(self, db_path: Path): + """ + Extract data from folders of elements with selected type + """ + self.connect_db(db_path) + # Iterate over folders + for folder_id in self.select.folders: + # Find the parent elements + for parent_element in get_children(folder_id, type=self.select.parent_type): + self.process_parent(parent_element) + + self.export() + + +def main( + database_path: Path, + common: CommonArgs, + image: ImageArgs, + filters: FilterArgs, + select: SelectArgs, +): + data_generator = DataGenerator( + common=common, image=image, filter=filters, select=select + ) + data_generator.run(db_path=database_path) + + export_parameters( + common=common, + image=image, + select=select, + filter=filters, + ) diff --git a/atr_data_generator/extract/db.py b/atr_data_generator/extract/db.py new file mode 100644 index 0000000000000000000000000000000000000000..9309e03d1ab3cd2ccfeb22e068f1a870a1c731b6 --- /dev/null +++ b/atr_data_generator/extract/db.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +from typing import List, Optional + +from arkindex_export import Element, Transcription +from arkindex_export.queries import list_children + +from atr_data_generator.extract.arguments import MANUAL + + +def get_children(parent_id: str, type: Optional[str]): + """Recursively list children elements. + + :param parent_id: ID of the parent element. + :param type: Optionally filter by element type. + :return: The filtered list of children. + """ + query = list_children(parent_id) + if type: + query = query.where(Element.type == type) + return query + + +def parse_sources(sources: List[str]): + """List of transcriptions sources. Manual source has a different treatment. 
+ + :param sources: List of str or MANUAL. + :return: A peewee filter by Transcription.worker_version + """ + query_filter = None + + if MANUAL in sources: + # Manual filtering + query_filter = Transcription.worker_version.is_null() + sources.remove(MANUAL) + + # Filter by worker_versions + if sources: + if query_filter: + query_filter |= Transcription.worker_version.in_(sources) + else: + query_filter = Transcription.worker_version.in_(sources) + return query_filter + + +def get_children_info( + parent_id: str, + type: Optional[str], + sources: Optional[List[str]], +): + """Get the information about the children elements and their transcriptions. Apply all needed filters. + + :param parent_id: ID of the parent element. + :param type: Optionally filter elements by type. + :param sources: Optionally filter transcriptions by their worker version source. + """ + + elements = list_children(parent_id) + + # Filter by type + if type: + elements = elements.where(Element.type == type) + + # Get transcriptions + transcriptions = Transcription.select().join( + elements, on=(Transcription.element == elements.c.id) + ) + + # Filter by transcription source + if sources: + transcriptions = transcriptions.where(parse_sources(sources.copy())) + + return transcriptions diff --git a/atr_data_generator/extract/main.py b/atr_data_generator/extract/main.py index 1aafd2a247521570b4ae99a442c8653112c477a0..d46aeb2c676b111703ce20f7a36663bb364e21ea 100644 --- a/atr_data_generator/extract/main.py +++ b/atr_data_generator/extract/main.py @@ -1,529 +1,529 @@ -# -*- coding: utf-8 -*- - -import logging -from collections import Counter, defaultdict -from dataclasses import dataclass -from operator import attrgetter -from typing import List, Optional - -import numpy as np -from apistar.exceptions import ErrorResponse -from arkindex import ArkindexClient, options_from_env -from line_image_extractor.extractor import extract, read_img, save_img -from line_image_extractor.image_utils import WHITE, rotate_and_trim -from tqdm import tqdm - -from atr_data_generator.arguments import CommonArgs -from atr_data_generator.extract.arguments import ( - FilterArgs, - ImageArgs, - SelectArgs, - Style, -) -from atr_data_generator.extract.arkindex import Element -from atr_data_generator.extract.utils import ( - TranscriptionData, - resize_transcription_data, -) -from atr_data_generator.utils import download_image, export_parameters, write_file - -logger = logging.getLogger(__name__) - -ROTATION_CLASSES_TO_ANGLES = { - "rotate_0": 0, - "rotate_left_90": 90, - "rotate_180": 180, - "rotate_right_90": -90, -} -TEXT_LINE = "text_line" -DEFAULT_RESCALE = 1.0 - - -@dataclass -class ATRDataGenerator: - common: CommonArgs - image: ImageArgs - filter: FilterArgs - api_client: ArkindexClient - - def __post_init__( - self, - ): - self.skipped_pages_count = 0 - self.skipped_vertical_lines_count = 0 - self.accepted_lines_count = 0 - - # Create output folders - self.out_line_text_dir, self.out_line_img_dir = ( - self.common.output_dir / subfolder / self.common.dataset_name - for subfolder in ("Transcriptions", "Lines") - ) - self.out_line_text_dir.mkdir(exist_ok=True, parents=True) - self.out_line_img_dir.mkdir(exist_ok=True, parents=True) - - # Setup cache if needed - self.setup_cache() - - @property - def should_load_metadatas(self): - return ( - self.should_filter_by_metadatas and not self.filter.filter_parent_metadatas - ) - - @property - def should_filter_by_class(self): - return self.filter.should_filter_by_class - - @property - def should_filter_by_style(self): - return self.filter.should_filter_by_style - - @property - def
should_filter_by_metadatas(self): - return self.filter.should_filter_by_metadatas - - @property - def should_filter_by_worker(self): - return self.filter.should_filter_by_worker - - @property - def scale(self): - return self.image.scale - - @property - def should_resize_polygons(self): - return self.scale.should_resize_polygons - - def setup_cache(self): - logger.info(f"Setting up cache to {self.common.cache_dir}") - self.img_cache_dir.mkdir(exist_ok=True, parents=True) - self._color = "grayscale" if self.image.grayscale else "rgb" - - @property - def img_cache_dir(self): - return self.common.cache_dir / "images" - - @property - def _cache_is_empty(self): - return not any(self.img_cache_dir.iterdir()) - - @property - def filter_zones(self): - return ( - self.should_filter_by_class - or self.should_filter_by_style - or self.should_filter_by_metadatas - ) - - def find_image_in_cache(self, image_url): - # id is last part before full/full/0/default.jpg - image_id = image_url.split("/")[-5].replace("%2F", "/") - return self.img_cache_dir / self._color / image_id - - def get_image(self, page: Element) -> np.ndarray: - cached_img_path = self.find_image_in_cache(page.image_url) - if not self._cache_is_empty and cached_img_path.exists(): - logger.info(f"Cached image exists: {cached_img_path} - {page.id}") - return read_img(cached_img_path, self.image.grayscale) - else: - pil_img = download_image(page.image_url) - if self.image.grayscale: - pil_img = pil_img.convert("L") - - logger.info(f"Image not in cache: {cached_img_path} - {page.id}") - cached_img_path.parent.mkdir(exist_ok=True, parents=True) - pil_img.save(cached_img_path, format="jpeg") - - return np.array(pil_img) - - def metadata_filtering(self, elt): - if self.filter.filter_parent_metadatas: - metadatas = [] - parents = self.api_client.paginate( - "ListElementParents", id=elt["id"], with_metadata=True - ) - for parent in parents: - metadatas.extend(parent["metadata"]) - else: - metadatas = elt["metadata"] - metadatas_dict = {metadata["name"]: metadata["value"] for metadata in metadatas} - for meta in self.filter.accepted_metadatas: - if not ( - meta in metadatas_dict - and metadatas_dict[meta] == self.filter.accepted_metadatas[meta] - ): - return False - return True - - def get_accepted_zones(self, page: Element, element_type: Optional[str]): - if not self.filter_zones: - return [] - - try: - accepted_zones = [] - for elt in self.api_client.paginate( - "ListElementChildren", - id=page.id, - type=element_type, - with_classes=self.should_filter_by_class, - with_metadata=self.should_load_metadatas, - recursive=True, - ): - - should_accept = True - if self.should_filter_by_class: - # at first filter to only have elements with accepted classes - # if accepted classes list is empty then should accept all - # except for ignored classes - elem_classes = [ - c for c in elt["classes"] if c["state"] != "rejected" - ] - - should_accept = len(self.filter.accepted_classes) == 0 - for classification in elem_classes: - class_name = classification["ml_class"]["name"] - if class_name in self.filter.accepted_classes: - should_accept = True - break - elif class_name in self.filter.ignored_classes: - should_accept = False - break - - if not should_accept: - continue - - if self.should_filter_by_style: - elem_classes = [ - c for c in elt["classes"] if c["state"] != "rejected" - ] - style_counts = Counter() - for classification in elem_classes: - class_name = classification["ml_class"]["name"] - if class_name in Style.list(): - style_counts[class_name] += 1 - - 
if len(style_counts) == 0: - # no handwritten or typewritten found, so other - found_class = Style.other - elif len(style_counts) == 1: - found_class = list(style_counts.keys())[0] - found_class = Style(found_class) - else: - raise ValueError( - f"Multiple style classes on the same element! {elt['id']} - {elem_classes}" - ) - - if not should_accept: - continue - - if self.should_filter_by_metadatas: - if self.metadata_filtering(elt): - accepted_zones.append(elt["zone"]["id"]) - else: - accepted_zones.append(elt["zone"]["id"]) - logger.info( - "Number of accepted zone for page {} : {}".format( - page.id, len(accepted_zones) - ) - ) - return accepted_zones - except ErrorResponse as e: - logger.info( - f"ListTranscriptions failed {e.status_code} - {e.title} - {e.content} - {page.id}" - ) - raise e - - def _validate_transcriptions(self, page_id: str, lines: List[TranscriptionData]): - # Maybe not keep - if not lines: - return - - line_elem_counter = Counter([trans.element_id for trans in lines]) - most_common = line_elem_counter.most_common(10) - if most_common[0][-1] > 1: - logger.error("Line elements have multiple transcriptions! Showing top 10:") - logger.error(f"{most_common}") - raise ValueError(f"Multiple transcriptions: {most_common[0]}") - - worker_version_counter = Counter([trans.worker_version_id for trans in lines]) - if len(worker_version_counter) > 1: - logger.warning( - f"There are transcriptions from multiple worker versions on this page: {page_id}:" - ) - logger.warning( - f"Top 10 worker versions: {worker_version_counter.most_common(10)}" - ) - - def _choose_best_transcriptions( - self, lines: List[TranscriptionData] - ) -> List[TranscriptionData]: - # Keep inspiration from https://gitlab.com/teklia/callico/-/blob/master/callico/process/imports.py#L189 - """ - Get the best transcription based on the order of accepted worker version ids. 
- :param lines: - :return: - """ - if not lines: - return [] - - trans_by_element = defaultdict(list) - for line in lines: - trans_by_element[line.element_id].append(line) - - best_transcriptions = [] - for elem, trans_list in trans_by_element.items(): - tmp_dict = {t.worker_version_id: t for t in trans_list} - - for wv in self.filter.accepted_worker_version_ids: - if wv in tmp_dict: - best_transcriptions.append(tmp_dict[wv]) - break - else: - logger.info(f"No suitable trans found for {elem}") - return best_transcriptions - - def get_transcriptions(self, page_id: str, element_type: Optional[str]): - lines = [] - accepted_zones = self.get_accepted_zones(page_id, element_type) - try: - for res in self.api_client.paginate( - "ListTranscriptions", id=page_id, recursive=True - ): - if ( - self.should_filter_by_worker - and res["worker_version_id"] - not in self.filter.accepted_worker_version_ids - ): - continue - if ( - self.should_filter_by_class - or self.should_filter_by_style - or self.should_filter_by_metadatas - ) and (res["element"]["zone"]["id"] not in accepted_zones): - continue - if res["element"]["type"] != self.filter.transcription_type: - continue - - text = res["text"] - if not text or not text.strip(): - continue - - if ( - "\n" in text.strip() - and not self.filter.transcription_type == "text" - ): - elem_id = res["element"]["id"] - raise ValueError( - f"Newlines are not allowed in line transcriptions - {page_id} - {elem_id} - {text}" - ) - - if "zone" in res: - polygon = res["zone"]["polygon"] - elif "element" in res: - polygon = res["element"]["zone"]["polygon"] - else: - raise ValueError(f"Data problem with polygon :: {res}") - - trans_data = TranscriptionData( - element_id=res["element"]["id"], - element_name=res["element"]["name"], - polygon=polygon, - text=text, - transcription_id=res["id"], - worker_version_id=res["worker_version_id"], - ) - - lines.append(trans_data) - - if self.should_filter_by_worker: - # if accepted worker versions have been defined then use them - lines = self._choose_best_transcriptions(lines) - else: - # if no accepted worker versions have been defined - # then check that there aren't multiple transcriptions - # on the same text line - self._validate_transcriptions(page_id, lines) - - if self.image.should_rotate: - classes_by_elem = self.get_children_classes(page_id) - - for trans in lines: - rotation_classes = [ - c - for c in classes_by_elem[trans.element_id] - if c in ROTATION_CLASSES_TO_ANGLES - ] - if len(rotation_classes) > 0: - if len(rotation_classes) > 1: - logger.warning( - f"Several rotation classes = {len(rotation_classes)} - {trans.element_id}" - ) - trans.rotation_class = rotation_classes[0] - else: - logger.warning(f"No rotation classes on {trans.element_id}") - - count_skipped = 0 - if self.filter.skip_vertical_lines: - filtered_lines = [] - for line in lines: - if line.is_vertical: - count_skipped += 1 - continue - filtered_lines.append(line) - - lines = filtered_lines - - count = len(lines) - - return lines, count, count_skipped - - except ErrorResponse as e: - logger.info( - f"ListTranscriptions failed {e.status_code} - {e.title} - {e.content} - {page_id}" - ) - raise e - - def get_children_classes(self, page_id): - return { - elem["id"]: [ - best_class["ml_class"]["name"] - for best_class in elem["classes"] - if best_class["state"] != "rejected" - ] - for elem in self.api_client.paginate( - "ListElementChildren", - id=page_id, - recursive=True, - type=TEXT_LINE, - with_classes=True, - ) - } - - def _save_line_image(self, 
page_id, line_img, trans: TranscriptionData = None): - # Get line id - line_id = trans.element_id - - # Get line number from its name - line_number = trans.element_name.split("_")[-1] - - if self.image.should_rotate and trans.rotation_class: - rotate_angle = ROTATION_CLASSES_TO_ANGLES[trans.rotation_class] - line_img = rotate_and_trim(line_img, rotate_angle, WHITE) - - save_img( - f"{self.out_line_img_dir}/{page_id}_{line_number:0>3}_{line_id}.jpg", - line_img, - ) - - def extract_lines(self, page: Element, element_type: Optional[str]): - lines, count, count_skipped = self.get_transcriptions(page.id, element_type) - if count == 0: - self.skipped_pages_count += 1 - logger.info( - f"{page.type.capitalize()} {page.id} skipped, because it has no {element_type}s" - ) - return - - logger.debug(f"Total num of lines {count + count_skipped}") - logger.debug(f"Num of accepted lines {count}") - logger.debug(f"Num of skipped lines {count_skipped}") - - self.skipped_vertical_lines_count += count_skipped - self.accepted_lines_count += count - - img = self.get_image(page) - - # sort vertically then horizontally - sorted_lines = sorted(lines, key=attrgetter("rect.y", "rect.x")) - - if self.should_resize_polygons: - sorted_lines = [ - resize_transcription_data( - line, - page, - self.scale, - ) - for line in sorted_lines - ] - - for trans in sorted_lines: - extracted_img = extract( - img=img, - polygon=np.array(trans.polygon), - bbox=trans.rect, - extraction_mode=self.image.extraction_mode.mode, - max_deskew_angle=self.image.max_deskew_angle, - skew_angle=self.image.skew_angle, - grayscale=self.image.grayscale, - ) - - # don't enumerate, read the line number from the elements's name (e.g. line_xx) so that it matches with Arkindex - self._save_line_image(page.id, extracted_img, trans) - - for trans in sorted_lines: - line_number = trans.element_name.split("_")[-1] - line_id = trans.element_id - write_file( - f"{self.out_line_text_dir}/{page.id}_{line_number:0>3}_{line_id}.txt", - trans.text, - ) - - def run_folders(self, folder_ids: list, parent_type: str, element_type: str): - for folder_id in tqdm(folder_ids, desc="Processing folders"): - logger.info(f"Processing folder {folder_id}") - # Look for parents - for parent in self.api_client.paginate( - "ListElementChildren", id=folder_id, type=parent_type, recursive=True - ): - self.extract_lines(Element(parent), element_type) - - def run_corpora(self, corpus_ids: list, parent_type: str, element_type: str): - for corpus_id in tqdm(corpus_ids): - logger.info(f"Processing corpus {corpus_id}") - # Look for parents - for parent in self.api_client.paginate( - "ListElements", corpus=corpus_id, type=parent_type, recursive=True - ): - self.extract_lines(Element(parent), element_type) - - -def main(common: CommonArgs, image: ImageArgs, filters: FilterArgs, select: SelectArgs): - api_client = ArkindexClient(**options_from_env()) - - data_generator = ATRDataGenerator( - common=common, - image=image, - filter=filters, - api_client=api_client, - ) - - element_type = filters.transcription_type - if element_type: - element_type = str(element_type) - logger.info(f"Will look for transcriptions of `{element_type}s`") - - # extract all the lines and transcriptions - if select.folders: - data_generator.run_folders(select.folders, select.parent_type, element_type) - elif select.corpora: - data_generator.run_corpora(select.corpora, select.parent_type, element_type) - else: - raise Exception("Please specify either a folder or a corpus.") - - if 
data_generator.skipped_vertical_lines_count > 0: - logger.info(f"Number of skipped pages: {data_generator.skipped_pages_count}") - _skipped_vertical_count = data_generator.skipped_vertical_lines_count - _total_count = _skipped_vertical_count + data_generator.accepted_lines_count - skipped_ratio = _skipped_vertical_count / _total_count * 100 - - logger.info( - f"Skipped {data_generator.skipped_vertical_lines_count} vertical lines ({round(skipped_ratio, 2)}%)" - ) - - export_parameters( - common=common, - image=image, - select=select, - filter=filters, - arkindex_api_url=api_client.document.url, - ) +# # -*- coding: utf-8 -*- + +# import logging +# from collections import Counter, defaultdict +# from dataclasses import dataclass +# from operator import attrgetter +# from typing import List, Optional + +# import numpy as np +# from apistar.exceptions import ErrorResponse +# from arkindex import ArkindexClient, options_from_env +# from line_image_extractor.extractor import extract, read_img, save_img +# from line_image_extractor.image_utils import WHITE, rotate_and_trim +# from tqdm import tqdm + +# from atr_data_generator.arguments import CommonArgs +# from atr_data_generator.extract.arguments import ( +# FilterArgs, +# ImageArgs, +# SelectArgs, +# Style, +# ) +# from atr_data_generator.extract.arkindex import Element +# from atr_data_generator.extract.utils import ( +# TranscriptionData, +# resize_transcription_data, +# ) +# from atr_data_generator.utils import download_image, export_parameters, write_file + +# logger = logging.getLogger(__name__) + +# ROTATION_CLASSES_TO_ANGLES = { +# "rotate_0": 0, +# "rotate_left_90": 90, +# "rotate_180": 180, +# "rotate_right_90": -90, +# } +# TEXT_LINE = "text_line" +# DEFAULT_RESCALE = 1.0 + + +# @dataclass +# class ATRDataGenerator: +# common: CommonArgs +# image: ImageArgs +# filter: FilterArgs +# api_client: ArkindexClient + +# def __post_init__( +# self, +# ): +# self.skipped_pages_count = 0 +# self.skipped_vertical_lines_count = 0 +# self.accepted_lines_count = 0 + +# # Create output folders +# self.out_line_text_dir, self.out_line_img_dir = ( +# self.common.output_dir / subfolder / self.common.dataset_name +# for subfolder in ("Transcriptions", "Lines") +# ) +# self.out_line_text_dir.mkdir(exist_ok=True, parents=True) +# self.out_line_img_dir.mkdir(exist_ok=True, parents=True) + +# # Setup cache if needed +# self.setup_cache() + +# @property +# def should_load_metadatas(self): +# return ( +# self.should_filter_by_metadatas and not self.filter.filter_parent_metadatas +# ) + +# @property +# def should_filter_by_class(self): +# return self.filter.should_filter_by_class + +# @property +# def should_filter_by_style(self): +# return self.filter.should_filter_by_style + +# @property +# def should_filter_by_metadatas(self): +# return self.filter.should_filter_by_metadatas + +# @property +# def should_filter_by_worker(self): +# return self.filter.should_filter_by_worker + +# @property +# def scale(self): +# return self.image.scale + +# @property +# def should_resize_polygons(self): +# return self.scale.should_resize_polygons + +# def setup_cache(self): +# logger.info(f"Setting up cache to {self.common.cache_dir}") +# self.img_cache_dir.mkdir(exist_ok=True, parents=True) +# self._color = "grayscale" if self.image.grayscale else "rgb" + +# @property +# def img_cache_dir(self): +# return self.common.cache_dir / "images" + +# @property +# def _cache_is_empty(self): +# return not any(self.img_cache_dir.iterdir()) + +# @property +# def filter_zones(self): +# 
return ( +# self.should_filter_by_class +# or self.should_filter_by_style +# or self.should_filter_by_metadatas +# ) + +# def find_image_in_cache(self, image_url): +# # id is last part before full/full/0/default.jpg +# image_id = image_url.split("/")[-5].replace("%2F", "/") +# return self.img_cache_dir / self._color / image_id + +# def get_image(self, page: Element) -> np.ndarray: +# cached_img_path = self.find_image_in_cache(page.image_url) +# if not self._cache_is_empty and cached_img_path.exists(): +# logger.info(f"Cached image exists: {cached_img_path} - {page.id}") +# return read_img(cached_img_path, self.image.grayscale) +# else: +# pil_img = download_image(page.image_url) +# if self.image.grayscale: +# pil_img = pil_img.convert("L") + +# logger.info(f"Image not in cache: {cached_img_path} - {page.id}") +# cached_img_path.parent.mkdir(exist_ok=True, parents=True) +# pil_img.save(cached_img_path, format="jpeg") + +# return np.array(pil_img) + +# def metadata_filtering(self, elt): +# if self.filter.filter_parent_metadatas: +# metadatas = [] +# parents = self.api_client.paginate( +# "ListElementParents", id=elt["id"], with_metadata=True +# ) +# for parent in parents: +# metadatas.extend(parent["metadata"]) +# else: +# metadatas = elt["metadata"] +# metadatas_dict = {metadata["name"]: metadata["value"] for metadata in metadatas} +# for meta in self.filter.accepted_metadatas: +# if not ( +# meta in metadatas_dict +# and metadatas_dict[meta] == self.filter.accepted_metadatas[meta] +# ): +# return False +# return True + +# def get_accepted_zones(self, page: Element, element_type: Optional[str]): +# if not self.filter_zones: +# return [] + +# try: +# accepted_zones = [] +# for elt in self.api_client.paginate( +# "ListElementChildren", +# id=page.id, +# type=element_type, +# with_classes=self.should_filter_by_class, +# with_metadata=self.should_load_metadatas, +# recursive=True, +# ): + +# should_accept = True +# if self.should_filter_by_class: +# # at first filter to only have elements with accepted classes +# # if accepted classes list is empty then should accept all +# # except for ignored classes +# elem_classes = [ +# c for c in elt["classes"] if c["state"] != "rejected" +# ] + +# should_accept = len(self.filter.accepted_classes) == 0 +# for classification in elem_classes: +# class_name = classification["ml_class"]["name"] +# if class_name in self.filter.accepted_classes: +# should_accept = True +# break +# elif class_name in self.filter.ignored_classes: +# should_accept = False +# break + +# if not should_accept: +# continue + +# if self.should_filter_by_style: +# elem_classes = [ +# c for c in elt["classes"] if c["state"] != "rejected" +# ] +# style_counts = Counter() +# for classification in elem_classes: +# class_name = classification["ml_class"]["name"] +# if class_name in Style.list(): +# style_counts[class_name] += 1 + +# if len(style_counts) == 0: +# # no handwritten or typewritten found, so other +# found_class = Style.other +# elif len(style_counts) == 1: +# found_class = list(style_counts.keys())[0] +# found_class = Style(found_class) +# else: +# raise ValueError( +# f"Multiple style classes on the same element! 
{elt['id']} - {elem_classes}" +# ) + +# if not should_accept: +# continue + +# if self.should_filter_by_metadatas: +# if self.metadata_filtering(elt): +# accepted_zones.append(elt["zone"]["id"]) +# else: +# accepted_zones.append(elt["zone"]["id"]) +# logger.info( +# "Number of accepted zone for page {} : {}".format( +# page.id, len(accepted_zones) +# ) +# ) +# return accepted_zones +# except ErrorResponse as e: +# logger.info( +# f"ListTranscriptions failed {e.status_code} - {e.title} - {e.content} - {page.id}" +# ) +# raise e + +# def _validate_transcriptions(self, page_id: str, lines: List[TranscriptionData]): +# # Maybe not keep +# if not lines: +# return + +# line_elem_counter = Counter([trans.element_id for trans in lines]) +# most_common = line_elem_counter.most_common(10) +# if most_common[0][-1] > 1: +# logger.error("Line elements have multiple transcriptions! Showing top 10:") +# logger.error(f"{most_common}") +# raise ValueError(f"Multiple transcriptions: {most_common[0]}") + +# worker_version_counter = Counter([trans.worker_version_id for trans in lines]) +# if len(worker_version_counter) > 1: +# logger.warning( +# f"There are transcriptions from multiple worker versions on this page: {page_id}:" +# ) +# logger.warning( +# f"Top 10 worker versions: {worker_version_counter.most_common(10)}" +# ) + +# def _choose_best_transcriptions( +# self, lines: List[TranscriptionData] +# ) -> List[TranscriptionData]: +# # Keep inspiration from https://gitlab.com/teklia/callico/-/blob/master/callico/process/imports.py#L189 +# """ +# Get the best transcription based on the order of accepted worker version ids. +# :param lines: +# :return: +# """ +# if not lines: +# return [] + +# trans_by_element = defaultdict(list) +# for line in lines: +# trans_by_element[line.element_id].append(line) + +# best_transcriptions = [] +# for elem, trans_list in trans_by_element.items(): +# tmp_dict = {t.worker_version_id: t for t in trans_list} + +# for wv in self.filter.accepted_worker_version_ids: +# if wv in tmp_dict: +# best_transcriptions.append(tmp_dict[wv]) +# break +# else: +# logger.info(f"No suitable trans found for {elem}") +# return best_transcriptions + +# def get_transcriptions(self, page_id: str, element_type: Optional[str]): +# lines = [] +# accepted_zones = self.get_accepted_zones(page_id, element_type) +# try: +# for res in self.api_client.paginate( +# "ListTranscriptions", id=page_id, recursive=True +# ): +# if ( +# self.should_filter_by_worker +# and res["worker_version_id"] +# not in self.filter.accepted_worker_version_ids +# ): +# continue +# if ( +# self.should_filter_by_class +# or self.should_filter_by_style +# or self.should_filter_by_metadatas +# ) and (res["element"]["zone"]["id"] not in accepted_zones): +# continue +# if res["element"]["type"] != self.filter.transcription_type: +# continue + +# text = res["text"] +# if not text or not text.strip(): +# continue + +# if ( +# "\n" in text.strip() +# and not self.filter.transcription_type == "text" +# ): +# elem_id = res["element"]["id"] +# raise ValueError( +# f"Newlines are not allowed in line transcriptions - {page_id} - {elem_id} - {text}" +# ) + +# if "zone" in res: +# polygon = res["zone"]["polygon"] +# elif "element" in res: +# polygon = res["element"]["zone"]["polygon"] +# else: +# raise ValueError(f"Data problem with polygon :: {res}") + +# trans_data = TranscriptionData( +# element_id=res["element"]["id"], +# element_name=res["element"]["name"], +# polygon=polygon, +# text=text, +# transcription_id=res["id"], +# 
worker_version_id=res["worker_version_id"], +# ) + +# lines.append(trans_data) + +# if self.should_filter_by_worker: +# # if accepted worker versions have been defined then use them +# lines = self._choose_best_transcriptions(lines) +# else: +# # if no accepted worker versions have been defined +# # then check that there aren't multiple transcriptions +# # on the same text line +# self._validate_transcriptions(page_id, lines) + +# if self.image.should_rotate: +# classes_by_elem = self.get_children_classes(page_id) + +# for trans in lines: +# rotation_classes = [ +# c +# for c in classes_by_elem[trans.element_id] +# if c in ROTATION_CLASSES_TO_ANGLES +# ] +# if len(rotation_classes) > 0: +# if len(rotation_classes) > 1: +# logger.warning( +# f"Several rotation classes = {len(rotation_classes)} - {trans.element_id}" +# ) +# trans.rotation_class = rotation_classes[0] +# else: +# logger.warning(f"No rotation classes on {trans.element_id}") + +# count_skipped = 0 +# if self.filter.skip_vertical_lines: +# filtered_lines = [] +# for line in lines: +# if line.is_vertical: +# count_skipped += 1 +# continue +# filtered_lines.append(line) + +# lines = filtered_lines + +# count = len(lines) + +# return lines, count, count_skipped + +# except ErrorResponse as e: +# logger.info( +# f"ListTranscriptions failed {e.status_code} - {e.title} - {e.content} - {page_id}" +# ) +# raise e + +# def get_children_classes(self, page_id): +# return { +# elem["id"]: [ +# best_class["ml_class"]["name"] +# for best_class in elem["classes"] +# if best_class["state"] != "rejected" +# ] +# for elem in self.api_client.paginate( +# "ListElementChildren", +# id=page_id, +# recursive=True, +# type=TEXT_LINE, +# with_classes=True, +# ) +# } + +# def _save_line_image(self, page_id, line_img, trans: TranscriptionData = None): +# # Get line id +# line_id = trans.element_id + +# # Get line number from its name +# line_number = trans.element_name.split("_")[-1] + +# if self.image.should_rotate and trans.rotation_class: +# rotate_angle = ROTATION_CLASSES_TO_ANGLES[trans.rotation_class] +# line_img = rotate_and_trim(line_img, rotate_angle, WHITE) + +# save_img( +# f"{self.out_line_img_dir}/{page_id}_{line_number:0>3}_{line_id}.jpg", +# line_img, +# ) + +# def extract_lines(self, page: Element, element_type: Optional[str]): +# lines, count, count_skipped = self.get_transcriptions(page.id, element_type) +# if count == 0: +# self.skipped_pages_count += 1 +# logger.info( +# f"{page.type.capitalize()} {page.id} skipped, because it has no {element_type}s" +# ) +# return + +# logger.debug(f"Total num of lines {count + count_skipped}") +# logger.debug(f"Num of accepted lines {count}") +# logger.debug(f"Num of skipped lines {count_skipped}") + +# self.skipped_vertical_lines_count += count_skipped +# self.accepted_lines_count += count + +# img = self.get_image(page) + +# # sort vertically then horizontally +# sorted_lines = sorted(lines, key=attrgetter("rect.y", "rect.x")) + +# if self.should_resize_polygons: +# sorted_lines = [ +# resize_transcription_data( +# line, +# page, +# self.scale, +# ) +# for line in sorted_lines +# ] + +# for trans in sorted_lines: +# extracted_img = extract( +# img=img, +# polygon=np.array(trans.polygon), +# bbox=trans.rect, +# extraction_mode=self.image.extraction_mode.mode, +# max_deskew_angle=self.image.max_deskew_angle, +# skew_angle=self.image.skew_angle, +# grayscale=self.image.grayscale, +# ) + +# # don't enumerate, read the line number from the elements's name (e.g. 
line_xx) so that it matches with Arkindex +# self._save_line_image(page.id, extracted_img, trans) + +# for trans in sorted_lines: +# line_number = trans.element_name.split("_")[-1] +# line_id = trans.element_id +# write_file( +# f"{self.out_line_text_dir}/{page.id}_{line_number:0>3}_{line_id}.txt", +# trans.text, +# ) + +# def run_folders(self, folder_ids: list, parent_type: str, element_type: str): +# for folder_id in tqdm(folder_ids, desc="Processing folders"): +# logger.info(f"Processing folder {folder_id}") +# # Look for parents +# for parent in self.api_client.paginate( +# "ListElementChildren", id=folder_id, type=parent_type, recursive=True +# ): +# self.extract_lines(Element(parent), element_type) + +# def run_corpora(self, corpus_ids: list, parent_type: str, element_type: str): +# for corpus_id in tqdm(corpus_ids): +# logger.info(f"Processing corpus {corpus_id}") +# # Look for parents +# for parent in self.api_client.paginate( +# "ListElements", corpus=corpus_id, type=parent_type, recursive=True +# ): +# self.extract_lines(Element(parent), element_type) + + +# def main(common: CommonArgs, image: ImageArgs, filters: FilterArgs, select: SelectArgs): +# api_client = ArkindexClient(**options_from_env()) + +# data_generator = ATRDataGenerator( +# common=common, +# image=image, +# filter=filters, +# api_client=api_client, +# ) + +# element_type = filters.transcription_type +# if element_type: +# element_type = str(element_type) +# logger.info(f"Will look for transcriptions of `{element_type}s`") + +# # extract all the lines and transcriptions +# if select.folders: +# data_generator.run_folders(select.folders, select.parent_type, element_type) +# elif select.corpora: +# data_generator.run_corpora(select.corpora, select.parent_type, element_type) +# else: +# raise Exception("Please specify either a folder or a corpus.") + +# if data_generator.skipped_vertical_lines_count > 0: +# logger.info(f"Number of skipped pages: {data_generator.skipped_pages_count}") +# _skipped_vertical_count = data_generator.skipped_vertical_lines_count +# _total_count = _skipped_vertical_count + data_generator.accepted_lines_count +# skipped_ratio = _skipped_vertical_count / _total_count * 100 + +# logger.info( +# f"Skipped {data_generator.skipped_vertical_lines_count} vertical lines ({round(skipped_ratio, 2)}%)" +# ) + +# export_parameters( +# common=common, +# image=image, +# select=select, +# filter=filters, +# arkindex_api_url=api_client.document.url, +# ) diff --git a/atr_data_generator/extract/utils.py b/atr_data_generator/extract/utils.py index f2f21b6c67475cf6939240e48c7026eea400831e..f8e13c66f5ef0e3c5c0f5b2eadbe6fa1619e80e8 100644 --- a/atr_data_generator/extract/utils.py +++ b/atr_data_generator/extract/utils.py @@ -1,78 +1,15 @@ # -*- coding: utf-8 -*- -import uuid +from line_image_extractor.image_utils import BoundingBox +from PIL import Image -from document_processing.transcription import Transcription -from document_processing.utils import TextOrientation -from line_image_extractor.image_utils import resize -from atr_data_generator.extract.arguments import ScaleArgs -from atr_data_generator.extract.arkindex import Element +def _is_vertical(bbox: BoundingBox): + return bbox.height > bbox.width -class TranscriptionData(Transcription): - def __init__( - self, - element_id, - polygon, - text, - confidence=None, - orientation=TextOrientation.HorizontalLeftToRight, - rotation_class=None, - rotation_class_confidence=None, - element_name: str = None, - transcription_id: uuid.UUID = None, - worker_version_id: 
uuid.UUID = None, - ): - super().__init__( - element_id, - polygon, - text, - confidence, - orientation, - rotation_class, - rotation_class_confidence, - ) - self.element_name = element_name - self.transcription_id = transcription_id - self.worker_version_id = worker_version_id - - @property - def is_vertical(self) -> bool: - """ - Used to filter out vertical lines. Will be ignored when rotation class is given. - """ - if self.rotation_class is None: - return self.rect.height > self.rect.width - return False - - def __repr__(self): - return str(vars(self)) - - @classmethod - def copy_replace_polygon(cls, trans: "TranscriptionData", new_polygon): - """ - Class method to keep the change logic inside the class - less likely to forget to update. - """ - return TranscriptionData( - element_id=trans.element_id, - element_name=trans.element_name, - polygon=new_polygon, - text=trans.text, - transcription_id=trans.transcription_id, - worker_version_id=trans.worker_version_id, - rotation_class=trans.rotation_class, - ) - - -def resize_transcription_data( - trans: TranscriptionData, - page: Element, - scale: ScaleArgs, -) -> TranscriptionData: - orig_polygon = trans.polygon - resized_polygon = resize( - orig_polygon, page.width, page.height, scale.x, scale.y_top, scale.y_bottom - ) - - return TranscriptionData.copy_replace_polygon(trans, resized_polygon) +def resize_image_height(img: Image.Image, fixed_height: int): + width, height = img.size + height_ratio = fixed_height / height + new_width = int(width * height_ratio) + return img.resize((new_width, fixed_height), Image.NEAREST) diff --git a/atr_data_generator/split/arguments.py b/atr_data_generator/split/arguments.py index e5317d24cfea8986880891595fe44e4bfb4a6fe9..f6493ef001e64699dba794f1773c5b660ec1290a 100644 --- a/atr_data_generator/split/arguments.py +++ b/atr_data_generator/split/arguments.py @@ -23,7 +23,7 @@ class SplitArgs(BaseArgs): val_ratio: float = 1 - train_ratio - test_ratio use_existing_split: bool = False - def _validate(self): + def __post_init__(self): if self.train_ratio + self.val_ratio + self.test_ratio != 1.0: raise argparse.ArgumentTypeError( f"Invalid ratios for (train, val, test) ({self.train_ratio}, {self.val_ratio}, {self.test_ratio})" diff --git a/atr_data_generator/split/main.py b/atr_data_generator/split/main.py index 30a441eea6573149813d3a713bce6a1d8c06773d..6d8ded3373713195c461696404c2f138027dd240 100644 --- a/atr_data_generator/split/main.py +++ b/atr_data_generator/split/main.py @@ -3,12 +3,13 @@ import logging import random from enum import Enum +from pathlib import Path import numpy as np from atr_data_generator.arguments import CommonArgs from atr_data_generator.split.arguments import SplitArgs -from atr_data_generator.utils import export_parameters, write_file +from atr_data_generator.utils import export_parameters logger = logging.getLogger(__name__) @@ -105,7 +106,7 @@ class PartitionSplitter: logger.info(f"Partition {split} is empty! 
Skipping...") continue file_name = f"{partitions_dir}/{Split(split).name}Lines.lst" - write_file(file_name, "\n".join(split_line_ids) + "\n") + Path(file_name).write_text("\n".join(split_line_ids) + "\n") return datasets diff --git a/atr_data_generator/utils.py b/atr_data_generator/utils.py index 592b4046c5d377679d1981b95675a06cca203886..d8cd1ba0e05bfded9d672370717bfd5c00ad0005 100644 --- a/atr_data_generator/utils.py +++ b/atr_data_generator/utils.py @@ -12,28 +12,43 @@ from typing import TYPE_CHECKING, Optional import requests import yaml from PIL import Image +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) if TYPE_CHECKING: from atr_data_generator.arguments import CommonArgs from atr_data_generator.extract.arguments import FilterArgs, ImageArgs, SelectArgs from atr_data_generator.split.arguments import SplitArgs -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) -def write_file(file_name, content): - with open(file_name, "w") as f: - f.write(content) +# See http://docs.python-requests.org/en/master/user/advanced/#timeouts +DOWNLOAD_TIMEOUT = (30, 60) -def write_json(d, filename): - with open(filename, "w") as f: - f.write(json.dumps(d, indent=4)) +def _retry_log(retry_state, *args, **kwargs): + logger.warning( + f"Request to {retry_state.args[0]} failed ({repr(retry_state.outcome.exception())}), " + f"retrying in {retry_state.idle_for} seconds" + ) -def write_yaml(d, filename): - with open(filename, "w") as f: - yaml.dump(d, f) +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=2), + retry=retry_if_exception_type(requests.RequestException), + before_sleep=_retry_log, + reraise=True, +) +def _retried_request(url): + resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT) + resp.raise_for_status() + return resp def download_image(url): @@ -44,8 +59,7 @@ def download_image(url): # Download the image # Cannot use stream=True as urllib's responses do not support the seek(int) method, # which is explicitly required by Image.open on file-like objects - resp = requests.get(url) - resp.raise_for_status() + resp = _retried_request(url) # Preprocess the image and prepare it for classification image = Image.open(BytesIO(resp.content)) @@ -62,22 +76,19 @@ def export_parameters( split: Optional[SplitArgs] = None, select: Optional[SelectArgs] = None, filter: Optional[FilterArgs] = None, - datasets=None, - arkindex_api_url: Optional[str] = None, ): """ Dump a JSON log file to keep track of parameters for this dataset """ - # Get config dict - config = {"common": common.dict()} + config = {"common": common.json()} if image: - config["image"] = image.dict() + config["image"] = image.json() if split: - config["split"] = split.dict() + config["split"] = split.json() if select: - config["select"] = select.dict() + config["select"] = select.json() if filter: - config["filter"] = filter.dict() + config["filter"] = filter.json() if common.log_parameters: # Get additional info on dataset and user @@ -85,26 +96,21 @@ def export_parameters( parameters = {} parameters["config"] = config - if datasets: - parameters["dataset"] = ( - datasets - if isinstance(datasets, dict) - else {"train": datasets[0], "valid": datasets[1], "test": datasets[2]} - ) parameters["info"] = { "user": getpass.getuser(), "date": current_datetime, "device": socket.gethostname(), - "arkindex_api_url": arkindex_api_url, } # Export parameter_file = ( common.output_dir / f"param-{common.dataset_name}-{current_datetime}.json" ) - 
write_json(parameters, parameter_file) + parameter_file.write_text(json.dumps(parameters, indent=2, default=str)) + logger.info(f"Parameters exported in file {parameter_file}") config_file = common.output_dir / "config.yaml" - write_yaml(config, config_file) + config_file.write_text(yaml.dump(config)) + logger.info(f"Config exported in file {config_file}") diff --git a/docs/extract/index.md b/docs/extract/index.md index 23bd7e03836aa540ceca7fd4798be967a7b3a004..fce719d44627797e928ba88489310fe85e10d9bd 100644 --- a/docs/extract/index.md +++ b/docs/extract/index.md @@ -1,8 +1,21 @@ ## Dataset extraction -The `extract` subcommand is used to extract data from Arkindex. Two folders will be created: +The `extract` subcommand is used to extract data from Arkindex. This will create: -- `Lines`, with the images that need to be transcribed, -- `Transcription`, with the groundtruth `.txt` transcriptions of each image. +- `images/`, a folder with the images that need to be transcribed, +- `labels.json`, a JSON file where each image is linked to its transcription. + +The full command is: + +```sh +atr-data-generator extract \ + --config path/to/configuration.yaml \ + --database-path path/to/db.sqlite +``` + +Both arguments are required: + +- `--config`, the path to the configuration file, +- `--database-path`, the path to the Arkindex SQLite export of the corpus. More details about the configuration file needed in the [Dataset extraction](./configuration.md) section. diff --git a/examples/extraction.yml b/examples/extraction.yml index 9027793a192d59228328803f2a5918dd52d751e6..ce017ac34e0890e82b3b97d42218b5a4097c0c7b 100644 --- a/examples/extraction.yml +++ b/examples/extraction.yml @@ -4,26 +4,20 @@ common: log_parameters: true output_dir: # Fill me filter: - accepted_classes: [] - accepted_metadatas: {} accepted_worker_version_ids: [] - filter_parent_metadatas: false - ignored_classes: [] skip_vertical_lines: false - style: null - transcription_type: # Fill me image: extraction_mode: deskew_min_area_rect + fixed_height: null grayscale: true max_deskew_angle: 45 scale: - x: null - y_bottom: null - y_top: null + x: 1.0 + y_bottom: 1.0 + y_top: 1.0 should_rotate: false skew_angle: 0 select: - corpora: [] # Fill me or folders - element_type: None - folders: [] # Fill me or corpora - parent_type: None + element_type: null + folders: # Fill me + parent_type: null diff --git a/requirements.txt b/requirements.txt index 34983d7dd3cc6513ac389b7d014a67cb71770d2c..6132e52d0d81979ecb3e3b8646ac618a17c232cb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ -arkindex-client==1.0.13 +arkindex-export==0.1.4 teklia-document-processing==0.2.0 teklia-line-image-extractor==0.2.8-rc3 teklia-toolbox==0.1.3 +tenacity==8.2.2 tqdm==4.64.1 diff --git a/tests/conftest.py b/tests/conftest.py index 7d90c973307ce8cfacc7cb2e4484bdbdd1c9f1c4..18bab8083edd75b340fdfde2b98721c62e0a6e36 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,90 +1,38 @@ # -*- coding: utf-8 -*- -import json -import os from pathlib import Path import pytest -from arkindex.mock import MockApiClient -from line_image_extractor import extractor - -TEST_WORKER_VERSION_ID = "1234-kaldi" - -TEST_VOLUME_ID = "6ebebd79-2a28-464f-b60a-aa47a864a586" +from PIL import Image FIXTURES = Path(__file__).resolve().parent / "data" @pytest.fixture(autouse=True) def setup_environment(responses): - """Setup needed environment variables""" - - # Allow accessing remote API schemas - # defaulting to the prod environment - schema_url = os.environ.get( - 
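
To make the new extraction output concrete: the updated documentation above describes an `images/` folder plus a `labels.json` file linking each image to its transcription. A hedged sketch of how such a dataset could be read back (the `my_dataset` path is an example, and the exact key format in `labels.json`, file name versus relative path, is an assumption here):

```python
import json
from pathlib import Path

output_dir = Path("my_dataset")  # the output_dir configured under `common` in the YAML config

# labels.json maps each extracted line image to its ground-truth transcription
labels = json.loads((output_dir / "labels.json").read_text())

for image_key, transcription in labels.items():
    print(image_key, "->", transcription)

# One transcription per extracted image (this is what tests/test_extract.py checks)
assert len(list(output_dir.rglob("*.jpg"))) == len(labels)
```
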
"ARKINDEX_API_SCHEMA_URL", - "https://arkindex.teklia.com/api/v1/openapi/?format=openapi-json", - ) - responses.add_passthru(schema_url) + arkindex_iiif_server = "https://europe-gamma.iiif.teklia.com/iiif/2" - # Setup a fake worker version ID - os.environ["WORKER_VERSION_ID"] = TEST_WORKER_VERSION_ID + # Allow downloading missing images + responses.add_passthru(arkindex_iiif_server) @pytest.fixture -def fake_volume_id(): - return TEST_VOLUME_ID +def database(): + return FIXTURES / "test_db.sqlite" @pytest.fixture -def fake_image(): - img_path = FIXTURES / "pinned_insects/d336445e-a3ea-4278-a973-a14daefab229.jpg" - return extractor.read_img(img_path) +def image_large_height(): + return Image.open(FIXTURES / "images" / "big_height.jpg") @pytest.fixture -def fake_run_filter_metadata(): - api_client = MockApiClient() - with open(FIXTURES / "Maurdor/ListElementChildren/fake_page.json") as f: - pages_json = json.load(f) - api_client.add_response( - "ListElementChildren", - response=pages_json, - id="fake_page", - type=None, - with_classes=False, - with_metadata=True, - recursive=True, - ) - return api_client +def image_small_height(): + return Image.open(FIXTURES / "images" / "small_height.jpg") @pytest.fixture -def fake_run_volume_api_client(fake_volume_id): - api_client = MockApiClient() - with open( - FIXTURES / "pinned_insects/ListElementChildren" / f"{fake_volume_id}.json" - ) as f: - pages_json = json.load(f) - api_client.add_response( - "ListElementChildren", - response=pages_json, - id=fake_volume_id, - recursive=True, - type="page", - ) - - trans_dir = FIXTURES / "pinned_insects/ListTranscriptions" - for trans_file in trans_dir.glob("*.json"): - trans_json = json.loads(trans_file.read_text()) - - api_client.add_response( - "ListTranscriptions", - response=trans_json, - id=trans_file.stem, - recursive=True, - ) - - return api_client +def image_cache(): + return FIXTURES / "images" @pytest.fixture diff --git a/tests/data/Maurdor/ListElementChildren/fake_page.json b/tests/data/Maurdor/ListElementChildren/fake_page.json deleted file mode 100644 index 723730b1017dd2f9e3099fb146da3db3cffdfbfd..0000000000000000000000000000000000000000 --- a/tests/data/Maurdor/ListElementChildren/fake_page.json +++ /dev/null @@ -1,249 +0,0 @@ -[ - { - "id": "0b719a5a-40c7-47ec-96c8-6c3064df5485", - "type": "text", - "name": "29", - "corpus": { - "id": "809222f2-b4a4-444c-a1a8-37667ccbff6b", - "name": "Maurdor", - "public": false - }, - "thumbnail_url": null, - "zone": { - "id": "0b719a5a-40c7-47ec-96c8-6c3064df5485", - "polygon": [ - [190, 1187], - [190, 1242], - [389, 1242], - [389, 1187], - [190, 1187] - ], - "image": { - "id": "65955cb6-7aeb-4f7c-8531-5d797c135e41", - "path": "public%2Fmaurdor%2Fdev_png%2FYXLOKV-00.png", - "width": 1700, - "height": 2339, - "url": "https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fmaurdor%2Fdev_png%2FYXLOKV-00.png", - "s3_url": null, - "status": "checked", - "server": { - "display_name": "https://europe-gamma.iiif.teklia.com/iiif/2", - "url": "https://europe-gamma.iiif.teklia.com/iiif/2", - "max_width": null, - "max_height": null - } - }, - "url": "https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fmaurdor%2Fdev_png%2FYXLOKV-00.png/190,1187,199,55/full/0/default.jpg" - }, - "rotation_angle": 0, - "mirrored": false, - "created": "2022-09-05T22:15:11.027308Z", - "classes": [], - "metadata": [ - { - "id": "8b5100cf-98c6-4d0f-8bbe-6df07cbdd02e", - "type": "text", - "name": "Language", - "value": "arabic", - "dates": [] - }, - { - "id": 
"bf7de461-1ca2-4160-b1c7-be63e28bd06e", - "type": "text", - "name": "Script", - "value": "typed", - "dates": [] - } - ], - "has_children": null, - "worker_version_id": "329f5b6e-78c9-4240-a2cf-78a746c6f897", - "confidence": null - }, - { - "id": "fbf0d8e5-4729-4bd1-988a-49e178f7d0e6", - "type": "text", - "name": "27", - "corpus": { - "id": "809222f2-b4a4-444c-a1a8-37667ccbff6b", - "name": "Maurdor", - "public": false - }, - "thumbnail_url": null, - "zone": { - "id": "fbf0d8e5-4729-4bd1-988a-49e178f7d0e6", - "polygon": [ - [676, 2124], - [676, 2195], - [1070, 2195], - [1070, 2124], - [676, 2124] - ], - "image": { - "id": "65955cb6-7aeb-4f7c-8531-5d797c135e41", - "path": "public%2Fmaurdor%2Fdev_png%2FYXLOKV-00.png", - "width": 1700, - "height": 2339, - "url": "https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fmaurdor%2Fdev_png%2FYXLOKV-00.png", - "s3_url": null, - "status": "checked", - "server": { - "display_name": "https://europe-gamma.iiif.teklia.com/iiif/2", - "url": "https://europe-gamma.iiif.teklia.com/iiif/2", - "max_width": null, - "max_height": null - } - }, - "url": "https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fmaurdor%2Fdev_png%2FYXLOKV-00.png/676,2124,394,71/full/0/default.jpg" - }, - "rotation_angle": 0, - "mirrored": false, - "created": "2022-09-05T22:15:11.027161Z", - "classes": [], - "metadata": [ - { - "id": "75f36675-92a1-43df-90d0-bccd97b43594", - "type": "text", - "name": "Function", - "value": "reference", - "dates": [] - }, - { - "id": "e90372ce-db83-47ca-9362-b2e5def497e8", - "type": "text", - "name": "Language", - "value": "english", - "dates": [] - }, - { - "id": "bd1ad52f-dbb4-4dc8-a865-d6c42abd690e", - "type": "text", - "name": "Script", - "value": "typed", - "dates": [] - } - ], - "has_children": null, - "worker_version_id": "329f5b6e-78c9-4240-a2cf-78a746c6f897", - "confidence": null - }, - { - "id": "18725942-24f5-4f81-a16f-62c1323c1041", - "type": "text", - "name": "28", - "corpus": { - "id": "809222f2-b4a4-444c-a1a8-37667ccbff6b", - "name": "Maurdor", - "public": false - }, - "thumbnail_url": null, - "zone": { - "id": "18725942-24f5-4f81-a16f-62c1323c1041", - "polygon": [ - [1154, 1244], - [1154, 1327], - [1456, 1327], - [1456, 1244], - [1154, 1244] - ], - "image": { - "id": "65955cb6-7aeb-4f7c-8531-5d797c135e41", - "path": "public%2Fmaurdor%2Fdev_png%2FYXLOKV-00.png", - "width": 1700, - "height": 2339, - "url": "https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fmaurdor%2Fdev_png%2FYXLOKV-00.png", - "s3_url": null, - "status": "checked", - "server": { - "display_name": "https://europe-gamma.iiif.teklia.com/iiif/2", - "url": "https://europe-gamma.iiif.teklia.com/iiif/2", - "max_width": null, - "max_height": null - } - }, - "url": "https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fmaurdor%2Fdev_png%2FYXLOKV-00.png/1154,1244,302,83/full/0/default.jpg" - }, - "rotation_angle": 0, - "mirrored": false, - "created": "2022-09-05T22:15:11.027234Z", - "classes": [], - "metadata": [ - { - "id": "1b78325e-24fa-452f-b11c-5db2ac52df59", - "type": "text", - "name": "Language", - "value": "arabic", - "dates": [] - }, - { - "id": "85ce21e0-0521-4e59-bfd0-a2060888be56", - "type": "text", - "name": "Script", - "value": "typed", - "dates": [] - } - ], - "has_children": null, - "worker_version_id": "329f5b6e-78c9-4240-a2cf-78a746c6f897", - "confidence": null - }, - { - "id": "3ebe7e48-fe2f-4533-92df-9895db05c3f5", - "type": "text", - "name": "23", - "corpus": { - "id": "809222f2-b4a4-444c-a1a8-37667ccbff6b", - "name": "Maurdor", - "public": false - }, - 
"thumbnail_url": null, - "zone": { - "id": "3ebe7e48-fe2f-4533-92df-9895db05c3f5", - "polygon": [ - [1403, 1763], - [1403, 1830], - [1470, 1830], - [1470, 1763], - [1403, 1763] - ], - "image": { - "id": "65955cb6-7aeb-4f7c-8531-5d797c135e41", - "path": "public%2Fmaurdor%2Fdev_png%2FYXLOKV-00.png", - "width": 1700, - "height": 2339, - "url": "https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fmaurdor%2Fdev_png%2FYXLOKV-00.png", - "s3_url": null, - "status": "checked", - "server": { - "display_name": "https://europe-gamma.iiif.teklia.com/iiif/2", - "url": "https://europe-gamma.iiif.teklia.com/iiif/2", - "max_width": null, - "max_height": null - } - }, - "url": "https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fmaurdor%2Fdev_png%2FYXLOKV-00.png/1403,1763,67,67/full/0/default.jpg" - }, - "rotation_angle": 0, - "mirrored": false, - "created": "2022-09-05T22:15:11.026869Z", - "classes": [], - "metadata": [ - { - "id": "42034ac4-1ed3-4893-a275-296484b417c5", - "type": "text", - "name": "Language", - "value": "arabic", - "dates": [] - }, - { - "id": "d20a52bb-28e7-407e-998e-b51f10af330a", - "type": "text", - "name": "Script", - "value": "typed", - "dates": [] - } - ], - "has_children": null, - "worker_version_id": "329f5b6e-78c9-4240-a2cf-78a746c6f897", - "confidence": null - } -] diff --git a/tests/data/images/39595d67-defb-43ab-b4af-26ba414bf4d9.jpg b/tests/data/images/39595d67-defb-43ab-b4af-26ba414bf4d9.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d9ad7bddda5c09d300ea9986d2c1ac271fa80b96 Binary files /dev/null and b/tests/data/images/39595d67-defb-43ab-b4af-26ba414bf4d9.jpg differ diff --git a/tests/data/images/6d73a325-3795-4571-933b-5783177544b7.jpg b/tests/data/images/6d73a325-3795-4571-933b-5783177544b7.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fea1cb6573f711401de024b59b61d3de6ecaa927 Binary files /dev/null and b/tests/data/images/6d73a325-3795-4571-933b-5783177544b7.jpg differ diff --git a/tests/data/images/80a84b30-1ae1-4c13-95d6-7d0d8ee16c51.jpg b/tests/data/images/80a84b30-1ae1-4c13-95d6-7d0d8ee16c51.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e51321cc169cb848103690f50071a4e364b80c16 Binary files /dev/null and b/tests/data/images/80a84b30-1ae1-4c13-95d6-7d0d8ee16c51.jpg differ diff --git a/tests/data/images/big_height.jpg b/tests/data/images/big_height.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ad80684928d90fc3bd001501c5e7750f4b884d75 Binary files /dev/null and b/tests/data/images/big_height.jpg differ diff --git a/tests/data/images/e3c755f2-0e1c-468e-ae4c-9206f0fd267a.jpg b/tests/data/images/e3c755f2-0e1c-468e-ae4c-9206f0fd267a.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3645cf5ab261fbd9e8d19182867d0b22b57cd4dc Binary files /dev/null and b/tests/data/images/e3c755f2-0e1c-468e-ae4c-9206f0fd267a.jpg differ diff --git a/tests/data/images/small_height.jpg b/tests/data/images/small_height.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6776a5fff2be0d69781e2b1a06b8339edd194e5d Binary files /dev/null and b/tests/data/images/small_height.jpg differ diff --git a/tests/data/test_db.sqlite b/tests/data/test_db.sqlite new file mode 100644 index 0000000000000000000000000000000000000000..112afe1aeddbd2af2e32a9e0ce6ff2c201cf4c86 Binary files /dev/null and b/tests/data/test_db.sqlite differ diff --git a/tests/test_extract.py b/tests/test_extract.py index 
7c9380f26d3fb83acb003775112c6801d8fb311b..0bab200ecca9bfcd0a30bfd727b3ba83ce5ef2bc 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -1,70 +1,65 @@ # -*- coding: utf-8 -*- + +import json + import pytest from atr_data_generator.arguments import CommonArgs -from atr_data_generator.extract.arguments import MANUAL, FilterArgs, ImageArgs -from atr_data_generator.extract.arkindex import Element -from atr_data_generator.extract.main import ATRDataGenerator +from atr_data_generator.extract.arguments import ( + MANUAL, + FilterArgs, + ImageArgs, + SelectArgs, +) +from atr_data_generator.extract.base import EXPORT_PATH, DataGenerator @pytest.mark.parametrize( - "worker_version_ids, expected_trans_lines", + "worker_version_ids", + ( + ([]), + ([MANUAL]), # only manual transcriptions + ), +) +@pytest.mark.parametrize( + "folders, expected_trans_lines", ( - ([], 55), - ([MANUAL], 55), # only manual transcriptions in this example - (["test_1234"], 0), # no transcriptions with this worker version - (["test_1234", MANUAL], 55), # searching by worker version in this order + (["a0c4522d-2d80-4766-a01c-b9d686f41f6a"], 17), + ( + [ + "a0c4522d-2d80-4766-a01c-b9d686f41f6a", + "39b9ac5c-89ab-4258-8116-965bf0ca0419", + ], + 38, + ), ), ) def test_run_volumes_with_worker_version( - fake_run_volume_api_client, - fake_volume_id, - fake_image, + database, + folders, + image_cache, worker_version_ids, expected_trans_lines, tmp_path, - mocker, ): - atr_data_gen = ATRDataGenerator( - common=CommonArgs(dataset_name="test", output_dir=tmp_path), + atr_data_gen = DataGenerator( + common=CommonArgs( + dataset_name="test", output_dir=tmp_path, cache_dir=image_cache + ), image=ImageArgs(), + select=SelectArgs(folders=folders), filter=FilterArgs(accepted_worker_version_ids=worker_version_ids), - api_client=fake_run_volume_api_client, ) - atr_data_gen.get_image = mocker.MagicMock() - # return same fake image for all the pages - atr_data_gen.get_image.return_value = fake_image - atr_data_gen.run_folders( - folder_ids=[fake_volume_id], parent_type="page", element_type="text_line" - ) + atr_data_gen.run(database) - trans_files = list(atr_data_gen.out_line_text_dir.glob("*.txt")) + # Read json transcription file + data = json.loads((atr_data_gen.common.output_dir / EXPORT_PATH).read_text()) # assert files aren't empty - assert all(len(trans_file.read_text().strip()) > 0 for trans_file in trans_files) - assert len(trans_files) == expected_trans_lines + assert all(map(lambda x: len(x) > 0, data.values())) + assert len(data) == expected_trans_lines # each image file should have one transcription file - img_files = list(atr_data_gen.out_line_img_dir.glob("*.jpg")) - assert len(img_files) == len(trans_files) - - assert len(atr_data_gen.api_client.history) == 9 - assert atr_data_gen.api_client.responses == [] - - -def test_get_accepted_zones_filter_metadata(tmp_path, fake_run_filter_metadata): - atr_data_gen = ATRDataGenerator( - common=CommonArgs(dataset_name="test", output_dir=tmp_path), - image=ImageArgs(scale={}), - api_client=fake_run_filter_metadata, - filter=FilterArgs(accepted_metadatas={"Language": "arabic"}), - ) - - assert atr_data_gen.get_accepted_zones( - page=Element(id="fake_page"), element_type=None - ) == [ - "0b719a5a-40c7-47ec-96c8-6c3064df5485", - "18725942-24f5-4f81-a16f-62c1323c1041", - "3ebe7e48-fe2f-4533-92df-9895db05c3f5", - ] + img_files = list(atr_data_gen.common.output_dir.rglob("*.jpg")) + assert len(img_files) == len(data) diff --git a/tests/test_image_utils.py b/tests/test_image_utils.py 
index cec9842aad114923611e20f8a2776bbdca3fa347..9dd89bff9d6f3c3e6bbda41a6d7b5c4b5825ae30 100644 --- a/tests/test_image_utils.py +++ b/tests/test_image_utils.py @@ -1,48 +1,30 @@ # -*- coding: utf-8 -*- -import cv2 -import numpy as np import pytest -from line_image_extractor.image_utils import determine_rotate_angle +import pytest_lazyfixture +from document_processing.utils import BoundingBox + +from atr_data_generator.extract.utils import _is_vertical, resize_image_height @pytest.mark.parametrize( - "angle, expected_rotate_angle", + "bbox, vertical", ( - (-1, -1), - (0, 0), - (10, 10), - (44.9, 45), - (45.1, -45), - (45, 0), - (46, -44), - (50, -40), - (89, -1), - (90, 0), - (91, 1), - (134, 44), - (135, 0), - (136, -44), - (179, -1), - (180, 0), - (-180, 0), - (-179, 1), - (-91, -1), - (-90, 0), - (-46, 44), - (-45, 0), - (-44, -44), + (BoundingBox(x=0, y=0, height=100, width=1000), False), + (BoundingBox(x=0, y=0, height=1000, width=100), True), ), ) -def test_determine_rotate_angle(angle, expected_rotate_angle): - top_left = [300, 300] - shape = [400, 100] - # create polygon with expected angle - box = cv2.boxPoints((top_left, shape, angle)) - box = np.intp(box) - _, _, calc_angle = cv2.minAreaRect(box) - rotate_angle = determine_rotate_angle(box) +def test_is_vertical(bbox, vertical): + assert _is_vertical(bbox) is vertical + - assert ( - round(rotate_angle) == expected_rotate_angle - ), f"C, A, R: {calc_angle} === {angle} === {rotate_angle}" +@pytest.mark.parametrize("height", (128, 200)) +@pytest.mark.parametrize( + "image", + ( + pytest_lazyfixture.lazy_fixture("image_large_height"), + pytest_lazyfixture.lazy_fixture("image_small_height"), + ), +) +def test_resized_image_fixed_height(height, image): + assert resize_image_height(image, height).height == height diff --git a/tox.ini b/tox.ini index 5961a065c6eb8326bfb833b7c9980ace36c34d65..37745361d14cecfbed890568b61611ec4a8cf584 100644 --- a/tox.ini +++ b/tox.ini @@ -1,14 +1,22 @@ [tox] -envlist = atr-data-generator +env_list = + atr-data-generator +minversion = 4.4.7 [testenv] -passenv = ARKINDEX_API_SCHEMA_URL -commands = - pytest {posargs} - +description = run the tests with pytest +package = wheel +wheel_build_env = .pkg deps = pytest + pytest-lazy-fixture pytest-responses pytest-mock ./document-processing -rrequirements.txt + +commands = + pytest {tty:--color=yes} {posargs} + +[pytest] +testpaths = tests
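
The `resize_image_height` helper exercised by `test_resized_image_fixed_height` above scales the width by the same factor as the height, so the aspect ratio is preserved. A quick sketch with made-up dimensions:

```python
from PIL import Image

from atr_data_generator.extract.utils import resize_image_height

# A 1000x400 line image resized to a fixed height of 128 px: the width is
# scaled by the same 128/400 factor, keeping the aspect ratio.
img = Image.new("L", (1000, 400))
resized = resize_image_height(img, fixed_height=128)
assert resized.size == (320, 128)  # int(1000 * 128 / 400) == 320
```
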
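
One note on the test style used in `tests/test_image_utils.py` and `tests/test_extract.py`: stacking two `@pytest.mark.parametrize` decorators runs the cross product of their values, and `pytest-lazy-fixture` (now pulled in via `tox.ini`) allows fixtures to appear as parameter values. A self-contained sketch with illustrative fixture names:

```python
import pytest
import pytest_lazyfixture


@pytest.fixture
def short_text():
    return "a"


@pytest.fixture
def long_text():
    return "a" * 1000


@pytest.mark.parametrize("repeat", (1, 10))
@pytest.mark.parametrize(
    "text",
    (
        pytest_lazyfixture.lazy_fixture("short_text"),
        pytest_lazyfixture.lazy_fixture("long_text"),
    ),
)
def test_repeat_keeps_content(repeat, text):
    # 2 repeat values x 2 lazy fixtures = 4 generated test cases
    assert (text * repeat).count("a") == len(text) * repeat
```
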