diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py index ea0a758c8352f0bbb9a8afb000c7e9fcc4aa17ba..54cca912fe7edd60aa1d354ab3bcf14e5ab11a7a 100644 --- a/dan/datasets/extract/__init__.py +++ b/dan/datasets/extract/__init__.py @@ -122,14 +122,12 @@ def add_extract_parser(subcommands) -> None: type=parse_worker_version, help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.", required=False, - default=False, ) parser.add_argument( "--entity-worker-version", type=parse_worker_version, help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.", required=False, - default=False, ) parser.add_argument( diff --git a/dan/datasets/extract/db.py b/dan/datasets/extract/db.py index 86a6caac4bef9014218d50c0936c87c7b5712462..79fd5d498ab6fb0659265a965417d31ae2189860 100644 --- a/dan/datasets/extract/db.py +++ b/dan/datasets/extract/db.py @@ -3,15 +3,17 @@ import ast from dataclasses import dataclass from itertools import starmap -from typing import List, NamedTuple, Optional, Union +from typing import List, Optional, Union from urllib.parse import urljoin from arkindex_export import Image from arkindex_export.models import Element as ArkindexElement -from arkindex_export.models import Entity as ArkindexEntity -from arkindex_export.models import EntityType as ArkindexEntityType -from arkindex_export.models import Transcription as ArkindexTranscription -from arkindex_export.models import TranscriptionEntity as ArkindexTranscriptionEntity +from arkindex_export.models import ( + Entity, + EntityType, + Transcription, + TranscriptionEntity, +) from arkindex_export.queries import list_children @@ -25,23 +27,6 @@ def bounding_box(polygon: list): return int(x), int(y), int(width), int(height) -# DB models -Transcription = NamedTuple( - "Transcription", - id=str, - text=str, -) - - -Entity = NamedTuple( - "Entity", - type=str, - value=str, - offset=float, - length=float, -) - - @dataclass class Element: id: str @@ -94,6 +79,7 @@ def get_elements( Image.height, ) ) + return list( starmap( lambda *x: Element(*x, max_width=max_width, max_height=max_height), @@ -118,47 +104,43 @@ def get_transcriptions( """ Retrieve transcriptions from an SQLite export of an Arkindex corpus """ - query = ArkindexTranscription.select( - ArkindexTranscription.id, ArkindexTranscription.text - ).where( - (ArkindexTranscription.element == element_id) - & build_worker_version_filter( - ArkindexTranscription, worker_version=transcription_worker_version - ) - ) - return list( - starmap( - Transcription, - query.tuples(), + query = Transcription.select( + Transcription.id, Transcription.text, Transcription.worker_version + ).where((Transcription.element == element_id)) + + if transcription_worker_version is not None: + query = query.where( + build_worker_version_filter( + Transcription, worker_version=transcription_worker_version + ) ) - ) + return query def get_transcription_entities( transcription_id: str, entity_worker_version: Union[str, bool] -) -> List[Entity]: +) -> List[TranscriptionEntity]: """ Retrieve transcription entities from an SQLite export of an Arkindex corpus """ query = ( - ArkindexTranscriptionEntity.select( - ArkindexEntityType.name, - ArkindexEntity.name, - ArkindexTranscriptionEntity.offset, - ArkindexTranscriptionEntity.length, - ) - .join(ArkindexEntity, on=ArkindexTranscriptionEntity.entity) - .join(ArkindexEntityType, on=ArkindexEntity.type) - .where( - (ArkindexTranscriptionEntity.transcription == transcription_id) - & build_worker_version_filter( - ArkindexTranscriptionEntity, worker_version=entity_worker_version - ) + TranscriptionEntity.select( + EntityType.name.alias("type"), + Entity.name.alias("name"), + TranscriptionEntity.offset, + TranscriptionEntity.length, + TranscriptionEntity.worker_version, ) + .join(Entity, on=TranscriptionEntity.entity) + .join(EntityType, on=Entity.type) + .where((TranscriptionEntity.transcription == transcription_id)) ) - return list( - starmap( - Entity, - query.order_by(ArkindexTranscriptionEntity.offset).tuples(), + + if entity_worker_version is not None: + query = query.where( + build_worker_version_filter( + TranscriptionEntity, worker_version=entity_worker_version + ) ) - ) + + return query.order_by(TranscriptionEntity.offset).namedtuples() diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py index b4dd1da9908ff04040919843a4559ead633d6813..258500bac154c1c5b93127212504694e6292b5aa 100644 --- a/dan/datasets/extract/extract.py +++ b/dan/datasets/extract/extract.py @@ -12,7 +12,6 @@ from tqdm import tqdm from dan import logger from dan.datasets.extract.db import ( Element, - Entity, get_elements, get_transcription_entities, get_transcriptions, @@ -51,8 +50,8 @@ class ArkindexExtractor: load_entities: bool = None, tokens: Path = None, use_existing_split: bool = None, - transcription_worker_version: str = None, - entity_worker_version: str = None, + transcription_worker_version: Optional[Union[str, bool]] = None, + entity_worker_version: Optional[Union[str, bool]] = None, train_prob: float = None, val_prob: float = None, max_width: Optional[int] = None, @@ -100,7 +99,7 @@ class ArkindexExtractor: def get_random_split(self): return next(self._assign_random_split()) - def reconstruct_text(self, text: str, entities: List[Entity]): + def reconstruct_text(self, text: str, entities): """ Insert tokens delimiting the start/end of each entity on the transcription. """ @@ -226,8 +225,8 @@ def run( train_folder: UUID, val_folder: UUID, test_folder: UUID, - transcription_worker_version: Union[str, bool], - entity_worker_version: Union[str, bool], + transcription_worker_version: Optional[Union[str, bool]], + entity_worker_version: Optional[Union[str, bool]], train_prob, val_prob, max_width: Optional[int], diff --git a/tests/test_db.py b/tests/test_db.py index 9230b8084f6eb1e10341348982eda8ad70c0fffd..6af150196cf9147793a3e6fa734bfef517c8bf59 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -4,8 +4,6 @@ import pytest from dan.datasets.extract.db import ( Element, - Entity, - Transcription, get_elements, get_transcription_entities, get_transcriptions, @@ -35,7 +33,7 @@ def test_get_elements(): @pytest.mark.parametrize( - "worker_version", (False, "0b2a429a-0da2-4b79-a6bb-330c6a07ac60") + "worker_version", (False, "0b2a429a-0da2-4b79-a6bb-330c6a07ac60", None) ) def test_get_transcriptions(worker_version): """ @@ -48,22 +46,17 @@ def test_get_transcriptions(worker_version): ) # Check number of results - assert len(transcriptions) == 1 - transcription = transcriptions.pop() - assert isinstance(transcription, Transcription) - - # Common keys - assert transcription.text == "[ T 8º SUP 26200" - - # Differences - if worker_version: - assert transcription.id == "3bd248d6-998a-4579-a00c-d4639f3825aa" - else: - assert transcription.id == "c551960a-0f82-4779-b975-77a457bcf273" + assert len(transcriptions) == 1 + int(worker_version is None) + for transcription in transcriptions: + assert transcription.text == "[ T 8º SUP 26200" + if worker_version: + assert transcription.worker_version.id == worker_version + elif worker_version is False: + assert transcription.worker_version is None @pytest.mark.parametrize( - "worker_version", (False, "0e2a98f5-71ac-48f6-973b-cc10ed440965") + "worker_version", (False, "0e2a98f5-71ac-48f6-973b-cc10ed440965", None) ) def test_get_transcription_entities(worker_version): transcription_id = "3bd248d6-998a-4579-a00c-d4639f3825aa" @@ -73,18 +66,15 @@ def test_get_transcription_entities(worker_version): ) # Check number of results - assert len(entities) == 1 - transcription_entity = entities.pop() - assert isinstance(transcription_entity, Entity) - - # Differences - if worker_version: - assert transcription_entity.type == "cote" - assert transcription_entity.value == "T 8 º SUP 26200" - assert transcription_entity.offset == 2 - assert transcription_entity.length == 14 - else: - assert transcription_entity.type == "Cote" - assert transcription_entity.value == "[ T 8º SUP 26200" - assert transcription_entity.offset == 0 - assert transcription_entity.length == 16 + assert len(entities) == 1 + (worker_version is None) + for transcription_entity in entities: + if worker_version: + assert transcription_entity.type == "cote" + assert transcription_entity.name == "T 8 º SUP 26200" + assert transcription_entity.offset == 2 + assert transcription_entity.length == 14 + elif worker_version is False: + assert transcription_entity.type == "Cote" + assert transcription_entity.name == "[ T 8º SUP 26200" + assert transcription_entity.offset == 0 + assert transcription_entity.length == 16 diff --git a/tests/test_extract.py b/tests/test_extract.py index 11a6e7eee01759fbf941df24b9a6852b01c751da..4ddc6fa258a4e4ac6d0e413833fa03db7376d42d 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -1,10 +1,15 @@ # -*- coding: utf-8 -*- +from typing import NamedTuple + import pytest -from dan.datasets.extract.extract import ArkindexExtractor, Entity +from dan.datasets.extract.extract import ArkindexExtractor from dan.datasets.extract.utils import EntityType, insert_token +# NamedTuple to mock actual database result +Entity = NamedTuple("Entity", offset=int, length=int, type=str, value=str) + @pytest.mark.parametrize( "text,count,offset,length,expected",