From db4bacb7f679ee85d1d83cbbb7f23adf5e9a34e0 Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Fri, 1 Dec 2023 14:14:24 +0100 Subject: [PATCH] Use helper to retrieve ID --- worker_generic_training_dataset/worker.py | 33 +++++++++-------------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py index 6384ba4..7830657 100644 --- a/worker_generic_training_dataset/worker.py +++ b/worker_generic_training_dataset/worker.py @@ -11,7 +11,7 @@ from typing import List, Optional from uuid import UUID from apistar.exceptions import ErrorResponse -from arkindex_export import Element, open_database +from arkindex_export import Element, WorkerRun, WorkerVersion, open_database from arkindex_export.queries import list_children from arkindex_worker.cache import ( CachedClassification, @@ -32,6 +32,7 @@ from arkindex_worker.models import Dataset from arkindex_worker.models import Element as WorkerElement from arkindex_worker.utils import create_tar_zst_archive from arkindex_worker.worker import DatasetWorker +from peewee import CharField from worker_generic_training_dataset.db import ( list_classifications, list_transcription_entities, @@ -50,6 +51,10 @@ def _format_element(element: WorkerElement) -> Element: return retrieve_element(element.id) +def get_object_id(instance: WorkerVersion | WorkerRun | None) -> CharField | None: + return instance.id if instance else None + + class DatasetExtractor(DatasetWorker): def configure(self) -> None: self.args: Namespace = self.parser.parse_args() @@ -147,9 +152,7 @@ class DatasetExtractor(DatasetWorker): class_name=classification.class_name, confidence=classification.confidence, state=classification.state, - worker_run_id=classification.worker_run.id - if classification.worker_run - else None, + worker_run_id=get_object_id(classification.worker_run), ) for classification in list_classifications(element.id) ] @@ -172,12 +175,8 @@ class DatasetExtractor(DatasetWorker): text=transcription.text, confidence=transcription.confidence, orientation=DEFAULT_TRANSCRIPTION_ORIENTATION, - worker_version_id=transcription.worker_version.id - if transcription.worker_version - else None, - worker_run_id=transcription.worker_run.id - if transcription.worker_run - else None, + worker_version_id=get_object_id(transcription.worker_version), + worker_run_id=get_object_id(transcription.worker_run), ) for transcription in list_transcriptions(element.id) ] @@ -202,9 +201,7 @@ class DatasetExtractor(DatasetWorker): name=transcription_entity.entity.name, validated=transcription_entity.entity.validated, metas=transcription_entity.entity.metas, - worker_run_id=transcription_entity.entity.worker_run.id - if transcription_entity.entity.worker_run - else None, + worker_run_id=get_object_id(transcription_entity.entity.worker_run), ) entities.append(entity) transcription_entities.append( @@ -215,9 +212,7 @@ class DatasetExtractor(DatasetWorker): offset=transcription_entity.offset, length=transcription_entity.length, confidence=transcription_entity.confidence, - worker_run_id=transcription_entity.worker_run.id - if transcription_entity.worker_run - else None, + worker_run_id=get_object_id(transcription_entity.worker_run), ) ) if entities: @@ -289,10 +284,8 @@ class DatasetExtractor(DatasetWorker): polygon=element.polygon, rotation_angle=element.rotation_angle, mirrored=element.mirrored, - worker_version_id=element.worker_version.id - if element.worker_version - else None, - worker_run_id=element.worker_run.id if element.worker_run else None, + worker_version_id=get_object_id(element.worker_version), + worker_run_id=get_object_id(element.worker_run), confidence=element.confidence, ) -- GitLab