Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • atr/dan
1 result
Show changes
Subproject commit 525c1a9e6d5a33075669085148247e2604dd092f
-e ./nerval
albumentations==1.3.1
arkindex-export==0.1.9
boto3==1.26.124
editdistance==0.6.2
flashlight-text==0.0.4
imageio==2.26.1
imagesize==1.4.1
......@@ -9,7 +9,6 @@ lxml==4.9.3
mdutils==1.6.0
nltk==3.8.1
numpy==1.24.3
prettytable==3.8.0
PyYAML==6.0
scipy==1.10.1
sentencepiece==0.1.99
......
......@@ -54,4 +54,6 @@ setup(
"docs": parse_requirements("doc-requirements.txt"),
"mlflow": parse_requirements("mlflow-requirements.txt"),
},
license="MIT",
license_files=("LICENSE",),
)
......@@ -21,7 +21,7 @@ from arkindex_export import (
WorkerVersion,
database,
)
from dan.datasets.extract.arkindex import SPLIT_NAMES
from dan.datasets.extract.arkindex import TEST_NAME, TRAIN_NAME, VAL_NAME
from tests import FIXTURES
......@@ -181,15 +181,16 @@ def mock_database(tmp_path_factory):
)
# Create dataset
split_names = [VAL_NAME, TEST_NAME, TRAIN_NAME]
dataset = Dataset.create(
id="dataset_id",
name="Dataset",
state="complete",
sets=",".join(SPLIT_NAMES),
sets=",".join(split_names),
)
# Create dataset elements
for split in SPLIT_NAMES:
for split in split_names:
element_path = (FIXTURES / "extraction" / "elements" / split).with_suffix(
".json"
)
......
Source diff could not be displayed: it is stored in LFS. Options to address this: view the blob.
| Split | CER (HTR-NER) | CER (HTR) | WER (HTR-NER) | WER (HTR) | WER (HTR no punct) |
|:-----:|:-------------:|:---------:|:-------------:|:---------:|:------------------:|
| train | 1.3023 | 1.3023 | 1.0 | 1.0 | 1.0 |
| val | 1.2683 | 1.2683 | 1.0 | 1.0 | 1.0 |
| test | 1.1224 | 1.1224 | 1.0 | 1.0 | 1.0 |
#### DAN evaluation
| Split | CER (HTR-NER) | CER (HTR) | WER (HTR-NER) | WER (HTR) | WER (HTR no punct) | NER |
|:-----:|:-------------:|:---------:|:-------------:|:---------:|:------------------:|:----:|
| train | 18.89 | 21.05 | 26.67 | 26.67 | 26.67 | 7.14 |
| val | 8.82 | 11.54 | 50.0 | 50.0 | 50.0 | 0.0 |
| test | 2.78 | 3.33 | 14.29 | 14.29 | 14.29 | 0.0 |
#### Nerval evaluation
##### train
| tag | predicted | matched | Precision | Recall | F1 | Support |
|:---------:|:---------:|:-------:|:---------:|:------:|:-----:|:-------:|
| Surname | 2 | 2 | 100.0 | 100.0 | 100.0 | 2 |
| Patron | 2 | 0 | 0.0 | 0.0 | 0 | 1 |
| Operai | 2 | 2 | 100.0 | 100.0 | 100.0 | 2 |
| Louche | 2 | 1 | 50.0 | 50.0 | 50.0 | 2 |
| Koala | 2 | 2 | 100.0 | 100.0 | 100.0 | 2 |
| Firstname | 2 | 2 | 100.0 | 100.0 | 100.0 | 2 |
| Chalumeau | 1 | 1 | 100.0 | 100.0 | 100.0 | 1 |
| Batiment | 2 | 2 | 100.0 | 100.0 | 100.0 | 2 |
| All | 15 | 12 | 80.0 | 85.71 | 82.76 | 14 |
##### val
| tag | predicted | matched | Precision | Recall | F1 | Support |
|:---------:|:---------:|:-------:|:---------:|:------:|:-----:|:-------:|
| Surname | 1 | 0 | 0.0 | 0.0 | 0 | 1 |
| Patron | 1 | 1 | 100.0 | 100.0 | 100.0 | 1 |
| Operai | 1 | 0 | 0.0 | 0.0 | 0 | 1 |
| Louche | 1 | 1 | 100.0 | 100.0 | 100.0 | 1 |
| Koala | 1 | 1 | 100.0 | 100.0 | 100.0 | 1 |
| Firstname | 1 | 1 | 100.0 | 100.0 | 100.0 | 1 |
| Chalumeau | 1 | 1 | 100.0 | 100.0 | 100.0 | 1 |
| Batiment | 1 | 1 | 100.0 | 100.0 | 100.0 | 1 |
| All | 8 | 6 | 75.0 | 75.0 | 75.0 | 8 |
##### test
| tag | predicted | matched | Precision | Recall | F1 | Support |
|:---------:|:---------:|:-------:|:---------:|:------:|:-----:|:-------:|
| Surname | 1 | 1 | 100.0 | 100.0 | 100.0 | 1 |
| Louche | 1 | 1 | 100.0 | 100.0 | 100.0 | 1 |
| Koala | 1 | 1 | 100.0 | 100.0 | 100.0 | 1 |
| Firstname | 1 | 1 | 100.0 | 100.0 | 100.0 | 1 |
| Chalumeau | 1 | 0 | 0.0 | 0.0 | 0 | 1 |
| Batiment | 1 | 1 | 100.0 | 100.0 | 100.0 | 1 |
| All | 6 | 5 | 83.33 | 83.33 | 83.33 | 6 |
#### 5 worst prediction(s)
| Image name | WER | Alignment between ground truth - prediction |
|:----------------------------------------:|:-----:|:---------------------------------------------------------:|
| 2c242f5c-e979-43c4-b6f2-a6d4815b651d.png | 50.0 | ⓈA ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF ⓄA Ⓟ14331 |
| | | |.||||||||||||||||||||||||.||||.|| |
| | | Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31 |
| 0dfe8bcd-ed0b-453e-bf19-cc697012296e.png | 26.67 | ⓈTemplié ⒻMarcelle Ⓑ93 ⓁJ Ⓚch ⓄE dachyle------- |
| | | ||||||||||||||||||||||||.|||||||||||.||.------- |
| | | ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376 |
| ffdec445-7f14-4f5f-be44-68d0844d0df1.png | 14.29 | ⓈNaudin ⒻMarie Ⓑ53 ⓁS ⒸV ⓀBelle mère |
| | | |||||||||||||||||||||||.|||||||||||| |
| | | ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère |
| 0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png | 12.5 | ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier Ⓟ-------12241 |
| | | |||||||||||||||||||||||||||||||||||||||||||||-------||||| |
| | | ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241 |
{
"train": {
"tests/data/prediction/images/0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier Ⓟ12241",
"tests/data/prediction/images/0dfe8bcd-ed0b-453e-bf19-cc697012296e.png": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁJ Ⓚch ⓄE dachyle"
},
"val": {
"tests/data/prediction/images/2c242f5c-e979-43c4-b6f2-a6d4815b651d.png": "ⓈA ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF ⓄA Ⓟ14331"
},
"test": {
"tests/data/prediction/images/ffdec445-7f14-4f5f-be44-68d0844d0df1.png": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS ⒸV ⓀBelle mère"
}
}
......@@ -53,19 +53,22 @@ def test_get_elements(mock_database):
]
@pytest.mark.parametrize("worker_version", (False, "worker_version_id", None))
def test_get_transcriptions(worker_version, mock_database):
@pytest.mark.parametrize(
"worker_versions",
([False], ["worker_version_id"], [], [False, "worker_version_id"]),
)
def test_get_transcriptions(worker_versions, mock_database):
"""
Assert transcriptions retrieval output against verified results
"""
element_id = "train-page_1-line_1"
transcriptions = get_transcriptions(
element_id=element_id,
transcription_worker_version=worker_version,
transcription_worker_versions=worker_versions,
)
expected_transcriptions = []
if worker_version in [False, None]:
if not worker_versions or False in worker_versions:
expected_transcriptions.append(
{
"text": "Caillet Maurice 28.9.06",
......@@ -73,7 +76,7 @@ def test_get_transcriptions(worker_version, mock_database):
}
)
if worker_version in ["worker_version_id", None]:
if not worker_versions or "worker_version_id" in worker_versions:
expected_transcriptions.append(
{
"text": "caillet maurice 28.9.06",
......@@ -106,7 +109,7 @@ def test_get_transcription_entities(worker_version, mock_database, supported_typ
transcription_id = "train-page_1-line_1" + (worker_version or "")
entities = get_transcription_entities(
transcription_id=transcription_id,
entity_worker_version=worker_version,
entity_worker_versions=[worker_version],
supported_types=supported_types,
)
......
......@@ -44,7 +44,7 @@ def test_add_metrics_table_row():
"WER (HTR-NER)",
"NER",
]
assert metrics_table.rows == [["train", 1.3023, 1.0, ""]]
assert metrics_table.rows == [["train", 130.23, 100, ""]]
@pytest.mark.parametrize(
......@@ -52,43 +52,49 @@ def test_add_metrics_table_row():
(
(
{
"nb_chars": 43,
"cer": 1.3023,
"nb_chars_no_token": 43,
"cer_no_token": 1.3023,
"nb_words": 9,
"wer": 1.0,
"nb_words_no_punct": 9,
"wer_no_punct": 1.0,
"nb_words_no_token": 9,
"wer_no_token": 1.0,
"nb_chars": 90,
"cer": 0.1889,
"nb_chars_no_token": 76,
"cer_no_token": 0.2105,
"nb_words": 15,
"wer": 0.2667,
"nb_words_no_punct": 15,
"wer_no_punct": 0.2667,
"nb_words_no_token": 15,
"wer_no_token": 0.2667,
"nb_tokens": 14,
"ner": 0.0714,
"nb_samples": 2,
},
{
"nb_chars": 41,
"cer": 1.2683,
"nb_chars_no_token": 41,
"cer_no_token": 1.2683,
"nb_words": 9,
"wer": 1.0,
"nb_words_no_punct": 9,
"wer_no_punct": 1.0,
"nb_words_no_token": 9,
"wer_no_token": 1.0,
"nb_samples": 2,
"nb_chars": 34,
"cer": 0.0882,
"nb_chars_no_token": 26,
"cer_no_token": 0.1154,
"nb_words": 8,
"wer": 0.5,
"nb_words_no_punct": 8,
"wer_no_punct": 0.5,
"nb_words_no_token": 8,
"wer_no_token": 0.5,
"nb_tokens": 8,
"ner": 0.0,
"nb_samples": 1,
},
{
"nb_chars": 49,
"cer": 1.1224,
"nb_chars_no_token": 49,
"cer_no_token": 1.1224,
"nb_words": 9,
"wer": 1.0,
"nb_words_no_punct": 9,
"wer_no_punct": 1.0,
"nb_words_no_token": 9,
"wer_no_token": 1.0,
"nb_samples": 2,
"nb_chars": 36,
"cer": 0.0278,
"nb_chars_no_token": 30,
"cer_no_token": 0.0333,
"nb_words": 7,
"wer": 0.1429,
"nb_words_no_punct": 7,
"wer_no_punct": 0.1429,
"nb_words_no_token": 7,
"wer_no_token": 0.1429,
"nb_tokens": 6,
"ner": 0.0,
"nb_samples": 1,
},
),
),
......@@ -97,7 +103,7 @@ def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
# Use the tmp_path as base folder
evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate"
evaluate.run(evaluate_config)
evaluate.run(evaluate_config, evaluate.NERVAL_THRESHOLD)
# Check that the evaluation results are correct
for split_name, expected_res in zip(
......@@ -106,7 +112,7 @@ def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
filename = (
evaluate_config["training"]["output_folder"]
/ "results"
/ f"predict_training-{split_name}_0.yaml"
/ f"predict_training-{split_name}_1685.yaml"
)
with filename.open() as f:
......@@ -123,7 +129,7 @@ def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
# Check the metrics Markdown table
captured_std = capsys.readouterr()
last_printed_lines = captured_std.out.split("\n")[-6:]
last_printed_lines = captured_std.out.split("\n")[10:]
assert (
"\n".join(last_printed_lines)
== Path(FIXTURES / "evaluate" / "metrics_table.md").read_text()
......
......@@ -254,10 +254,10 @@ def test_extract(
# Keep the whole text
entity_separators=None,
tokens=tokens_path if load_entities else None,
transcription_worker_version=transcription_entities_worker_version,
entity_worker_version=transcription_entities_worker_version
transcription_worker_versions=[transcription_entities_worker_version],
entity_worker_versions=[transcription_entities_worker_version]
if load_entities
else None,
else [],
keep_spaces=keep_spaces,
subword_vocab_size=subword_vocab_size,
)
......@@ -414,12 +414,7 @@ def test_extract(
def test_empty_transcription(allow_empty, mock_database):
extractor = ArkindexExtractor(
element_type=["text_line"],
output=None,
entity_separators=None,
tokens=None,
transcription_worker_version=None,
entity_worker_version=None,
keep_spaces=False,
allow_empty=allow_empty,
)
element_no_transcription = Element(id="unknown")
......@@ -466,7 +461,7 @@ def test_entities_to_xml(mock_database, nestation, xml_output, separators):
text=transcription.text,
predictions=get_transcription_entities(
transcription_id="tr-with-entities",
entity_worker_version=nestation,
entity_worker_versions=[nestation],
supported_types=["name", "fullname", "person", "adj"],
),
entity_separators=separators,
......@@ -501,7 +496,7 @@ def test_entities_to_xml_partial_entities(
text=transcription.text,
predictions=get_transcription_entities(
transcription_id="tr-with-entities",
entity_worker_version="non-nested-id",
entity_worker_versions=["non-nested-id"],
supported_types=supported_entities,
),
entity_separators=separators,
......