Skip to content
Snippets Groups Projects
db.py 2.85 KiB
# -*- coding: utf-8 -*-
from typing import NamedTuple

from arkindex_export import Classification, Image
from arkindex_export.models import (
    Element,
    Entity,
    EntityType,
    Transcription,
    TranscriptionEntity,
)
from arkindex_export.queries import list_children
from arkindex_worker.cache import (
    CachedElement,
    CachedEntity,
    CachedTranscription,
    CachedTranscriptionEntity,
)

# Orientation assigned to every cached transcription built by
# ``parse_transcription`` (the exported row's own orientation is not read).
# NOTE(review): "horizontal-lr" presumably means horizontal, left-to-right text.
DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"


def retrieve_element(element_id: str):
    """Fetch one exported element by its primary key.

    Any error raised by ``Element.get_by_id`` for a missing id is passed
    through to the caller unchanged.
    """
    element = Element.get_by_id(element_id)
    return element


def list_classifications(element_id: str):
    """Build a query over every classification attached to an element.

    The query is returned unevaluated, so callers may refine or iterate it.
    """
    belongs_to_element = Classification.element_id == element_id
    return Classification.select().where(belongs_to_element)


def parse_transcription(transcription: NamedTuple, element: CachedElement):
    """Convert an exported transcription row into a cache row for *element*.

    Args:
        transcription: Exported transcription row. Its ``id``, ``text``,
            ``confidence`` and worker-version foreign key are read.
        element: Cached element the transcription belongs to.

    Returns:
        A new ``CachedTranscription``. This function does not save it.
    """
    confidence = transcription.confidence
    if confidence is None:
        # Dodge not-null constraint for now. Only a missing value is replaced:
        # a real confidence of 0.0 is falsy but must be kept as-is.
        confidence = 1.0
    return CachedTranscription(
        id=transcription.id,
        element=element,
        text=transcription.text,
        confidence=confidence,
        # The exported row's orientation is not read; every cached
        # transcription gets the same default.
        orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
        # Use the raw foreign-key value. Going through
        # ``transcription.worker_version.id`` makes peewee fetch the related
        # row, i.e. one extra query per transcription. NULL stays None.
        worker_version_id=transcription.worker_version_id,
    )


def list_transcriptions(element: CachedElement):
    """Return every exported transcription of *element* as cache rows.

    Each row is converted with ``parse_transcription`` and attached to
    *element*.
    """
    same_element = Transcription.element_id == element.id
    return [
        parse_transcription(row, element)
        for row in Transcription.select().where(same_element)
    ]


def parse_entities(data: NamedTuple, transcription: CachedTranscription):
    """Split one joined entity row into its two cache objects.

    Args:
        data: Row whose attribute names match the aliases selected in
            ``retrieve_entities`` (``entity_id``, ``type``, ``name``,
            ``validated``, ``metas``, ``transcription_entity_id``,
            ``offset``, ``length``, ``confidence``).
        transcription: Cached transcription the entity is linked to.

    Returns:
        A ``(CachedEntity, CachedTranscriptionEntity)`` pair, where the
        second object references the first.
    """
    cached_entity = CachedEntity(
        id=data.entity_id,
        type=data.type,
        name=data.name,
        validated=data.validated,
        metas=data.metas,
    )
    link = CachedTranscriptionEntity(
        id=data.transcription_entity_id,
        transcription=transcription,
        entity=cached_entity,
        offset=data.offset,
        length=data.length,
        confidence=data.confidence,
    )
    return cached_entity, link


def retrieve_entities(transcription: CachedTranscription):
    """Load the entities mentioned in *transcription* together with their links.

    Args:
        transcription: Cached transcription whose exported entity links are
            looked up by ``id``.

    Returns:
        A pair ``(entities, transcription_entities)`` of lists of equal
        length, where ``transcription_entities[i]`` links ``entities[i]`` to
        *transcription*. Both lists are empty when nothing matches, so the
        return type is the same in every case.
    """
    query = (
        TranscriptionEntity.select(
            # Aliases match the attribute names read by ``parse_entities``.
            TranscriptionEntity.id.alias("transcription_entity_id"),
            TranscriptionEntity.length.alias("length"),
            TranscriptionEntity.offset.alias("offset"),
            TranscriptionEntity.confidence.alias("confidence"),
            Entity.id.alias("entity_id"),
            EntityType.name.alias("type"),
            Entity.name,
            Entity.validated,
            Entity.metas,
        )
        .where(TranscriptionEntity.transcription_id == transcription.id)
        .join(Entity, on=TranscriptionEntity.entity)
        .join(EntityType, on=Entity.type)
    )
    entities, transcription_entities = [], []
    for entity_data in query.namedtuples():
        entity, transcription_entity = parse_entities(entity_data, transcription)
        entities.append(entity)
        transcription_entities.append(transcription_entity)
    return entities, transcription_entities


def get_children(parent_id, element_type=None):
    """Build a query over the children of *parent_id*, joined with ``Image``.

    When *element_type* is truthy, only children of that type are kept.
    The query is returned unevaluated.
    """
    children = list_children(parent_id).join(Image)
    if not element_type:
        return children
    return children.where(Element.type == element_type)