Verified Commit 9d2d6bf5 authored by Yoann Schneider, committed by Mélodie Boillet

Fix worker version filters

parent cddebe2f
@@ -122,14 +122,12 @@ def add_extract_parser(subcommands) -> None:
         type=parse_worker_version,
         help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
         required=False,
-        default=False,
     )
     parser.add_argument(
         "--entity-worker-version",
         type=parse_worker_version,
         help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
         required=False,
-        default=False,
     )
     parser.add_argument(
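With the implicit default=False gone, omitting either flag now yields argparse's default of None, which the reworked queries below treat as "apply no worker_version filter at all". A minimal sketch of the three resulting states, assuming parse_worker_version simply maps the MANUAL_SOURCE keyword to False (neither the helper body nor the MANUAL_SOURCE value is shown in this diff):

# Illustrative sketch only; parse_worker_version's real body is not part of this commit.
MANUAL_SOURCE = "manual"  # assumed sentinel matching the help text above

def parse_worker_version(value: str):
    # "manual" becomes False (manual annotations only); anything else is kept as a worker version id.
    return False if value == MANUAL_SOURCE else value

# Flag omitted                          -> None  -> no filtering on worker_version
# --transcription-worker-version manual -> False -> manual annotations only
# --transcription-worker-version <uuid> -> str   -> that worker version only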
@@ -3,15 +3,17 @@
 import ast
 from dataclasses import dataclass
 from itertools import starmap
-from typing import List, NamedTuple, Optional, Union
+from typing import List, Optional, Union
 from urllib.parse import urljoin

 from arkindex_export import Image
 from arkindex_export.models import Element as ArkindexElement
-from arkindex_export.models import Entity as ArkindexEntity
-from arkindex_export.models import EntityType as ArkindexEntityType
-from arkindex_export.models import Transcription as ArkindexTranscription
-from arkindex_export.models import TranscriptionEntity as ArkindexTranscriptionEntity
+from arkindex_export.models import (
+    Entity,
+    EntityType,
+    Transcription,
+    TranscriptionEntity,
+)
 from arkindex_export.queries import list_children
@@ -25,23 +27,6 @@ def bounding_box(polygon: list):
     return int(x), int(y), int(width), int(height)


-# DB models
-Transcription = NamedTuple(
-    "Transcription",
-    id=str,
-    text=str,
-)
-
-Entity = NamedTuple(
-    "Entity",
-    type=str,
-    value=str,
-    offset=float,
-    length=float,
-)
-
-
 @dataclass
 class Element:
     id: str
@@ -94,6 +79,7 @@ def get_elements(
             Image.height,
         )
     )
+
     return list(
         starmap(
             lambda *x: Element(*x, max_width=max_width, max_height=max_height),
@@ -118,47 +104,43 @@ def get_transcriptions(
     """
     Retrieve transcriptions from an SQLite export of an Arkindex corpus
     """
-    query = ArkindexTranscription.select(
-        ArkindexTranscription.id, ArkindexTranscription.text
-    ).where(
-        (ArkindexTranscription.element == element_id)
-        & build_worker_version_filter(
-            ArkindexTranscription, worker_version=transcription_worker_version
-        )
-    )
-    return list(
-        starmap(
-            Transcription,
-            query.tuples(),
-        )
-    )
+    query = Transcription.select(
+        Transcription.id, Transcription.text, Transcription.worker_version
+    ).where((Transcription.element == element_id))
+    if transcription_worker_version is not None:
+        query = query.where(
+            build_worker_version_filter(
+                Transcription, worker_version=transcription_worker_version
+            )
+        )
+    return query


 def get_transcription_entities(
     transcription_id: str, entity_worker_version: Union[str, bool]
-) -> List[Entity]:
+) -> List[TranscriptionEntity]:
     """
     Retrieve transcription entities from an SQLite export of an Arkindex corpus
     """
     query = (
-        ArkindexTranscriptionEntity.select(
-            ArkindexEntityType.name,
-            ArkindexEntity.name,
-            ArkindexTranscriptionEntity.offset,
-            ArkindexTranscriptionEntity.length,
-        )
-        .join(ArkindexEntity, on=ArkindexTranscriptionEntity.entity)
-        .join(ArkindexEntityType, on=ArkindexEntity.type)
-        .where(
-            (ArkindexTranscriptionEntity.transcription == transcription_id)
-            & build_worker_version_filter(
-                ArkindexTranscriptionEntity, worker_version=entity_worker_version
-            )
-        )
+        TranscriptionEntity.select(
+            EntityType.name.alias("type"),
+            Entity.name.alias("name"),
+            TranscriptionEntity.offset,
+            TranscriptionEntity.length,
+            TranscriptionEntity.worker_version,
+        )
+        .join(Entity, on=TranscriptionEntity.entity)
+        .join(EntityType, on=Entity.type)
+        .where((TranscriptionEntity.transcription == transcription_id))
     )
-    return list(
-        starmap(
-            Entity,
-            query.order_by(ArkindexTranscriptionEntity.offset).tuples(),
-        )
-    )
+    if entity_worker_version is not None:
+        query = query.where(
+            build_worker_version_filter(
+                TranscriptionEntity, worker_version=entity_worker_version
+            )
+        )
+    return query.order_by(TranscriptionEntity.offset).namedtuples()
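Both queries now add the worker_version clause only when a filter value was actually supplied. build_worker_version_filter itself is not shown in this diff; a minimal sketch of a compatible helper, assuming False means "manual annotations only" (rows with no worker version) and a string means "only this worker version":

# Assumed behaviour, consistent with the call sites above; not the repository's exact code.
def build_worker_version_filter(model, worker_version):
    if worker_version:
        # A worker version id was given: keep only rows produced by it.
        return model.worker_version == worker_version
    # worker_version is False (manual): keep only rows with no worker version set.
    return model.worker_version.is_null()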
@@ -12,7 +12,6 @@ from tqdm import tqdm
 from dan import logger
 from dan.datasets.extract.db import (
     Element,
-    Entity,
     get_elements,
     get_transcription_entities,
     get_transcriptions,
@@ -51,8 +50,8 @@ class ArkindexExtractor:
         load_entities: bool = None,
         tokens: Path = None,
         use_existing_split: bool = None,
-        transcription_worker_version: str = None,
-        entity_worker_version: str = None,
+        transcription_worker_version: Optional[Union[str, bool]] = None,
+        entity_worker_version: Optional[Union[str, bool]] = None,
         train_prob: float = None,
         val_prob: float = None,
         max_width: Optional[int] = None,
@@ -100,7 +99,7 @@ class ArkindexExtractor:
     def get_random_split(self):
         return next(self._assign_random_split())

-    def reconstruct_text(self, text: str, entities: List[Entity]):
+    def reconstruct_text(self, text: str, entities):
         """
         Insert tokens delimiting the start/end of each entity on the transcription.
         """
@@ -226,8 +225,8 @@ def run(
     train_folder: UUID,
     val_folder: UUID,
     test_folder: UUID,
-    transcription_worker_version: Union[str, bool],
-    entity_worker_version: Union[str, bool],
+    transcription_worker_version: Optional[Union[str, bool]],
+    entity_worker_version: Optional[Union[str, bool]],
     train_prob,
     val_prob,
     max_width: Optional[int],
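End to end, the parser, ArkindexExtractor and run() now hand the same three-state value down to the db helpers. A hypothetical call for illustration (the element id is made up; the worker version id is the one reused by the tests below):

# Hypothetical usage, for illustration only.
element_id = "00000000-0000-0000-0000-000000000000"  # made-up id

get_transcriptions(element_id, transcription_worker_version=None)   # all transcriptions, any source
get_transcriptions(element_id, transcription_worker_version=False)  # manual transcriptions only
get_transcriptions(
    element_id,
    transcription_worker_version="0b2a429a-0da2-4b79-a6bb-330c6a07ac60",
)  # transcriptions from that worker version only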
@@ -4,8 +4,6 @@ import pytest
 from dan.datasets.extract.db import (
     Element,
-    Entity,
-    Transcription,
     get_elements,
     get_transcription_entities,
     get_transcriptions,
@@ -35,7 +33,7 @@ def test_get_elements():
 @pytest.mark.parametrize(
-    "worker_version", (False, "0b2a429a-0da2-4b79-a6bb-330c6a07ac60")
+    "worker_version", (False, "0b2a429a-0da2-4b79-a6bb-330c6a07ac60", None)
 )
 def test_get_transcriptions(worker_version):
     """
@@ -48,22 +46,17 @@ def test_get_transcriptions(worker_version):
     )

     # Check number of results
-    assert len(transcriptions) == 1
-    transcription = transcriptions.pop()
-    assert isinstance(transcription, Transcription)
-    # Common keys
-    assert transcription.text == "[ T 8º SUP 26200"
-    # Differences
-    if worker_version:
-        assert transcription.id == "3bd248d6-998a-4579-a00c-d4639f3825aa"
-    else:
-        assert transcription.id == "c551960a-0f82-4779-b975-77a457bcf273"
+    assert len(transcriptions) == 1 + int(worker_version is None)
+    for transcription in transcriptions:
+        assert transcription.text == "[ T 8º SUP 26200"
+        if worker_version:
+            assert transcription.worker_version.id == worker_version
+        elif worker_version is False:
+            assert transcription.worker_version is None


 @pytest.mark.parametrize(
-    "worker_version", (False, "0e2a98f5-71ac-48f6-973b-cc10ed440965")
+    "worker_version", (False, "0e2a98f5-71ac-48f6-973b-cc10ed440965", None)
 )
 def test_get_transcription_entities(worker_version):
     transcription_id = "3bd248d6-998a-4579-a00c-d4639f3825aa"
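A note on the expected counts in this test and the entities test below: when worker_version is None the query is no longer filtered, so both the manual row and the worker-produced row come back (hence 1 + int(worker_version is None)); with False or a version id exactly one row matches, and the newly selected worker_version column lets each branch check which source it came from.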
@@ -73,18 +66,15 @@ def test_get_transcription_entities(worker_version):
     )

     # Check number of results
-    assert len(entities) == 1
-    transcription_entity = entities.pop()
-    assert isinstance(transcription_entity, Entity)
-    # Differences
-    if worker_version:
-        assert transcription_entity.type == "cote"
-        assert transcription_entity.value == "T 8 º SUP 26200"
-        assert transcription_entity.offset == 2
-        assert transcription_entity.length == 14
-    else:
-        assert transcription_entity.type == "Cote"
-        assert transcription_entity.value == "[ T 8º SUP 26200"
-        assert transcription_entity.offset == 0
-        assert transcription_entity.length == 16
+    assert len(entities) == 1 + (worker_version is None)
+    for transcription_entity in entities:
+        if worker_version:
+            assert transcription_entity.type == "cote"
+            assert transcription_entity.name == "T 8 º SUP 26200"
+            assert transcription_entity.offset == 2
+            assert transcription_entity.length == 14
+        elif worker_version is False:
+            assert transcription_entity.type == "Cote"
+            assert transcription_entity.name == "[ T 8º SUP 26200"
+            assert transcription_entity.offset == 0
+            assert transcription_entity.length == 16
 # -*- coding: utf-8 -*-
+from typing import NamedTuple
+
 import pytest

-from dan.datasets.extract.extract import ArkindexExtractor, Entity
+from dan.datasets.extract.extract import ArkindexExtractor
 from dan.datasets.extract.utils import EntityType, insert_token

+# NamedTuple to mock actual database result
+Entity = NamedTuple("Entity", offset=int, length=int, type=str, value=str)
+

 @pytest.mark.parametrize(
     "text,count,offset,length,expected",