From 95f6f6145bdc2b4d545cedfb7fff4f9ecfc46310 Mon Sep 17 00:00:00 2001
From: Manon blanco <blanco@teklia.com>
Date: Wed, 10 May 2023 09:24:25 +0000
Subject: [PATCH] Apply 5 suggestion(s) to 3 file(s)

---
 requirements.txt                           |  2 --
 worker_generic_training_dataset/utils.py   |  2 +-
 worker_generic_training_dataset/worker.py  | 44 +++++++++++-----------
 3 files changed, 22 insertions(+), 26 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 96225e1..918e5dc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,2 @@
 arkindex-base-worker==0.3.3-rc3
 arkindex-export==0.1.2
-imageio==2.27.0
-opencv-python-headless==4.7.0.72
diff --git a/worker_generic_training_dataset/utils.py b/worker_generic_training_dataset/utils.py
index 6bfb519..76c2bca 100644
--- a/worker_generic_training_dataset/utils.py
+++ b/worker_generic_training_dataset/utils.py
@@ -17,7 +17,7 @@ def bounding_box(polygon: list):
 
 
 def build_image_url(element):
-    x, y, width, height = bounding_box(ast.literal_eval(element.polygon))
+    x, y, width, height = bounding_box(json.loads(element.polygon))
     return urljoin(
         element.image.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
     )
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index 518f562..88a16a6 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -182,18 +182,16 @@ class DatasetExtractor(BaseWorker):
             )
             for transcription in list_transcriptions(element)
         ]
-        if not transcriptions:
-            return []
-
-        logger.info(f"Inserting {len(transcriptions)} transcription(s)")
-        with cache_database.atomic():
-            CachedTranscription.bulk_create(
-                model_list=transcriptions,
-                batch_size=BULK_BATCH_SIZE,
-            )
+        if transcriptions:
+            logger.info(f"Inserting {len(transcriptions)} transcription(s)")
+            with cache_database.atomic():
+                CachedTranscription.bulk_create(
+                    model_list=transcriptions,
+                    batch_size=BULK_BATCH_SIZE,
+                )
         return transcriptions
 
-    def insert_entities(self, transcriptions: List[CachedTranscription]):
+    def insert_entities(self, transcriptions: List[CachedTranscription]) -> None:
         logger.info("Listing entities")
         extracted_entities = []
         for transcription in transcriptions:
@@ -218,19 +216,19 @@
                     ),
                 )
             )
-        if not extracted_entities:
-            # Early return if no entities found
-            return
-
-        entities, transcription_entities = zip(*extracted_entities)
-
-        # First insert entities since they are foreign keys on transcription entities
-        logger.info(f"Inserting {len(entities)} entities")
-        with cache_database.atomic():
-            CachedEntity.bulk_create(
-                model_list=entities,
-                batch_size=BULK_BATCH_SIZE,
-            )
+        # Guard the unpack: zip(*[]) yields nothing to split into two tuples
+        entities, transcription_entities = (
+            zip(*extracted_entities) if extracted_entities else ((), ())
+        )
+
+        if entities:
+            # First insert entities since they are foreign keys on transcription entities
+            logger.info(f"Inserting {len(entities)} entities")
+            with cache_database.atomic():
+                CachedEntity.bulk_create(
+                    model_list=entities,
+                    batch_size=BULK_BATCH_SIZE,
+                )
 
         if transcription_entities:
             # Insert transcription entities
-- 
GitLab
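
A short, self-contained sketch of the two patterns this patch relies on. The sample polygon, the iiif.example.com base URL, and the printed output are illustrative assumptions, not values from the repository; the switch to json.loads suggests the polygons are serialized as JSON arrays, for which JSON parsing is both stricter and faster than ast.literal_eval.

import json
from urllib.parse import urljoin


def bounding_box(polygon: list):
    # Axis-aligned bounding box of [[x, y], ...] vertex pairs.
    xs, ys = zip(*polygon)
    x, y = min(xs), min(ys)
    return x, y, max(xs) - x, max(ys) - y


# Parse the JSON-encoded polygon, as build_image_url() now does.
polygon = json.loads("[[0, 0], [200, 0], [200, 100], [0, 100]]")
x, y, width, height = bounding_box(polygon)

# IIIF Image API region crop, assembled the same way as in utils.py.
url = urljoin(
    "https://iiif.example.com/image" + "/",
    f"{x},{y},{width},{height}/full/0/default.jpg",
)
print(url)  # https://iiif.example.com/image/0,0,200,100/full/0/default.jpg

# The worker.py hunk depends on guarded unpacking: zip(*[]) yields an empty
# iterator, so a bare two-name unpack raises ValueError when nothing was
# extracted. The conditional expression keeps the "no entities" case cheap.
pairs = []
entities, transcription_entities = zip(*pairs) if pairs else ((), ())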