From 9d2d6bf5712f0f5cbfb444bc1cd4844ef0e7646a Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Wed, 19 Jul 2023 08:25:54 +0000
Subject: [PATCH] Fix worker version filters

---
 dan/datasets/extract/__init__.py |  2 -
 dan/datasets/extract/db.py       | 90 +++++++++++++-------------------
 dan/datasets/extract/extract.py  | 11 ++--
 tests/test_db.py                 | 52 ++++++++----------
 tests/test_extract.py            |  7 ++-
 5 files changed, 68 insertions(+), 94 deletions(-)

diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py
index ea0a758c..54cca912 100644
--- a/dan/datasets/extract/__init__.py
+++ b/dan/datasets/extract/__init__.py
@@ -122,14 +122,12 @@ def add_extract_parser(subcommands) -> None:
         type=parse_worker_version,
         help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
         required=False,
-        default=False,
     )
     parser.add_argument(
         "--entity-worker-version",
         type=parse_worker_version,
         help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
         required=False,
-        default=False,
     )
 
     parser.add_argument(
diff --git a/dan/datasets/extract/db.py b/dan/datasets/extract/db.py
index 86a6caac..79fd5d49 100644
--- a/dan/datasets/extract/db.py
+++ b/dan/datasets/extract/db.py
@@ -3,15 +3,17 @@
 import ast
 from dataclasses import dataclass
 from itertools import starmap
-from typing import List, NamedTuple, Optional, Union
+from typing import List, Optional, Union
 from urllib.parse import urljoin
 
 from arkindex_export import Image
 from arkindex_export.models import Element as ArkindexElement
-from arkindex_export.models import Entity as ArkindexEntity
-from arkindex_export.models import EntityType as ArkindexEntityType
-from arkindex_export.models import Transcription as ArkindexTranscription
-from arkindex_export.models import TranscriptionEntity as ArkindexTranscriptionEntity
+from arkindex_export.models import (
+    Entity,
+    EntityType,
+    Transcription,
+    TranscriptionEntity,
+)
 from arkindex_export.queries import list_children
 
 
@@ -25,23 +27,6 @@ def bounding_box(polygon: list):
     return int(x), int(y), int(width), int(height)
 
 
-# DB models
-Transcription = NamedTuple(
-    "Transcription",
-    id=str,
-    text=str,
-)
-
-
-Entity = NamedTuple(
-    "Entity",
-    type=str,
-    value=str,
-    offset=float,
-    length=float,
-)
-
-
 @dataclass
 class Element:
     id: str
@@ -94,6 +79,7 @@ def get_elements(
             Image.height,
         )
     )
