Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: atr/dan
Commits on Source (18)

Showing 744 additions and 445 deletions
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.0.282
+    rev: v0.1.6
     hooks:
+      # Run the linter.
       - id: ruff
         args: [--fix, --exit-non-zero-on-fix]
-  - repo: https://github.com/ambv/black
-    rev: 23.7.0
-    hooks:
-      - id: black
+      # Run the formatter.
+      - id: ruff-format
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.4.0
     hooks:
@@ -44,7 +43,7 @@ repos:
     rev: 0.7.16
     hooks:
       - id: mdformat
-        exclude: tests/data/analyze
+        exclude: tests/data/analyze|tests/data/evaluate/metrics_table.md
         # Optionally add plugins
         additional_dependencies:
           - mdformat-mkdocs[recommended]
@@ -68,13 +68,17 @@
         "train": [
             "loss_ce",
             "cer",
+            "cer_no_token",
             "wer",
-            "wer_no_punct"
+            "wer_no_punct",
+            "wer_no_token"
         ],
         "eval": [
             "cer",
+            "cer_no_token",
            "wer",
-            "wer_no_punct"
+            "wer_no_punct",
+            "wer_no_token"
         ]
     },
     "validation": {
...
@@ -77,13 +77,17 @@
         "train": [
             "loss_ce",
             "cer",
+            "cer_no_token",
             "wer",
-            "wer_no_punct"
+            "wer_no_punct",
+            "wer_no_token"
         ],
         "eval": [
             "cer",
+            "cer_no_token",
             "wer",
-            "wer_no_punct"
+            "wer_no_punct",
+            "wer_no_token"
         ]
     },
     "validation": {
...
@@ -68,13 +68,17 @@
         "train": [
             "loss_ce",
             "cer",
+            "cer_no_token",
             "wer",
-            "wer_no_punct"
+            "wer_no_punct",
+            "wer_no_token"
         ],
         "eval": [
             "cer",
+            "cer_no_token",
             "wer",
-            "wer_no_punct"
+            "wer_no_punct",
+            "wer_no_token"
         ]
     },
     "validation": {
...
import logging
import re
from typing import Dict, List
from dan.utils import EntityType
logger = logging.getLogger(__name__)
def convert(text: str, ner_tokens: Dict[str, EntityType]) -> str:
# Mapping to find a starting token for an ending token efficiently
mapping_end_start: Dict[str, str] = {
entity_type.end: entity_type.start for entity_type in ner_tokens.values()
}
# Mapping to find the entity name for a starting token efficiently
mapping_start_name: Dict[str, str] = {
entity_type.start: name for name, entity_type in ner_tokens.items()
}
starting_tokens: List[str] = mapping_start_name.keys()
ending_tokens: List[str] = mapping_end_start.keys()
has_ending_tokens: bool = set(ending_tokens) != {
""
} # Whether ending tokens are used
# Spacing starting tokens and ending tokens (if necessary)
tokens_spacing: re.Pattern = re.compile(
r"([" + "".join([*starting_tokens, *ending_tokens]) + "])"
)
text: str = tokens_spacing.sub(r" \1 ", text)
iob: List[str] = [] # List of IOB formatted strings
entity_types: List[str] = [] # Encountered entity types
inside: bool = False # Whether we are inside an entity
for token in text.split():
# Encountering a starting token
if token in starting_tokens:
entity_types.append(token)
# Stopping any current entity type
inside = False
continue
# Encountering an ending token
elif has_ending_tokens and token in ending_tokens:
if not entity_types:
logger.warning(
f"Missing starting token for ending token {token}, skipping the entity"
)
continue
# Making sure this ending token closes the current entity
assert (
entity_types[-1] == mapping_end_start[token]
), f"Ending token {token} doesn't match the starting token {entity_types[-1]}"
# Removing the current entity from the queue as it is its end
entity_types.pop()
# If there is still entities in the queue, we continue in the parent one
# Else, we are not in any entity anymore
inside = bool(entity_types)
continue
# The token is not part of an entity
if not entity_types:
iob.append(f"{token} O")
continue
# The token is part of at least one entity
entity_name: str = mapping_start_name[entity_types[-1]]
if inside:
# Inside the same entity
iob.append(f"{token} I-{entity_name}")
continue
# Starting a new entity
iob.append(f"{token} B-{entity_name}")
inside = True
# Concatenating all formatted iob strings
return "\n".join(iob)
@@ -4,6 +4,7 @@ Preprocess datasets for training.
 """

 from dan.datasets.analyze import add_analyze_parser
+from dan.datasets.download import add_download_parser
 from dan.datasets.entities import add_entities_parser
 from dan.datasets.extract import add_extract_parser
 from dan.datasets.tokens import add_tokens_parser
@@ -18,6 +19,7 @@ def add_dataset_parser(subcommands) -> None:
     subcommands = parser.add_subparsers(metavar="subcommand")

     add_extract_parser(subcommands)
+    add_download_parser(subcommands)
     add_analyze_parser(subcommands)
     add_entities_parser(subcommands)
     add_tokens_parser(subcommands)
@@ -3,7 +3,7 @@ import logging
 from collections import Counter, defaultdict
 from functools import partial
 from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Dict, List

 import imagesize
 import numpy as np
@@ -157,7 +157,7 @@ class Statistics:
             level=3,
         )

-    def run(self, labels: Dict, tokens: Optional[Dict]):
+    def run(self, labels: Dict, tokens: Dict | None):
         # Iterate over each split
         for split_name, split_data in labels.items():
             self.document.new_header(level=1, title=split_name.capitalize())
@@ -175,7 +175,7 @@ class Statistics:
         self.document.create_md_file()


-def run(labels: Dict, tokens: Optional[Dict], output: Path) -> None:
+def run(labels: Dict, tokens: Dict | None, output: Path) -> None:
     """
     Compute and save a dataset statistics.
     """
...
# -*- coding: utf-8 -*-
"""
Download images of a dataset from a split extracted by DAN
"""
import pathlib
from dan.datasets.download.images import run
def _valid_image_format(value: str):
im_format = value
if not im_format.startswith("."):
im_format = "." + im_format
return im_format
def add_download_parser(subcommands) -> None:
parser = subcommands.add_parser(
"download",
description=__doc__,
help=__doc__,
)
# Required arguments.
parser.add_argument(
"--output",
type=pathlib.Path,
help="Path where the `split.json` file is stored and where the data will be generated.",
required=True,
)
parser.add_argument(
"--max-width",
type=int,
help="Images larger than this width will be resized to this width.",
)
parser.add_argument(
"--max-height",
type=int,
help="Images larger than this height will be resized to this height.",
)
# Formatting arguments
parser.add_argument(
"--image-format",
type=_valid_image_format,
default=".jpg",
help="Images will be saved under this format.",
)
parser.set_defaults(func=run)
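For reference, the `_valid_image_format` helper above only normalizes the extension to a leading dot, so `--image-format jpg` and `--image-format .jpg` are equivalent; a quick check:

assert _valid_image_format("png") == ".png"
assert _valid_image_format(".png") == ".png"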
# -*- coding: utf-8 -*-
from pathlib import Path
class ImageDownloadError(Exception):
"""
Raised when an element's image could not be downloaded
"""
def __init__(
self, split: str, path: Path, url: str, exc: Exception, *args: object
) -> None:
super().__init__(*args)
self.split: str = split
self.path: str = str(path)
self.url: str = url
self.message = f"{str(exc)} for element {path.stem}"
# -*- coding: utf-8 -*-
import json
import logging
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, Tuple
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
from dan.datasets.download.exceptions import ImageDownloadError
from dan.datasets.download.utils import download_image, get_bbox
from line_image_extractor.extractor import extract
from line_image_extractor.image_utils import (
BoundingBox,
Extraction,
polygon_to_bbox,
)
IMAGES_DIR = "images" # Subpath to the images directory.
IIIF_URL = "{image_url}/{bbox}/{size}/0/default.jpg"
# IIIF 2.0 uses `full`
IIIF_FULL_SIZE = "full"
logger = logging.getLogger(__name__)
class ImageDownloader:
"""
Download images from extracted data
"""
def __init__(
self,
output: Path | None = None,
max_width: int | None = None,
max_height: int | None = None,
image_extension: str = "",
) -> None:
self.output = output
self.max_width = max_width
self.max_height = max_height
self.image_extension = image_extension
# Load split file
split_file = output / "split.json" if output else None
self.split: Dict = (
json.loads(split_file.read_text())
if split_file and split_file.is_file()
else {}
)
# Create directories
for split_name in self.split:
Path(output, IMAGES_DIR, split_name).mkdir(parents=True, exist_ok=True)
self.data: Dict = defaultdict(dict)
def check_extraction(self, values: dict) -> str | None:
# Check dataset_id parameter
if values.get("dataset_id") is None:
return "Dataset ID not found"
# Check image parameters
if not (image := values.get("image")):
return "Image information not found"
# Only support iiif_url with polygon for now
if not image.get("iiif_url"):
return "Image IIIF URL not found"
if not image.get("polygon"):
return "Image polygon not found"
# Check text parameter
if values.get("text") is None:
return "Text not found"
def get_iiif_size_arg(self, width: int, height: int) -> str:
if (self.max_width is None or width <= self.max_width) and (
self.max_height is None or height <= self.max_height
):
return IIIF_FULL_SIZE
bigger_width = self.max_width and width >= self.max_width
bigger_height = self.max_height and height >= self.max_height
if bigger_width and bigger_height:
# Resize to the biggest dim to keep aspect ratio
# Only resize width is bigger than max size
# This ratio tells which dim needs the biggest shrinking
ratio = width * self.max_height / (height * self.max_width)
return f"{self.max_width}," if ratio > 1 else f",{self.max_height}"
elif bigger_width:
return f"{self.max_width},"
# Only resize height is bigger than max size
elif bigger_height:
return f",{self.max_height}"
def build_iiif_url(
self, polygon: List[List[int]], image_url: str
) -> Tuple[BoundingBox, str]:
bbox = polygon_to_bbox(polygon)
size = self.get_iiif_size_arg(width=bbox.width, height=bbox.height)
# Rotations are done using the lib
return IIIF_URL.format(image_url=image_url, bbox=get_bbox(polygon), size=size)
def build_tasks(self) -> List[Dict[str, str]]:
tasks = []
for split, items in self.split.items():
# Create directories
destination = self.output / IMAGES_DIR / split
destination.mkdir(parents=True, exist_ok=True)
for element_id, values in items.items():
filename = Path(element_id).with_suffix(self.image_extension)
error = self.check_extraction(values)
if error:
logger.warning(f"{destination / filename}: {error}")
continue
image_path = destination / values["dataset_id"] / filename
image_path.parent.mkdir(parents=True, exist_ok=True)
self.data[split][str(image_path)] = values["text"]
# Create task for multithreading pool if image does not exist yet
if image_path.exists():
continue
polygon = values["image"]["polygon"]
iiif_url = values["image"]["iiif_url"]
tasks.append(
{
"split": split,
"polygon": polygon,
"image_url": self.build_iiif_url(polygon, iiif_url),
"destination": image_path,
}
)
return tasks
def get_image(
self,
split: str,
polygon: List[List[int]],
image_url: str,
destination: Path,
) -> None:
"""Save the element's image to the given path and applies any image operations needed.
:param split: Dataset split this image belongs to.
:param polygon: Polygon of the processed element.
:param image_url: Base IIIF URL of the image.
:param destination: Where the image should be saved.
"""
bbox = polygon_to_bbox(polygon)
try:
img: Image.Image = download_image(image_url)
# The polygon's coordinate are in the referential of the full image
# We need to remove the offset of the bounding rectangle
polygon = [(x - bbox.x, y - bbox.y) for x, y in polygon]
# Normalize bbox
bbox = BoundingBox(x=0, y=0, width=bbox.width, height=bbox.height)
image = extract(
img=cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR),
polygon=np.asarray(polygon).clip(0),
bbox=bbox,
extraction_mode=Extraction.boundingRect,
max_deskew_angle=45,
)
cv2.imwrite(str(destination), image)
except Exception as e:
raise ImageDownloadError(
split=split, path=destination, url=image_url, exc=e
)
def download_images(self, tasks: List[Dict[str, str]]) -> None:
"""
Execute each image download task in parallel
:param tasks: List of tasks to execute.
"""
failed_downloads = []
with tqdm(
desc="Downloading images", total=len(tasks)
) as pbar, ThreadPoolExecutor() as executor:
def process_future(future: Future):
"""
Callback function called at the end of the thread
"""
# Update the progress bar count
pbar.update(1)
exc = future.exception()
if exc is None:
# No error
return
# If failed, tag for removal
assert isinstance(exc, ImageDownloadError)
# Remove transcription from labels dict
del self.data[exc.split][exc.path]
# Save tried URL
failed_downloads.append((exc.url, exc.message))
# Submit all tasks
for task in tasks:
executor.submit(self.get_image, **task).add_done_callback(
process_future
)
if failed_downloads:
logger.error(f"Failed to download {len(failed_downloads)} image(s).")
print(*list(map(": ".join, failed_downloads)), sep="\n")
def export(self) -> None:
"""
Writes a `labels.json` file containing a mapping of the images that have been correctly uploaded (identified by its path)
to the ground-truth transcription (with NER tokens if needed).
"""
(self.output / "labels.json").write_text(
json.dumps(
self.data,
sort_keys=True,
indent=4,
)
)
def run(self) -> None:
"""
Download the missing images from a `split.json` file and build a `labels.json` file containing
a mapping of the images that have been correctly uploaded (identified by its path)
to the ground-truth transcription (with NER tokens if needed).
"""
tasks: List[Dict[str, str]] = self.build_tasks()
self.download_images(tasks)
self.export()
def run(
output: Path,
max_width: int | None,
max_height: int | None,
image_format: str,
):
"""
Download the missing images from a `split.json` file and build a `labels.json` file containing
a mapping of the images that have been correctly uploaded (identified by its path)
to the ground-truth transcription (with NER tokens if needed).
:param output: Path where the `split.json` file is stored and where the data will be generated
:param max_width: Images larger than this width will be resized to this width
:param max_height: Images larger than this height will be resized to this height
:param image_format: Images will be saved under this format
"""
ImageDownloader(
output=output,
max_width=max_width,
max_height=max_height,
image_extension=image_format,
).run()
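Putting the class and its wrapper together, a minimal programmatic use of the downloader added above could look like the following sketch (the output directory is hypothetical); it reads `<output>/split.json`, downloads the missing images and writes `<output>/labels.json`:

from pathlib import Path

ImageDownloader(
    output=Path("data/my-dataset"),   # hypothetical directory containing split.json
    max_width=1800,                   # resize anything wider than 1800 px
    max_height=None,                  # no height limit
    image_extension=".jpg",
).run()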
# -*- coding: utf-8 -*-
import logging
from io import BytesIO
from typing import List
import requests
from PIL import Image, ImageOps
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
logger = logging.getLogger(__name__)
# See http://docs.python-requests.org/en/master/user/advanced/#timeouts
DOWNLOAD_TIMEOUT = (30, 60)
def _retry_log(retry_state, *args, **kwargs):
logger.warning(
f"Request to {retry_state.args[0]} failed ({repr(retry_state.outcome.exception())}), "
f"retrying in {retry_state.idle_for} seconds"
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2),
retry=retry_if_exception_type(requests.RequestException),
before_sleep=_retry_log,
reraise=True,
)
def _retried_request(url: str) -> requests.Response:
resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
resp.raise_for_status()
return resp
def download_image(url: str) -> Image.Image:
"""
Download an image and open it with Pillow
"""
assert url.startswith("http"), "Image URL must be HTTP(S)"
# Download the image
# Cannot use stream=True as urllib's responses do not support the seek(int) method,
# which is explicitly required by Image.open on file-like objects
try:
resp = _retried_request(url)
except requests.HTTPError as e:
if "/full/" in url and 400 <= e.response.status_code < 500:
# Retry with max instead of full as IIIF size
resp = _retried_request(url.replace("/full/", "/max/"))
else:
raise e
# Preprocess the image and prepare it for classification
image = Image.open(BytesIO(resp.content)).convert("RGB")
# Do not rotate JPEG images (see https://github.com/python-pillow/Pillow/issues/4703)
image = ImageOps.exif_transpose(image)
logger.debug(
"Downloaded image {} - size={}x{}".format(url, image.size[0], image.size[1])
)
return image
def get_bbox(polygon: List[List[int]]) -> str:
"""
Returns a comma-separated string of upper left-most pixel, width + height of the image
"""
all_x, all_y = zip(*polygon)
x, y = min(all_x), min(all_y)
width, height = max(all_x) - x, max(all_y) - y
return ",".join(list(map(str, [int(x), int(y), int(width), int(height)])))
@@ -5,7 +5,6 @@ Extract dataset from Arkindex using a corpus export.
 import argparse
 import pathlib
-from typing import Union
 from uuid import UUID

 from dan.datasets.extract.arkindex import run
@@ -13,7 +12,7 @@ from dan.datasets.extract.arkindex import run
 MANUAL_SOURCE = "manual"


-def parse_worker_version(worker_version_id) -> Union[str, bool]:
+def parse_worker_version(worker_version_id) -> str | bool:
     if worker_version_id == MANUAL_SOURCE:
         return False
@@ -34,13 +33,6 @@ def validate_char(char):
     return char


-def _valid_image_format(value: str):
-    im_format = value
-    if not im_format.startswith("."):
-        im_format = "." + im_format
-    return im_format
-
-
 def add_extract_parser(subcommands) -> None:
     parser = subcommands.add_parser(
         "extract",
@@ -55,18 +47,19 @@ def add_extract_parser(subcommands) -> None:
         help="Path where the data were exported from Arkindex.",
     )
     parser.add_argument(
-        "--element-type",
+        "--dataset-id",
         nargs="+",
-        type=str,
-        help="Type of elements to retrieve.",
+        type=UUID,
+        help="ID of the dataset to extract from Arkindex.",
         required=True,
+        dest="dataset_ids",
     )
     parser.add_argument(
-        "--parent-element-type",
+        "--element-type",
+        nargs="+",
         type=str,
-        help="Type of the parent element containing the data.",
-        required=False,
-        default="page",
+        help="Type of elements to retrieve.",
+        required=True,
     )
     parser.add_argument(
         "--output",
@@ -75,25 +68,6 @@ def add_extract_parser(subcommands) -> None:
         required=True,
     )

-    parser.add_argument(
-        "--train-folder",
-        type=UUID,
-        help="ID of the training folder to extract from Arkindex.",
-        required=True,
-    )
-    parser.add_argument(
-        "--val-folder",
-        type=UUID,
-        help="ID of the validation folder to extract from Arkindex.",
-        required=True,
-    )
-    parser.add_argument(
-        "--test-folder",
-        type=UUID,
-        help="ID of the testing folder to extract from Arkindex.",
-        required=True,
-    )
-
     # Optional arguments.
     parser.add_argument(
         "--entity-separators",
@@ -131,18 +105,6 @@ def add_extract_parser(subcommands) -> None:
         required=False,
     )

-    parser.add_argument(
-        "--max-width",
-        type=int,
-        help="Images larger than this width will be resized to this width.",
-    )
-
-    parser.add_argument(
-        "--max-height",
-        type=int,
-        help="Images larger than this height will be resized to this height.",
-    )
-
     parser.add_argument(
         "--subword-vocab-size",
         type=int,
@@ -151,13 +113,6 @@ def add_extract_parser(subcommands) -> None:
     )

     # Formatting arguments
-    parser.add_argument(
-        "--image-format",
-        type=_valid_image_format,
-        default=".jpg",
-        help="Images will be saved under this format.",
-    )
-
     parser.add_argument(
         "--keep-spaces",
         action="store_true",
...
@@ -5,55 +5,40 @@ import logging
 import pickle
 import random
 from collections import defaultdict
-from concurrent.futures import Future, ThreadPoolExecutor
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List
 from uuid import UUID

-import cv2
-import numpy as np
-from PIL import Image
 from tqdm import tqdm

-from arkindex_export import open_database
+from arkindex_export import Dataset, DatasetElement, Element, open_database
 from dan.datasets.extract.db import (
-    Element,
+    get_dataset_elements,
     get_elements,
     get_transcription_entities,
     get_transcriptions,
 )
 from dan.datasets.extract.exceptions import (
-    ImageDownloadError,
     NoTranscriptionError,
     ProcessingError,
     UnknownTokenInText,
 )
 from dan.datasets.extract.utils import (
     Tokenizer,
-    download_image,
     entities_to_xml,
-    get_bbox,
     get_translation_map,
     get_vocabulary,
     normalize_linebreaks,
     normalize_spaces,
 )
 from dan.utils import LMTokenMapping, parse_tokens
-from line_image_extractor.extractor import extract
-from line_image_extractor.image_utils import (
-    BoundingBox,
-    Extraction,
-    polygon_to_bbox,
-)

-IMAGES_DIR = "images"  # Subpath to the images directory.
 LANGUAGE_DIR = "language_model"  # Subpath to the language model directory.

 TRAIN_NAME = "train"
-SPLIT_NAMES = [TRAIN_NAME, "val", "test"]
-IIIF_URL = "{image_url}/{bbox}/{size}/0/default.jpg"
-# IIIF 2.0 uses `full`
-IIIF_FULL_SIZE = "full"
+VAL_NAME = "val"
+TEST_NAME = "test"
+SPLIT_NAMES = [TRAIN_NAME, VAL_NAME, TEST_NAME]

 logger = logging.getLogger(__name__)
@@ -65,34 +50,26 @@ class ArkindexExtractor:
     def __init__(
         self,
-        folders: list = [],
+        dataset_ids: List[UUID] | None = None,
         element_type: List[str] = [],
-        parent_element_type: str = None,
-        output: Path = None,
+        output: Path | None = None,
         entity_separators: List[str] = ["\n", " "],
         unknown_token: str = "",
-        tokens: Path = None,
-        transcription_worker_version: Optional[Union[str, bool]] = None,
-        entity_worker_version: Optional[Union[str, bool]] = None,
-        max_width: Optional[int] = None,
-        max_height: Optional[int] = None,
+        tokens: Path | None = None,
+        transcription_worker_version: str | bool | None = None,
+        entity_worker_version: str | bool | None = None,
         keep_spaces: bool = False,
-        image_extension: str = "",
         allow_empty: bool = False,
         subword_vocab_size: int = 1000,
     ) -> None:
-        self.folders = folders
+        self.dataset_ids = dataset_ids
         self.element_type = element_type
-        self.parent_element_type = parent_element_type
         self.output = output
         self.entity_separators = entity_separators
         self.unknown_token = unknown_token
         self.tokens = parse_tokens(tokens) if tokens else {}
         self.transcription_worker_version = transcription_worker_version
         self.entity_worker_version = entity_worker_version
-        self.max_width = max_width
-        self.max_height = max_height
-        self.image_extension = image_extension
         self.allow_empty = allow_empty
         self.mapping = LMTokenMapping()
         self.keep_spaces = keep_spaces
@@ -104,41 +81,9 @@ class ArkindexExtractor:
         self.language_tokens = []
         self.language_lexicon = defaultdict(list)

-        # Image download tasks to process
-        self.tasks: List[Dict[str, str]] = []
-
         # NER extraction
         self.translation_map: Dict[str, str] | None = get_translation_map(self.tokens)

-    def get_iiif_size_arg(self, width: int, height: int) -> str:
-        if (self.max_width is None or width <= self.max_width) and (
-            self.max_height is None or height <= self.max_height
-        ):
-            return IIIF_FULL_SIZE
-
-        bigger_width = self.max_width and width >= self.max_width
-        bigger_height = self.max_height and height >= self.max_height
-
-        if bigger_width and bigger_height:
-            # Resize to the biggest dim to keep aspect ratio
-            # Only resize width is bigger than max size
-            # This ratio tells which dim needs the biggest shrinking
-            ratio = width * self.max_height / (height * self.max_width)
-            return f"{self.max_width}," if ratio > 1 else f",{self.max_height}"
-        elif bigger_width:
-            return f"{self.max_width},"
-        # Only resize height is bigger than max size
-        elif bigger_height:
-            return f",{self.max_height}"
-
-    def build_iiif_url(self, polygon, image_url) -> Tuple[BoundingBox, str]:
-        bbox = polygon_to_bbox(json.loads(str(polygon)))
-        size = self.get_iiif_size_arg(width=bbox.width, height=bbox.height)
-        # Rotations are done using the lib
-        return bbox, IIIF_URL.format(
-            image_url=image_url, bbox=get_bbox(polygon), size=size
-        )
-
     def translate(self, text: str):
         """
         Use translation map to replace XML tags to actual tokens
@@ -177,48 +122,7 @@ class ArkindexExtractor:
             )
         )

-    def get_image(
-        self,
-        split: str,
-        polygon: List[List[int]],
-        image_url: str,
-        destination: Path,
-    ) -> None:
-        """Save the element's image to the given path and applies any image operations needed.
-
-        :param split: Dataset split this image belongs to.
-        :param polygon: Polygon of the processed element.
-        :param image_url: Base IIIF URL of the image.
-        :param destination: Where the image should be saved.
-        """
-        bbox, download_url = self.build_iiif_url(polygon=polygon, image_url=image_url)
-        try:
-            img: Image.Image = download_image(download_url)
-
-            # The polygon's coordinate are in the referential of the full image
-            # We need to remove the offset of the bounding rectangle
-            polygon = [(x - bbox.x, y - bbox.y) for x, y in polygon]
-
-            # Normalize bbox
-            bbox = BoundingBox(x=0, y=0, width=bbox.width, height=bbox.height)
-            image = extract(
-                img=cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR),
-                polygon=np.asarray(polygon).clip(0),
-                bbox=bbox,
-                extraction_mode=Extraction.boundingRect,
-                max_deskew_angle=45,
-            )
-
-            destination.parent.mkdir(parents=True, exist_ok=True)
-            cv2.imwrite(str(destination), image)
-        except Exception as e:
-            raise ImageDownloadError(
-                split=split, path=destination, url=download_url, exc=e
-            )
-
-    def format_text(self, text: str, charset: Optional[set] = None):
+    def format_text(self, text: str, charset: set | None = None):
         if not self.keep_spaces:
             text = normalize_spaces(text)
             text = normalize_linebreaks(text)
@@ -234,11 +138,7 @@ class ArkindexExtractor:
             )
         return text.strip()

-    def process_element(
-        self,
-        element: Element,
-        split: str,
-    ):
+    def process_element(self, dataset_parent: DatasetElement, element: Element):
         """
         Extract an element's data and save it to disk.
         The output path is directly related to the split of the element.
@@ -248,46 +148,33 @@ class ArkindexExtractor:
         if self.unknown_token in text:
             raise UnknownTokenInText(element_id=element.id)

-        image_path = Path(self.output, IMAGES_DIR, split, element.id).with_suffix(
-            self.image_extension
-        )
-        # Create task for multithreading pool if image does not exist yet
-        if not image_path.exists():
-            self.tasks.append(
-                {
-                    "split": split,
-                    "polygon": json.loads(str(element.polygon)),
-                    "image_url": element.image.url,
-                    "destination": image_path,
-                }
-            )
-
         text = self.format_text(
             text,
             # Do not replace unknown characters in train split
-            charset=self.charset if split != TRAIN_NAME else None,
+            charset=self.charset if dataset_parent.set_name != TRAIN_NAME else None,
         )

-        self.data[split][str(image_path)] = text
+        self.data[dataset_parent.set_name][element.id] = {
+            "dataset_id": dataset_parent.dataset_id,
+            "text": text,
+            "image": {
+                "iiif_url": element.image.url,
+                "polygon": json.loads(element.polygon),
+            },
+        }
         self.charset = self.charset.union(set(text))

-    def process_parent(
-        self,
-        pbar,
-        parent: Element,
-        split: str,
-    ):
+    def process_parent(self, pbar, dataset_parent: DatasetElement):
         """
         Extract data from a parent element.
         """
-        base_description = (
-            f"Extracting data from {parent.type} ({parent.id}) for split ({split})"
-        )
+        parent = dataset_parent.element
+        base_description = f"Extracting data from {parent.type} ({parent.id}) for split ({dataset_parent.set_name})"
         pbar.set_description(desc=base_description)
         if self.element_type == [parent.type]:
             try:
-                self.process_element(parent, split)
+                self.process_element(dataset_parent, parent)
             except ProcessingError as e:
                 logger.warning(f"Skipping {parent.id}: {str(e)}")
         # Extract children elements
@@ -302,7 +189,7 @@ class ArkindexExtractor:
                 # Update description to update the children processing progress
                 pbar.set_description(desc=base_description + f" ({idx}/{nb_children})")
                 try:
-                    self.process_element(element, split)
+                    self.process_element(dataset_parent, element)
                 except ProcessingError as e:
                     logger.warning(f"Skipping {element.id}: {str(e)}")
@@ -326,8 +213,10 @@ class ArkindexExtractor:
         # Build LM corpus
         train_corpus = [
-            text.replace(self.mapping.linebreak.display, self.mapping.space.display)
-            for text in self.data["train"].values()
+            values["text"].replace(
+                self.mapping.linebreak.display, self.mapping.space.display
+            )
+            for values in self.data[TRAIN_NAME].values()
         ]

         tokenizer = Tokenizer(
@@ -361,7 +250,7 @@ class ArkindexExtractor:
         ]

     def export(self):
-        (self.output / "labels.json").write_text(
+        (self.output / "split.json").write_text(
             json.dumps(
                 self.data,
                 sort_keys=True,
@@ -382,87 +271,52 @@ class ArkindexExtractor:
             pickle.dumps(sorted(list(self.charset)))
         )

-    def download_images(self):
-        failed_downloads = []
-        with tqdm(
-            desc="Downloading images", total=len(self.tasks)
-        ) as pbar, ThreadPoolExecutor() as executor:
-
-            def process_future(future: Future):
-                """
-                Callback function called at the end of the thread
-                """
-                # Update the progress bar count
-                pbar.update(1)
-                exc = future.exception()
-                if exc is None:
-                    # No error
-                    return
-
-                # If failed, tag for removal
-                assert isinstance(exc, ImageDownloadError)
-                # Remove transcription from labels dict
-                del self.data[exc.split][exc.path]
-                # Save tried URL
-                failed_downloads.append((exc.url, exc.message))
-
-            # Submit all tasks
-            for task in self.tasks:
-                executor.submit(self.get_image, **task).add_done_callback(
-                    process_future
-                )
-
-        if failed_downloads:
-            logger.error(f"Failed to download {len(failed_downloads)} image(s).")
-            print(*list(map(": ".join, failed_downloads)), sep="\n")
-
     def run(self):
-        # Iterate over the subsets to find the page images and labels.
-        for folder_id, split in zip(self.folders, SPLIT_NAMES):
-            with tqdm(
-                get_elements(
-                    folder_id,
-                    [self.parent_element_type],
-                ),
-                desc=f"Extracting data from ({folder_id}) for split ({split})",
-            ) as pbar:
-                # Iterate over the pages to create splits at page level.
-                for parent in pbar:
-                    self.process_parent(
-                        pbar=pbar,
-                        parent=parent,
-                        split=split,
-                    )
-                    # Progress bar updates
-                    pbar.update()
-                    pbar.refresh()
+        # Retrieve the Dataset and its splits from the cache
+        for dataset_id in self.dataset_ids:
+            dataset = Dataset.get_by_id(dataset_id)
+            splits = dataset.sets.split(",")
+            if not set(splits).issubset(set(SPLIT_NAMES)):
+                logger.warning(
+                    f'Dataset {dataset.name} ({dataset.id}) does not have "{TRAIN_NAME}", "{VAL_NAME}" and "{TEST_NAME}" steps'
+                )
+                continue
+
+            # Iterate over the subsets to find the page images and labels.
+            for split in splits:
+                with tqdm(
+                    get_dataset_elements(dataset, split),
+                    desc=f"Extracting data from ({dataset_id}) for split ({split})",
+                ) as pbar:
+                    # Iterate over the pages to create splits at page level.
+                    for parent in pbar:
+                        self.process_parent(
+                            pbar=pbar,
+                            dataset_parent=parent,
+                        )
+                        # Progress bar updates
+                        pbar.update()
+                        pbar.refresh()

         if not self.data:
             raise Exception(
                 "No data was extracted using the provided export database and parameters."
             )

-        self.download_images()
         self.format_lm_files()
         self.export()


 def run(
     database: Path,
+    dataset_ids: List[UUID],
     element_type: List[str],
-    parent_element_type: str,
     output: Path,
     entity_separators: List[str],
     unknown_token: str,
     tokens: Path,
-    train_folder: UUID,
-    val_folder: UUID,
-    test_folder: UUID,
-    transcription_worker_version: Optional[Union[str, bool]],
-    entity_worker_version: Optional[Union[str, bool]],
-    max_width: Optional[int],
-    max_height: Optional[int],
-    image_format: str,
+    transcription_worker_version: str | bool | None,
+    entity_worker_version: str | bool | None,
     keep_spaces: bool,
     allow_empty: bool,
     subword_vocab_size: int,
@@ -470,27 +324,19 @@ def run(
     assert database.exists(), f"No file found @ {database}"
     open_database(path=database)

-    folders = [str(train_folder), str(val_folder), str(test_folder)]
-
     # Create directories
-    for split in SPLIT_NAMES:
-        Path(output, IMAGES_DIR, split).mkdir(parents=True, exist_ok=True)
     Path(output, LANGUAGE_DIR).mkdir(parents=True, exist_ok=True)

     ArkindexExtractor(
-        folders=folders,
+        dataset_ids=dataset_ids,
         element_type=element_type,
-        parent_element_type=parent_element_type,
         output=output,
         entity_separators=entity_separators,
         unknown_token=unknown_token,
         tokens=tokens,
         transcription_worker_version=transcription_worker_version,
         entity_worker_version=entity_worker_version,
-        max_width=max_width,
-        max_height=max_height,
         keep_spaces=keep_spaces,
-        image_extension=image_format,
         allow_empty=allow_empty,
         subword_vocab_size=subword_vocab_size,
     ).run()
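Based on `process_element` and `export` above, the `split.json` file now written by the extractor maps each split and element ID to its dataset ID, transcription and image information; roughly this shape (all IDs and values below are made up):

# Sketch of the structure written to <output>/split.json
split_json = {
    "train": {
        "element-uuid": {
            "dataset_id": "dataset-uuid",
            "text": "ground-truth transcription, possibly with NER tokens",
            "image": {
                "iiif_url": "https://iiif.example.com/image-uuid",
                "polygon": [[0, 0], [1000, 0], [1000, 600], [0, 600]],
            },
        },
    },
    "val": {},
    "test": {},
}

The new download command then consumes this file and writes the flat image-path-to-text mapping in `labels.json`.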
 # -*- coding: utf-8 -*-
-from typing import List, Optional, Union
+from typing import List

 from arkindex_export import Image
 from arkindex_export.models import (
+    Dataset,
+    DatasetElement,
     Element,
     Entity,
     EntityType,
@@ -13,6 +14,26 @@ from arkindex_export.models import (
 from arkindex_export.queries import list_children


+def get_dataset_elements(
+    dataset: Dataset,
+    split: str,
+):
+    """
+    Retrieve dataset elements in a specific split from an SQLite export of an Arkindex corpus
+    """
+    query = (
+        DatasetElement.select()
+        .join(Element)
+        .join(Image, on=(DatasetElement.element.image == Image.id))
+        .where(
+            DatasetElement.dataset == dataset,
+            DatasetElement.set_name == split,
+        )
+    )
+
+    return query
+
+
 def get_elements(
     parent_id: str,
     element_type: List[str],
@@ -41,7 +62,7 @@ def build_worker_version_filter(ArkindexModel, worker_version):
 def get_transcriptions(
-    element_id: str, transcription_worker_version: Union[str, bool]
+    element_id: str, transcription_worker_version: str | bool
 ) -> List[Transcription]:
     """
     Retrieve transcriptions from an SQLite export of an Arkindex corpus
@@ -61,7 +82,7 @@ def get_transcriptions(
 def get_transcription_entities(
     transcription_id: str,
-    entity_worker_version: Optional[Union[str, bool]],
+    entity_worker_version: str | bool | None,
     supported_types: List[str],
 ) -> List[TranscriptionEntity]:
     """
...
 # -*- coding: utf-8 -*-
-from pathlib import Path


 class ProcessingError(Exception):
@@ -21,21 +20,6 @@ class ElementProcessingError(ProcessingError):
         self.element_id = element_id


-class ImageDownloadError(Exception):
-    """
-    Raised when an element's image could not be downloaded
-    """
-
-    def __init__(
-        self, split: str, path: Path, url: str, exc: Exception, *args: object
-    ) -> None:
-        super().__init__(*args)
-
-        self.split: str = split
-        self.path: str = str(path)
-        self.url: str = url
-        self.message = f"{str(exc)} for element {path.stem}"
-
-
 class NoTranscriptionError(ElementProcessingError):
     """
     Raised when there are no transcriptions on an element
...
@@ -4,31 +4,19 @@ import logging
 import operator
 import re
 from dataclasses import dataclass, field
-from io import BytesIO
 from pathlib import Path
 from tempfile import NamedTemporaryFile
-from typing import Dict, Iterator, List, Optional, Union
+from typing import Dict, Iterator, List

-import requests
 import sentencepiece as spm
 from lxml.etree import Element, SubElement, tostring
 from nltk import wordpunct_tokenize
-from PIL import Image, ImageOps
-from tenacity import (
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_exponential,
-)

 from arkindex_export import TranscriptionEntity
 from dan.utils import EntityType, LMTokenMapping

 logger = logging.getLogger(__name__)

-# See http://docs.python-requests.org/en/master/user/advanced/#timeouts
-DOWNLOAD_TIMEOUT = (30, 60)
-
 # replace \t with regular space and consecutive spaces
 TRIM_SPACE_REGEX = re.compile(r"[\t ]+")
 TRIM_RETURN_REGEX = re.compile(r"[\r\n]+")
@@ -42,57 +30,6 @@ ENCODING_MAP = {
 }


-def _retry_log(retry_state, *args, **kwargs):
-    logger.warning(
-        f"Request to {retry_state.args[0]} failed ({repr(retry_state.outcome.exception())}), "
-        f"retrying in {retry_state.idle_for} seconds"
-    )
-
-
-@retry(
-    stop=stop_after_attempt(3),
-    wait=wait_exponential(multiplier=2),
-    retry=retry_if_exception_type(requests.RequestException),
-    before_sleep=_retry_log,
-    reraise=True,
-)
-def _retried_request(url):
-    resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
-    resp.raise_for_status()
-    return resp
-
-
-def download_image(url):
-    """
-    Download an image and open it with Pillow
-    """
-    assert url.startswith("http"), "Image URL must be HTTP(S)"
-
-    # Download the image
-    # Cannot use stream=True as urllib's responses do not support the seek(int) method,
-    # which is explicitly required by Image.open on file-like objects
-    try:
-        resp = _retried_request(url)
-    except requests.HTTPError as e:
-        if "/full/" in url and 400 <= e.response.status_code < 500:
-            # Retry with max instead of full as IIIF size
-            resp = _retried_request(url.replace("/full/", "/max/"))
-        else:
-            raise e
-
-    # Preprocess the image and prepare it for classification
-    image = Image.open(BytesIO(resp.content)).convert("RGB")
-
-    # Do not rotate JPEG images (see https://github.com/python-pillow/Pillow/issues/4703)
-    image = ImageOps.exif_transpose(image)
-
-    logger.debug(
-        "Downloaded image {} - size={}x{}".format(url, image.size[0], image.size[1])
-    )
-
-    return image
-
-
 def normalize_linebreaks(text: str) -> str:
     """
     Remove begin/ending linebreaks.
@@ -111,17 +48,6 @@ def normalize_spaces(text: str) -> str:
     return TRIM_SPACE_REGEX.sub(" ", text.strip())


-def get_bbox(polygon: List[List[int]]) -> str:
-    """
-    Arkindex polygon stored as string
-    returns a comma-separated string of upper left-most pixel, width + height of the image
-    """
-    all_x, all_y = zip(*polygon)
-    x, y = min(all_x), min(all_y)
-    width, height = max(all_x) - x, max(all_y) - y
-    return ",".join(list(map(str, [int(x), int(y), int(width), int(height)])))
-
-
 def get_vocabulary(tokenized_text: List[str]) -> set[str]:
     """
     Compute set of vocabulary from tokenzied text.
@@ -146,7 +72,7 @@ class Tokenizer:
     unknown_token: str
     outdir: Path
     mapping: LMTokenMapping
-    tokens: Optional[EntityType] = None
+    tokens: EntityType | None = None
     subword_vocab_size: int = 1000
     sentencepiece_model: spm.SentencePieceProcessor = field(init=False)
@@ -155,7 +81,7 @@ class Tokenizer:
         return self.outdir / "subword_tokenizer"

     @property
-    def ner_tokens(self) -> Union[List[str], Iterator[str]]:
+    def ner_tokens(self) -> List[str] | Iterator[str]:
         if self.tokens is None:
             return []
         return itertools.chain(
@@ -253,7 +179,7 @@ def slugify(text: str):
     return text.replace(" ", "_")


-def get_translation_map(tokens: Dict[str, EntityType]) -> Optional[Dict[str, str]]:
+def get_translation_map(tokens: Dict[str, EntityType]) -> Dict[str, str] | None:
     if not tokens:
         return
@@ -321,7 +247,7 @@ class XMLEntity:
 def entities_to_xml(
     text: str,
     predictions: List[TranscriptionEntity],
-    entity_separators: Optional[List[str]] = None,
+    entity_separators: List[str] | None = None,
 ) -> str:
     """Represent the transcription and its entities in XML format. Each entity will be exposed with an XML tag.
     Its type will be used to name the tag.
@@ -341,7 +267,7 @@ def entities_to_xml(
             return separator
         return ""

-    def add_portion(entity_offset: Optional[int] = None):
+    def add_portion(entity_offset: int | None = None):
         """
         Add the portion of text between entities either:
         - after the last node, if there is one before
...
 # -*- coding: utf-8 -*-
-from typing import Dict, List, Union
+from typing import Dict, List

 import numpy as np
 import torch
@@ -559,7 +559,7 @@ class CTCLanguageDecoder:
     def post_process(
         self, hypotheses: List[CTCHypothesis], batch_sizes: torch.LongTensor
-    ) -> Dict[str, List[Union[str, float]]]:
+    ) -> Dict[str, List[str | float]]:
         """
         Post-process hypotheses to output JSON. Exports only the best hypothesis for each image.
         :param hypotheses: List of hypotheses returned by the decoder.
@@ -594,7 +594,7 @@ class CTCLanguageDecoder:
     def __call__(
         self, batch_features: torch.FloatTensor, batch_frames: torch.LongTensor
-    ) -> Dict[str, List[Union[str, float]]]:
+    ) -> Dict[str, List[str | float]]:
         """
         Decode a feature vector using n-gram language modelling.
         :param batch_features: Feature vector of size (batch_size, n_tokens, n_frames).
...
@@ -11,7 +11,7 @@ import torch
 import torch.multiprocessing as mp

 from dan.ocr.manager.training import Manager
-from dan.ocr.utils import update_config
+from dan.ocr.utils import add_metrics_table_row, create_metrics_table, update_config
 from dan.utils import read_json

 logger = logging.getLogger(__name__)
@@ -50,20 +50,34 @@ def eval(rank, config, mlflow_logging):
     model = Manager(config)
     model.load_model()

-    metrics = ["cer", "wer", "wer_no_punct", "time"]
+    metric_names = [
+        "cer",
+        "cer_no_token",
+        "wer",
+        "wer_no_punct",
+        "wer_no_token",
+        "time",
+    ]
+    if config["dataset"]["tokens"] is not None:
+        metric_names.append("ner")
+
+    metrics_table = create_metrics_table(metric_names)

     for dataset_name in config["dataset"]["datasets"]:
-        for set_name in ["test", "val", "train"]:
+        for set_name in ["train", "val", "test"]:
             logger.info(f"Evaluating on set `{set_name}`")
-            model.evaluate(
+            metrics = model.evaluate(
                 "{}-{}".format(dataset_name, set_name),
                 [
                     (dataset_name, set_name),
                 ],
-                metrics,
-                output=True,
+                metric_names,
                 mlflow_logging=mlflow_logging,
             )

+            add_metrics_table_row(metrics_table, set_name, metrics)
+
+    print(metrics_table)
+

 def run(config: dict):
     update_config(config)
...
@@ -3,7 +3,7 @@ import re
 from collections import defaultdict
 from operator import attrgetter
 from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Dict, List

 import editdistance
 import numpy as np
@@ -19,11 +19,12 @@ REGEX_CONSECUTIVE_SPACES = re.compile(r" +")
 # Keep only one space character
 REGEX_ONLY_ONE_SPACE = re.compile(r"\s+")

+# Mapping between computation tasks (CER, WER, NER) and their metric keyword
+METRICS_KEYWORD = {"cer": "chars", "wer": "words", "ner": "tokens"}
+

 class MetricManager:
-    def __init__(
-        self, metric_names: List[str], dataset_name: str, tokens: Optional[Path]
-    ):
+    def __init__(self, metric_names: List[str], dataset_name: str, tokens: Path | None):
         self.dataset_name: str = dataset_name

         self.remove_tokens: str = None
@@ -34,42 +35,50 @@ class MetricManager:
                 + list(map(attrgetter("end"), tokens.values()))
             )
             self.remove_tokens: re.Pattern = re.compile(r"([" + layout_tokens + "])")
+            self.keep_tokens: re.Pattern = re.compile(r"([^" + layout_tokens + "])")

         self.metric_names: List[str] = metric_names
         self.epoch_metrics = defaultdict(list)

-    def edit_cer_from_string(self, gt: str, pred: str):
+    def format_string_for_cer(self, text: str, remove_token: bool = False):
         """
-        Format and compute edit distance between two strings at character level
+        Format string for CER computation: remove layout tokens and extra spaces
         """
-        gt = self.format_string_for_cer(gt)
-        pred = self.format_string_for_cer(pred)
-        return editdistance.eval(gt, pred)
+        if remove_token and self.remove_tokens is not None:
+            text = self.remove_tokens.sub("", text)

-    def nb_chars_cer_from_string(self, gt: str) -> int:
-        """
-        Compute length after formatting of ground truth string
-        """
-        return len(self.format_string_for_cer(gt))
+        text = REGEX_CONSECUTIVE_LINEBREAKS.sub("\n", text)
+        return REGEX_CONSECUTIVE_SPACES.sub(" ", text).strip()

-    def format_string_for_wer(self, text: str, remove_punct: bool = False):
+    def format_string_for_wer(
+        self, text: str, remove_punct: bool = False, remove_token: bool = False
+    ):
         """
         Format string for WER computation: remove layout tokens, treat punctuation as word, replace line break by space
         """
         if remove_punct:
             text = REGEX_PUNCTUATION.sub("", text)
-        if self.remove_tokens is not None:
+        if remove_token and self.remove_tokens is not None:
             text = self.remove_tokens.sub("", text)
         return REGEX_ONLY_ONE_SPACE.sub(" ", text).strip().split(" ")

-    def format_string_for_cer(self, text: str):
+    def format_string_for_ner(self, text: str):
         """
-        Format string for CER computation: remove layout tokens and extra spaces
+        Format string for NER computation: only keep layout tokens
         """
-        if self.remove_tokens is not None:
-            text = self.remove_tokens.sub("", text)
-        text = REGEX_CONSECUTIVE_LINEBREAKS.sub("\n", text)
-        return REGEX_CONSECUTIVE_SPACES.sub(" ", text).strip()
+        return self.keep_tokens.sub("", text)
+
+    def _format_string(self, task: str, *args, **kwargs):
+        """
+        Call the proper `format_string_for_*` method for the given task
+        """
+        match task:
+            case "cer":
+                return self.format_string_for_cer(*args, **kwargs)
+            case "wer":
+                return self.format_string_for_wer(*args, **kwargs)
+            case "ner":
+                return self.format_string_for_ner(*args, **kwargs)

     def update_metrics(self, batch_metrics):
         """
@@ -97,11 +106,20 @@ class MetricManager:
                     display_values["sample_time"] = float(round(sample_time, 4))
                     display_values[metric_name] = value
                     continue
-                case "cer":
-                    num_name, denom_name = "edit_chars", "nb_chars"
-                case "wer" | "wer_no_punct":
+                case (
+                    "cer"
+                    | "cer_no_token"
+                    | "wer"
+                    | "wer_no_punct"
+                    | "wer_no_token"
+                    | "ner"
+                ):
+                    keyword = METRICS_KEYWORD[metric_name[:3]]
                     suffix = metric_name[3:]
-                    num_name, denom_name = "edit_words" + suffix, "nb_words" + suffix
+                    num_name, denom_name = (
+                        "edit_" + keyword + suffix,
+                        "nb_" + keyword + suffix,
+                    )
                 case "loss" | "loss_ce":
                     display_values[metric_name] = round(
                         float(
@@ -139,21 +157,37 @@ class MetricManager:
         gt, prediction = values["str_y"], values["str_x"]
         for metric_name in metric_names:
             match metric_name:
-                case "cer":
-                    metrics["edit_chars"] = list(
-                        map(self.edit_cer_from_string, gt, prediction)
-                    )
-                    metrics["nb_chars"] = list(map(self.nb_chars_cer_from_string, gt))
-                case "wer" | "wer_no_punct":
+                case (
+                    "cer"
+                    | "cer_no_token"
+                    | "wer"
+                    | "wer_no_punct"
+                    | "wer_no_token"
+                    | "ner"
+                ):
+                    task = metric_name[:3]
+                    keyword = METRICS_KEYWORD[task]
                     suffix = metric_name[3:]
-                    split_gt = list(map(self.format_string_for_wer, gt, [bool(suffix)]))
+
+                    # Add extra parameters for the format functions
+                    extras = []
+                    if suffix == "_no_punct":
+                        extras.append([{"remove_punct": True}])
+                    elif suffix == "_no_token":
+                        extras.append([{"remove_token": True}])
+
+                    # Run the format function for the desired computation (CER, WER or NER)
+                    split_gt = list(map(self._format_string, [task], gt, *extras))
                     split_pred = list(
-                        map(self.format_string_for_wer, prediction, [bool(suffix)])
+                        map(self._format_string, [task], prediction, *extras)
                     )
-                    metrics["edit_words" + suffix] = list(
+
+                    # Compute and store edit distance/length for the desired level
+                    # (chars, words or tokens) as metrics
+                    metrics["edit_" + keyword + suffix] = list(
                         map(editdistance.eval, split_gt, split_pred)
                     )
-                    metrics["nb_words" + suffix] = list(map(len, split_gt))
+                    metrics["nb_" + keyword + suffix] = list(map(len, split_gt))
                 case "loss" | "loss_ce":
                     metrics[metric_name] = [
                         values[metric_name],
...
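The metric bookkeeping above relies on a naming convention: the first three characters of a metric name select the task (and thus the `METRICS_KEYWORD` level), and the rest is a suffix carried into the numerator/denominator keys. A quick illustration of that slicing (not project code):

METRICS_KEYWORD = {"cer": "chars", "wer": "words", "ner": "tokens"}

metric_name = "wer_no_punct"
task, suffix = metric_name[:3], metric_name[3:]  # "wer", "_no_punct"
keyword = METRICS_KEYWORD[task]                  # "words"
num_name = "edit_" + keyword + suffix            # "edit_words_no_punct"
denom_name = "nb_" + keyword + suffix            # "nb_words_no_punct"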
@@ -749,8 +749,8 @@ class GenericTrainingManager:
         return display_values

     def evaluate(
-        self, custom_name, sets_list, metric_names, mlflow_logging=False, output=False
-    ):
+        self, custom_name, sets_list, metric_names, mlflow_logging=False
+    ) -> Dict[str, int | float]:
         """
         Main loop for evaluation
         """
@@ -798,19 +798,19 @@ class GenericTrainingManager:
             display_values, logging_name, mlflow_logging, self.is_master
         )

-        # output metrics values if requested
-        if output:
-            if "pred" in metric_names:
-                self.output_pred(custom_name)
-            metrics = self.metric_manager[custom_name].get_display_values(output=True)
-            path = self.paths["results"] / "predict_{}_{}.yaml".format(
-                custom_name, self.latest_epoch
-            )
-            path.write_text(yaml.dump(metrics))
+        if "pred" in metric_names:
+            self.output_pred(custom_name)

-            if mlflow_logging:
-                # Log mlflow artifacts
-                mlflow.log_artifact(path, "predictions")
+        metrics = self.metric_manager[custom_name].get_display_values(output=True)
+        path = self.paths["results"] / "predict_{}_{}.yaml".format(
+            custom_name, self.latest_epoch
+        )
+        path.write_text(yaml.dump(metrics))
+
+        if mlflow_logging:
+            # Log mlflow artifacts
+            mlflow.log_artifact(path, "predictions")
+
+        return metrics

     def output_pred(self, name):
         path = self.paths["results"] / "predict_{}_{}.yaml".format(
...