From ac10b18e909c2b9845cae57164caf08f2f92577f Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Fri, 31 Mar 2023 17:37:48 +0200
Subject: [PATCH] fix image caching

---
 worker_generic_training_dataset/db.py     | 14 ++++++++------
 worker_generic_training_dataset/utils.py  |  4 ++--
 worker_generic_training_dataset/worker.py |  8 ++++++--
 3 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/worker_generic_training_dataset/db.py b/worker_generic_training_dataset/db.py
index c4a38ea..05ccb67 100644
--- a/worker_generic_training_dataset/db.py
+++ b/worker_generic_training_dataset/db.py
@@ -81,9 +81,11 @@ def retrieve_entities(transcription: CachedTranscription):
         .join(Entity, on=TranscriptionEntity.entity)
         .join(EntityType, on=Entity.type)
     )
-    return zip(
-        *[
-            parse_entities(entity_data, transcription)
-            for entity_data in query.namedtuples()
-        ]
-    )
+    data = [
+        parse_entities(entity_data, transcription)
+        for entity_data in query.namedtuples()
+    ]
+    if not data:
+        return [], []
+
+    return zip(*data)
diff --git a/worker_generic_training_dataset/utils.py b/worker_generic_training_dataset/utils.py
index 151f903..b6e8c3e 100644
--- a/worker_generic_training_dataset/utils.py
+++ b/worker_generic_training_dataset/utils.py
@@ -33,7 +33,7 @@ def build_image_url(element: Element):
 
 def download_image(element: Element, folder: Path):
     """
-    Download the image to `folder / {element.id}.jpg`
+    Download the image to `folder / {element.image.id}.jpg`
     """
     tries = 1
     # retry loop
@@ -43,7 +43,7 @@ def download_image(element: Element, folder: Path):
         try:
             image = iio.imread(build_image_url(element))
             cv2.imwrite(
-                str(folder / f"{element.id}.jpg"),
+                str(folder / f"{element.image.id}.jpg"),
                 cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
             )
             break
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index 48420c3..cbd0555 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -49,6 +49,9 @@ class DatasetExtractor(ElementsWorker):
         # Initialize db that will be written
         self.initialize_database()
 
+        # Cached Images downloaded and created in DB
+        self.cached_images = dict()
+
     def initialize_database(self):
         # Create db at
         # - self.workdir / "db.sqlite" in Arkindex mode
@@ -101,13 +104,14 @@ class DatasetExtractor(ElementsWorker):
 
     def insert_element(self, element: Element, parent_id: str):
         logger.info(f"Processing element ({element.id})")
-        if element.image:
+        if element.image and element.image.id not in self.cached_images:
             # Download image
             logger.info("Downloading image")
             download_image(element, folder=IMAGE_FOLDER)
             # Insert image
             logger.info("Inserting image")
-            CachedImage.create(
+            # Store images in case some other elements use it as well
+            self.cached_images[element.image.id] = CachedImage.create(
                 id=element.image.id,
                 width=element.image.width,
                 height=element.image.height,
-- 
GitLab