Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: atr/dan
@@ -2,11 +2,13 @@
import json
import uuid
from operator import itemgetter
from typing import List, Optional, Union
from typing import List
import pytest
from arkindex_export import (
Dataset,
DatasetElement,
Element,
ElementPath,
Entity,
@@ -19,14 +21,15 @@ from arkindex_export import (
WorkerVersion,
database,
)
from dan.datasets.extract.arkindex import SPLIT_NAMES
from tests import FIXTURES
@pytest.fixture(scope="session")
@pytest.fixture()
def mock_database(tmp_path_factory):
def create_transcription_entity(
transcription: Transcription,
worker_version: Union[str, None],
worker_version: str | None,
type: str,
name: str,
offset: int,
@@ -80,7 +83,7 @@ def mock_database(tmp_path_factory):
**entity,
)
def create_element(id: str, parent: Optional[Element] = None) -> None:
def create_element(id: str, parent: Element | None = None) -> None:
element_path = (FIXTURES / "extraction" / "elements" / id).with_suffix(".json")
element_json = json.loads(element_path.read_text())
@@ -133,6 +136,8 @@ def mock_database(tmp_path_factory):
WorkerRun,
ImageServer,
Image,
Dataset,
DatasetElement,
Element,
ElementPath,
EntityType,
@@ -175,8 +180,29 @@ def mock_database(tmp_path_factory):
type="worker",
)
# Create folders
create_element(id="root")
# Create dataset
dataset = Dataset.create(
id="dataset_id",
name="Dataset",
state="complete",
sets=",".join(SPLIT_NAMES),
)
# Create dataset elements
for split in SPLIT_NAMES:
element_path = (FIXTURES / "extraction" / "elements" / split).with_suffix(
".json"
)
element_json = json.loads(element_path.read_text())
# Recursive function to create children
for child in element_json.get("children", []):
create_element(id=child)
# Linking the element to the dataset split
DatasetElement.create(
id=child, element_id=child, dataset=dataset, set_name=split
)
# Create data for entities extraction tests
# Create transcription
@@ -257,5 +283,12 @@ def evaluate_config():
@pytest.fixture
def prediction_data_path():
return FIXTURES / "prediction"
@pytest.fixture
def split_content():
splits = json.loads((FIXTURES / "extraction" / "split.json").read_text())
for split in splits:
for element_id in splits[split]:
splits[split][element_id]["image"]["iiif_url"] = splits[split][element_id][
"image"
]["iiif_url"].replace("{FIXTURES}", str(FIXTURES))
return splits
| Split | CER (HTR-NER) | CER (HTR) | WER (HTR-NER) | WER (HTR) | WER (HTR no punct) |
|:-----:|:-------------:|:---------:|:-------------:|:---------:|:------------------:|
| train | 1.3023 | 1.3023 | 1.0 | 1.0 | 1.0 |
| val | 1.2683 | 1.2683 | 1.0 | 1.0 | 1.0 |
| test | 1.1224 | 1.1224 | 1.0 | 1.0 | 1.0 |
{
"test": {
"test-page_1-line_1": {
"dataset_id": "dataset_id",
"image": {
"iiif_url": "{FIXTURES}/extraction/images/text_line/test-page_1-line_1.jpg",
"polygon": [
[
37,
191
],
[
37,
339
],
[
767,
339
],
[
767,
191
],
[
37,
191
]
]
},
"text": "ⓢCou⁇e⁇ ⓕBouis ⓑ⁇.12.14"
},
"test-page_1-line_2": {
"dataset_id": "dataset_id",
"image": {
"iiif_url": "{FIXTURES}/extraction/images/text_line/test-page_1-line_2.jpg",
"polygon": [
[
28,
339
],
[
28,
464
],
[
767,
464
],
[
767,
339
],
[
28,
339
]
]
},
"text": "ⓢ⁇outrain ⓕA⁇ol⁇⁇e ⓑ9.4.13"
},
"test-page_1-line_3": {
"dataset_id": "dataset_id",
"image": {
"iiif_url": "{FIXTURES}/extraction/images/text_line/test-page_1-line_3.jpg",
"polygon": [
[
28,
464
],
[
28,
614
],
[
767,
614
],
[
767,
464
],
[
28,
464
]
]
},
"text": "ⓢ⁇abale ⓕ⁇ran⁇ais ⓑ26.3.11"
},
"test-page_2-line_1": {
"dataset_id": "dataset_id",
"image": {
"iiif_url": "{FIXTURES}/extraction/images/text_line/test-page_2-line_1.jpg",
"polygon": [
[
14,
199
],
[
14,
330
],
[
767,
330
],
[
767,
199
],
[
14,
199
]
]
},
"text": "ⓢ⁇urosoy ⓕBouis ⓑ22⁇4⁇18"
},
"test-page_2-line_2": {
"dataset_id": "dataset_id",
"image": {
"iiif_url": "{FIXTURES}/extraction/images/text_line/test-page_2-line_2.jpg",
"polygon": [
[
16,
330
],
[
16,
471
],
[
765,
471
],
[
765,
330
],
[
16,
330
]
]
},
"text": "ⓢColaiani ⓕAn⁇els ⓑ28.11.1⁇"
},
"test-page_2-line_3": {
"dataset_id": "dataset_id",
"image": {
"iiif_url": "{FIXTURES}/extraction/images/text_line/test-page_2-line_3.jpg",
"polygon": [
[
11,
473
],
[
11,
598
],
[
772,
598
],
[
772,
473
],
[
11,
473
]
]
},
"text": "ⓢRenouar⁇ ⓕMaurice ⓑ2⁇.⁇.04"
}
},
"train": {
"train-page_1-line_1": {
"dataset_id": "dataset_id",
"image": {
"iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_1-line_1.jpg",
"polygon": [
[
27,
187
],
[
27,
327
],
[
754,
327
],
[
754,
187
],
[
27,
187
]
]
},
"text": "ⓢCaillet ⓕMaurice ⓑ28.9.06"
},
"train-page_1-line_2": {
"dataset_id": "dataset_id",
"image": {
"iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_1-line_2.jpg",
"polygon": [
[
28,
328
],
[
28,
465
],
[
755,
465
],
[
755,
328
],
[
28,
328
]
]
},
"text": "ⓢReboul ⓕJean ⓑ30.9.02"
},
"train-page_1-line_3": {
"dataset_id": "dataset_id",
"image": {
"iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_1-line_3.jpg",
"polygon": [
[
23,
463
],
[
23,
604
],
[
803,
604
],
[
803,
463
],
[
23,
463
]
]
},
"text": "ⓢBareyre ⓕJean ⓑ28.3.11"
},
"train-page_1-line_4": {
"dataset_id": "dataset_id",
"image": {
"iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_1-line_4.jpg",
"polygon": [
[
21,
604
],
[
21,
743
],
[
812,
743
],
[
812,
604
],
[
21,
604
]
]
},
"text": "ⓢRoussy ⓕJean ⓑ4.11.14"
},
"train-page_2-line_1": {
"dataset_id": "dataset_id",
"image": {
"iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_2-line_1.jpg",
"polygon": [
[
18,
197
],
[
18,
340
],
[
751,
340
],
[
751,
197
],
[
18,
197
]
]
},
"text": "ⓢMarin ⓕMarcel ⓑ10.8.06"
},
"train-page_2-line_2": {
"dataset_id": "dataset_id",
"image": {
"iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_2-line_2.jpg",
"polygon": [
[
18,
340
],
[
18,
476
],
[
751,
476
],
[
751,
340
],
[
18,
340
]
]
},
"text": "ⓢAmical ⓕEloi ⓑ11.10.04"
},
"train-page_2-line_3": {
"dataset_id": "dataset_id",
"image": {
"iiif_url": "{FIXTURES}/extraction/images/text_line/train-page_2-line_3.jpg",
"polygon": [
[
21,
476
],
[
21,
615
],
[
746,
615
],
[
746,
476
],
[
21,
476
]
]
},
"text": "ⓢBiros ⓕMael ⓑ30.10.10"
}
},
"val": {
"val-page_1-line_1": {
"dataset_id": "dataset_id",
"image": {
"iiif_url": "{FIXTURES}/extraction/images/text_line/val-page_1-line_1.jpg",
"polygon": [
[
14,
211
],
[
14,
347
],
[
755,
347
],
[
755,
211
],
[
14,
211
]
]
},
"text": "ⓢMonar⁇ ⓕBouis ⓑ29⁇⁇⁇04"
},
"val-page_1-line_2": {
"dataset_id": "dataset_id",
"image": {
"iiif_url": "{FIXTURES}/extraction/images/text_line/val-page_1-line_2.jpg",
"polygon": [
[
14,
350
],
[
14,
484
],
[
748,
484
],
[
748,
350
],
[
14,
350
]
]
},
"text": "ⓢAstier ⓕArt⁇ur ⓑ11⁇2⁇13"
},
"val-page_1-line_3": {
"dataset_id": "dataset_id",
"image": {
"iiif_url": "{FIXTURES}/extraction/images/text_line/val-page_1-line_3.jpg",
"polygon": [
[
11,
484
],
[
11,
622
],
[
751,
622
],
[
751,
484
],
[
11,
484
]
]
},
"text": "ⓢ⁇e ⁇lie⁇er ⓕJules ⓑ21⁇11⁇11"
}
}
}
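For reference, every entry in the split.json fixture above follows the same shape: split name → element id → {dataset_id, image: {iiif_url, polygon}, text}. A minimal sketch of walking that structure, assuming the file sits in the current directory (not part of the diff):

import json
from pathlib import Path

splits = json.loads(Path("split.json").read_text())
for split_name, elements in splits.items():
    for element_id, entry in elements.items():
        # Each entry carries the IIIF URL, the crop polygon and the expected transcription
        print(split_name, element_id, entry["image"]["iiif_url"], entry["text"])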
import logging
import pytest
from dan.bio import convert
from dan.utils import EntityType
ST_TEXT = """ⒶBryan B ⒷParis ⒸJanuary 1st, 1987
ⒶJoe J ⒷGrenoble ⒸAugust 24, 1995
ⒶHannah H ⒷLille ⒸSeptember 15, 2002"""
ST_ET_TEXT = """ⒶBryanⒷ and ⒶJoeⒷ will visit the ⒸEiffel TowerⒹ in ⒸParisⒹ next ⒺTuesdayⒻ.
ⒶHannahⒷ will visit the ⒸPlace ⒶCharles de GaulleⒷ étoileⒹ on ⒺWednesdayⒻ."""
def test_convert_with_error():
ner_tokens = {
"Person": EntityType(start="", end=""),
"Location": EntityType(start="", end=""),
}
with pytest.raises(
AssertionError, match="Ending token Ⓓ doesn't match the starting token Ⓐ"
):
convert("ⒶFredⒹ", ner_tokens)
def test_convert_with_warnings(caplog):
ner_tokens = {
"Person": EntityType(start="", end=""),
"Location": EntityType(start="", end=""),
}
assert convert("BryanⒷ and ⒶJoeⒷ will visit the Eiffel TowerⒹ", ner_tokens).split(
"\n"
) == [
"Bryan O",
"and O",
"Joe B-Person",
"will O",
"visit O",
"the O",
"Eiffel O",
"Tower O",
]
assert [(level, message) for _, level, message in caplog.record_tuples] == [
(
logging.WARNING,
"Missing starting token for ending token Ⓑ, skipping the entity",
),
(
logging.WARNING,
"Missing starting token for ending token Ⓓ, skipping the entity",
),
]
def test_convert_starting_tokens():
ner_tokens = {
"Person": EntityType(start=""),
"Location": EntityType(start=""),
"Date": EntityType(start=""),
}
assert convert(ST_TEXT, ner_tokens).split("\n") == [
"Bryan B-Person",
"B I-Person",
"Paris B-Location",
"January B-Date",
"1st, I-Date",
"1987 I-Date",
"Joe B-Person",
"J I-Person",
"Grenoble B-Location",
"August B-Date",
"24, I-Date",
"1995 I-Date",
"Hannah B-Person",
"H I-Person",
"Lille B-Location",
"September B-Date",
"15, I-Date",
"2002 I-Date",
]
def test_convert_starting_and_ending_tokens():
ner_tokens = {
"Person": EntityType(start="", end=""),
"Location": EntityType(start="", end=""),
"Date": EntityType(start="", end=""),
}
assert convert(ST_ET_TEXT, ner_tokens).split("\n") == [
"Bryan B-Person",
"and O",
"Joe B-Person",
"will O",
"visit O",
"the O",
"Eiffel B-Location",
"Tower I-Location",
"in O",
"Paris B-Location",
"next O",
"Tuesday B-Date",
". O",
"Hannah B-Person",
"will O",
"visit O",
"the O",
"Place B-Location",
"Charles B-Person",
"de I-Person",
"Gaulle I-Person",
"étoile I-Location",
"on O",
"Wednesday B-Date",
". O",
]
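The tests above pin down the call signature of dan.bio.convert: it takes a tagged text plus a mapping of entity names to EntityType tokens and returns BIO lines joined by newlines. A small usage sketch reusing ST_TEXT from above; the write_bio helper and the output path are hypothetical, not part of the project:

from pathlib import Path

from dan.bio import convert
from dan.utils import EntityType

def write_bio(texts, ner_tokens, output_path):
    # One BIO block per text, separated by a blank line
    Path(output_path).write_text("\n\n".join(convert(text, ner_tokens) for text in texts))

write_bio(
    [ST_TEXT],
    {
        "Person": EntityType(start="Ⓐ"),
        "Location": EntityType(start="Ⓑ"),
        "Date": EntityType(start="Ⓒ"),
    },
    "train.bio",
)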
@@ -4,26 +4,53 @@ from operator import itemgetter
import pytest
from arkindex_export import Dataset, DatasetElement, Element
from dan.datasets.extract.arkindex import TRAIN_NAME
from dan.datasets.extract.db import (
Element,
get_dataset_elements,
get_elements,
get_transcription_entities,
get_transcriptions,
)
def test_get_dataset_elements(mock_database):
"""
Assert dataset elements retrieval output against verified results
"""
dataset_elements = get_dataset_elements(
dataset=Dataset.select().get(),
split=TRAIN_NAME,
)
# ID verification
assert all(
isinstance(dataset_element, DatasetElement)
for dataset_element in dataset_elements
)
assert [dataset_element.element.id for dataset_element in dataset_elements] == [
"train-page_1",
"train-page_2",
]
def test_get_elements(mock_database):
"""
Assert elements retrieval output against verified results
"""
elements = get_elements(
parent_id="train",
element_type=["double_page"],
parent_id="train-page_1",
element_type=["text_line"],
)
# ID verification
assert all(isinstance(element, Element) for element in elements)
assert [element.id for element in elements] == ["train-page_1", "train-page_2"]
assert [element.id for element in elements] == [
"train-page_1-line_1",
"train-page_1-line_2",
"train-page_1-line_3",
"train-page_1-line_4",
]
@pytest.mark.parametrize("worker_version", (False, "worker_version_id", None))
......
# -*- coding: utf-8 -*-
import json
import logging
from operator import attrgetter, methodcaller
from pathlib import Path
import pytest
from PIL import Image, ImageChops
from dan.datasets.download.images import IIIF_FULL_SIZE, ImageDownloader
from dan.datasets.download.utils import download_image
from line_image_extractor.image_utils import BoundingBox
from tests import FIXTURES
EXTRACTION_DATA_PATH = FIXTURES / "extraction"
@pytest.mark.parametrize(
"max_width, max_height, width, height, resize",
(
(1000, 2000, 900, 800, IIIF_FULL_SIZE),
(1000, 2000, 1100, 800, "1000,"),
(1000, 2000, 1100, 2800, ",2000"),
(1000, 2000, 2000, 3000, "1000,"),
),
)
def test_get_iiif_size_arg(max_width, max_height, width, height, resize):
assert (
ImageDownloader(max_width=max_width, max_height=max_height).get_iiif_size_arg(
width=width, height=height
)
== resize
)
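The parametrization above implies the resizing rule: keep the full image when both dimensions fit, otherwise cap whichever dimension exceeds its limit by the larger factor and let IIIF scale the other one. A standalone sketch of that rule (a hypothetical helper, not the actual ImageDownloader.get_iiif_size_arg implementation):

def iiif_size_arg_sketch(width: int, height: int, max_width: int, max_height: int) -> str:
    # Both dimensions within bounds: request the image at full size
    if width <= max_width and height <= max_height:
        return "full"  # assumed value of IIIF_FULL_SIZE
    # Cap the dimension that overshoots the most; "w," and ",h" preserve the aspect ratio
    if width / max_width >= height / max_height:
        return f"{max_width},"
    return f",{max_height}"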
def test_download(split_content, monkeypatch, tmp_path):
# Mock download_image so that it simply opens it with Pillow
monkeypatch.setattr(
"dan.datasets.download.images.download_image", lambda url: Image.open(url)
)
output = tmp_path / "download"
output.mkdir(parents=True, exist_ok=True)
(output / "split.json").write_text(json.dumps(split_content))
def mock_build_image_url(polygon, image_url, *args, **kwargs):
# During tests, the image URL is its local path
return image_url
extractor = ImageDownloader(
output=output,
image_extension=".jpg",
)
# Mock build_image_url to simply return the path to the image
extractor.build_iiif_url = mock_build_image_url
extractor.run()
# Check files
IMAGE_DIR = output / "images"
TEST_DIR = IMAGE_DIR / "test" / "dataset_id"
TRAIN_DIR = IMAGE_DIR / "train" / "dataset_id"
VAL_DIR = IMAGE_DIR / "val" / "dataset_id"
expected_paths = [
# Images of test folder
TEST_DIR / "test-page_1-line_1.jpg",
TEST_DIR / "test-page_1-line_2.jpg",
TEST_DIR / "test-page_1-line_3.jpg",
TEST_DIR / "test-page_2-line_1.jpg",
TEST_DIR / "test-page_2-line_2.jpg",
TEST_DIR / "test-page_2-line_3.jpg",
# Images of train folder
TRAIN_DIR / "train-page_1-line_1.jpg",
TRAIN_DIR / "train-page_1-line_2.jpg",
TRAIN_DIR / "train-page_1-line_3.jpg",
TRAIN_DIR / "train-page_1-line_4.jpg",
TRAIN_DIR / "train-page_2-line_1.jpg",
TRAIN_DIR / "train-page_2-line_2.jpg",
TRAIN_DIR / "train-page_2-line_3.jpg",
# Images of val folder
VAL_DIR / "val-page_1-line_1.jpg",
VAL_DIR / "val-page_1-line_2.jpg",
VAL_DIR / "val-page_1-line_3.jpg",
output / "labels.json",
output / "split.json",
]
assert sorted(filter(methodcaller("is_file"), output.rglob("*"))) == expected_paths
# Check "labels.json"
expected_labels = {
"test": {
str(TEST_DIR / "test-page_1-line_1.jpg"): "ⓢCou⁇e⁇ ⓕBouis ⓑ⁇.12.14",
str(TEST_DIR / "test-page_1-line_2.jpg"): "ⓢ⁇outrain ⓕA⁇ol⁇⁇e ⓑ9.4.13",
str(TEST_DIR / "test-page_1-line_3.jpg"): "ⓢ⁇abale ⓕ⁇ran⁇ais ⓑ26.3.11",
str(TEST_DIR / "test-page_2-line_1.jpg"): "ⓢ⁇urosoy ⓕBouis ⓑ22⁇4⁇18",
str(TEST_DIR / "test-page_2-line_2.jpg"): "ⓢColaiani ⓕAn⁇els ⓑ28.11.1⁇",
str(TEST_DIR / "test-page_2-line_3.jpg"): "ⓢRenouar⁇ ⓕMaurice ⓑ2⁇.⁇.04",
},
"train": {
str(TRAIN_DIR / "train-page_1-line_1.jpg"): "ⓢCaillet ⓕMaurice ⓑ28.9.06",
str(TRAIN_DIR / "train-page_1-line_2.jpg"): "ⓢReboul ⓕJean ⓑ30.9.02",
str(TRAIN_DIR / "train-page_1-line_3.jpg"): "ⓢBareyre ⓕJean ⓑ28.3.11",
str(TRAIN_DIR / "train-page_1-line_4.jpg"): "ⓢRoussy ⓕJean ⓑ4.11.14",
str(TRAIN_DIR / "train-page_2-line_1.jpg"): "ⓢMarin ⓕMarcel ⓑ10.8.06",
str(TRAIN_DIR / "train-page_2-line_2.jpg"): "ⓢAmical ⓕEloi ⓑ11.10.04",
str(TRAIN_DIR / "train-page_2-line_3.jpg"): "ⓢBiros ⓕMael ⓑ30.10.10",
},
"val": {
str(VAL_DIR / "val-page_1-line_1.jpg"): "ⓢMonar⁇ ⓕBouis ⓑ29⁇⁇⁇04",
str(VAL_DIR / "val-page_1-line_2.jpg"): "ⓢAstier ⓕArt⁇ur ⓑ11⁇2⁇13",
str(VAL_DIR / "val-page_1-line_3.jpg"): "ⓢ⁇e ⁇lie⁇er ⓕJules ⓑ21⁇11⁇11",
},
}
assert json.loads((output / "labels.json").read_text()) == expected_labels
# Check cropped images
for expected_path in expected_paths:
if expected_path.suffix != ".jpg":
continue
assert ImageChops.difference(
Image.open(
EXTRACTION_DATA_PATH / "images" / "text_line" / expected_path.name
),
Image.open(expected_path),
)
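In test_download above, build_iiif_url is mocked so the "URL" is simply the local fixture path. For context, a hedged sketch of what such a builder could look like, following the IIIF Image API pattern (region/size/rotation/quality) visible in the URLs of the error tests below; the helper name and the bounding-box handling are assumptions:

from line_image_extractor.image_utils import polygon_to_bbox

def build_iiif_url_sketch(image_url: str, polygon: list, size: str = "full") -> str:
    # Crop region is the polygon's bounding box, expressed as x,y,width,height
    x, y, width, height = polygon_to_bbox(polygon)
    return f"{image_url}/{x},{y},{width},{height}/{size}/0/default.jpg"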
def test_download_image_error(monkeypatch, caplog, capsys):
task = {
"split": "train",
"polygon": [],
"image_url": "deadbeef",
"destination": Path("/dev/null"),
}
monkeypatch.setattr(
"dan.datasets.download.images.polygon_to_bbox",
lambda polygon: BoundingBox(0, 0, 0, 0),
)
extractor = ImageDownloader(image_extension=".jpg")
# Add the key in data
extractor.data[task["split"]][str(task["destination"])] = "deadbeefdata"
# Build a random task
extractor.download_images([task])
# Key should have been removed
assert str(task["destination"]) not in extractor.data[task["split"]]
# Check error log
assert len(caplog.record_tuples) == 1
_, level, msg = caplog.record_tuples[0]
assert level == logging.ERROR
assert msg == "Failed to download 1 image(s)."
# Check stdout
captured = capsys.readouterr()
assert captured.out == "deadbeef: Image URL must be HTTP(S) for element null\n"
def test_download_image_error_try_max(responses, caplog):
# An image's URL
url = (
"https://blabla.com/iiif/2/image_path.jpg/231,699,2789,3659/full/0/default.jpg"
)
fixed_url = (
"https://blabla.com/iiif/2/image_path.jpg/231,699,2789,3659/max/0/default.jpg"
)
# Fake responses error
responses.add(
responses.GET,
url,
status=400,
)
# Correct response with max
responses.add(
responses.GET,
fixed_url,
status=200,
body=next((FIXTURES / "prediction" / "images").iterdir()).read_bytes(),
)
image = download_image(url)
assert image
# We try 3 times with the first URL
# Then the first try with the new URL is successful
assert len(responses.calls) == 4
assert list(map(attrgetter("request.url"), responses.calls)) == [url] * 3 + [
fixed_url
]
# Check error log
assert len(caplog.record_tuples) == 2
# We should only have WARNING levels
assert set(level for _, level, _ in caplog.record_tuples) == {logging.WARNING}
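The assertions above encode the expected behavior of download_image: retry the original URL up to three times, then swap the IIIF "full" size keyword for "max" and try once more. A minimal sketch of that retry-then-fallback flow using requests; the real implementation may differ in backoff and error handling:

import requests

def download_with_fallback_sketch(url: str, retries: int = 3) -> bytes:
    # Give the original URL a few chances before giving up on it
    for _ in range(retries):
        response = requests.get(url)
        if response.ok:
            return response.content
    # Some IIIF servers reject "full"; fall back to the "max" size keyword
    response = requests.get(url.replace("/full/", "/max/"))
    response.raise_for_status()
    return response.content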
# -*- coding: utf-8 -*-
import shutil
from pathlib import Path
import pytest
import yaml
from prettytable import PrettyTable
from dan.ocr import evaluate
from dan.ocr.utils import add_metrics_table_row, create_metrics_table
from tests import FIXTURES
def test_create_metrics_table():
metric_names = ["ignored", "wer", "cer", "time", "ner"]
metrics_table = create_metrics_table(metric_names)
assert isinstance(metrics_table, PrettyTable)
assert metrics_table.field_names == [
"Split",
"CER (HTR-NER)",
"WER (HTR-NER)",
"NER",
]
def test_add_metrics_table_row():
metric_names = ["ignored", "wer", "cer", "time", "ner"]
metrics_table = create_metrics_table(metric_names)
metrics = {
"ignored": "whatever",
"wer": 1.0,
"cer": 1.3023,
"time": 42,
}
add_metrics_table_row(metrics_table, "train", metrics)
assert isinstance(metrics_table, PrettyTable)
assert metrics_table.field_names == [
"Split",
"CER (HTR-NER)",
"WER (HTR-NER)",
"NER",
]
assert metrics_table.rows == [["train", 1.3023, 1.0, ""]]
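For context, a short usage sketch of the two helpers exercised above, using metric values taken from this diff; printing the PrettyTable is an assumption about how the Markdown table checked in test_evaluate is produced:

from dan.ocr.utils import add_metrics_table_row, create_metrics_table

table = create_metrics_table(["cer", "wer", "ner"])
add_metrics_table_row(table, "train", {"cer": 1.3023, "wer": 1.0})
add_metrics_table_row(table, "val", {"cer": 1.2683, "wer": 1.0})
print(table)  # Renders the Split / CER (HTR-NER) / WER (HTR-NER) / NER columns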
@pytest.mark.parametrize(
"training_res, val_res, test_res",
(
@@ -16,34 +54,46 @@ from tests import FIXTURES
{
"nb_chars": 43,
"cer": 1.3023,
"nb_chars_no_token": 43,
"cer_no_token": 1.3023,
"nb_words": 9,
"wer": 1.0,
"nb_words_no_punct": 9,
"wer_no_punct": 1.0,
"nb_words_no_token": 9,
"wer_no_token": 1.0,
"nb_samples": 2,
},
{
"nb_chars": 41,
"cer": 1.2683,
"nb_chars_no_token": 41,
"cer_no_token": 1.2683,
"nb_words": 9,
"wer": 1.0,
"nb_words_no_punct": 9,
"wer_no_punct": 1.0,
"nb_words_no_token": 9,
"wer_no_token": 1.0,
"nb_samples": 2,
},
{
"nb_chars": 49,
"cer": 1.1224,
"nb_chars_no_token": 49,
"cer_no_token": 1.1224,
"nb_words": 9,
"wer": 1.0,
"nb_words_no_punct": 9,
"wer_no_punct": 1.0,
"nb_words_no_token": 9,
"wer_no_token": 1.0,
"nb_samples": 2,
},
),
),
)
def test_evaluate(training_res, val_res, test_res, evaluate_config):
def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
# Use the tmp_path as base folder
evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate"
@@ -70,3 +120,11 @@ def test_evaluate(training_res, val_res, test_res, evaluate_config):
# Remove results files
shutil.rmtree(evaluate_config["training"]["output_folder"] / "results")
# Check the metrics Markdown table
captured_std = capsys.readouterr()
last_printed_lines = captured_std.out.split("\n")[-6:]
assert (
"\n".join(last_printed_lines)
== Path(FIXTURES / "evaluate" / "metrics_table.md").read_text()
)
# -*- coding: utf-8 -*-
import json
import logging
import pickle
import re
from operator import attrgetter, methodcaller
from pathlib import Path
from operator import methodcaller
from typing import NamedTuple
from unittest.mock import patch
import pytest
from PIL import Image, ImageChops
from arkindex_export import Element, Transcription, TranscriptionEntity
from dan.datasets.extract.arkindex import IIIF_FULL_SIZE, ArkindexExtractor
from arkindex_export import (
DatasetElement,
Element,
Transcription,
TranscriptionEntity,
)
from dan.datasets.extract.arkindex import ArkindexExtractor
from dan.datasets.extract.db import get_transcription_entities
from dan.datasets.extract.exceptions import (
NoTranscriptionError,
@@ -21,13 +22,11 @@ from dan.datasets.extract.exceptions import (
)
from dan.datasets.extract.utils import (
EntityType,
download_image,
entities_to_xml,
normalize_linebreaks,
normalize_spaces,
)
from dan.utils import parse_tokens
from line_image_extractor.image_utils import BoundingBox, polygon_to_bbox
from tests import FIXTURES
EXTRACTION_DATA_PATH = FIXTURES / "extraction"
@@ -51,24 +50,6 @@ def filter_tokens(keys):
return {key: value for key, value in TOKENS.items() if key in keys}
@pytest.mark.parametrize(
"max_width, max_height, width, height, resize",
(
(1000, 2000, 900, 800, IIIF_FULL_SIZE),
(1000, 2000, 1100, 800, "1000,"),
(1000, 2000, 1100, 2800, ",2000"),
(1000, 2000, 2000, 3000, "1000,"),
),
)
def test_get_iiif_size_arg(max_width, max_height, width, height, resize):
assert (
ArkindexExtractor(max_width=max_width, max_height=max_height).get_iiif_size_arg(
width=width, height=height
)
== resize
)
@pytest.mark.parametrize(
"text,trimmed",
(
@@ -109,28 +90,18 @@ def test_process_element_unknown_token_in_text_error(mock_database, tmp_path):
output = tmp_path / "extraction"
arkindex_extractor = ArkindexExtractor(output=output)
# Create an element with an invalid transcription
element = Element.create(
id="element_id",
name="1",
type="page",
polygon="[]",
created=0.0,
updated=0.0,
)
Transcription.create(
id="transcription_id",
text="Is this text valid⁇",
element=element,
)
# Retrieve a dataset element and update its transcription with an invalid one
dataset_element = DatasetElement.select().first()
element = dataset_element.element
Transcription.update({Transcription.text: "Is this text valid⁇"}).execute()
with pytest.raises(
UnknownTokenInText,
match=re.escape(
"Unknown token found in the transcription text of element (element_id)"
f"Unknown token found in the transcription text of element ({element.id})"
),
):
arkindex_extractor.process_element(element, "val")
arkindex_extractor.process_element(dataset_element, element)
@pytest.mark.parametrize(
@@ -255,12 +226,11 @@ def test_process_element_unknown_token_in_text_error(mock_database, tmp_path):
),
),
)
@patch("dan.datasets.extract.arkindex.download_image")
def test_extract(
mock_download_image,
load_entities,
keep_spaces,
transcription_entities_worker_version,
split_content,
mock_database,
expected_subword_language_corpus,
subword_vocab_size,
@@ -277,14 +247,9 @@ def test_extract(
if token
]
def mock_build_image_url(image_url, polygon, *args, **kwargs):
# During tests, the image URL is its local path
return polygon_to_bbox(json.loads(str(polygon))), image_url
extractor = ArkindexExtractor(
folders=["train", "val", "test"],
dataset_ids=["dataset_id"],
element_type=["text_line"],
parent_element_type="double_page",
output=output,
# Keep the whole text
entity_separators=None,
@@ -294,43 +259,12 @@ def test_extract(
if load_entities
else None,
keep_spaces=keep_spaces,
image_extension=".jpg",
subword_vocab_size=subword_vocab_size,
)
# Mock build_image_url to simply return the path to the image
extractor.build_iiif_url = mock_build_image_url
# Mock download_image so that it simply opens it with Pillow
mock_download_image.side_effect = Image.open
extractor.run()
# Check files
IMAGE_DIR = output / "images"
TEST_DIR = IMAGE_DIR / "test"
TRAIN_DIR = IMAGE_DIR / "train"
VAL_DIR = IMAGE_DIR / "val"
expected_paths = [
output / "charset.pkl",
# Images of test folder
TEST_DIR / "test-page_1-line_1.jpg",
TEST_DIR / "test-page_1-line_2.jpg",
TEST_DIR / "test-page_1-line_3.jpg",
TEST_DIR / "test-page_2-line_1.jpg",
TEST_DIR / "test-page_2-line_2.jpg",
TEST_DIR / "test-page_2-line_3.jpg",
# Images of train folder
TRAIN_DIR / "train-page_1-line_1.jpg",
TRAIN_DIR / "train-page_1-line_2.jpg",
TRAIN_DIR / "train-page_1-line_3.jpg",
TRAIN_DIR / "train-page_1-line_4.jpg",
TRAIN_DIR / "train-page_2-line_1.jpg",
TRAIN_DIR / "train-page_2-line_2.jpg",
TRAIN_DIR / "train-page_2-line_3.jpg",
# Images of val folder
VAL_DIR / "val-page_1-line_1.jpg",
VAL_DIR / "val-page_1-line_2.jpg",
VAL_DIR / "val-page_1-line_3.jpg",
output / "labels.json",
# Language resources
output / "language_model" / "corpus_characters.txt",
output / "language_model" / "corpus_subwords.txt",
@@ -341,64 +275,42 @@ def test_extract(
output / "language_model" / "subword_tokenizer.model",
output / "language_model" / "subword_tokenizer.vocab",
output / "language_model" / "tokens.txt",
output / "split.json",
]
assert sorted(filter(methodcaller("is_file"), output.rglob("*"))) == expected_paths
# Check "labels.json"
expected_labels = {
"test": {
str(TEST_DIR / "test-page_1-line_1.jpg"): "ⓢCou⁇e⁇ ⓕBouis ⓑ⁇.12.14",
str(TEST_DIR / "test-page_1-line_2.jpg"): "ⓢ⁇outrain ⓕA⁇ol⁇⁇e ⓑ9.4.13",
str(TEST_DIR / "test-page_1-line_3.jpg"): "ⓢ⁇abale ⓕ⁇ran⁇ais ⓑ26.3.11",
str(TEST_DIR / "test-page_2-line_1.jpg"): "ⓢ⁇urosoy ⓕBouis ⓑ22⁇4⁇18",
str(TEST_DIR / "test-page_2-line_2.jpg"): "ⓢColaiani ⓕAn⁇els ⓑ28.11.1⁇",
str(TEST_DIR / "test-page_2-line_3.jpg"): "ⓢRenouar⁇ ⓕMaurice ⓑ2⁇.⁇.04",
},
"train": {
str(TRAIN_DIR / "train-page_1-line_1.jpg"): "ⓢCaillet ⓕMaurice ⓑ28.9.06",
str(TRAIN_DIR / "train-page_1-line_2.jpg"): "ⓢReboul ⓕJean ⓑ30.9.02",
str(TRAIN_DIR / "train-page_1-line_3.jpg"): "ⓢBareyre ⓕJean ⓑ28.3.11",
str(TRAIN_DIR / "train-page_1-line_4.jpg"): "ⓢRoussy ⓕJean ⓑ4.11.14",
str(TRAIN_DIR / "train-page_2-line_1.jpg"): "ⓢMarin ⓕMarcel ⓑ10.8.06",
str(TRAIN_DIR / "train-page_2-line_2.jpg"): "ⓢAmical ⓕEloi ⓑ11.10.04",
str(TRAIN_DIR / "train-page_2-line_3.jpg"): "ⓢBiros ⓕMael ⓑ30.10.10",
},
"val": {
str(VAL_DIR / "val-page_1-line_1.jpg"): "ⓢMonar⁇ ⓕBouis ⓑ29⁇⁇⁇04",
str(VAL_DIR / "val-page_1-line_2.jpg"): "ⓢAstier ⓕArt⁇ur ⓑ11⁇2⁇13",
str(VAL_DIR / "val-page_1-line_3.jpg"): "ⓢ⁇e ⁇lie⁇er ⓕJules ⓑ21⁇11⁇11",
},
}
# Check "split.json"
# Transcriptions with worker version are in lowercase
if transcription_entities_worker_version:
for split in expected_labels:
for path in expected_labels[split]:
expected_labels[split][path] = expected_labels[split][path].lower()
for split in split_content:
for element_id in split_content[split]:
split_content[split][element_id]["text"] = split_content[split][
element_id
]["text"].lower()
# If we do not load entities, remove tokens
if not load_entities:
token_translations = {ord(token): None for token in tokens}
for split in expected_labels:
for path in expected_labels[split]:
expected_labels[split][path] = expected_labels[split][path].translate(
token_translations
)
for split in split_content:
for element_id in split_content[split]:
split_content[split][element_id]["text"] = split_content[split][
element_id
]["text"].translate(token_translations)
# Replace double spaces with regular space
if not keep_spaces:
for split in expected_labels:
for path in expected_labels[split]:
expected_labels[split][path] = TWO_SPACES_REGEX.sub(
" ", expected_labels[split][path]
for split in split_content:
for element_id in split_content[split]:
split_content[split][element_id]["text"] = TWO_SPACES_REGEX.sub(
" ", split_content[split][element_id]["text"]
)
assert json.loads((output / "labels.json").read_text()) == expected_labels
assert json.loads((output / "split.json").read_text()) == split_content
# Check "charset.pkl"
expected_charset = set()
for label in expected_labels["train"].values():
expected_charset.update(set(label))
for values in split_content["train"].values():
expected_charset.update(set(values["text"]))
if load_entities:
expected_charset.update(tokens)
@@ -497,118 +409,17 @@ def test_extract(
output / "language_model" / "lexicon_subwords.txt"
).read_text() == "\n".join(expected_language_subword_lexicon)
# Check cropped images
for expected_path in expected_paths:
if expected_path.suffix != ".jpg":
continue
assert ImageChops.difference(
Image.open(
EXTRACTION_DATA_PATH / "images" / "text_line" / expected_path.name
),
Image.open(expected_path),
)
@patch("dan.datasets.extract.arkindex.ArkindexExtractor.build_iiif_url")
def test_download_image_error(iiif_url, caplog, capsys):
task = {
"split": "train",
"polygon": [],
"image_url": "deadbeef",
"destination": Path("/dev/null"),
}
# Make download_image crash
iiif_url.return_value = BoundingBox(0, 0, 0, 0), task["image_url"]
extractor = ArkindexExtractor(
folders=["train", "val", "test"],
element_type=["text_line"],
parent_element_type="double_page",
output=None,
entity_separators=None,
tokens=None,
transcription_worker_version=None,
entity_worker_version=None,
keep_spaces=False,
image_extension=".jpg",
)
# Build a random task
extractor.tasks = [task]
# Add the key in data
extractor.data[task["split"]][str(task["destination"])] = "deadbeefdata"
extractor.download_images()
# Key should have been removed
assert task["destination"] not in extractor.data[task["split"]]
# Check error log
assert len(caplog.record_tuples) == 1
_, level, msg = caplog.record_tuples[0]
assert level == logging.ERROR
assert msg == "Failed to download 1 image(s)."
# Check stdout
captured = capsys.readouterr()
assert captured.out == "deadbeef: Image URL must be HTTP(S) for element null\n"
def test_download_image_error_try_max(responses, caplog):
# An image's URL
url = (
"https://blabla.com/iiif/2/image_path.jpg/231,699,2789,3659/full/0/default.jpg"
)
fixed_url = (
"https://blabla.com/iiif/2/image_path.jpg/231,699,2789,3659/max/0/default.jpg"
)
# Fake responses error
responses.add(
responses.GET,
url,
status=400,
)
# Correct response with max
responses.add(
responses.GET,
fixed_url,
status=200,
body=next((FIXTURES / "prediction" / "images").iterdir()).read_bytes(),
)
image = download_image(url)
assert image
# We try 3 times with the first URL
# Then the first try with the new URL is successful
assert len(responses.calls) == 4
assert list(map(attrgetter("request.url"), responses.calls)) == [url] * 3 + [
fixed_url
]
# Check error log
assert len(caplog.record_tuples) == 2
# We should only have WARNING levels
assert set(level for _, level, _ in caplog.record_tuples) == {logging.WARNING}
@pytest.mark.parametrize("allow_empty", (True, False))
def test_empty_transcription(allow_empty, mock_database):
extractor = ArkindexExtractor(
folders=["train", "val", "test"],
element_type=["text_line"],
parent_element_type="double_page",
output=None,
entity_separators=None,
tokens=None,
transcription_worker_version=None,
entity_worker_version=None,
keep_spaces=False,
image_extension=".jpg",
allow_empty=allow_empty,
)
element_no_transcription = Element(id="unknown")
......