+
     return list(
         starmap(
             lambda *x: Element(*x, max_width=max_width, max_height=max_height),
@@ -118,47 +104,43 @@ def get_transcriptions(
     """
     Retrieve transcriptions from an SQLite export of an Arkindex corpus
     """
-    query = ArkindexTranscription.select(
-        ArkindexTranscription.id, ArkindexTranscription.text
-    ).where(
-        (ArkindexTranscription.element == element_id)
-        & build_worker_version_filter(
-            ArkindexTranscription, worker_version=transcription_worker_version
-        )
-    )
-    return list(
-        starmap(
-            Transcription,
-            query.tuples(),
+    query = Transcription.select(
+        Transcription.id, Transcription.text, Transcription.worker_version
+    ).where((Transcription.element == element_id))
+
+    if transcription_worker_version is not None:
+        query = query.where(
+            build_worker_version_filter(
+                Transcription, worker_version=transcription_worker_version
+            )
         )
-    )
+    return query
 
 
 def get_transcription_entities(
     transcription_id: str, entity_worker_version: Union[str, bool]
-) -> List[Entity]:
+) -> List[TranscriptionEntity]:
     """
     Retrieve transcription entities from an SQLite export of an Arkindex corpus
     """
     query = (
-        ArkindexTranscriptionEntity.select(
-            ArkindexEntityType.name,
-            ArkindexEntity.name,
-            ArkindexTranscriptionEntity.offset,
-            ArkindexTranscriptionEntity.length,
-        )
-        .join(ArkindexEntity, on=ArkindexTranscriptionEntity.entity)
-        .join(ArkindexEntityType, on=ArkindexEntity.type)
-        .where(
-            (ArkindexTranscriptionEntity.transcription == transcription_id)
-            & build_worker_version_filter(
-                ArkindexTranscriptionEntity, worker_version=entity_worker_version
-            )
+        TranscriptionEntity.select(
+            EntityType.name.alias("type"),
+            Entity.name.alias("name"),
+            TranscriptionEntity.offset,
+            TranscriptionEntity.length,
+            TranscriptionEntity.worker_version,
         )
+        .join(Entity, on=TranscriptionEntity.entity)
+        .join(EntityType, on=Entity.type)
+        .where((TranscriptionEntity.transcription == transcription_id))
     )
-    return list(
-        starmap(
-            Entity,
-            query.order_by(ArkindexTranscriptionEntity.offset).tuples(),
+
+    if entity_worker_version is not None:
+        query = query.where(
+            build_worker_version_filter(
+                TranscriptionEntity, worker_version=entity_worker_version
+            )
         )
-    )
+
+    return query.order_by(TranscriptionEntity.offset).namedtuples()
diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py
index b4dd1da9..258500ba 100644
--- a/dan/datasets/extract/extract.py
+++ b/dan/datasets/extract/extract.py
@@ -12,7 +12,6 @@ from tqdm import tqdm
 from dan import logger
 from dan.datasets.extract.db import (
     Element,
-    Entity,
     get_elements,
     get_transcription_entities,
     get_transcriptions,
@@ -51,8 +50,8 @@ class ArkindexExtractor:
         load_entities: bool = None,
         tokens: Path = None,
         use_existing_split: bool = None,
-        transcription_worker_version: str = None,
-        entity_worker_version: str = None,
+        transcription_worker_version: Optional[Union[str, bool]] = None,
+        entity_worker_version: Optional[Union[str, bool]] = None,
         train_prob: float = None,
         val_prob: float = None,
         max_width: Optional[int] = None,
@@ -100,7 +99,7 @@ class ArkindexExtractor:
     def get_random_split(self):
         return next(self._assign_random_split())
 
-    def reconstruct_text(self, text: str, entities: List[Entity]):
+    def reconstruct_text(self, text: str, entities):
         """
         Insert tokens delimiting the start/end of each entity on the transcription.
         """
@@ -226,8 +225,8 @@ def run(
     train_folder: UUID,
     val_folder: UUID,
     test_folder: UUID,
-    transcription_worker_version: Union[str, bool],
-    entity_worker_version: Union[str, bool],
+    transcription_worker_version: Optional[Union[str, bool]],
+    entity_worker_version: Optional[Union[str, bool]],
     train_prob,
     val_prob,
     max_width: Optional[int],
diff --git a/tests/test_db.py b/tests/test_db.py
index 9230b808..6af15019 100644
--- a/tests/test_db.py
+++ b/tests/test_db.py
@@ -4,8 +4,6 @@ import pytest
 
 from dan.datasets.extract.db import (
     Element,
-    Entity,
-    Transcription,
     get_elements,
     get_transcription_entities,
     get_transcriptions,
@@ -35,7 +33,7 @@ def test_get_elements():
 
 
 @pytest.mark.parametrize(
-    "worker_version", (False, "0b2a429a-0da2-4b79-a6bb-330c6a07ac60")
+    "worker_version", (False, "0b2a429a-0da2-4b79-a6bb-330c6a07ac60", None)
 )
 def test_get_transcriptions(worker_version):
     """
@@ -48,22 +46,17 @@ def test_get_transcriptions(worker_version):
     )
 
     # Check number of results
-    assert len(transcriptions) == 1
-    transcription = transcriptions.pop()
-    assert isinstance(transcription, Transcription)
-
-    # Common keys
-    assert transcription.text == "[ T 8º SUP 26200"
-
-    # Differences
-    if worker_version:
-        assert transcription.id == "3bd248d6-998a-4579-a00c-d4639f3825aa"
-    else:
-        assert transcription.id == "c551960a-0f82-4779-b975-77a457bcf273"
+    assert len(transcriptions) == 1 + int(worker_version is None)
+    for transcription in transcriptions:
+        assert transcription.text == "[ T 8º SUP 26200"
+        if worker_version:
+            assert transcription.worker_version.id == worker_version
+        elif worker_version is False:
+            assert transcription.worker_version is None
 
 
 @pytest.mark.parametrize(
-    "worker_version", (False, "0e2a98f5-71ac-48f6-973b-cc10ed440965")
+    "worker_version", (False, "0e2a98f5-71ac-48f6-973b-cc10ed440965", None)
 )
 def test_get_transcription_entities(worker_version):
     transcription_id = "3bd248d6-998a-4579-a00c-d4639f3825aa"
@@ -73,18 +66,15 @@ def test_get_transcription_entities(worker_version):
     )
 
     # Check number of results
-    assert len(entities) == 1
-    transcription_entity = entities.pop()
-    assert isinstance(transcription_entity, Entity)
-
-    # Differences
-    if worker_version:
-        assert transcription_entity.type == "cote"
-        assert transcription_entity.value == "T 8 º SUP 26200"
-        assert transcription_entity.offset == 2
-        assert transcription_entity.length == 14
-    else:
-        assert transcription_entity.type == "Cote"
-        assert transcription_entity.value == "[ T 8º SUP 26200"
-        assert transcription_entity.offset == 0
-        assert transcription_entity.length == 16
+    assert len(entities) == 1 + (worker_version is None)
+    for transcription_entity in entities:
+        if worker_version:
+            assert transcription_entity.type == "cote"
+            assert transcription_entity.name == "T 8 º SUP 26200"
+            assert transcription_entity.offset == 2
+            assert transcription_entity.length == 14
+        elif worker_version is False:
+            assert transcription_entity.type == "Cote"
+            assert transcription_entity.name == "[ T 8º SUP 26200"
+            assert transcription_entity.offset == 0
+            assert transcription_entity.length == 16
diff --git a/tests/test_extract.py b/tests/test_extract.py
index 11a6e7ee..4ddc6fa2 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -1,10 +1,15 @@
 # -*- coding: utf-8 -*-
 
+from typing import NamedTuple
+
 import pytest
 
-from dan.datasets.extract.extract import ArkindexExtractor, Entity
+from dan.datasets.extract.extract import ArkindexExtractor
 from dan.datasets.extract.utils import EntityType, insert_token
 
+# NamedTuple to mock actual database result
+Entity = NamedTuple("Entity", offset=int, length=int, type=str, value=str)
+
 
 @pytest.mark.parametrize(
     "text,count,offset,length,expected",
-- 
GitLab