Commit 9aae0c16 authored by Yoann Schneider

Merge branch 'extract-several-worker-version' into 'main'

Support multiple worker_version for dataset extraction

Closes #250

See merge request !345
parents baeeddeb fbdce150
@@ -93,16 +93,18 @@ def add_extract_parser(subcommands) -> None:
     )
     parser.add_argument(
-        "--transcription-worker-version",
+        "--transcription-worker-versions",
         type=parse_worker_version,
+        nargs="+",
         help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
         required=False,
+        default=[],
     )
     parser.add_argument(
-        "--entity-worker-version",
+        "--entity-worker-versions",
         type=parse_worker_version,
+        nargs="+",
         help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
         required=False,
+        default=[],
     )
     parser.add_argument(
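For reference, `parse_worker_version` is the `type=` callback applied to each value of these `nargs="+"` options. A minimal sketch of its plausible behaviour, assuming `MANUAL_SOURCE == "manual"` and following the convention documented in the filter helper below (`False` stands for manual annotations, i.e. a null `worker_version`); this is not the committed implementation:

```python
# Hypothetical sketch of the CLI value parser.
MANUAL_SOURCE = "manual"  # assumed value of the sentinel


def parse_worker_version(value: str) -> str | bool:
    # "manual" becomes False, which build_worker_version_filter turns into
    # `worker_version IS NULL`; any other value is kept as a worker version ID.
    if value == MANUAL_SOURCE:
        return False
    return value
```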
@@ -56,8 +56,8 @@ class ArkindexExtractor:
         entity_separators: List[str] = ["\n", " "],
         unknown_token: str = "⁇",
         tokens: Path | None = None,
-        transcription_worker_version: str | bool | None = None,
-        entity_worker_version: str | bool | None = None,
+        transcription_worker_versions: List[str | bool] = [],
+        entity_worker_versions: List[str | bool] = [],
         keep_spaces: bool = False,
         allow_empty: bool = False,
         subword_vocab_size: int = 1000,
@@ -68,8 +68,8 @@ class ArkindexExtractor:
         self.entity_separators = entity_separators
         self.unknown_token = unknown_token
         self.tokens = parse_tokens(tokens) if tokens else {}
-        self.transcription_worker_version = transcription_worker_version
-        self.entity_worker_version = entity_worker_version
+        self.transcription_worker_versions = transcription_worker_versions
+        self.entity_worker_versions = entity_worker_versions
         self.allow_empty = allow_empty
         self.mapping = LMTokenMapping()
         self.keep_spaces = keep_spaces
@@ -98,7 +98,7 @@ class ArkindexExtractor:
         If the entities are needed, they are added to the transcription using tokens.
         """
         transcriptions = get_transcriptions(
-            element.id, self.transcription_worker_version
+            element.id, self.transcription_worker_versions
         )
         if len(transcriptions) == 0:
             if self.allow_empty:
@@ -112,7 +112,7 @@ class ArkindexExtractor:
             entities = get_transcription_entities(
                 transcription.id,
-                self.entity_worker_version,
+                self.entity_worker_versions,
                 supported_types=list(self.tokens),
             )
@@ -319,8 +319,8 @@ def run(
     entity_separators: List[str],
     unknown_token: str,
     tokens: Path,
-    transcription_worker_version: str | bool | None,
-    entity_worker_version: str | bool | None,
+    transcription_worker_versions: List[str | bool],
+    entity_worker_versions: List[str | bool],
     keep_spaces: bool,
     allow_empty: bool,
     subword_vocab_size: int,
@@ -338,8 +338,8 @@ def run(
         entity_separators=entity_separators,
         unknown_token=unknown_token,
         tokens=tokens,
-        transcription_worker_version=transcription_worker_version,
-        entity_worker_version=entity_worker_version,
+        transcription_worker_versions=transcription_worker_versions,
+        entity_worker_versions=entity_worker_versions,
         keep_spaces=keep_spaces,
         allow_empty=allow_empty,
         subword_vocab_size=subword_vocab_size,
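A hypothetical call of the updated `ArkindexExtractor` signature (argument values are placeholders, other arguments left at their defaults; per the filter helper below, `False` selects manual annotations and an empty list disables filtering):

```python
# Illustrative only; not taken from the committed code.
extractor = ArkindexExtractor(
    element_type=["text_line"],
    # Keep manual transcriptions OR those of one specific worker version.
    transcription_worker_versions=[False, "worker_version_id"],
    # Empty list: keep entities from any worker version (no filtering).
    entity_worker_versions=[],
)
```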
@@ -51,18 +51,22 @@ def get_elements(
     return query


-def build_worker_version_filter(ArkindexModel, worker_version):
+def build_worker_version_filter(ArkindexModel, worker_versions: List[str | bool]):
     """
     `False` worker version means `manual` worker_version -> null field.
     """
-    if worker_version:
-        return ArkindexModel.worker_version == worker_version
-    else:
-        return ArkindexModel.worker_version.is_null()
+    condition = None
+    for worker_version in worker_versions:
+        condition |= (
+            ArkindexModel.worker_version == worker_version
+            if worker_version
+            else ArkindexModel.worker_version.is_null()
+        )
+    return condition


 def get_transcriptions(
-    element_id: str, transcription_worker_version: str | bool
+    element_id: str, transcription_worker_versions: List[str | bool]
 ) -> List[Transcription]:
     """
     Retrieve transcriptions from an SQLite export of an Arkindex corpus
@@ -71,10 +75,10 @@ def get_transcriptions(
         Transcription.id, Transcription.text, Transcription.worker_version
     ).where((Transcription.element == element_id))

-    if transcription_worker_version is not None:
+    if transcription_worker_versions:
         query = query.where(
             build_worker_version_filter(
-                Transcription, worker_version=transcription_worker_version
+                Transcription, worker_versions=transcription_worker_versions
             )
         )
     return query
@@ -82,7 +86,7 @@ def get_transcriptions(
 def get_transcription_entities(
     transcription_id: str,
-    entity_worker_version: str | bool | None,
+    entity_worker_versions: List[str | bool],
     supported_types: List[str],
 ) -> List[TranscriptionEntity]:
     """
@@ -104,10 +108,10 @@ def get_transcription_entities(
         )
     )

-    if entity_worker_version is not None:
+    if entity_worker_versions:
         query = query.where(
             build_worker_version_filter(
-                TranscriptionEntity, worker_version=entity_worker_version
+                TranscriptionEntity, worker_versions=entity_worker_versions
             )
         )
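The heart of the change is `build_worker_version_filter`, which now OR-combines one peewee expression per requested version: `condition` starts as `None`, so the first `|=` evaluates `None | <expression>` through the expression's reflected operator, and the empty-list case never reaches the loop thanks to the `if worker_versions:` guards in both callers. A self-contained toy illustration of the pattern, on a made-up model rather than the project's actual Arkindex schema:

```python
# Toy demonstration of the OR-chaining used by build_worker_version_filter;
# the model and rows below are invented for illustration.
from typing import List

from peewee import CharField, Model, SqliteDatabase

db = SqliteDatabase(":memory:")


class Transcription(Model):
    text = CharField()
    worker_version = CharField(null=True)  # NULL means a manual transcription

    class Meta:
        database = db


def build_worker_version_filter(ArkindexModel, worker_versions: List[str | bool]):
    # Same logic as the merged helper: False selects NULL worker_version rows.
    condition = None
    for worker_version in worker_versions:
        condition |= (
            ArkindexModel.worker_version == worker_version
            if worker_version
            else ArkindexModel.worker_version.is_null()
        )
    return condition


db.create_tables([Transcription])
Transcription.create(text="manual", worker_version=None)
Transcription.create(text="from worker", worker_version="worker_version_id")
Transcription.create(text="from other worker", worker_version="other_id")

# Manual transcriptions OR those of one specific worker version:
rows = Transcription.select().where(
    build_worker_version_filter(Transcription, [False, "worker_version_id"])
)
print(sorted(t.text for t in rows))  # ['from worker', 'manual']
```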
@@ -8,20 +8,20 @@ Use the `teklia-dan dataset extract` command to extract a dataset from an Arkindex export database.

 - Store the set of characters encountered in the dataset (in the `charset.pkl` file),
 - Generate the resources needed to build an n-gram language model at character, subword or word-level with [kenlm](https://github.com/kpu/kenlm) (in the `language_model/` folder).

 | Parameter | Description | Type | Default |
 | --- | --- | --- | --- |
 | `database` | Path to an Arkindex export database in SQLite format. | `pathlib.Path` | |
 | `--dataset-id` | ID of the dataset to extract from Arkindex. | `uuid` | |
 | `--element-type` | Type of the elements to extract. You may specify multiple types. | `str` | |
 | `--output` | Folder where the data will be generated. | `pathlib.Path` | |
 | `--entity-separators` | Removes all text that does not appear in an entity or in the list of given ordered characters. If several separators follow each other, keep only the first to appear in the list. Do not give any arguments to keep the whole text (see [dedicated section](#examples)). | `str` | |
 | `--unknown-token` | Token to use to replace character in the validation/test sets that is not included in the training set. | `str` | `⁇` |
 | `--tokens` | Mapping between starting tokens and end tokens to extract text with their entities. | `pathlib.Path` | |
-| `--transcription-worker-version` | Filter transcriptions by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
-| `--entity-worker-version` | Filter transcriptions entities by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
+| `--transcription-worker-versions` | Filter transcriptions by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
+| `--entity-worker-versions` | Filter transcriptions entities by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
 | `--keep-spaces` | Transcriptions are trimmed by default. Use this flag to disable this behaviour. | `bool` | `False` |
 | `--allow-empty` | Elements with no transcriptions are skipped by default. This flag disables this behaviour. | `bool` | `False` |
 | `--subword-vocab-size` | Size of the vocabulary used to train the sentencepiece subword tokenizer used to train the optional language model. | `int` | `1000` |

 The `--tokens` argument expects a YAML-formatted file with a specific format. A list of entries with each entry describing a NER entity. The label of the entity is the key to a dict mapping the starting and ending tokens respectively. This file can be generated by the `teklia-dan dataset tokens` command. More details in the [dedicated page](./tokens.md).
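As an illustration, a hypothetical invocation combining manual transcriptions with those of a specific worker version (the database path and UUIDs are placeholders):

```shell
teklia-dan dataset extract database.sqlite \
    --dataset-id dataset_uuid \
    --element-type text_line \
    --output data/ \
    --transcription-worker-versions manual transcription_worker_uuid \
    --entity-worker-versions manual entity_worker_uuid
```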
@@ -53,19 +53,22 @@ def test_get_elements(mock_database):
     ]


-@pytest.mark.parametrize("worker_version", (False, "worker_version_id", None))
-def test_get_transcriptions(worker_version, mock_database):
+@pytest.mark.parametrize(
+    "worker_versions",
+    ([False], ["worker_version_id"], [], [False, "worker_version_id"]),
+)
+def test_get_transcriptions(worker_versions, mock_database):
     """
     Assert transcriptions retrieval output against verified results
     """
     element_id = "train-page_1-line_1"

     transcriptions = get_transcriptions(
         element_id=element_id,
-        transcription_worker_version=worker_version,
+        transcription_worker_versions=worker_versions,
     )

     expected_transcriptions = []
-    if worker_version in [False, None]:
+    if not worker_versions or False in worker_versions:
         expected_transcriptions.append(
             {
                 "text": "Caillet Maurice 28.9.06",
@@ -73,7 +76,7 @@ def test_get_transcriptions(worker_version, mock_database):
             }
         )

-    if worker_version in ["worker_version_id", None]:
+    if not worker_versions or "worker_version_id" in worker_versions:
         expected_transcriptions.append(
             {
                 "text": "caillet maurice 28.9.06",
@@ -106,7 +109,7 @@ def test_get_transcription_entities(worker_version, mock_database, supported_types):
     transcription_id = "train-page_1-line_1" + (worker_version or "")

     entities = get_transcription_entities(
         transcription_id=transcription_id,
-        entity_worker_version=worker_version,
+        entity_worker_versions=[worker_version],
         supported_types=supported_types,
     )
@@ -254,10 +254,10 @@ def test_extract(
         # Keep the whole text
         entity_separators=None,
         tokens=tokens_path if load_entities else None,
-        transcription_worker_version=transcription_entities_worker_version,
-        entity_worker_version=transcription_entities_worker_version
+        transcription_worker_versions=[transcription_entities_worker_version],
+        entity_worker_versions=[transcription_entities_worker_version]
         if load_entities
-        else None,
+        else [],
         keep_spaces=keep_spaces,
         subword_vocab_size=subword_vocab_size,
     )
@@ -414,12 +414,7 @@ def test_extract(
 def test_empty_transcription(allow_empty, mock_database):
     extractor = ArkindexExtractor(
         element_type=["text_line"],
-        output=None,
-        entity_separators=None,
-        tokens=None,
-        transcription_worker_version=None,
-        entity_worker_version=None,
-        keep_spaces=False,
         allow_empty=allow_empty,
     )
     element_no_transcription = Element(id="unknown")
@@ -466,7 +461,7 @@ def test_entities_to_xml(mock_database, nestation, xml_output, separators):
         text=transcription.text,
         predictions=get_transcription_entities(
             transcription_id="tr-with-entities",
-            entity_worker_version=nestation,
+            entity_worker_versions=[nestation],
             supported_types=["name", "fullname", "person", "adj"],
         ),
         entity_separators=separators,
@@ -501,7 +496,7 @@ def test_entities_to_xml_partial_entities(
         text=transcription.text,
         predictions=get_transcription_entities(
             transcription_id="tr-with-entities",
-            entity_worker_version="non-nested-id",
+            entity_worker_versions=["non-nested-id"],
             supported_types=supported_entities,
         ),
         entity_separators=separators,
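The test updates follow a single migration pattern: wrap the previous scalar in a one-element list, and pass an empty list instead of `None` to disable filtering. A sketch for hypothetical calling code (`tr_id`, `types`, and `worker_version` are placeholder names, not identifiers from the project):

```python
# Old call (single value, None = no filtering):
#   get_transcription_entities(tr_id, entity_worker_version=worker_version,
#                              supported_types=types)
# New call (list, [] = no filtering):
entities = get_transcription_entities(
    transcription_id=tr_id,
    entity_worker_versions=[worker_version] if worker_version is not None else [],
    supported_types=types,
)
```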