diff --git a/worker_generic_training_dataset/db.py b/worker_generic_training_dataset/db.py
index 6d413747f83a374b571653e950e175be8bb317fd..e81e652051f7e93f62dca597857506b7196d1b01 100644
--- a/worker_generic_training_dataset/db.py
+++ b/worker_generic_training_dataset/db.py
@@ -28,7 +28,7 @@ def list_classifications(element_id: str):
     return Classification.select().where(Classification.element_id == element_id)
 
 
-def parse_transcription(transcription: NamedTuple, element: CachedElement):
+def parse_transcription(transcription: Transcription, element: CachedElement):
     return CachedTranscription(
         id=transcription.id,
         element=element,
@@ -44,7 +44,7 @@ def parse_transcription(transcription: NamedTuple, element: CachedElement):
 
 def list_transcriptions(element: CachedElement):
     query = Transcription.select().where(Transcription.element_id == element.id)
-    return [parse_transcription(x, element) for x in query]
+    return [parse_transcription(transcription, element) for transcription in query]
 
 
 def parse_entities(data: NamedTuple, transcription: CachedTranscription):
@@ -92,7 +92,7 @@ def retrieve_entities(transcription: CachedTranscription):
     return zip(*data)
 
 
-def get_children(parent_id, element_type=None):
+def get_children(parent_id: UUID, element_type=None):
     query = list_children(parent_id).join(Image)
     if element_type:
         query = query.where(Element.type == element_type)
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index 30120a142d8c518169d5c14acecafb46f266446f..81743500786c39f1571c2935874b6dac823c8d03 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -63,7 +63,7 @@ class DatasetExtractor(BaseWorker):
         # Initialize db that will be written
         self.initialize_database()
 
-        # Cached Images downloaded and created in DB
+        # CachedImages downloaded and created in DB
         self.cached_images = dict()
 
     def read_training_related_information(self):
@@ -72,7 +72,6 @@ def read_training_related_information(self):
         - train_folder_id
        - validation_folder_id
         - test_folder_id (optional)
-        - model_id
         """
 
         logger.info("Retrieving information from process_information")
@@ -92,12 +91,11 @@
 
         # - self.workdir / "db.sqlite" in Arkindex mode
         # - self.args.database in dev mode
         database_path = (
-            Path(self.args.database)
+            self.args.database
             if self.is_read_only
             else self.work_dir / "db.sqlite"
         )
-        if database_path.exists():
-            database_path.unlink()
+        database_path.unlink(missing_ok=True)
 
         init_cache_db(database_path)
@@ -118,13 +116,13 @@
             )
             raise e
 
-        # Find latest that is in "done" state
+        # Find the latest export in the "done" state
         exports = sorted(
             list(filter(lambda exp: exp["state"] == "done", exports)),
             key=operator.itemgetter("updated"),
             reverse=True,
         )
-        assert len(exports) > 0, "No available exports found."
+        assert len(exports) > 0, f"No available exports found for corpus {self.corpus_id}."
 
         # Download the latest export to a tmpfile
         try:
@@ -172,6 +170,6 @@
             mirrored=element.mirrored,
-            worker_version_id=element.worker_version
+            worker_version_id=element.worker_version.id
             if element.worker_version
             else None,
             confidence=element.confidence,
         )
@@ -207,12 +205,12 @@
             batch_size=BULK_BATCH_SIZE,
         )
 
+        # Insert entities
         logger.info("Listing entities")
         entities, transcription_entities = [], []
-        for transcription in transcriptions:
-            ents, transc_ents = retrieve_entities(transcription)
+        for ents, transc_ents in map(retrieve_entities, transcriptions):
             entities.extend(ents)
             transcription_entities.extend(transc_ents)
 
         if entities:
             logger.info(f"Inserting {len(entities)} entities")
@@ -276,7 +274,6 @@
         # TAR + ZSTD Image folder and store as task artifact
         zstd_archive_path = self.work_dir / "arkindex_data.zstd"
         logger.info(f"Compressing the images to {zstd_archive_path}")
-
         create_tar_zst_archive(source=image_folder, destination=zstd_archive_path)
 
         # Cleanup image folder
@@ -285,6 +282,7 @@
 
 
 def main():
     DatasetExtractor(
-        description="Fill base-worker cache with information about dataset and extract images"
+        description="Fill base-worker cache with information about dataset and extract images",
+        support_cache=True,
     ).run()