From ec6c64fa18e90d8a203b7a991eafd1b30cf7a89d Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Fri, 31 Mar 2023 17:31:02 +0200
Subject: [PATCH] first working version

---
 requirements.txt                              |   3 +
 worker_generic_training_dataset/db.py         |  89 +++++++++++
 worker_generic_training_dataset/exceptions.py |  36 +++++
 worker_generic_training_dataset/utils.py      |  55 +++++++
 worker_generic_training_dataset/worker.py     | 140 ++++++++++++++++--
 5 files changed, 310 insertions(+), 13 deletions(-)
 create mode 100644 worker_generic_training_dataset/db.py
 create mode 100644 worker_generic_training_dataset/exceptions.py
 create mode 100644 worker_generic_training_dataset/utils.py

diff --git a/requirements.txt b/requirements.txt
index 5ff1be6..5611000 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,4 @@
 arkindex-base-worker==0.3.2
+arkindex-export==0.1.2
+imageio==2.27.0
+opencv-python==4.7.0.72
diff --git a/worker_generic_training_dataset/db.py b/worker_generic_training_dataset/db.py
new file mode 100644
index 0000000..c4a38ea
--- /dev/null
+++ b/worker_generic_training_dataset/db.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+from typing import NamedTuple
+
+from arkindex_export import Classification
+from arkindex_export.models import (
+    Element,
+    Entity,
+    EntityType,
+    Transcription,
+    TranscriptionEntity,
+)
+from arkindex_worker.cache import (
+    CachedElement,
+    CachedEntity,
+    CachedTranscription,
+    CachedTranscriptionEntity,
+)
+
+DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"
+
+
+def retrieve_element(element_id: str):
+    return Element.get_by_id(element_id)
+
+
+def list_classifications(element: Element):
+    query = Classification.select().where(Classification.element == element)
+    return query
+
+
+def parse_transcription(transcription: NamedTuple, element: CachedElement):
+    return CachedTranscription(
+        id=transcription.id,
+        element=element,
+        text=transcription.text,
+        confidence=transcription.confidence,
+        orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
+        worker_version_id=transcription.worker_version.id
+        if transcription.worker_version
+        else None,
+    )
+
+
+def list_transcriptions(element: CachedElement):
+    query = Transcription.select().where(Transcription.element_id == element.id)
+    return [parse_transcription(x, element) for x in query]
+
+
+def parse_entities(data: NamedTuple, transcription: CachedTranscription):
+    entity = CachedEntity(
+        id=data.entity_id,
+        type=data.type,
+        name=data.name,
+        validated=data.validated,
+        metas=data.metas,
+    )
+    return entity, CachedTranscriptionEntity(
+        id=data.transcription_entity_id,
+        transcription=transcription,
+        entity=entity,
+        offset=data.offset,
+        length=data.length,
+        confidence=data.confidence,
+    )
+
+
+def retrieve_entities(transcription: CachedTranscription):
+    query = (
+        TranscriptionEntity.select(
+            TranscriptionEntity.id.alias("transcription_entity_id"),
+            TranscriptionEntity.length.alias("length"),
+            TranscriptionEntity.offset.alias("offset"),
+            TranscriptionEntity.confidence.alias("confidence"),
+            Entity.id.alias("entity_id"),
+            EntityType.name.alias("type"),
+            Entity.name,
+            Entity.validated,
+            Entity.metas,
+        )
+        .where(TranscriptionEntity.transcription_id == transcription.id)
+        .join(Entity, on=TranscriptionEntity.entity)
+        .join(EntityType, on=Entity.type)
+    )
+    return zip(
+        *[
+            parse_entities(entity_data, transcription)
+            for entity_data in query.namedtuples()
+        ]
+    )
diff --git a/worker_generic_training_dataset/exceptions.py b/worker_generic_training_dataset/exceptions.py
new file mode 100644
index 0000000..062d580
--- /dev/null
+++ b/worker_generic_training_dataset/exceptions.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+
+
+class ElementProcessingError(Exception):
+    """
+    Raised when a problem is encountered while processing an element
+    """
+
+    element_id: str
+    """
+    ID of the element being processed.
+    """
+
+    def __init__(self, element_id: str, *args: object) -> None:
+        super().__init__(*args)
+        self.element_id = element_id
+
+
+class ImageDownloadError(ElementProcessingError):
+    """
+    Raised when an element's image could not be downloaded
+    """
+
+    error: Exception
+    """
+    Error encountered.
+    """
+
+    def __init__(self, element_id: str, error: Exception, *args: object) -> None:
+        super().__init__(element_id, *args)
+        self.error = error
+
+    def __str__(self) -> str:
+        return (
+            f"Couldn't retrieve image of element ({self.element_id}: {str(self.error)})"
+        )
diff --git a/worker_generic_training_dataset/utils.py b/worker_generic_training_dataset/utils.py
new file mode 100644
index 0000000..151f903
--- /dev/null
+++ b/worker_generic_training_dataset/utils.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+import ast
+import logging
+import time
+from pathlib import Path
+from urllib.parse import urljoin
+
+import cv2
+import imageio.v2 as iio
+from arkindex_export.models import Element
+from worker_generic_training_dataset.exceptions import ImageDownloadError
+
+logger = logging.getLogger(__name__)
+MAX_RETRIES = 5
+
+
+def bounding_box(polygon: list):
+    """
+    Returns a 4-tuple (x, y, width, height) for the bounding box of a Polygon (list of points)
+    """
+    all_x, all_y = zip(*polygon)
+    x, y = min(all_x), min(all_y)
+    width, height = max(all_x) - x, max(all_y) - y
+    return int(x), int(y), int(width), int(height)
+
+
+def build_image_url(element: Element):
+    x, y, width, height = bounding_box(ast.literal_eval(element.polygon))
+    return urljoin(
+        element.image.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
+    )
+
+
+def download_image(element: Element, folder: Path):
+    """
+    Download the image to `folder / {element.id}.jpg`
+    """
+    tries = 1
+    # retry loop
+    while True:
+        if tries > MAX_RETRIES:
+            raise ImageDownloadError(element.id, Exception("Maximum retries reached."))
+        try:
+            image = iio.imread(build_image_url(element))
+            cv2.imwrite(
+                str(folder / f"{element.id}.jpg"),
+                cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
+            )
+            break
+        except TimeoutError:
+            logger.warning("Timeout, retry in 1 second.")
+            time.sleep(1)
+            tries += 1
+        except Exception as e:
+            raise ImageDownloadError(element.id, e)
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index 5b18a7b..48420c3 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -1,13 +1,38 @@
 # -*- coding: utf-8 -*-
 import logging
 import operator
