From ac10b18e909c2b9845cae57164caf08f2f92577f Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Fri, 31 Mar 2023 17:37:48 +0200 Subject: [PATCH] fix image caching --- worker_generic_training_dataset/db.py | 14 ++++++++------ worker_generic_training_dataset/utils.py | 4 ++-- worker_generic_training_dataset/worker.py | 8 ++++++-- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/worker_generic_training_dataset/db.py b/worker_generic_training_dataset/db.py index c4a38ea..05ccb67 100644 --- a/worker_generic_training_dataset/db.py +++ b/worker_generic_training_dataset/db.py @@ -81,9 +81,11 @@ def retrieve_entities(transcription: CachedTranscription): .join(Entity, on=TranscriptionEntity.entity) .join(EntityType, on=Entity.type) ) - return zip( - *[ - parse_entities(entity_data, transcription) - for entity_data in query.namedtuples() - ] - ) + data = [ + parse_entities(entity_data, transcription) + for entity_data in query.namedtuples() + ] + if not data: + return [], [] + + return zip(*data) diff --git a/worker_generic_training_dataset/utils.py b/worker_generic_training_dataset/utils.py index 151f903..b6e8c3e 100644 --- a/worker_generic_training_dataset/utils.py +++ b/worker_generic_training_dataset/utils.py @@ -33,7 +33,7 @@ def build_image_url(element: Element): def download_image(element: Element, folder: Path): """ - Download the image to `folder / {element.id}.jpg` + Download the image to `folder / {element.image.id}.jpg` """ tries = 1 # retry loop @@ -43,7 +43,7 @@ def download_image(element: Element, folder: Path): try: image = iio.imread(build_image_url(element)) cv2.imwrite( - str(folder / f"{element.id}.jpg"), + str(folder / f"{element.image.id}.jpg"), cv2.cvtColor(image, cv2.COLOR_BGR2RGB), ) break diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py index 48420c3..cbd0555 100644 --- a/worker_generic_training_dataset/worker.py +++ b/worker_generic_training_dataset/worker.py @@ -49,6 +49,9 @@ class DatasetExtractor(ElementsWorker): # Initialize db that will be written self.initialize_database() + # Cached Images downloaded and created in DB + self.cached_images = dict() + def initialize_database(self): # Create db at # - self.workdir / "db.sqlite" in Arkindex mode @@ -101,13 +104,14 @@ class DatasetExtractor(ElementsWorker): def insert_element(self, element: Element, parent_id: str): logger.info(f"Processing element ({element.id})") - if element.image: + if element.image and element.image.id not in self.cached_images: # Download image logger.info("Downloading image") download_image(element, folder=IMAGE_FOLDER) # Insert image logger.info("Inserting image") - CachedImage.create( + # Store images in case some other elements use it as well + self.cached_images[element.image.id] = CachedImage.create( id=element.image.id, width=element.image.width, height=element.image.height, -- GitLab