diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py
index 534ab1d9ad827853b990eb21181191190194d9b0..87522faefdd3fd798d2c98fbffd85c4c0c0330bf 100644
--- a/dan/datasets/extract/__init__.py
+++ b/dan/datasets/extract/__init__.py
@@ -15,16 +15,16 @@ from dan.datasets.extract.arkindex import run
 MANUAL_SOURCE = "manual"
 
 
-def parse_worker_version(worker_version_id) -> str | bool:
-    if worker_version_id == MANUAL_SOURCE:
+def parse_source(source) -> str | bool:
+    if source == MANUAL_SOURCE:
         return False
 
     try:
-        UUID(worker_version_id)
+        UUID(source)
     except ValueError:
-        raise argparse.ArgumentTypeError(f"`{worker_version_id}` is not a valid UUID.")
+        raise argparse.ArgumentTypeError(f"`{source}` is not a valid UUID.")
 
-    return worker_version_id
+    return source
 
 
 def validate_char(char):
@@ -97,18 +97,32 @@ def add_extract_parser(subcommands) -> None:
 
     parser.add_argument(
         "--transcription-worker-versions",
-        type=parse_worker_version,
+        type=parse_source,
         nargs="+",
         help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
         default=[],
     )
 
     parser.add_argument(
         "--entity-worker-versions",
-        type=parse_worker_version,
+        type=parse_source,
         nargs="+",
         help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
         default=[],
     )
+    parser.add_argument(
+        "--transcription-worker-runs",
+        type=parse_source,
+        nargs="+",
+        help=f"Filter transcriptions by worker_run. Use {MANUAL_SOURCE} for manual filtering.",
+        default=[],
+    )
+    parser.add_argument(
+        "--entity-worker-runs",
+        type=parse_source,
+        nargs="+",
+        help=f"Filter transcriptions entities by worker_run. Use {MANUAL_SOURCE} for manual filtering.",
+        default=[],
+    )
     parser.add_argument(
         "--subword-vocab-size",
diff --git a/dan/datasets/extract/arkindex.py b/dan/datasets/extract/arkindex.py
index a74299325f1ebc1a9696c7b59fb3167f06f4f86b..dce5b24401136ce39eae8fad975f45495fe8de52 100644
--- a/dan/datasets/extract/arkindex.py
+++ b/dan/datasets/extract/arkindex.py
@@ -61,6 +61,8 @@ class ArkindexExtractor:
         tokens: Path | None = None,
         transcription_worker_versions: List[str | bool] = [],
         entity_worker_versions: List[str | bool] = [],
+        transcription_worker_runs: List[str | bool] = [],
+        entity_worker_runs: List[str | bool] = [],
         keep_spaces: bool = False,
         allow_empty: bool = False,
         subword_vocab_size: int = 1000,
@@ -73,6 +75,8 @@ class ArkindexExtractor:
         self.tokens = parse_tokens(tokens) if tokens else {}
         self.transcription_worker_versions = transcription_worker_versions
         self.entity_worker_versions = entity_worker_versions
+        self.transcription_worker_runs = transcription_worker_runs
+        self.entity_worker_runs = entity_worker_runs
         self.allow_empty = allow_empty
         self.mapping = LMTokenMapping()
         self.keep_spaces = keep_spaces
@@ -101,7 +105,9 @@ class ArkindexExtractor:
         If the entities are needed, they are added to the transcription using tokens.
""" transcriptions = get_transcriptions( - element.id, self.transcription_worker_versions + element.id, + self.transcription_worker_versions, + self.transcription_worker_runs, ) if len(transcriptions) == 0: if self.allow_empty: @@ -117,6 +123,7 @@ class ArkindexExtractor: entities = get_transcription_entities( transcription.id, self.entity_worker_versions, + self.entity_worker_runs, supported_types=list(self.tokens), ) @@ -328,6 +335,8 @@ def run( tokens: Path, transcription_worker_versions: List[str | bool], entity_worker_versions: List[str | bool], + transcription_worker_runs: List[str | bool], + entity_worker_runs: List[str | bool], keep_spaces: bool, allow_empty: bool, subword_vocab_size: int, @@ -347,6 +356,8 @@ def run( tokens=tokens, transcription_worker_versions=transcription_worker_versions, entity_worker_versions=entity_worker_versions, + transcription_worker_runs=transcription_worker_runs, + entity_worker_runs=entity_worker_runs, keep_spaces=keep_spaces, allow_empty=allow_empty, subword_vocab_size=subword_vocab_size, diff --git a/dan/datasets/extract/db.py b/dan/datasets/extract/db.py index 13408f5ac7be1558b3ff0e64c5ce91dfcd620cc2..dfb29191a7adf9889b4ae635b99f3ee4308c6282 100644 --- a/dan/datasets/extract/db.py +++ b/dan/datasets/extract/db.py @@ -68,14 +68,33 @@ def build_worker_version_filter(ArkindexModel, worker_versions: List[str | bool] return condition +def build_worker_run_filter(ArkindexModel, worker_runs: List[str | bool]): + """ + `False` worker run means `manual` worker_run -> null field. + """ + condition = None + for worker_run in worker_runs: + condition |= ( + ArkindexModel.worker_run == worker_run + if worker_run + else ArkindexModel.worker_run.is_null() + ) + return condition + + def get_transcriptions( - element_id: str, transcription_worker_versions: List[str | bool] + element_id: str, + transcription_worker_versions: List[str | bool], + transcription_worker_runs: List[str | bool], ) -> List[Transcription]: """ Retrieve transcriptions from an SQLite export of an Arkindex corpus """ query = Transcription.select( - Transcription.id, Transcription.text, Transcription.worker_version + Transcription.id, + Transcription.text, + Transcription.worker_version, + Transcription.worker_run, ).where((Transcription.element == element_id)) if transcription_worker_versions: @@ -84,12 +103,21 @@ def get_transcriptions( Transcription, worker_versions=transcription_worker_versions ) ) + + if transcription_worker_runs: + query = query.where( + build_worker_run_filter( + Transcription, worker_runs=transcription_worker_runs + ) + ) + return query def get_transcription_entities( transcription_id: str, entity_worker_versions: List[str | bool], + entity_worker_runs: List[str | bool], supported_types: List[str], ) -> List[TranscriptionEntity]: """ @@ -102,6 +130,7 @@ def get_transcription_entities( TranscriptionEntity.offset, TranscriptionEntity.length, TranscriptionEntity.worker_version, + TranscriptionEntity.worker_run, ) .join(Entity, on=TranscriptionEntity.entity) .join(EntityType, on=Entity.type) @@ -118,6 +147,11 @@ def get_transcription_entities( ) ) + if entity_worker_runs: + query = query.where( + build_worker_run_filter(TranscriptionEntity, worker_runs=entity_worker_runs) + ) + return query.order_by( TranscriptionEntity.offset, TranscriptionEntity.length.desc() ).dicts() diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py index 4ddb9fde61c01204cb304a5bec56df16825f06a5..b000b4794dc5346eaa7dc510200d61d9969bcd89 100644 --- a/dan/datasets/extract/utils.py 
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
index 4ddb9fde61c01204cb304a5bec56df16825f06a5..b000b4794dc5346eaa7dc510200d61d9969bcd89 100644
--- a/dan/datasets/extract/utils.py
+++ b/dan/datasets/extract/utils.py
@@ -209,6 +209,7 @@ class XMLEntity:
     offset: int
     length: int
     worker_version: str
+    worker_run: str
    children: List["XMLEntity"] = field(default_factory=list)
 
     @property
@@ -223,6 +224,7 @@ class XMLEntity:
                     offset=child["offset"] - self.offset,
                     length=child["length"],
                     worker_version=child["worker_version"],
+                    worker_run=child["worker_run"],
                 )
             )
 
diff --git a/docs/usage/datasets/extract.md b/docs/usage/datasets/extract.md
index c098824b01d25708bd03ea213187632b342124b6..cf0f012401a1d654034ed2b5007aa45f0aec534e 100644
--- a/docs/usage/datasets/extract.md
+++ b/docs/usage/datasets/extract.md
@@ -19,6 +19,8 @@ Use the `teklia-dan dataset extract` command to extract a dataset from an Arkindex export database (SQLite format).
 | `--tokens` | Mapping between starting tokens and end tokens to extract text with their entities. | `pathlib.Path` | |
 | `--transcription-worker-versions` | Filter transcriptions by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
 | `--entity-worker-versions` | Filter transcriptions entities by worker_version. Use `manual` for manual filtering | `str` or `uuid` | |
+| `--transcription-worker-runs` | Filter transcriptions by worker_run. Use `manual` for manual filtering. | `str` or `uuid` | |
+| `--entity-worker-runs` | Filter transcriptions entities by worker_run. Use `manual` for manual filtering. | `str` or `uuid` | |
 | `--keep-spaces` | Transcriptions are trimmed by default. Use this flag to disable this behaviour. | `bool` | `False` |
 | `--allow-empty` | Elements with no transcriptions are skipped by default. This flag disables this behaviour. | `bool` | `False` |
 | `--subword-vocab-size` | Size of the vocabulary used to train the sentencepiece subword tokenizer used to train the optional language model. | `int` | `1000` |
diff --git a/tests/conftest.py b/tests/conftest.py
index 3a4e1557e5a0f8d7fe14a5435bc79aebbc6a4de7..50f1348c60268206f780a8188d9754429ebd73e9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -33,6 +33,7 @@ def mock_database(tmp_path_factory):
     def create_transcription_entity(
         transcription: Transcription,
         worker_version: str | None,
+        worker_run: str | None,
         type: str,
         name: str,
         offset: int,
@@ -45,6 +46,7 @@ def mock_database(tmp_path_factory):
             name=name,
             type=entity_type,
             worker_version=worker_version,
+            worker_run=worker_run,
         )
         TranscriptionEntity.create(
             entity=entity,
@@ -52,6 +54,7 @@ def mock_database(tmp_path_factory):
             offset=offset,
             transcription=transcription,
             worker_version=worker_version,
+            worker_run=worker_run,
         )
 
     def create_transcriptions(element: Element, entities: List[dict]) -> None:
@@ -65,24 +68,31 @@ def mock_database(tmp_path_factory):
         for offset, entity in enumerate(entities[1:], start=1):
             entity["offset"] += offset
 
-        for worker_version in [None, "worker_version_id"]:
+        for worker_version, worker_run in [
+            (None, None),
+            ("worker_version_id", "worker_run_id"),
+        ]:
+            transcription_suffix = ""
             # Use different transcriptions to filter by worker version
-            if worker_version == "worker_version_id":
+            if worker_version and worker_run:
+                transcription_suffix = "source"
                 for entity in entities:
                     entity["name"] = entity["name"].lower()
 
             transcription = Transcription.create(
-                id=element.id + (worker_version or ""),
+                id=element.id + transcription_suffix,
                 # Add extra spaces to test the "keep_spaces" parameters of the "extract" command
                 text="  ".join(map(itemgetter("name"), entities)),
                 element=element,
                 worker_version=worker_version,
+                worker_run=worker_run,
             )
             for entity in entities:
                 create_transcription_entity(
                     transcription=transcription,
                     worker_version=worker_version,
+                    worker_run=worker_run,
                     **entity,
                 )
 
 
@@ -183,6 +193,11 @@ def mock_database(tmp_path_factory):
         type="worker",
     )
 
+    WorkerRun.create(
+        id="worker_run_id",
+        worker_version="worker_version_id",
+    )
+
     # Create dataset
     split_names = [VAL_NAME, TEST_NAME, TRAIN_NAME]
     dataset = Dataset.create(
@@ -216,82 +231,88 @@ def mock_database(tmp_path_factory):
         element=Element.select().first(),
     )
 
-    # Create worker version
-    WorkerVersion.bulk_create(
-        [
-            WorkerVersion(
-                id=f"{nestation}-id",
-                slug=nestation,
-                name=nestation,
-                repository_url="http://repository/url",
-                revision="main",
-                type="worker",
-            )
-            for nestation in ("nested", "non-nested", "special-chars")
-        ]
-    )
+    # Create worker version and worker run
+    for nestation in ("nested", "non-nested", "special-chars"):
+        WorkerVersion.create(
+            id=f"worker-version-{nestation}-id",
+            slug=nestation,
+            name=nestation,
+            repository_url="http://repository/url",
+            revision="main",
+            type="worker",
+        )
+        WorkerRun.create(
+            id=f"worker-run-{nestation}-id",
+            worker_version=f"worker-version-{nestation}-id",
+        )
 
     # Create entities
     for entity in [
         # Non-nested entities
         {
-            "worker_version": "non-nested-id",
+            "source": "non-nested",
             "type": "adj",
             "name": "great",
             "offset": 4,
         },
         {
-            "worker_version": "non-nested-id",
+            "source": "non-nested",
             "type": "name",
             "name": "Charles",
             "offset": 15,
         },
         {
-            "worker_version": "non-nested-id",
+            "source": "non-nested",
             "type": "person",
             "name": "us",
             "offset": 43,
         },
         # Nested entities
         {
-            "worker_version": "nested-id",
+            "source": "nested",
             "type": "fullname",
             "name": "Charles III",
             "offset": 15,
         },
        {
-            "worker_version": "nested-id",
+            "source": "nested",
             "type": "name",
             "name": "Charles",
             "offset": 15,
         },
         {
-            "worker_version": "nested-id",
+            "source": "nested",
             "type": "person",
             "name": "us",
             "offset": 43,
         },
         # Special characters
         {
-            "worker_version": "special-chars-id",
+            "source": "special-chars",
             "type": "Arkindex's entity",
             "name": "great",
             "offset": 4,
         },
         {
-            "worker_version": "special-chars-id",
+            "source": "special-chars",
             "type": '"Name" (1)',
             "name": "Charles",
             "offset": 15,
         },
         {
-            "worker_version": "special-chars-id",
+            "source": "special-chars",
             "type": "Person /!\\",
             "name": "us",
             "offset": 43,
         },
     ]:
-        create_transcription_entity(transcription=transcription, **entity)
+        source = entity.pop("source")
+        create_transcription_entity(
+            transcription=transcription,
+            worker_version=f"worker-version-{source}-id",
+            worker_run=f"worker-run-{source}-id",
+            **entity,
+        )
 
     return database_path
diff --git a/tests/test_db.py b/tests/test_db.py
index d7ab5c798090bcbf4cd33ca0911476dfd81c1dff..875add45c44d3e395296c5c9f0f5aad90fdf38e4 100644
--- a/tests/test_db.py
+++ b/tests/test_db.py
@@ -57,33 +57,45 @@ def test_get_elements(mock_database):
 
 
 @pytest.mark.parametrize(
-    "worker_versions",
-    ([False], ["worker_version_id"], [], [False, "worker_version_id"]),
+    "sources",
+    ([False], ["id"], [], [False, "id"]),
 )
-def test_get_transcriptions(worker_versions, mock_database):
+def test_get_transcriptions(sources, mock_database):
     """
     Assert transcriptions retrieval output against verified results
     """
+    worker_versions = [
+        f"worker_version_{source}" if isinstance(source, str) else source
+        for source in sources
+    ]
+    worker_runs = [
+        f"worker_run_{source}" if isinstance(source, str) else source
+        for source in sources
+    ]
+
     element_id = "train-page_1-line_1"
     transcriptions = get_transcriptions(
         element_id=element_id,
         transcription_worker_versions=worker_versions,
+        transcription_worker_runs=worker_runs,
     )
 
     expected_transcriptions = []
-    if not worker_versions or False in worker_versions:
+    if not sources or False in sources:
         expected_transcriptions.append(
             {
                 "text": "Caillet  Maurice  28.9.06",
                 "worker_version_id": None,
+                "worker_run_id": None,
             }
         )
 
-    if not worker_versions or "worker_version_id" in worker_versions:
+    if not sources or "id" in sources:
         expected_transcriptions.append(
             {
                 "text": "caillet  maurice  28.9.06",
                 "worker_version_id": "worker_version_id",
+                "worker_run_id": "worker_run_id",
             }
         )
 
@@ -95,6 +107,9 @@ def test_get_transcriptions(sources, mock_database):
                 "worker_version_id": transcription.worker_version.id
                 if transcription.worker_version
                 else None,
+                "worker_run_id": transcription.worker_run.id
+                if transcription.worker_run
+                else None,
             }
             for transcription in transcriptions
         ],
@@ -104,15 +119,19 @@ def test_get_transcriptions(sources, mock_database):
     )
 
 
-@pytest.mark.parametrize("worker_version", (False, "worker_version_id", None))
+@pytest.mark.parametrize("source", (False, "id", None))
 @pytest.mark.parametrize(
     "supported_types", (["surname"], ["surname", "firstname", "birthdate"])
 )
-def test_get_transcription_entities(worker_version, mock_database, supported_types):
-    transcription_id = "train-page_1-line_1" + (worker_version or "")
+def test_get_transcription_entities(source, mock_database, supported_types):
+    worker_version = f"worker_version_{source}" if isinstance(source, str) else source
+    worker_run = f"worker_run_{source}" if isinstance(source, str) else source
+
+    transcription_id = "train-page_1-line_1" + ("source" if source else "")
     entities = get_transcription_entities(
         transcription_id=transcription_id,
         entity_worker_versions=[worker_version],
+        entity_worker_runs=[worker_run],
         supported_types=supported_types,
     )
 
@@ -141,9 +160,10 @@ def test_get_transcription_entities(worker_version, mock_database, supported_types):
         filter(lambda ent: ent["type"] in supported_types, expected_entities)
     )
     for entity in expected_entities:
-        if worker_version:
+        if source:
             entity["name"] = entity["name"].lower()
         entity["worker_version"] = worker_version or None
+        entity["worker_run"] = worker_run or None
 
     assert (
         sorted(
diff --git a/tests/test_extract.py b/tests/test_extract.py
index 92143b0d706209faa1daba246ab52ae1347891d5..54e03f890416671b89efe0cb36875faf50c93a97 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -459,31 +459,31 @@ def test_extract_transcription_no_translation(mock_database, tokens):
     (
         # Non-nested
         (
-            "non-nested-id",
+            "non-nested",
             "<root>The <adj>great</adj> king <name>Charles</name> III has eaten \nwith <person>us</person>.</root>",
             None,
         ),
         # Non-nested no text between entities
         (
-            "non-nested-id",
+            "non-nested",
             "<root><adj>great</adj> <name>Charles</name>\n<person>us</person></root>",
             ["\n", " "],
         ),
         # Nested
         (
-            "nested-id",
+            "nested",
             "<root>The great king <fullname><name>Charles</name> III</fullname> has eaten \nwith <person>us</person>.</root>",
             None,
         ),
         # Nested no text between entities
         (
-            "nested-id",
+            "nested",
             "<root><fullname><name>Charles</name> III</fullname>\n<person>us</person></root>",
             ["\n", " "],
         ),
         # Special characters in entities
         (
-            "special-chars-id",
+            "special-chars",
             "<root>The <Arkindex_s_entity>great</Arkindex_s_entity> king <_Name_1_>Charles</_Name_1_> III has eaten \nwith <Person_>us</Person_>.</root>",
             None,
         ),
@@ -496,7 +496,8 @@ def test_entities_to_xml(mock_database, nestation, xml_output, separators):
         text=transcription.text,
         predictions=get_transcription_entities(
             transcription_id="tr-with-entities",
-            entity_worker_versions=[nestation],
+            entity_worker_versions=[f"worker-version-{nestation}-id"],
+            entity_worker_runs=[f"worker-run-{nestation}-id"],
             supported_types=[
                 "name",
                 "fullname",
@@ -539,7 +540,8 @@ def test_entities_to_xml_partial_entities(
         text=transcription.text,
         predictions=get_transcription_entities(
             transcription_id="tr-with-entities",
-            entity_worker_versions=["non-nested-id"],
+            entity_worker_versions=["worker-version-non-nested-id"],
+            entity_worker_runs=["worker-run-non-nested-id"],
             supported_types=supported_entities,
         ),
         entity_separators=separators,
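
Taken together, both new CLI flags go through `parse_source`, so every value must be either the literal `manual` (parsed to `False` and later turned into an `is_null()` filter by `build_worker_run_filter`) or a valid UUID. A small self-contained sketch of that parsing contract, copied from the patched `dan/datasets/extract/__init__.py` (the UUID below is a made-up placeholder):

    import argparse
    from uuid import UUID
    
    MANUAL_SOURCE = "manual"
    
    
    def parse_source(source) -> str | bool:
        # `manual` becomes False, which the db filters translate to a null field.
        if source == MANUAL_SOURCE:
            return False
    
        try:
            UUID(source)
        except ValueError:
            raise argparse.ArgumentTypeError(f"`{source}` is not a valid UUID.")
    
        return source
    
    
    assert parse_source("manual") is False
    assert parse_source("12341234-1234-1234-1234-123412341234") == (
        "12341234-1234-1234-1234-123412341234"
    )
    
    try:
        parse_source("not-a-uuid")
    except argparse.ArgumentTypeError as error:
        print(error)  # `not-a-uuid` is not a valid UUID.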