+from pathlib import Path
 
 from apistar.exceptions import ErrorResponse
-from arkindex_worker.cache import create_tables, create_version_table, init_cache_db
+from arkindex_export import open_database
+from arkindex_export.models import Element
+from arkindex_export.queries import list_children
+from arkindex_worker.cache import (
+    CachedClassification,
+    CachedElement,
+    CachedEntity,
+    CachedImage,
+    CachedTranscription,
+    CachedTranscriptionEntity,
+    create_tables,
+    create_version_table,
+)
+from arkindex_worker.cache import db as cache_database
+from arkindex_worker.cache import init_cache_db
 from arkindex_worker.worker import ElementsWorker
+from worker_generic_training_dataset.db import (
+    list_classifications,
+    list_transcriptions,
+    retrieve_element,
+    retrieve_entities,
+)
+from worker_generic_training_dataset.utils import download_image
 
 logger = logging.getLogger(__name__)
 
+IMAGE_FOLDER = Path("images")
+BULK_BATCH_SIZE = 50
+
 
 class DatasetExtractor(ElementsWorker):
     def configure(self):
@@ -26,12 +51,10 @@ class DatasetExtractor(ElementsWorker):
 
     def initialize_database(self):
         # Create db at
-        # - self.workdir.parent / self.task_id in Arkindex mode
+        # - self.workdir / "db.sqlite" in Arkindex mode
         # - self.args.database in dev mode
         database_path = (
-            self.args.database
-            if self.is_read_only
-            else self.workdir.parent / self.task_id
+            self.args.database if self.is_read_only else self.workdir / "db.sqlite"
         )
 
         init_cache_db(database_path)
