diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py
index 278ee89250fd1b4a9246f32c5655acb1fec1620e..8ce3c74c00becc11e26f020ad8ebcc2deda17ca4 100644
--- a/dan/datasets/extract/__init__.py
+++ b/dan/datasets/extract/__init__.py
@@ -93,16 +93,18 @@ def add_extract_parser(subcommands) -> None:
     )
     parser.add_argument(
-        "--transcription-worker-version",
+        "--transcription-worker-versions",
         type=parse_worker_version,
+        nargs="+",
         help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
-        required=False,
+        default=[],
     )
     parser.add_argument(
-        "--entity-worker-version",
+        "--entity-worker-versions",
         type=parse_worker_version,
+        nargs="+",
         help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
-        required=False,
+        default=[],
     )
     parser.add_argument(
diff --git a/dan/datasets/extract/arkindex.py b/dan/datasets/extract/arkindex.py
index e2aa30882248982b4bc8f2fba6fc998d8a7ba170..d8dcc1f6e13a8fb8414244a010c90fbf64db9830 100644
--- a/dan/datasets/extract/arkindex.py
+++ b/dan/datasets/extract/arkindex.py
@@ -56,8 +56,8 @@ class ArkindexExtractor:
         entity_separators: List[str] = ["\n", " "],
         unknown_token: str = "⁇",
         tokens: Path | None = None,
-        transcription_worker_version: str | bool | None = None,
-        entity_worker_version: str | bool | None = None,
+        transcription_worker_versions: List[str | bool] = [],
+        entity_worker_versions: List[str | bool] = [],
         keep_spaces: bool = False,
         allow_empty: bool = False,
         subword_vocab_size: int = 1000,
@@ -68,8 +68,8 @@ class ArkindexExtractor:
         self.entity_separators = entity_separators
         self.unknown_token = unknown_token
         self.tokens = parse_tokens(tokens) if tokens else {}
-        self.transcription_worker_version = transcription_worker_version
-        self.entity_worker_version = entity_worker_version
+        self.transcription_worker_versions = transcription_worker_versions
+        self.entity_worker_versions = entity_worker_versions
         self.allow_empty = allow_empty
         self.mapping = LMTokenMapping()
         self.keep_spaces = keep_spaces
@@ -98,7 +98,7 @@ class ArkindexExtractor:
         If the entities are needed, they are added to the transcription using tokens.
""" transcriptions = get_transcriptions( - element.id, self.transcription_worker_version + element.id, self.transcription_worker_versions ) if len(transcriptions) == 0: if self.allow_empty: @@ -112,7 +112,7 @@ class ArkindexExtractor: entities = get_transcription_entities( transcription.id, - self.entity_worker_version, + self.entity_worker_versions, supported_types=list(self.tokens), ) @@ -319,8 +319,8 @@ def run( entity_separators: List[str], unknown_token: str, tokens: Path, - transcription_worker_version: str | bool | None, - entity_worker_version: str | bool | None, + transcription_worker_versions: List[str | bool], + entity_worker_versions: List[str | bool], keep_spaces: bool, allow_empty: bool, subword_vocab_size: int, @@ -338,8 +338,8 @@ def run( entity_separators=entity_separators, unknown_token=unknown_token, tokens=tokens, - transcription_worker_version=transcription_worker_version, - entity_worker_version=entity_worker_version, + transcription_worker_versions=transcription_worker_versions, + entity_worker_versions=entity_worker_versions, keep_spaces=keep_spaces, allow_empty=allow_empty, subword_vocab_size=subword_vocab_size, diff --git a/dan/datasets/extract/db.py b/dan/datasets/extract/db.py index 251467995c2cc790b24946b09820bcc21acecd49..ce9e9a8753bc42d3be3c0ac0a4a011f1d892dd5c 100644 --- a/dan/datasets/extract/db.py +++ b/dan/datasets/extract/db.py @@ -51,18 +51,22 @@ def get_elements( return query -def build_worker_version_filter(ArkindexModel, worker_version): +def build_worker_version_filter(ArkindexModel, worker_versions: List[str | bool]): """ `False` worker version means `manual` worker_version -> null field. """ - if worker_version: - return ArkindexModel.worker_version == worker_version - else: - return ArkindexModel.worker_version.is_null() + condition = None + for worker_version in worker_versions: + condition |= ( + ArkindexModel.worker_version == worker_version + if worker_version + else ArkindexModel.worker_version.is_null() + ) + return condition def get_transcriptions( - element_id: str, transcription_worker_version: str | bool + element_id: str, transcription_worker_versions: List[str | bool] ) -> List[Transcription]: """ Retrieve transcriptions from an SQLite export of an Arkindex corpus @@ -71,10 +75,10 @@ def get_transcriptions( Transcription.id, Transcription.text, Transcription.worker_version ).where((Transcription.element == element_id)) - if transcription_worker_version is not None: + if transcription_worker_versions: query = query.where( build_worker_version_filter( - Transcription, worker_version=transcription_worker_version + Transcription, worker_versions=transcription_worker_versions ) ) return query @@ -82,7 +86,7 @@ def get_transcriptions( def get_transcription_entities( transcription_id: str, - entity_worker_version: str | bool | None, + entity_worker_versions: List[str | bool], supported_types: List[str], ) -> List[TranscriptionEntity]: """ @@ -104,10 +108,10 @@ def get_transcription_entities( ) ) - if entity_worker_version is not None: + if entity_worker_versions: query = query.where( build_worker_version_filter( - TranscriptionEntity, worker_version=entity_worker_version + TranscriptionEntity, worker_versions=entity_worker_versions ) ) diff --git a/docs/usage/datasets/extract.md b/docs/usage/datasets/extract.md index f3f5f05c4f6193adc5510dfe583a43717bbc00da..c098824b01d25708bd03ea213187632b342124b6 100644 --- a/docs/usage/datasets/extract.md +++ b/docs/usage/datasets/extract.md @@ -8,20 +8,20 @@ Use the `teklia-dan dataset extract` command 
@@ -8,20 +8,20 @@ Use the `teklia-dan dataset extract` command to extract a dataset from an Arkind

 - Store the set of characters encountered in the dataset (in the `charset.pkl` file),
 - Generate the resources needed to build a n-gram language model at character, subword or word-level with [kenlm](https://github.com/kpu/kenlm) (in the `language_model/` folder).

-| Parameter | Description | Type | Default |
-| --- | --- | --- | --- |
-| `database` | Path to an Arkindex export database in SQLite format. | `pathlib.Path` | |
-| `--dataset-id ` | ID of the dataset to extract from Arkindex. | `uuid` | |
-| `--element-type` | Type of the elements to extract. You may specify multiple types. | `str` | |
-| `--output` | Folder where the data will be generated. | `pathlib.Path` | |
-| `--entity-separators` | Removes all text that does not appear in an entity or in the list of given ordered characters. If several separators follow each other, keep only the first to appear in the list. Do not give any arguments to keep the whole text (see [dedicated section](#examples)). | `str` | |
-| `--unknown-token` | Token to use to replace character in the validation/test sets that is not included in the training set. | `str` | `⁇` |
-| `--tokens` | Mapping between starting tokens and end tokens to extract text with their entities. | `pathlib.Path` | |
-| `--transcription-worker-version` | Filter transcriptions by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
-| `--entity-worker-version` | Filter transcriptions entities by worker_version. Use `manual` for manual filtering | `str` or `uuid` | |
-| `--keep-spaces` | Transcriptions are trimmed by default. Use this flag to disable this behaviour. | `bool` | `False` |
-| `--allow-empty` | Elements with no transcriptions are skipped by default. This flag disables this behaviour. | `bool` | `False` |
-| `--subword-vocab-size` | Size of the vocabulary used to train the sentencepiece subword tokenizer used to train the optional language model. | `int` | `1000` |
+| Parameter | Description | Type | Default |
+| --- | --- | --- | --- |
+| `database` | Path to an Arkindex export database in SQLite format. | `pathlib.Path` | |
+| `--dataset-id` | ID of the dataset to extract from Arkindex. | `uuid` | |
+| `--element-type` | Type of the elements to extract. You may specify multiple types. | `str` | |
+| `--output` | Folder where the data will be generated. | `pathlib.Path` | |
+| `--entity-separators` | Removes all text that does not appear in an entity or in the list of given ordered characters. If several separators follow each other, keep only the first to appear in the list. Do not give any arguments to keep the whole text (see [dedicated section](#examples)). | `str` | |
+| `--unknown-token` | Token used to replace characters in the validation/test sets that are not included in the training set. | `str` | `⁇` |
+| `--tokens` | Mapping between starting tokens and end tokens to extract text with their entities. | `pathlib.Path` | |
+| `--transcription-worker-versions` | Filter transcriptions by worker_version. You may specify multiple versions. Use `manual` for manual filtering. | `str` or `uuid` | |
+| `--entity-worker-versions` | Filter transcription entities by worker_version. You may specify multiple versions. Use `manual` for manual filtering. | `str` or `uuid` | |
+| `--keep-spaces` | Transcriptions are trimmed by default. Use this flag to disable this behaviour. | `bool` | `False` |
+| `--allow-empty` | Elements with no transcriptions are skipped by default. This flag disables this behaviour. | `bool` | `False` |
+| `--subword-vocab-size` | Size of the vocabulary used to train the sentencepiece subword tokenizer used to train the optional language model. | `int` | `1000` |

 The `--tokens` argument expects a YAML-formatted file with a specific format. A list of entries with each entry describing a NER entity. The label of the entity is the key to a dict mapping the starting and ending tokens respectively. This file can be generated by the `teklia-dan dataset tokens` command. More details in the [dedicated page](./tokens.md).
diff --git a/tests/test_db.py b/tests/test_db.py
index 0da81d8e99ae3e8123df48ca73b3266a80e7c223..ba3058148c84704ff08e1f73fb935c2a0a0462b4 100644
--- a/tests/test_db.py
+++ b/tests/test_db.py
@@ -53,19 +53,22 @@ def test_get_elements(mock_database):
     ]


-@pytest.mark.parametrize("worker_version", (False, "worker_version_id", None))
-def test_get_transcriptions(worker_version, mock_database):
+@pytest.mark.parametrize(
+    "worker_versions",
+    ([False], ["worker_version_id"], [], [False, "worker_version_id"]),
+)
+def test_get_transcriptions(worker_versions, mock_database):
     """
     Assert transcriptions retrieval output against verified results
     """
     element_id = "train-page_1-line_1"
     transcriptions = get_transcriptions(
         element_id=element_id,
-        transcription_worker_version=worker_version,
+        transcription_worker_versions=worker_versions,
     )

     expected_transcriptions = []
-    if worker_version in [False, None]:
+    if not worker_versions or False in worker_versions:
         expected_transcriptions.append(
             {
                 "text": "Caillet Maurice 28.9.06",
@@ -73,7 +76,7 @@ def test_get_transcriptions(worker_versions, mock_database):
             }
         )

-    if worker_version in ["worker_version_id", None]:
+    if not worker_versions or "worker_version_id" in worker_versions:
         expected_transcriptions.append(
             {
                 "text": "caillet maurice 28.9.06",
@@ -106,7 +109,7 @@ def test_get_transcription_entities(worker_version, mock_database, supported_typ
     transcription_id = "train-page_1-line_1" + (worker_version or "")
     entities = get_transcription_entities(
         transcription_id=transcription_id,
-        entity_worker_version=worker_version,
+        entity_worker_versions=[worker_version],
         supported_types=supported_types,
     )

diff --git a/tests/test_extract.py b/tests/test_extract.py
index 4672118538041285c81238a8c5dd942023fd5573..29132e863278ec7aaf3dc9a43cff96e4c7e7a8b0 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -254,10 +254,10 @@ def test_extract(
         # Keep the whole text
         entity_separators=None,
         tokens=tokens_path if load_entities else None,
-        transcription_worker_version=transcription_entities_worker_version,
-        entity_worker_version=transcription_entities_worker_version
+        transcription_worker_versions=[transcription_entities_worker_version],
+        entity_worker_versions=[transcription_entities_worker_version]
         if load_entities
-        else None,
+        else [],
         keep_spaces=keep_spaces,
         subword_vocab_size=subword_vocab_size,
     )
@@ -414,12 +414,7 @@
 def test_empty_transcription(allow_empty, mock_database):
     extractor = ArkindexExtractor(
         element_type=["text_line"],
-        output=None,
         entity_separators=None,
-        tokens=None,
-        transcription_worker_version=None,
-        entity_worker_version=None,
-        keep_spaces=False,
         allow_empty=allow_empty,
     )
     element_no_transcription = Element(id="unknown")
@@ -466,7 +461,7 @@ def test_entities_to_xml(mock_database, nestation, xml_output, separators):
         text=transcription.text,
         predictions=get_transcription_entities(
             transcription_id="tr-with-entities",
-            entity_worker_version=nestation,
+            entity_worker_versions=[nestation],
             supported_types=["name", "fullname", "person", "adj"],
         ),
         entity_separators=separators,
@@ -501,7 +496,7 @@ def test_entities_to_xml_partial_entities(
         text=transcription.text,
         predictions=get_transcription_entities(
             transcription_id="tr-with-entities",
-            entity_worker_version="non-nested-id",
+            entity_worker_versions=["non-nested-id"],
             supported_types=supported_entities,
         ),
         entity_separators=separators,
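
Reviewer note, not part of the patch: the sketch below illustrates how the new build_worker_version_filter folds several worker versions into a single peewee OR clause, and why the callers now guard with `if transcription_worker_versions:` before applying it. The filter function is copied from the patch; the Transcription model, the in-memory SQLite database, and the version identifiers are simplified stand-ins for the Arkindex export schema (where worker_version is a foreign key), assumed here only for the demonstration.

# Illustration only. Assumes peewee and an in-memory SQLite database;
# `Transcription` is a minimal stand-in for the Arkindex export model.
from peewee import CharField, Model, SqliteDatabase

db = SqliteDatabase(":memory:")


class Transcription(Model):
    text = CharField()
    # A null worker_version marks a manual transcription.
    worker_version = CharField(null=True)

    class Meta:
        database = db


def build_worker_version_filter(ArkindexModel, worker_versions):
    # Copied from the patch: `False` stands for `manual`, i.e. a null field.
    # The first `|=` combines with None; peewee accepts this via its reflected
    # operator and renders a leading `NULL OR ...`, which SQL's three-valued
    # logic reduces to the remaining condition inside a WHERE clause.
    condition = None
    for worker_version in worker_versions:
        condition |= (
            ArkindexModel.worker_version == worker_version
            if worker_version
            else ArkindexModel.worker_version.is_null()
        )
    return condition


db.connect()
db.create_tables([Transcription])
Transcription.create(text="manual text", worker_version=None)
Transcription.create(text="worker text", worker_version="worker_version_id")
Transcription.create(text="other text", worker_version="other_worker_id")

# [False, "worker_version_id"] keeps manual transcriptions plus those produced
# by worker_version_id, mirroring a hypothetical CLI call such as:
#   teklia-dan dataset extract ... --transcription-worker-versions manual <uuid>
query = Transcription.select().where(
    build_worker_version_filter(Transcription, [False, "worker_version_id"])
)
print(sorted(t.text for t in query))  # ['manual text', 'worker text']

Note that build_worker_version_filter returns None for an empty list, so the `if transcription_worker_versions:` / `if entity_worker_versions:` guards in db.py are what make the new default `[]` behave like the old `None`: no `.where()` is added and every transcription or entity is returned.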