From 9b9ead1a6b88bb07053d2cc726cfe1a73530cf73 Mon Sep 17 00:00:00 2001
From: manonBlanco <blanco@teklia.com>
Date: Tue, 14 May 2024 15:32:14 +0200
Subject: [PATCH] Integration of Worker Runs in extraction
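
Transcriptions and transcription entities can now be filtered by worker
run, mirroring the existing worker-version filters. The new
--transcription-worker-runs and --entity-worker-runs options accept
either the literal `manual` (matching a null worker_run field) or a
worker run UUID, and can be combined with the worker-version options.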

---
 dan/datasets/extract/__init__.py | 28 +++++++++---
 dan/datasets/extract/arkindex.py | 13 +++++-
 dan/datasets/extract/db.py       | 38 +++++++++++++++-
 dan/datasets/extract/utils.py    |  2 +
 docs/usage/datasets/extract.md   |  2 +
 tests/conftest.py                | 75 ++++++++++++++++++++------------
 tests/test_db.py                 | 38 ++++++++++++----
 tests/test_extract.py            | 16 ++++---
 8 files changed, 159 insertions(+), 53 deletions(-)
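
Notes:

Both new options take the same values as their worker-version
counterparts: the literal `manual` (parsed to `False` and matched
against a null `worker_run` field) or a worker run UUID, e.g.
`teklia-dan dataset extract ... --transcription-worker-runs manual`.
The worker-version and worker-run filters are applied as separate
`.where()` clauses, so they combine with AND.

A minimal usage sketch of the updated query helper, assuming the
Arkindex SQLite export has already been opened; the element ID comes
from the test fixtures and the worker run UUID is hypothetical:

    from dan.datasets.extract.db import get_transcriptions

    transcriptions = get_transcriptions(
        element_id="train-page_1-line_1",
        # An empty list disables worker-version filtering.
        transcription_worker_versions=[],
        # Keep manual transcriptions (null worker_run) and those
        # produced by one specific worker run.
        transcription_worker_runs=[
            False,
            "12341234-1234-1234-1234-123412341234",
        ],
    )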

diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py
index 534ab1d9..87522fae 100644
--- a/dan/datasets/extract/__init__.py
+++ b/dan/datasets/extract/__init__.py
@@ -15,16 +15,16 @@ from dan.datasets.extract.arkindex import run
 MANUAL_SOURCE = "manual"
 
 
-def parse_worker_version(worker_version_id) -> str | bool:
-    if worker_version_id == MANUAL_SOURCE:
+def parse_source(source) -> str | bool:
+    if source == MANUAL_SOURCE:
         return False
 
     try:
-        UUID(worker_version_id)
+        UUID(source)
     except ValueError:
-        raise argparse.ArgumentTypeError(f"`{worker_version_id}` is not a valid UUID.")
+        raise argparse.ArgumentTypeError(f"`{source}` is not a valid UUID.")
 
-    return worker_version_id
+    return source
 
 
 def validate_char(char):
@@ -97,18 +97,32 @@ def add_extract_parser(subcommands) -> None:
 
     parser.add_argument(
         "--transcription-worker-versions",
-        type=parse_worker_version,
+        type=parse_source,
         nargs="+",
         help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
         default=[],
     )
     parser.add_argument(
         "--entity-worker-versions",
-        type=parse_worker_version,
+        type=parse_source,
         nargs="+",
         help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
         default=[],
     )
+    parser.add_argument(
+        "--transcription-worker-runs",
+        type=parse_source,
+        nargs="+",
+        help=f"Filter transcriptions by worker_run. Use {MANUAL_SOURCE} for manual filtering.",
+        default=[],
+    )
+    parser.add_argument(
+        "--entity-worker-runs",
+        type=parse_source,
+        nargs="+",
+        help=f"Filter transcriptions entities by worker_run. Use {MANUAL_SOURCE} for manual filtering.",
+        default=[],
+    )
 
     parser.add_argument(
         "--subword-vocab-size",
diff --git a/dan/datasets/extract/arkindex.py b/dan/datasets/extract/arkindex.py
index a7429932..dce5b244 100644
--- a/dan/datasets/extract/arkindex.py
+++ b/dan/datasets/extract/arkindex.py
@@ -61,6 +61,8 @@ class ArkindexExtractor:
         tokens: Path | None = None,
         transcription_worker_versions: List[str | bool] = [],
         entity_worker_versions: List[str | bool] = [],
+        transcription_worker_runs: List[str | bool] = [],
+        entity_worker_runs: List[str | bool] = [],
         keep_spaces: bool = False,
         allow_empty: bool = False,
         subword_vocab_size: int = 1000,
@@ -73,6 +75,8 @@ class ArkindexExtractor:
         self.tokens = parse_tokens(tokens) if tokens else {}
         self.transcription_worker_versions = transcription_worker_versions
         self.entity_worker_versions = entity_worker_versions
+        self.transcription_worker_runs = transcription_worker_runs
+        self.entity_worker_runs = entity_worker_runs
         self.allow_empty = allow_empty
         self.mapping = LMTokenMapping()
         self.keep_spaces = keep_spaces
@@ -101,7 +105,9 @@ class ArkindexExtractor:
         If the entities are needed, they are added to the transcription using tokens.
         """
         transcriptions = get_transcriptions(
-            element.id, self.transcription_worker_versions
+            element.id,
+            self.transcription_worker_versions,
+            self.transcription_worker_runs,
         )
         if len(transcriptions) == 0:
             if self.allow_empty:
@@ -117,6 +123,7 @@ class ArkindexExtractor:
         entities = get_transcription_entities(
             transcription.id,
             self.entity_worker_versions,
+            self.entity_worker_runs,
             supported_types=list(self.tokens),
         )
 
@@ -328,6 +335,8 @@ def run(
     tokens: Path,
     transcription_worker_versions: List[str | bool],
     entity_worker_versions: List[str | bool],
+    transcription_worker_runs: List[str | bool],
+    entity_worker_runs: List[str | bool],
     keep_spaces: bool,
     allow_empty: bool,
     subword_vocab_size: int,
@@ -347,6 +356,8 @@ def run(
         tokens=tokens,
         transcription_worker_versions=transcription_worker_versions,
         entity_worker_versions=entity_worker_versions,
+        transcription_worker_runs=transcription_worker_runs,
+        entity_worker_runs=entity_worker_runs,
         keep_spaces=keep_spaces,
         allow_empty=allow_empty,
         subword_vocab_size=subword_vocab_size,
diff --git a/dan/datasets/extract/db.py b/dan/datasets/extract/db.py
index 13408f5a..dfb29191 100644
--- a/dan/datasets/extract/db.py
+++ b/dan/datasets/extract/db.py
@@ -68,14 +68,33 @@ def build_worker_version_filter(ArkindexModel, worker_versions: List[str | bool]
     return condition
 
 
+def build_worker_run_filter(ArkindexModel, worker_runs: List[str | bool]):
+    """
+    A `False` worker run means a `manual` source -> null `worker_run` field.
+    """
+    condition = None
+    for worker_run in worker_runs:
+        condition |= (
+            ArkindexModel.worker_run == worker_run
+            if worker_run
+            else ArkindexModel.worker_run.is_null()
+        )
+    return condition
+
+
 def get_transcriptions(
-    element_id: str, transcription_worker_versions: List[str | bool]
+    element_id: str,
+    transcription_worker_versions: List[str | bool],
+    transcription_worker_runs: List[str | bool],
 ) -> List[Transcription]:
     """
     Retrieve transcriptions from an SQLite export of an Arkindex corpus
     """
     query = Transcription.select(
-        Transcription.id, Transcription.text, Transcription.worker_version
+        Transcription.id,
+        Transcription.text,
+        Transcription.worker_version,
+        Transcription.worker_run,
     ).where((Transcription.element == element_id))
 
     if transcription_worker_versions:
@@ -84,12 +103,21 @@ def get_transcriptions(
                 Transcription, worker_versions=transcription_worker_versions
             )
         )
+
+    if transcription_worker_runs:
+        query = query.where(
+            build_worker_run_filter(
+                Transcription, worker_runs=transcription_worker_runs
+            )
+        )
+
     return query
 
 
 def get_transcription_entities(
     transcription_id: str,
     entity_worker_versions: List[str | bool],
+    entity_worker_runs: List[str | bool],
     supported_types: List[str],
 ) -> List[TranscriptionEntity]:
     """
@@ -102,6 +130,7 @@ def get_transcription_entities(
             TranscriptionEntity.offset,
             TranscriptionEntity.length,
             TranscriptionEntity.worker_version,
+            TranscriptionEntity.worker_run,
         )
         .join(Entity, on=TranscriptionEntity.entity)
         .join(EntityType, on=Entity.type)
@@ -118,6 +147,11 @@ def get_transcription_entities(
             )
         )
 
+    if entity_worker_runs:
+        query = query.where(
+            build_worker_run_filter(TranscriptionEntity, worker_runs=entity_worker_runs)
+        )
+
     return query.order_by(
         TranscriptionEntity.offset, TranscriptionEntity.length.desc()
     ).dicts()
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
index 4ddb9fde..b000b479 100644
--- a/dan/datasets/extract/utils.py
+++ b/dan/datasets/extract/utils.py
@@ -209,6 +209,7 @@ class XMLEntity:
     offset: int
     length: int
     worker_version: str
+    worker_run: str
     children: List["XMLEntity"] = field(default_factory=list)
 
     @property
@@ -223,6 +224,7 @@ class XMLEntity:
                 offset=child["offset"] - self.offset,
                 length=child["length"],
                 worker_version=child["worker_version"],
+                worker_run=child["worker_run"],
             )
         )
 
diff --git a/docs/usage/datasets/extract.md b/docs/usage/datasets/extract.md
index c098824b..cf0f0124 100644
--- a/docs/usage/datasets/extract.md
+++ b/docs/usage/datasets/extract.md
@@ -19,6 +19,8 @@ Use the `teklia-dan dataset extract` command to extract a dataset from an Arkind
 | `--tokens`                        | Mapping between starting tokens and end tokens to extract text with their entities.                                                                                                                                                                                       | `pathlib.Path`  |         |
 | `--transcription-worker-versions` | Filter transcriptions by worker_version. Use `manual` for manual filtering.                                                                                                                                                                                               | `str` or `uuid` |         |
 | `--entity-worker-versions`        | Filter transcriptions entities by worker_version. Use `manual` for manual filtering                                                                                                                                                                                       | `str` or `uuid` |         |
+| `--transcription-worker-runs`     | Filter transcriptions by worker_run. Use `manual` for manual filtering.                                                                                                                                                                                                   | `str` or `uuid` |         |
+| `--entity-worker-runs`            | Filter transcription entities by worker_run. Use `manual` for manual filtering.                                                                                                                                                                                           | `str` or `uuid` |         |
 | `--keep-spaces`                   | Transcriptions are trimmed by default. Use this flag to disable this behaviour.                                                                                                                                                                                           | `bool`          | `False` |
 | `--allow-empty`                   | Elements with no transcriptions are skipped by default. This flag disables this behaviour.                                                                                                                                                                                | `bool`          | `False` |
 | `--subword-vocab-size`            | Size of the vocabulary used to train the sentencepiece subword tokenizer used to train the optional language model.                                                                                                                                                       | `int`           | `1000`  |
diff --git a/tests/conftest.py b/tests/conftest.py
index 3a4e1557..50f1348c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -33,6 +33,7 @@ def mock_database(tmp_path_factory):
     def create_transcription_entity(
         transcription: Transcription,
         worker_version: str | None,
+        worker_run: str | None,
         type: str,
         name: str,
         offset: int,
@@ -45,6 +46,7 @@ def mock_database(tmp_path_factory):
             name=name,
             type=entity_type,
             worker_version=worker_version,
+            worker_run=worker_run,
         )
         TranscriptionEntity.create(
             entity=entity,
@@ -52,6 +54,7 @@ def mock_database(tmp_path_factory):
             offset=offset,
             transcription=transcription,
             worker_version=worker_version,
+            worker_run=worker_run,
         )
 
     def create_transcriptions(element: Element, entities: List[dict]) -> None:
@@ -65,24 +68,31 @@ def mock_database(tmp_path_factory):
         for offset, entity in enumerate(entities[1:], start=1):
             entity["offset"] += offset
 
-        for worker_version in [None, "worker_version_id"]:
+        for worker_version, worker_run in [
+            (None, None),
+            ("worker_version_id", "worker_run_id"),
+        ]:
+            transcription_suffix = ""
             # Use different transcriptions to filter by worker version
-            if worker_version == "worker_version_id":
+            if worker_version and worker_run:
+                transcription_suffix = "source"
                 for entity in entities:
                     entity["name"] = entity["name"].lower()
 
             transcription = Transcription.create(
-                id=element.id + (worker_version or ""),
+                id=element.id + transcription_suffix,
                 # Add extra spaces to test the "keep_spaces" parameters of the "extract" command
                 text="  ".join(map(itemgetter("name"), entities)),
                 element=element,
                 worker_version=worker_version,
+                worker_run=worker_run,
             )
 
             for entity in entities:
                 create_transcription_entity(
                     transcription=transcription,
                     worker_version=worker_version,
+                    worker_run=worker_run,
                     **entity,
                 )
 
@@ -183,6 +193,11 @@ def mock_database(tmp_path_factory):
         type="worker",
     )
 
+    WorkerRun.create(
+        id="worker_run_id",
+        worker_version="worker_version_id",
+    )
+
     # Create dataset
     split_names = [VAL_NAME, TEST_NAME, TRAIN_NAME]
     dataset = Dataset.create(
@@ -216,82 +231,88 @@ def mock_database(tmp_path_factory):
         element=Element.select().first(),
     )
 
-    # Create worker version
-    WorkerVersion.bulk_create(
-        [
-            WorkerVersion(
-                id=f"{nestation}-id",
-                slug=nestation,
-                name=nestation,
-                repository_url="http://repository/url",
-                revision="main",
-                type="worker",
-            )
-            for nestation in ("nested", "non-nested", "special-chars")
-        ]
-    )
+    # Create worker version and worker run
+    for nestation in ("nested", "non-nested", "special-chars"):
+        WorkerVersion.create(
+            id=f"worker-version-{nestation}-id",
+            slug=nestation,
+            name=nestation,
+            repository_url="http://repository/url",
+            revision="main",
+            type="worker",
+        )
+        WorkerRun.create(
+            id=f"worker-run-{nestation}-id",
+            worker_version=f"worker-version-{nestation}-id",
+        )
 
     # Create entities
     for entity in [
         # Non-nested entities
         {
-            "worker_version": "non-nested-id",
+            "source": "non-nested",
             "type": "adj",
             "name": "great",
             "offset": 4,
         },
         {
-            "worker_version": "non-nested-id",
+            "source": "non-nested",
             "type": "name",
             "name": "Charles",
             "offset": 15,
         },
         {
-            "worker_version": "non-nested-id",
+            "source": "non-nested",
             "type": "person",
             "name": "us",
             "offset": 43,
         },
         # Nested entities
         {
-            "worker_version": "nested-id",
+            "source": "nested",
             "type": "fullname",
             "name": "Charles III",
             "offset": 15,
         },
         {
-            "worker_version": "nested-id",
+            "source": "nested",
             "type": "name",
             "name": "Charles",
             "offset": 15,
         },
         {
-            "worker_version": "nested-id",
+            "source": "nested",
             "type": "person",
             "name": "us",
             "offset": 43,
         },
         # Special characters
         {
-            "worker_version": "special-chars-id",
+            "source": "special-chars",
             "type": "Arkindex's entity",
             "name": "great",
             "offset": 4,
         },
         {
-            "worker_version": "special-chars-id",
+            "source": "special-chars",
             "type": '"Name" (1)',
             "name": "Charles",
             "offset": 15,
         },
         {
-            "worker_version": "special-chars-id",
+            "source": "special-chars",
             "type": "Person /!\\",
             "name": "us",
             "offset": 43,
         },
     ]:
-        create_transcription_entity(transcription=transcription, **entity)
+        source = entity.pop("source")
+        create_transcription_entity(
+            transcription=transcription,
+            worker_version=f"worker-version-{source}-id",
+            worker_run=f"worker-run-{source}-id",
+            **entity,
+        )
 
     return database_path
 
diff --git a/tests/test_db.py b/tests/test_db.py
index d7ab5c79..875add45 100644
--- a/tests/test_db.py
+++ b/tests/test_db.py
@@ -57,33 +57,45 @@ def test_get_elements(mock_database):
 
 
 @pytest.mark.parametrize(
-    "worker_versions",
-    ([False], ["worker_version_id"], [], [False, "worker_version_id"]),
+    "sources",
+    ([False], ["id"], [], [False, "id"]),
 )
-def test_get_transcriptions(worker_versions, mock_database):
+def test_get_transcriptions(sources, mock_database):
     """
     Assert transcriptions retrieval output against verified results
     """
+    worker_versions = [
+        f"worker_version_{source}" if isinstance(source, str) else source
+        for source in sources
+    ]
+    worker_runs = [
+        f"worker_run_{source}" if isinstance(source, str) else source
+        for source in sources
+    ]
+
     element_id = "train-page_1-line_1"
     transcriptions = get_transcriptions(
         element_id=element_id,
         transcription_worker_versions=worker_versions,
+        transcription_worker_runs=worker_runs,
     )
 
     expected_transcriptions = []
-    if not worker_versions or False in worker_versions:
+    if not sources or False in sources:
         expected_transcriptions.append(
             {
                 "text": "Caillet  Maurice  28.9.06",
                 "worker_version_id": None,
+                "worker_run_id": None,
             }
         )
 
-    if not worker_versions or "worker_version_id" in worker_versions:
+    if not sources or "id" in sources:
         expected_transcriptions.append(
             {
                 "text": "caillet  maurice  28.9.06",
                 "worker_version_id": "worker_version_id",
+                "worker_run_id": "worker_run_id",
             }
         )
 
@@ -95,6 +107,9 @@ def test_get_transcriptions(worker_versions, mock_database):
                     "worker_version_id": transcription.worker_version.id
                     if transcription.worker_version
                     else None,
+                    "worker_run_id": transcription.worker_run.id
+                    if transcription.worker_run
+                    else None,
                 }
                 for transcription in transcriptions
             ],
@@ -104,15 +119,19 @@ def test_get_transcriptions(worker_versions, mock_database):
     )
 
 
-@pytest.mark.parametrize("worker_version", (False, "worker_version_id", None))
+@pytest.mark.parametrize("source", (False, "id", None))
 @pytest.mark.parametrize(
     "supported_types", (["surname"], ["surname", "firstname", "birthdate"])
 )
-def test_get_transcription_entities(worker_version, mock_database, supported_types):
-    transcription_id = "train-page_1-line_1" + (worker_version or "")
+def test_get_transcription_entities(source, mock_database, supported_types):
+    worker_version = f"worker_version_{source}" if isinstance(source, str) else source
+    worker_run = f"worker_run_{source}" if isinstance(source, str) else source
+
+    transcription_id = "train-page_1-line_1" + ("source" if source else "")
     entities = get_transcription_entities(
         transcription_id=transcription_id,
         entity_worker_versions=[worker_version],
+        entity_worker_runs=[worker_run],
         supported_types=supported_types,
     )
 
@@ -141,9 +160,10 @@ def test_get_transcription_entities(worker_version, mock_database, supported_typ
         filter(lambda ent: ent["type"] in supported_types, expected_entities)
     )
     for entity in expected_entities:
-        if worker_version:
+        if source:
             entity["name"] = entity["name"].lower()
         entity["worker_version"] = worker_version or None
+        entity["worker_run"] = worker_run or None
 
     assert (
         sorted(
diff --git a/tests/test_extract.py b/tests/test_extract.py
index 92143b0d..54e03f89 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -459,31 +459,31 @@ def test_extract_transcription_no_translation(mock_database, tokens):
     (
         # Non-nested
         (
-            "non-nested-id",
+            "non-nested",
             "<root>The <adj>great</adj> king <name>Charles</name> III has eaten \nwith <person>us</person>.</root>",
             None,
         ),
         # Non-nested no text between entities
         (
-            "non-nested-id",
+            "non-nested",
             "<root><adj>great</adj> <name>Charles</name>\n<person>us</person></root>",
             ["\n", " "],
         ),
         # Nested
         (
-            "nested-id",
+            "nested",
             "<root>The great king <fullname><name>Charles</name> III</fullname> has eaten \nwith <person>us</person>.</root>",
             None,
         ),
         # Nested no text between entities
         (
-            "nested-id",
+            "nested",
             "<root><fullname><name>Charles</name> III</fullname>\n<person>us</person></root>",
             ["\n", " "],
         ),
         # Special characters in entities
         (
-            "special-chars-id",
+            "special-chars",
             "<root>The <Arkindex_s_entity>great</Arkindex_s_entity> king <_Name_1_>Charles</_Name_1_> III has eaten \nwith <Person_>us</Person_>.</root>",
             None,
         ),
@@ -496,7 +496,8 @@ def test_entities_to_xml(mock_database, nestation, xml_output, separators):
             text=transcription.text,
             predictions=get_transcription_entities(
                 transcription_id="tr-with-entities",
-                entity_worker_versions=[nestation],
+                entity_worker_versions=[f"worker-version-{nestation}-id"],
+                entity_worker_runs=[f"worker-run-{nestation}-id"],
                 supported_types=[
                     "name",
                     "fullname",
@@ -539,7 +540,8 @@ def test_entities_to_xml_partial_entities(
             text=transcription.text,
             predictions=get_transcription_entities(
                 transcription_id="tr-with-entities",
-                entity_worker_versions=["non-nested-id"],
+                entity_worker_versions=["worker-version-non-nested-id"],
+                entity_worker_runs=["worker-run-non-nested-id"],
                 supported_types=supported_entities,
             ),
             entity_separators=separators,
-- 
GitLab