@@ -49,8 +72,9 @@ class DatasetExtractor(ElementsWorker):
             )["results"]
         except ErrorResponse as e:
             logger.error(
-                f"Could not list exports of corpus ({self.corpus_id}): {str(e)}"
+                f"Could not list exports of corpus ({self.corpus_id}): {str(e.content)}"
            )
+            raise e
 
         # Find latest that is in "done" state
         exports = sorted(
@@ -62,21 +86,111 @@ class DatasetExtractor(ElementsWorker):
         # Download latest it in a tmpfile
         try:
             export_id = exports[0]["id"]
-            download_url = self.api_client.request(
+            logger.info(f"Downloading export ({export_id})...")
+            self.export = self.api_client.request(
                 "DownloadExport",
                 id=export_id,
-            )["results"]
+            )
+            logger.info(f"Downloaded export ({export_id}) @ `{self.export.name}`")
+            open_database(self.export.name)
         except ErrorResponse as e:
             logger.error(
                 f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e)}"
             )
-        print(download_url)
+            raise e
+
+    def insert_element(self, element: Element, parent_id: str):
+        logger.info(f"Processing element ({element.id})")
+        if element.image:
+            # Download image
+            logger.info("Downloading image")
+            download_image(element, folder=IMAGE_FOLDER)
+            # Insert image
+            logger.info("Inserting image")
+            CachedImage.create(
+                id=element.image.id,
+                width=element.image.width,
+                height=element.image.height,
+                url=element.image.url,
+            )
 
-    def process_element(self, element):
-        ...
+        # Insert element
+        logger.info("Inserting element")
+        cached_element = CachedElement.create(
+            id=element.id,
+            parent_id=parent_id,
+            type=element.type,
+            image=element.image.id if element.image else None,
+            polygon=element.polygon,
+            rotation_angle=element.rotation_angle,
+            mirrored=element.mirrored,
+            worker_version_id=element.worker_version.id
+            if element.worker_version
+            else None,
+            confidence=element.confidence,
+        )
 
-        # List Transcriptions, Metas
-        #
+        # Insert classifications
+        logger.info("Listing classifications")
+        classifications = [
+            CachedClassification(
+                id=classification.id,
+                element=cached_element,
+                class_name=classification.class_name,
+                confidence=classification.confidence,
+                state=classification.state,
+            )
+            for classification in list_classifications(element)
+        ]
+        if classifications:
+            logger.info(f"Inserting {len(classifications)} classifications")
+            with cache_database.atomic():
+                CachedClassification.bulk_create(
+                    model_list=classifications,
+                    batch_size=BULK_BATCH_SIZE,
+                )
+
+        # Insert transcriptions
+        logger.info("Listing transcriptions")
+        transcriptions = list_transcriptions(cached_element)
+        if transcriptions:
+            logger.info(f"Inserting {len(transcriptions)} transcriptions")
+            with cache_database.atomic():
+                CachedTranscription.bulk_create(
+                    model_list=transcriptions,
+                    batch_size=BULK_BATCH_SIZE,
+                )
+
+        logger.info("Listing entities")
+        entities, transcription_entities = [], []
+        for transcription in transcriptions:
+            ents, transc_ents = retrieve_entities(transcription)
+            entities.extend(ents)
+            transcription_entities.extend(transc_ents)
+
+        if entities:
+            logger.info(f"Inserting {len(entities)} entities")
+            with cache_database.atomic():
+                CachedEntity.bulk_create(
+                    model_list=entities,
+                    batch_size=BULK_BATCH_SIZE,
+                )
+            # Insert transcription entities
+            logger.info(
+                f"Inserting {len(transcription_entities)} transcription entities"
+            )
+            with cache_database.atomic():
+                CachedTranscriptionEntity.bulk_create(
+                    model_list=transcription_entities,
+                    batch_size=BULK_BATCH_SIZE,
+                )
+
+    def process_element(self, element):
+        # Retrieve parent and create parent
+        parent = retrieve_element(element.id)
+        self.insert_element(parent, parent_id=None)
+        for child in list_children(parent_id=element.id):
+            self.insert_element(child, parent_id=element.id)
 
 
 def main():
-- 
GitLab