From 95f6f6145bdc2b4d545cedfb7fff4f9ecfc46310 Mon Sep 17 00:00:00 2001
From: Manon blanco <blanco@teklia.com>
Date: Wed, 10 May 2023 09:24:25 +0000
Subject: [PATCH] Apply 5 suggestion(s) to 3 file(s)

---
 requirements.txt                          |  2 --
 worker_generic_training_dataset/utils.py  |  3 ++-
 worker_generic_training_dataset/worker.py | 43 +++++++++++------------
 3 files changed, 22 insertions(+), 26 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 96225e1..918e5dc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,2 @@
 arkindex-base-worker==0.3.3-rc3
 arkindex-export==0.1.2
-imageio==2.27.0
-opencv-python-headless==4.7.0.72
diff --git a/worker_generic_training_dataset/utils.py b/worker_generic_training_dataset/utils.py
index 6bfb519..76c2bca 100644
--- a/worker_generic_training_dataset/utils.py
+++ b/worker_generic_training_dataset/utils.py
@@ -0,0 +1 @@
+import json
@@ -17,7 +18,7 @@ def bounding_box(polygon: list):
 
 
 def build_image_url(element):
-    x, y, width, height = bounding_box(ast.literal_eval(element.polygon))
+    x, y, width, height = bounding_box(json.loads(element.polygon))
     return urljoin(
         element.image.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
     )
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index 518f562..88a16a6 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -182,18 +182,16 @@ class DatasetExtractor(BaseWorker):
             )
             for transcription in list_transcriptions(element)
         ]
-        if not transcriptions:
-            return []
-
-        logger.info(f"Inserting {len(transcriptions)} transcription(s)")
-        with cache_database.atomic():
-            CachedTranscription.bulk_create(
-                model_list=transcriptions,
-                batch_size=BULK_BATCH_SIZE,
-            )
+        if transcriptions:
+            logger.info(f"Inserting {len(transcriptions)} transcription(s)")
+            with cache_database.atomic():
+                CachedTranscription.bulk_create(
+                    model_list=transcriptions,
+                    batch_size=BULK_BATCH_SIZE,
+                )
         return transcriptions
 
-    def insert_entities(self, transcriptions: List[CachedTranscription]):
+    def insert_entities(self, transcriptions: List[CachedTranscription]) -> None:
         logger.info("Listing entities")
         extracted_entities = []
         for transcription in transcriptions:
@@ -218,19 +216,18 @@ class DatasetExtractor(BaseWorker):
                         ),
                     )
                 )
-        if not extracted_entities:
-            # Early return if no entities found
-            return
-
-        entities, transcription_entities = zip(*extracted_entities)
-
-        # First insert entities since they are foreign keys on transcription entities
-        logger.info(f"Inserting {len(entities)} entities")
-        with cache_database.atomic():
-            CachedEntity.bulk_create(
-                model_list=entities,
-                batch_size=BULK_BATCH_SIZE,
-            )
+        # Unpack pairs into parallel tuples, keeping both names bound when empty
+        entities, transcription_entities = (
+            zip(*extracted_entities) if extracted_entities else ((), ())
+        )
+        if entities:
+            # First insert entities since they are foreign keys on transcription entities
+            logger.info(f"Inserting {len(entities)} entities")
+            with cache_database.atomic():
+                CachedEntity.bulk_create(
+                    model_list=entities,
+                    batch_size=BULK_BATCH_SIZE,
+                )
 
         if transcription_entities:
             # Insert transcription entities
-- 
GitLab