Skip to content
Snippets Groups Projects
Commit ac36224e authored by Manon Blanco's avatar Manon Blanco Committed by Yoann Schneider
Browse files

Apply 14 suggestion(s) to 2 file(s)

parent 91bfb128
No related branches found
No related tags found
1 merge request: !2 "Implement worker"
Pipeline #81806 failed
...@@ -28,7 +28,7 @@ def list_classifications(element_id: str): ...@@ -28,7 +28,7 @@ def list_classifications(element_id: str):
return Classification.select().where(Classification.element_id == element_id) return Classification.select().where(Classification.element_id == element_id)
def parse_transcription(transcription: NamedTuple, element: CachedElement): def parse_transcription(transcription: Transcription, element: CachedElement):
return CachedTranscription( return CachedTranscription(
id=transcription.id, id=transcription.id,
element=element, element=element,
...@@ -44,7 +44,7 @@ def parse_transcription(transcription: NamedTuple, element: CachedElement): ...@@ -44,7 +44,7 @@ def parse_transcription(transcription: NamedTuple, element: CachedElement):
def list_transcriptions(element: CachedElement):
    """Return the cached transcriptions attached to *element*.

    Queries the export database for all Transcription rows whose
    element_id matches the given element, and converts each row into a
    CachedTranscription via parse_transcription.
    """
    rows = Transcription.select().where(Transcription.element_id == element.id)
    return [parse_transcription(row, element) for row in rows]
def parse_entities(data: NamedTuple, transcription: CachedTranscription): def parse_entities(data: NamedTuple, transcription: CachedTranscription):
...@@ -92,7 +92,7 @@ def retrieve_entities(transcription: CachedTranscription): ...@@ -92,7 +92,7 @@ def retrieve_entities(transcription: CachedTranscription):
return zip(*data) return zip(*data)
def get_children(parent_id, element_type=None): def get_children(parent_id: UUID, element_type=None):
query = list_children(parent_id).join(Image) query = list_children(parent_id).join(Image)
if element_type: if element_type:
query = query.where(Element.type == element_type) query = query.where(Element.type == element_type)
......
...@@ -63,7 +63,7 @@ class DatasetExtractor(BaseWorker): ...@@ -63,7 +63,7 @@ class DatasetExtractor(BaseWorker):
# Initialize db that will be written # Initialize db that will be written
self.initialize_database() self.initialize_database()
# Cached Images downloaded and created in DB # CachedImage downloaded and created in DB
self.cached_images = dict() self.cached_images = dict()
def read_training_related_information(self): def read_training_related_information(self):
...@@ -72,7 +72,6 @@ class DatasetExtractor(BaseWorker): ...@@ -72,7 +72,6 @@ class DatasetExtractor(BaseWorker):
- train_folder_id - train_folder_id
- validation_folder_id - validation_folder_id
- test_folder_id (optional) - test_folder_id (optional)
- model_id
""" """
logger.info("Retrieving information from process_information") logger.info("Retrieving information from process_information")
...@@ -92,12 +91,12 @@ class DatasetExtractor(BaseWorker): ...@@ -92,12 +91,12 @@ class DatasetExtractor(BaseWorker):
# - self.workdir / "db.sqlite" in Arkindex mode # - self.workdir / "db.sqlite" in Arkindex mode
# - self.args.database in dev mode # - self.args.database in dev mode
database_path = ( database_path = (
Path(self.args.database) self.args.database
if self.is_read_only if self.is_read_only
else self.work_dir / "db.sqlite" else self.work_dir / "db.sqlite"
) )
if database_path.exists(): if database_path.exists():
database_path.unlink() database_path.unlink(missing_ok=True)
init_cache_db(database_path) init_cache_db(database_path)
...@@ -118,13 +117,13 @@ class DatasetExtractor(BaseWorker): ...@@ -118,13 +117,13 @@ class DatasetExtractor(BaseWorker):
) )
raise e raise e
# Find latest that is in "done" state # Find the latest that is in "done" state
exports = sorted( exports = sorted(
list(filter(lambda exp: exp["state"] == "done", exports)), list(filter(lambda exp: exp["state"] == "done", exports)),
key=operator.itemgetter("updated"), key=operator.itemgetter("updated"),
reverse=True, reverse=True,
) )
assert len(exports) > 0, "No available exports found." assert len(exports) > 0, f"No available exports found for the corpus {self.corpus_id}."
# Download latest it in a tmpfile # Download latest it in a tmpfile
try: try:
...@@ -172,6 +171,8 @@ class DatasetExtractor(BaseWorker): ...@@ -172,6 +171,8 @@ class DatasetExtractor(BaseWorker):
mirrored=element.mirrored, mirrored=element.mirrored,
worker_version_id=element.worker_version worker_version_id=element.worker_version
if element.worker_version if element.worker_version
worker_version_id=element.worker_version.id
if element.worker_version
else None, else None,
confidence=element.confidence, confidence=element.confidence,
) )
...@@ -207,12 +208,9 @@ class DatasetExtractor(BaseWorker): ...@@ -207,12 +208,9 @@ class DatasetExtractor(BaseWorker):
batch_size=BULK_BATCH_SIZE, batch_size=BULK_BATCH_SIZE,
) )
# Insert entities
logger.info("Listing entities") logger.info("Listing entities")
entities, transcription_entities = [], [] entities, transcription_entities = zip(*[retrieve_entities(transcription) for transcription in transcriptions))
for transcription in transcriptions:
ents, transc_ents = retrieve_entities(transcription)
entities.extend(ents)
transcription_entities.extend(transc_ents)
if entities: if entities:
logger.info(f"Inserting {len(entities)} entities") logger.info(f"Inserting {len(entities)} entities")
...@@ -276,7 +274,6 @@ class DatasetExtractor(BaseWorker): ...@@ -276,7 +274,6 @@ class DatasetExtractor(BaseWorker):
# TAR + ZSTD Image folder and store as task artifact # TAR + ZSTD Image folder and store as task artifact
zstd_archive_path = self.work_dir / "arkindex_data.zstd" zstd_archive_path = self.work_dir / "arkindex_data.zstd"
logger.info(f"Compressing the images to {zstd_archive_path}") logger.info(f"Compressing the images to {zstd_archive_path}")
create_tar_zst_archive(source=image_folder, destination=zstd_archive_path) create_tar_zst_archive(source=image_folder, destination=zstd_archive_path)
# Cleanup image folder # Cleanup image folder
...@@ -285,7 +282,7 @@ class DatasetExtractor(BaseWorker): ...@@ -285,7 +282,7 @@ class DatasetExtractor(BaseWorker):
def main():
    """Entry point: instantiate the dataset-extraction worker (with the
    base-worker cache enabled) and run it."""
    worker = DatasetExtractor(
        description="Fill base-worker cache with information about dataset and extract images", support_cache=True
    )
    worker.run()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment