Skip to content
Snippets Groups Projects
db.py 2.47 KiB
Newer Older
Yoann Schneider's avatar
Yoann Schneider committed
# -*- coding: utf-8 -*-
from operator import attrgetter
Yoann Schneider's avatar
Yoann Schneider committed
from typing import List, Optional

from arkindex_export import Dataset, DatasetElement, Element, Transcription
Yoann Schneider's avatar
Yoann Schneider committed
from arkindex_export.queries import list_children

from atr_data_generator.extract.arguments import MANUAL


def get_dataset_elements(dataset: Dataset, split: str):
    """
    Build the query listing the elements of a dataset that belong to one split,
    from an SQLite export of an Arkindex corpus.

    :param dataset: Dataset object the elements are attached to.
    :param split: Name of the dataset set (split) to keep.
    :return: A peewee query selecting ``DatasetElement.element`` (joined on
        ``Element``) for the rows matching both the dataset and the split.
    """
    in_dataset = DatasetElement.dataset == dataset
    in_split = DatasetElement.set_name == split

    query = DatasetElement.select(DatasetElement.element).join(Element)
    return query.where(in_dataset, in_split)
Yoann Schneider's avatar
Yoann Schneider committed


def parse_sources(sources: List[str]):
    """Build a peewee filter on ``Transcription.worker_version`` from a list of sources.

    The ``MANUAL`` source has a different treatment: it is matched with
    ``Transcription.worker_version IS NULL``. Every other entry is treated as a
    worker version and matched with ``IN``. When both kinds are present, the two
    conditions are OR-ed together.

    The input list is NOT modified, so callers do not need to pass a copy.

    :param sources: List of worker version IDs and/or ``MANUAL``.
    :return: A peewee expression filtering on ``Transcription.worker_version``,
        or ``None`` when ``sources`` is empty.
    """
    include_manual = MANUAL in sources
    # Drop every MANUAL occurrence (not just the first) so none leaks into IN (...)
    worker_versions = [source for source in sources if source != MANUAL]

    query_filter = None

    if include_manual:
        # Manual filtering
        query_filter = Transcription.worker_version.is_null()

    # Filter by worker_versions
    if worker_versions:
        versions_filter = Transcription.worker_version.in_(worker_versions)
        # Compare against None explicitly: the filter is a peewee expression,
        # whose truthiness is not a meaningful "is set" check.
        if query_filter is not None:
            query_filter |= versions_filter
        else:
            query_filter = versions_filter
    return query_filter


def get_children_info(
    parent_id: str,
    type: Optional[str],
    sources: Optional[List[str]],
):
    """Get the transcriptions of an element and of its children. Apply all needed filters.

    :param parent_id: ID of the parent element. The parent itself is included
        alongside the elements returned by ``list_children`` so it is processed too.
    :param type: If set, only keep transcriptions of elements of this type
        (the filter applies to the parent as well).
    :param sources: If set, only keep transcriptions coming from these sources
        (worker version IDs and/or ``MANUAL``), as interpreted by ``parse_sources``.
    :return: A peewee query on ``Transcription``, ordered by element name, then by
        element ID to break ties between identically named elements.
    """
    children = list_children(parent_id)

    # Insert parent in the query to allow to process it.
    # NOTE(review): every ID becomes a bound parameter of the IN clause; very
    # large subtrees may hit SQLite's host-parameter limit — confirm on big corpora.
    element_ids = list(map(attrgetter("id"), children)) + [parent_id]
    elements = Element.select().where(Element.id.in_(element_ids))

    # Filter by type
    if type:
        elements = elements.where(Element.type == type)

    # Get transcriptions, joined on the filtered elements subquery
    transcriptions = Transcription.select().join(
        elements, on=(Transcription.element == elements.c.id)
    )

    # Filter by transcription source (pass a copy so the caller's list is untouched)
    if sources:
        transcriptions = transcriptions.where(parse_sources(sources.copy()))

    # Order through the joined subquery's columns: ``Element`` is only referenced
    # inside the subquery, so its table alias is not visible in the outer query.
    # The element ID is an additional ordering in case there are identical names.
    return transcriptions.order_by(elements.c.name, Transcription.element_id)