diff --git a/worker_generic_training_dataset/db.py b/worker_generic_training_dataset/db.py index c4a38ea9e9acc61b91cb19b3a010bf1f51aae689..05ccb6790f422a779220c51ea62188bc6f1c9c1e 100644 --- a/worker_generic_training_dataset/db.py +++ b/worker_generic_training_dataset/db.py @@ -81,9 +81,11 @@ def retrieve_entities(transcription: CachedTranscription): .join(Entity, on=TranscriptionEntity.entity) .join(EntityType, on=Entity.type) ) - return zip( - *[ - parse_entities(entity_data, transcription) - for entity_data in query.namedtuples() - ] - ) + data = [ + parse_entities(entity_data, transcription) + for entity_data in query.namedtuples() + ] + if not data: + return [], [] + + return zip(*data) diff --git a/worker_generic_training_dataset/utils.py b/worker_generic_training_dataset/utils.py index 151f903fd505c92a8cb079a767fcc0531e1134bf..b6e8c3edd2d87b46f56af14b4e32b1c4a3e54829 100644 --- a/worker_generic_training_dataset/utils.py +++ b/worker_generic_training_dataset/utils.py @@ -33,7 +33,7 @@ def build_image_url(element: Element): def download_image(element: Element, folder: Path): """ - Download the image to `folder / {element.id}.jpg` + Download the image to `folder / {element.image.id}.jpg` """ tries = 1 # retry loop @@ -43,7 +43,7 @@ def download_image(element: Element, folder: Path): try: image = iio.imread(build_image_url(element)) cv2.imwrite( - str(folder / f"{element.id}.jpg"), + str(folder / f"{element.image.id}.jpg"), cv2.cvtColor(image, cv2.COLOR_BGR2RGB), ) break diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py index 48420c3a6c452b123ab3aad90cab348dc557faf4..cbd0555a79d922741e97b9137ddea25f6f4539f5 100644 --- a/worker_generic_training_dataset/worker.py +++ b/worker_generic_training_dataset/worker.py @@ -49,6 +49,9 @@ class DatasetExtractor(ElementsWorker): # Initialize db that will be written self.initialize_database() + # Cached Images downloaded and created in DB + self.cached_images = dict() + def initialize_database(self): # Create db at # - self.workdir / "db.sqlite" in Arkindex mode @@ -101,13 +104,14 @@ class DatasetExtractor(ElementsWorker): def insert_element(self, element: Element, parent_id: str): logger.info(f"Processing element ({element.id})") - if element.image: + if element.image and element.image.id not in self.cached_images: # Download image logger.info("Downloading image") download_image(element, folder=IMAGE_FOLDER) # Insert image logger.info("Inserting image") - CachedImage.create( + # Store images in case some other elements use it as well + self.cached_images[element.image.id] = CachedImage.create( id=element.image.id, width=element.image.width, height=element.image.height,