From fbdce150bad74ee0ccc91d957e816226c0c204e5 Mon Sep 17 00:00:00 2001
From: manonBlanco <blanco@teklia.com>
Date: Tue, 16 Jan 2024 10:30:06 +0100
Subject: [PATCH] Support multiple worker_version for dataset extraction

---
 dan/datasets/extract/__init__.py | 10 ++++++----
 dan/datasets/extract/arkindex.py | 20 ++++++++++----------
 dan/datasets/extract/db.py       | 26 +++++++++++++++-----------
 docs/usage/datasets/extract.md   | 28 ++++++++++++++--------------
 tests/test_db.py                 | 15 +++++++++------
 tests/test_extract.py            | 15 +++++----------
 6 files changed, 59 insertions(+), 55 deletions(-)

diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py
index 278ee892..8ce3c74c 100644
--- a/dan/datasets/extract/__init__.py
+++ b/dan/datasets/extract/__init__.py
@@ -93,16 +93,18 @@ def add_extract_parser(subcommands) -> None:
     )
 
     parser.add_argument(
-        "--transcription-worker-version",
+        "--transcription-worker-versions",
         type=parse_worker_version,
+        nargs="+",
         help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
-        required=False,
+        default=[],
     )
     parser.add_argument(
-        "--entity-worker-version",
+        "--entity-worker-versions",
         type=parse_worker_version,
+        nargs="+",
         help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
-        required=False,
+        default=[],
     )
 
     parser.add_argument(
diff --git a/dan/datasets/extract/arkindex.py b/dan/datasets/extract/arkindex.py
index e2aa3088..d8dcc1f6 100644
--- a/dan/datasets/extract/arkindex.py
+++ b/dan/datasets/extract/arkindex.py
@@ -56,8 +56,8 @@ class ArkindexExtractor:
         entity_separators: List[str] = ["\n", " "],
         unknown_token: str = "⁇",
         tokens: Path | None = None,
-        transcription_worker_version: str | bool | None = None,
-        entity_worker_version: str | bool | None = None,
+        transcription_worker_versions: List[str | bool] = [],
+        entity_worker_versions: List[str | bool] = [],
         keep_spaces: bool = False,
         allow_empty: bool = False,
         subword_vocab_size: int = 1000,
@@ -68,8 +68,8 @@ class ArkindexExtractor:
         self.entity_separators = entity_separators
         self.unknown_token = unknown_token
         self.tokens = parse_tokens(tokens) if tokens else {}
-        self.transcription_worker_version = transcription_worker_version
-        self.entity_worker_version = entity_worker_version
+        self.transcription_worker_versions = transcription_worker_versions
+        self.entity_worker_versions = entity_worker_versions
         self.allow_empty = allow_empty
         self.mapping = LMTokenMapping()
         self.keep_spaces = keep_spaces
@@ -98,7 +98,7 @@ class ArkindexExtractor:
         If the entities are needed, they are added to the transcription using tokens.
         """
         transcriptions = get_transcriptions(
-            element.id, self.transcription_worker_version
+            element.id, self.transcription_worker_versions
         )
         if len(transcriptions) == 0:
             if self.allow_empty:
@@ -112,7 +112,7 @@ class ArkindexExtractor:
 
         entities = get_transcription_entities(
             transcription.id,
-            self.entity_worker_version,
+            self.entity_worker_versions,
             supported_types=list(self.tokens),
         )
 
@@ -319,8 +319,8 @@ def run(
     entity_separators: List[str],
     unknown_token: str,
     tokens: Path,
-    transcription_worker_version: str | bool | None,
-    entity_worker_version: str | bool | None,
+    transcription_worker_versions: List[str | bool],
+    entity_worker_versions: List[str | bool],
     keep_spaces: bool,
     allow_empty: bool,
     subword_vocab_size: int,
@@ -338,8 +338,8 @@ def run(
         entity_separators=entity_separators,
         unknown_token=unknown_token,
         tokens=tokens,
-        transcription_worker_version=transcription_worker_version,
-        entity_worker_version=entity_worker_version,
+        transcription_worker_versions=transcription_worker_versions,
+        entity_worker_versions=entity_worker_versions,
         keep_spaces=keep_spaces,
         allow_empty=allow_empty,
         subword_vocab_size=subword_vocab_size,
diff --git a/dan/datasets/extract/db.py b/dan/datasets/extract/db.py
index 25146799..ce9e9a87 100644
--- a/dan/datasets/extract/db.py
+++ b/dan/datasets/extract/db.py
@@ -51,18 +51,22 @@ def get_elements(
     return query
 
 
-def build_worker_version_filter(ArkindexModel, worker_version):
+def build_worker_version_filter(ArkindexModel, worker_versions: List[str | bool]):
     """
     `False` worker version means `manual` worker_version -> null field.
     """
-    if worker_version:
-        return ArkindexModel.worker_version == worker_version
-    else:
-        return ArkindexModel.worker_version.is_null()
+    condition = None
+    for worker_version in worker_versions:
+        condition |= (
+            ArkindexModel.worker_version == worker_version
+            if worker_version
+            else ArkindexModel.worker_version.is_null()
+        )
+    return condition
 
 
 def get_transcriptions(
-    element_id: str, transcription_worker_version: str | bool
+    element_id: str, transcription_worker_versions: List[str | bool]
 ) -> List[Transcription]:
     """
     Retrieve transcriptions from an SQLite export of an Arkindex corpus
@@ -71,10 +75,10 @@ def get_transcriptions(
         Transcription.id, Transcription.text, Transcription.worker_version
     ).where((Transcription.element == element_id))
 
-    if transcription_worker_version is not None:
+    if transcription_worker_versions:
         query = query.where(
             build_worker_version_filter(
-                Transcription, worker_version=transcription_worker_version
+                Transcription, worker_versions=transcription_worker_versions
             )
         )
     return query
@@ -82,7 +86,7 @@ def get_transcriptions(
 
 def get_transcription_entities(
     transcription_id: str,
-    entity_worker_version: str | bool | None,
+    entity_worker_versions: List[str | bool],
     supported_types: List[str],
 ) -> List[TranscriptionEntity]:
     """
@@ -104,10 +108,10 @@ def get_transcription_entities(
         )
     )
 
-    if entity_worker_version is not None:
+    if entity_worker_versions:
         query = query.where(
             build_worker_version_filter(
-                TranscriptionEntity, worker_version=entity_worker_version
+                TranscriptionEntity, worker_versions=entity_worker_versions
             )
         )
 
diff --git a/docs/usage/datasets/extract.md b/docs/usage/datasets/extract.md
index f3f5f05c..c098824b 100644
--- a/docs/usage/datasets/extract.md
+++ b/docs/usage/datasets/extract.md
@@ -8,20 +8,20 @@ Use the `teklia-dan dataset extract` command to extract a dataset from an Arkind
 - Store the set of characters encountered in the dataset (in the `charset.pkl` file),
 - Generate the resources needed to build a n-gram language model at character, subword or word-level with [kenlm](https://github.com/kpu/kenlm) (in the `language_model/` folder).
 
-| Parameter                        | Description                                                                                                                                                                                                                                                               | Type            | Default |
-| -------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------- | ------- |
-| `database`                       | Path to an Arkindex export database in SQLite format.                                                                                                                                                                                                                     | `pathlib.Path`  |         |
-| `--dataset-id `                  | ID of the dataset to extract from Arkindex.                                                                                                                                                                                                                               | `uuid`          |         |
-| `--element-type`                 | Type of the elements to extract. You may specify multiple types.                                                                                                                                                                                                          | `str`           |         |
-| `--output`                       | Folder where the data will be generated.                                                                                                                                                                                                                                  | `pathlib.Path`  |         |
-| `--entity-separators`            | Removes all text that does not appear in an entity or in the list of given ordered characters. If several separators follow each other, keep only the first to appear in the list. Do not give any arguments to keep the whole text (see [dedicated section](#examples)). | `str`           |         |
-| `--unknown-token`                | Token to use to replace character in the validation/test sets that is not included in the training set.                                                                                                                                                                   | `str`           | `⁇`     |
-| `--tokens`                       | Mapping between starting tokens and end tokens to extract text with their entities.                                                                                                                                                                                       | `pathlib.Path`  |         |
-| `--transcription-worker-version` | Filter transcriptions by worker_version. Use `manual` for manual filtering.                                                                                                                                                                                               | `str` or `uuid` |         |
-| `--entity-worker-version`        | Filter transcriptions entities by worker_version. Use `manual` for manual filtering                                                                                                                                                                                       | `str` or `uuid` |         |
-| `--keep-spaces`                  | Transcriptions are trimmed by default. Use this flag to disable this behaviour.                                                                                                                                                                                           | `bool`          | `False` |
-| `--allow-empty`                  | Elements with no transcriptions are skipped by default. This flag disables this behaviour.                                                                                                                                                                                | `bool`          | `False` |
-| `--subword-vocab-size`           | Size of the vocabulary used to train the sentencepiece subword tokenizer used to train the optional language model.                                                                                                                                                       | `int`           | `1000`  |
+| Parameter                         | Description                                                                                                                                                                                                                                                               | Type            | Default |
+| --------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------- | ------- |
+| `database`                        | Path to an Arkindex export database in SQLite format.                                                                                                                                                                                                                     | `pathlib.Path`  |         |
+| `--dataset-id `                   | ID of the dataset to extract from Arkindex.                                                                                                                                                                                                                               | `uuid`          |         |
+| `--element-type`                  | Type of the elements to extract. You may specify multiple types.                                                                                                                                                                                                          | `str`           |         |
+| `--output`                        | Folder where the data will be generated.                                                                                                                                                                                                                                  | `pathlib.Path`  |         |
+| `--entity-separators`             | Removes all text that does not appear in an entity or in the list of given ordered characters. If several separators follow each other, keep only the first to appear in the list. Do not give any arguments to keep the whole text (see [dedicated section](#examples)). | `str`           |         |
+| `--unknown-token`                 | Token to use to replace character in the validation/test sets that is not included in the training set.                                                                                                                                                                   | `str`           | `⁇`     |
+| `--tokens`                        | Mapping between starting tokens and end tokens to extract text with their entities.                                                                                                                                                                                       | `pathlib.Path`  |         |
+| `--transcription-worker-versions` | Filter transcriptions by worker_version. Use `manual` for manual filtering.                                                                                                                                                                                               | `str` or `uuid` |         |
+| `--entity-worker-versions`        | Filter transcriptions entities by worker_version. Use `manual` for manual filtering                                                                                                                                                                                       | `str` or `uuid` |         |
+| `--keep-spaces`                   | Transcriptions are trimmed by default. Use this flag to disable this behaviour.                                                                                                                                                                                           | `bool`          | `False` |
+| `--allow-empty`                   | Elements with no transcriptions are skipped by default. This flag disables this behaviour.                                                                                                                                                                                | `bool`          | `False` |
+| `--subword-vocab-size`            | Size of the vocabulary used to train the sentencepiece subword tokenizer used to train the optional language model.                                                                                                                                                       | `int`           | `1000`  |
 
 The `--tokens` argument expects a YAML-formatted file with a specific format. A list of entries with each entry describing a NER entity. The label of the entity is the key to a dict mapping the starting and ending tokens respectively. This file can be generated by the `teklia-dan dataset tokens` command. More details in the [dedicated page](./tokens.md).
 
diff --git a/tests/test_db.py b/tests/test_db.py
index 0da81d8e..ba305814 100644
--- a/tests/test_db.py
+++ b/tests/test_db.py
@@ -53,19 +53,22 @@ def test_get_elements(mock_database):
     ]
 
 
-@pytest.mark.parametrize("worker_version", (False, "worker_version_id", None))
-def test_get_transcriptions(worker_version, mock_database):
+@pytest.mark.parametrize(
+    "worker_versions",
+    ([False], ["worker_version_id"], [], [False, "worker_version_id"]),
+)
+def test_get_transcriptions(worker_versions, mock_database):
     """
     Assert transcriptions retrieval output against verified results
     """
     element_id = "train-page_1-line_1"
     transcriptions = get_transcriptions(
         element_id=element_id,
-        transcription_worker_version=worker_version,
+        transcription_worker_versions=worker_versions,
     )
 
     expected_transcriptions = []
-    if worker_version in [False, None]:
+    if not worker_versions or False in worker_versions:
         expected_transcriptions.append(
             {
                 "text": "Caillet  Maurice  28.9.06",
@@ -73,7 +76,7 @@ def test_get_transcriptions(worker_version, mock_database):
             }
         )
 
-    if worker_version in ["worker_version_id", None]:
+    if not worker_versions or "worker_version_id" in worker_versions:
         expected_transcriptions.append(
             {
                 "text": "caillet  maurice  28.9.06",
@@ -106,7 +109,7 @@ def test_get_transcription_entities(worker_version, mock_database, supported_typ
     transcription_id = "train-page_1-line_1" + (worker_version or "")
     entities = get_transcription_entities(
         transcription_id=transcription_id,
-        entity_worker_version=worker_version,
+        entity_worker_versions=[worker_version],
         supported_types=supported_types,
     )
 
diff --git a/tests/test_extract.py b/tests/test_extract.py
index 46721185..29132e86 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -254,10 +254,10 @@ def test_extract(
         # Keep the whole text
         entity_separators=None,
         tokens=tokens_path if load_entities else None,
-        transcription_worker_version=transcription_entities_worker_version,
-        entity_worker_version=transcription_entities_worker_version
+        transcription_worker_versions=[transcription_entities_worker_version],
+        entity_worker_versions=[transcription_entities_worker_version]
         if load_entities
-        else None,
+        else [],
         keep_spaces=keep_spaces,
         subword_vocab_size=subword_vocab_size,
     )
@@ -414,12 +414,7 @@ def test_extract(
 def test_empty_transcription(allow_empty, mock_database):
     extractor = ArkindexExtractor(
         element_type=["text_line"],
-        output=None,
         entity_separators=None,
-        tokens=None,
-        transcription_worker_version=None,
-        entity_worker_version=None,
-        keep_spaces=False,
         allow_empty=allow_empty,
     )
     element_no_transcription = Element(id="unknown")
@@ -466,7 +461,7 @@ def test_entities_to_xml(mock_database, nestation, xml_output, separators):
             text=transcription.text,
             predictions=get_transcription_entities(
                 transcription_id="tr-with-entities",
-                entity_worker_version=nestation,
+                entity_worker_versions=[nestation],
                 supported_types=["name", "fullname", "person", "adj"],
             ),
             entity_separators=separators,
@@ -501,7 +496,7 @@ def test_entities_to_xml_partial_entities(
             text=transcription.text,
             predictions=get_transcription_entities(
                 transcription_id="tr-with-entities",
-                entity_worker_version="non-nested-id",
+                entity_worker_versions=["non-nested-id"],
                 supported_types=supported_entities,
             ),
             entity_separators=separators,
-- 
GitLab