Skip to content
Snippets Groups Projects
from_sql.py 3.99 KiB
Newer Older
# -*- coding: utf-8 -*-
import logging
from operator import itemgetter
from tempfile import _TemporaryFileWrapper
from typing import List
from uuid import UUID

from apistar.exceptions import ErrorResponse
from arkindex_export import (
    Classification,
    DatasetElement,
    Element,
    Entity,
    EntityType,
    Image,
    Transcription,
    TranscriptionEntity,
    WorkerRun,
    WorkerVersion,
    open_database,
)
from arkindex_export.queries import list_children
from arkindex_worker.models import Element as ArkindexElement
from arkindex_worker.models import Transcription as ArkindexTranscription
from peewee import CharField
from worker_generic_training_dataset import Extractor

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# NOTE(review): neither constant is referenced in the visible part of this
# file — presumably kept for other modules or the parent worker; confirm
# before removing.
BULK_BATCH_SIZE = 50
DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"


def get_object_id(instance: WorkerVersion | WorkerRun | None) -> CharField | None:
    """
    Return the `id` attribute of `instance`, or `None` when `instance` is falsy.

    :param instance: Object exposing an `id` attribute, or `None`.
    :returns: `instance.id`, or `None` if no (truthy) instance was given.
    """
    # Truthiness (not an identity check against None) decides the outcome.
    if not instance:
        return None
    return instance.id


class DatasetExtractorFromSQL(Extractor):
    """
    Extractor whose listing methods query the `arkindex_export` models
    instead of going through the API.

    `configure` downloads the latest finished export of the corpus and hands
    the downloaded file to `open_database` before any listing method is used.
    """

    def configure(self):
        """
        Run the parent configuration, then download and open the latest export.
        """
        super().configure()
        self.download_latest_export()

    def download_latest_export(self) -> None:
        """
        Download the latest export of the current corpus.
        Export must be in `"done"` state.

        The downloaded file is stored on `self.export` and its path is passed
        to `open_database`.

        :raises ErrorResponse: If listing or downloading exports fails
            (the error is logged, then re-raised).
        :raises AssertionError: If the corpus has no export in `"done"` state.
        """
        try:
            exports = list(
                self.api_client.paginate(
                    "ListExports",
                    id=self.corpus_id,
                )
            )
        except ErrorResponse as e:
            logger.error(
                f"Could not list exports of corpus ({self.corpus_id}): {str(e.content)}"
            )
            raise

        # Only finished exports can be downloaded
        done_exports = [exp for exp in exports if exp["state"] == "done"]
        if not done_exports:
            # Raised explicitly (rather than with an `assert` statement) so the
            # check still runs under `python -O`; the exception type and the
            # message are unchanged for callers.
            raise AssertionError(
                f"No available exports found for the corpus {self.corpus_id}."
            )

        # Most recently updated export; on ties the first one listed wins
        latest_export = max(done_exports, key=itemgetter("updated"))
        export_id: str = latest_export["id"]

        # Download latest export
        logger.info(f"Downloading export ({export_id})...")
        try:
            self.export: _TemporaryFileWrapper = self.api_client.request(
                "DownloadExport",
                id=export_id,
            )
        except ErrorResponse as e:
            logger.error(
                f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e.content)}"
            )
            raise

        logger.info(f"Downloaded export ({export_id}) @ `{self.export.name}`")
        open_database(self.export.name)

    def list_set_elements(self, dataset_id: UUID, set_name: str):
        """
        Build the query selecting the elements of one set of a dataset.

        Elements are joined with `Image` and `DatasetElement`, then filtered
        on the dataset and set name.

        :param dataset_id: ID of the dataset to filter on.
        :param set_name: Name of the set to filter on.
        :returns: The `Element` select query (not executed here).
        """
        return (
            Element.select()
            .join(Image)
            .switch(Element)
            .join(DatasetElement, on=DatasetElement.element)
            .where(
                DatasetElement.dataset == dataset_id,
                DatasetElement.set_name == set_name,
            )
        )

    def list_classifications(self, element_id: UUID):
        """
        Iterate over the classifications attached to an element.

        :param element_id: ID of the element to filter on.
        :returns: Iterator over the matching `Classification` rows.
        """
        return (
            Classification.select()
            .where(Classification.element == element_id)
            .iterator()
        )

    def list_transcriptions(self, element: ArkindexElement, **kwargs):
        """
        Iterate over the transcriptions attached to an element.

        :param element: Element whose `id` is used as filter.
        :param kwargs: Accepted but ignored here.
        :returns: Iterator over the matching `Transcription` rows.
        """
        return (
            Transcription.select().where(Transcription.element == element.id).iterator()
        )

    def list_transcription_entities(
        self, transcription: ArkindexTranscription, **kwargs
    ):
        """
        Iterate over the entities linked to a transcription.

        Rows are joined with `Entity` and `EntityType`.

        :param transcription: Transcription whose `id` is used as filter.
        :param kwargs: Accepted but ignored here.
        :returns: Iterator over the matching `TranscriptionEntity` rows.
        """
        return (
            TranscriptionEntity.select()
            .where(TranscriptionEntity.transcription == transcription.id)
            .join(Entity, on=TranscriptionEntity.entity)
            .join(EntityType, on=Entity.type)
        ).iterator()

    def list_element_children(self, element: ArkindexElement, **kwargs):
        """
        Iterate over the children of an element, as returned by `list_children`.

        :param element: Parent element whose `id` is passed to `list_children`.
        :param kwargs: Accepted but ignored here.
        :returns: Iterator over the `list_children` query results.
        """
        return list_children(element.id).iterator()


def main():
    """Instantiate the SQL-backed dataset extractor and run it."""
    worker = DatasetExtractorFromSQL(
        description="Fill base-worker cache with information about dataset and extract images",
    )
    worker.run()


# Run the worker only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()