# -*- coding: utf-8 -*-
"""Worker that extracts dataset information from the SQLite export of a corpus."""
import logging
from operator import itemgetter
from tempfile import _TemporaryFileWrapper
from typing import List
from uuid import UUID

from apistar.exceptions import ErrorResponse
from arkindex_export import (
    Classification,
    DatasetElement,
    Element,
    Entity,
    EntityType,
    Image,
    Transcription,
    TranscriptionEntity,
    WorkerRun,
    WorkerVersion,
    open_database,
)
from arkindex_export.queries import list_children
from arkindex_worker.models import Element as ArkindexElement
from arkindex_worker.models import Transcription as ArkindexTranscription
from peewee import CharField

from worker_generic_training_dataset import Extractor

logger: logging.Logger = logging.getLogger(__name__)

BULK_BATCH_SIZE = 50
DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"


def get_object_id(instance: WorkerVersion | WorkerRun | None) -> CharField | None:
    """Return the ``id`` of ``instance``, or ``None`` when no instance is given."""
    return instance.id if instance else None


class DatasetExtractorFromSQL(Extractor):
    """Extractor that reads elements, transcriptions, entities and classifications
    from a downloaded SQLite export instead of live API calls."""

    def configure(self):
        """Run the base configuration, then fetch and open the corpus export."""
        super().configure()
        self.download_latest_export()

    def download_latest_export(self) -> None:
        """
        Download the latest export of the current corpus.
        Export must be in `"done"` state.

        :raises AssertionError: When no export in ``"done"`` state exists for the corpus.
        :raises ErrorResponse: When listing or downloading the export fails.
        """
        try:
            exports = list(
                self.api_client.paginate(
                    "ListExports",
                    id=self.corpus_id,
                )
            )
        except ErrorResponse as e:
            logger.error(
                f"Could not list exports of corpus ({self.corpus_id}): {str(e.content)}"
            )
            # Bare raise keeps the original traceback intact
            raise

        # Find the latest that is in "done" state
        exports: List[dict] = sorted(
            (exp for exp in exports if exp["state"] == "done"),
            key=itemgetter("updated"),
            reverse=True,
        )
        # Explicit raise instead of `assert`: the check must survive `python -O`
        if not exports:
            raise AssertionError(
                f"No available exports found for the corpus {self.corpus_id}."
            )

        # Download latest export; export_id is bound outside the try so the
        # error message below can always reference it
        export_id: str = exports[0]["id"]
        try:
            logger.info(f"Downloading export ({export_id})...")
            self.export: _TemporaryFileWrapper = self.api_client.request(
                "DownloadExport",
                id=export_id,
            )
            logger.info(f"Downloaded export ({export_id}) @ `{self.export.name}`")
            open_database(self.export.name)
        except ErrorResponse as e:
            logger.error(
                f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e.content)}"
            )
            raise

    def list_set_elements(self, dataset_id: UUID, set_name: str):
        """Query the elements (joined with their image) belonging to one set of a dataset."""
        return (
            Element.select()
            .join(Image)
            .switch(Element)
            .join(DatasetElement, on=DatasetElement.element)
            .where(
                DatasetElement.dataset == dataset_id,
                DatasetElement.set_name == set_name,
            )
        )

    def list_classifications(self, element_id: UUID):
        """Iterate over the classifications of the given element."""
        return (
            Classification.select()
            .where(Classification.element == element_id)
            .iterator()
        )

    def list_transcriptions(self, element: ArkindexElement, **kwargs):
        """Iterate over the transcriptions of the given element."""
        return (
            Transcription.select().where(Transcription.element == element.id).iterator()
        )

    def list_transcription_entities(
        self, transcription: ArkindexTranscription, **kwargs
    ):
        """Iterate over the entities (with their type) of the given transcription."""
        return (
            TranscriptionEntity.select()
            .where(TranscriptionEntity.transcription == transcription.id)
            .join(Entity, on=TranscriptionEntity.entity)
            .join(EntityType, on=Entity.type)
        ).iterator()

    def list_element_children(self, element: ArkindexElement, **kwargs):
        """Iterate over the children of the given element."""
        return list_children(element.id).iterator()


def main():
    """Entry point: run the dataset extraction worker."""
    DatasetExtractorFromSQL(
        description="Fill base-worker cache with information about dataset and extract images",
    ).run()


if __name__ == "__main__":
    main()