Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: atr/dan
Commits on Source (18)
Showing changes with 744 additions and 445 deletions
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.0.282
rev: v0.1.6
hooks:
# Run the linter.
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/ambv/black
rev: 23.7.0
hooks:
- id: black
# Run the formatter.
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
......@@ -44,7 +43,7 @@ repos:
rev: 0.7.16
hooks:
- id: mdformat
exclude: tests/data/analyze
exclude: tests/data/analyze|tests/data/evaluate/metrics_table.md
# Optionally add plugins
additional_dependencies:
- mdformat-mkdocs[recommended]
......@@ -68,13 +68,17 @@
"train": [
"loss_ce",
"cer",
"cer_no_token",
"wer",
"wer_no_punct"
"wer_no_punct",
"wer_no_token"
],
"eval": [
"cer",
"cer_no_token",
"wer",
"wer_no_punct"
"wer_no_punct",
"wer_no_token"
]
},
"validation": {
......
......@@ -77,13 +77,17 @@
"train": [
"loss_ce",
"cer",
"cer_no_token",
"wer",
"wer_no_punct"
"wer_no_punct",
"wer_no_token"
],
"eval": [
"cer",
"cer_no_token",
"wer",
"wer_no_punct"
"wer_no_punct",
"wer_no_token"
]
},
"validation": {
......
......@@ -68,13 +68,17 @@
"train": [
"loss_ce",
"cer",
"cer_no_token",
"wer",
"wer_no_punct"
"wer_no_punct",
"wer_no_token"
],
"eval": [
"cer",
"cer_no_token",
"wer",
"wer_no_punct"
"wer_no_punct",
"wer_no_token"
]
},
"validation": {
......
import logging
import re
from typing import Dict, List
from dan.utils import EntityType
logger = logging.getLogger(__name__)
def convert(text: str, ner_tokens: Dict[str, EntityType]) -> str:
# Mapping to find a starting token for an ending token efficiently
mapping_end_start: Dict[str, str] = {
entity_type.end: entity_type.start for entity_type in ner_tokens.values()
}
# Mapping to find the entity name for a starting token efficiently
mapping_start_name: Dict[str, str] = {
entity_type.start: name for name, entity_type in ner_tokens.items()
}
starting_tokens: List[str] = mapping_start_name.keys()
ending_tokens: List[str] = mapping_end_start.keys()
has_ending_tokens: bool = set(ending_tokens) != {
""
} # Whether ending tokens are used
# Spacing starting tokens and ending tokens (if necessary)
tokens_spacing: re.Pattern = re.compile(
r"([" + "".join([*starting_tokens, *ending_tokens]) + "])"
)
text: str = tokens_spacing.sub(r" \1 ", text)
iob: List[str] = [] # List of IOB formatted strings
entity_types: List[str] = [] # Encountered entity types
inside: bool = False # Whether we are inside an entity
for token in text.split():
# Encountering a starting token
if token in starting_tokens:
entity_types.append(token)
# Stopping any current entity type
inside = False
continue
# Encountering an ending token
elif has_ending_tokens and token in ending_tokens:
if not entity_types:
logger.warning(
f"Missing starting token for ending token {token}, skipping the entity"
)
continue
# Making sure this ending token closes the current entity
assert (
entity_types[-1] == mapping_end_start[token]
), f"Ending token {token} doesn't match the starting token {entity_types[-1]}"
# Removing the current entity from the queue as it is its end
entity_types.pop()
# If there are still entities in the queue, we continue in the parent one
# Else, we are not in any entity anymore
inside = bool(entity_types)
continue
# The token is not part of an entity
if not entity_types:
iob.append(f"{token} O")
continue
# The token is part of at least one entity
entity_name: str = mapping_start_name[entity_types[-1]]
if inside:
# Inside the same entity
iob.append(f"{token} I-{entity_name}")
continue
# Starting a new entity
iob.append(f"{token} B-{entity_name}")
inside = True
# Concatenating all formatted iob strings
return "\n".join(iob)
......@@ -4,6 +4,7 @@ Preprocess datasets for training.
"""
from dan.datasets.analyze import add_analyze_parser
from dan.datasets.download import add_download_parser
from dan.datasets.entities import add_entities_parser
from dan.datasets.extract import add_extract_parser
from dan.datasets.tokens import add_tokens_parser
......@@ -18,6 +19,7 @@ def add_dataset_parser(subcommands) -> None:
subcommands = parser.add_subparsers(metavar="subcommand")
add_extract_parser(subcommands)
add_download_parser(subcommands)
add_analyze_parser(subcommands)
add_entities_parser(subcommands)
add_tokens_parser(subcommands)
......@@ -3,7 +3,7 @@ import logging
from collections import Counter, defaultdict
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List
import imagesize
import numpy as np
......@@ -157,7 +157,7 @@ class Statistics:
level=3,
)
def run(self, labels: Dict, tokens: Optional[Dict]):
def run(self, labels: Dict, tokens: Dict | None):
# Iterate over each split
for split_name, split_data in labels.items():
self.document.new_header(level=1, title=split_name.capitalize())
......@@ -175,7 +175,7 @@ class Statistics:
self.document.create_md_file()
def run(labels: Dict, tokens: Optional[Dict], output: Path) -> None:
def run(labels: Dict, tokens: Dict | None, output: Path) -> None:
"""
Compute and save dataset statistics.
"""
......
# -*- coding: utf-8 -*-
"""
Download images of a dataset from a split extracted by DAN
"""
import pathlib
from dan.datasets.download.images import run
def _valid_image_format(value: str):
im_format = value
if not im_format.startswith("."):
im_format = "." + im_format
return im_format
def add_download_parser(subcommands) -> None:
parser = subcommands.add_parser(
"download",
description=__doc__,
help=__doc__,
)
# Required arguments.
parser.add_argument(
"--output",
type=pathlib.Path,
help="Path where the `split.json` file is stored and where the data will be generated.",
required=True,
)
parser.add_argument(
"--max-width",
type=int,
help="Images larger than this width will be resized to this width.",
)
parser.add_argument(
"--max-height",
type=int,
help="Images larger than this height will be resized to this height.",
)
# Formatting arguments
parser.add_argument(
"--image-format",
type=_valid_image_format,
default=".jpg",
help="Images will be saved under this format.",
)
parser.set_defaults(func=run)
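The --image-format value is normalized by _valid_image_format, so the leading dot is optional. A small illustration (not part of the diff):

assert _valid_image_format("jpg") == ".jpg"
assert _valid_image_format(".jpg") == ".jpg"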
# -*- coding: utf-8 -*-
from pathlib import Path
class ImageDownloadError(Exception):
"""
Raised when an element's image could not be downloaded
"""
def __init__(
self, split: str, path: Path, url: str, exc: Exception, *args: object
) -> None:
super().__init__(*args)
self.split: str = split
self.path: str = str(path)
self.url: str = url
self.message = f"{str(exc)} for element {path.stem}"
# -*- coding: utf-8 -*-
import json
import logging
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, Tuple
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
from dan.datasets.download.exceptions import ImageDownloadError
from dan.datasets.download.utils import download_image, get_bbox
from line_image_extractor.extractor import extract
from line_image_extractor.image_utils import (
BoundingBox,
Extraction,
polygon_to_bbox,
)
IMAGES_DIR = "images" # Subpath to the images directory.
IIIF_URL = "{image_url}/{bbox}/{size}/0/default.jpg"
# IIIF 2.0 uses `full`
IIIF_FULL_SIZE = "full"
logger = logging.getLogger(__name__)
class ImageDownloader:
"""
Download images from extracted data
"""
def __init__(
self,
output: Path | None = None,
max_width: int | None = None,
max_height: int | None = None,
image_extension: str = "",
) -> None:
self.output = output
self.max_width = max_width
self.max_height = max_height
self.image_extension = image_extension
# Load split file
split_file = output / "split.json" if output else None
self.split: Dict = (
json.loads(split_file.read_text())
if split_file and split_file.is_file()
else {}
)
# Create directories
for split_name in self.split:
Path(output, IMAGES_DIR, split_name).mkdir(parents=True, exist_ok=True)
self.data: Dict = defaultdict(dict)
def check_extraction(self, values: dict) -> str | None:
# Check dataset_id parameter
if values.get("dataset_id") is None:
return "Dataset ID not found"
# Check image parameters
if not (image := values.get("image")):
return "Image information not found"
# Only support iiif_url with polygon for now
if not image.get("iiif_url"):
return "Image IIIF URL not found"
if not image.get("polygon"):
return "Image polygon not found"
# Check text parameter
if values.get("text") is None:
return "Text not found"
def get_iiif_size_arg(self, width: int, height: int) -> str:
if (self.max_width is None or width <= self.max_width) and (
self.max_height is None or height <= self.max_height
):
return IIIF_FULL_SIZE
bigger_width = self.max_width and width >= self.max_width
bigger_height = self.max_height and height >= self.max_height
if bigger_width and bigger_height:
# Resize along the largest dimension to keep the aspect ratio
# This ratio tells which dimension needs the most shrinking
ratio = width * self.max_height / (height * self.max_width)
return f"{self.max_width}," if ratio > 1 else f",{self.max_height}"
# Only the width is bigger than the max size
elif bigger_width:
return f"{self.max_width},"
# Only the height is bigger than the max size
elif bigger_height:
return f",{self.max_height}"
def build_iiif_url(
self, polygon: List[List[int]], image_url: str
) -> Tuple[BoundingBox, str]:
bbox = polygon_to_bbox(polygon)
size = self.get_iiif_size_arg(width=bbox.width, height=bbox.height)
# Rotations are done using the lib
return IIIF_URL.format(image_url=image_url, bbox=get_bbox(polygon), size=size)
def build_tasks(self) -> List[Dict[str, str]]:
tasks = []
for split, items in self.split.items():
# Create directories
destination = self.output / IMAGES_DIR / split
destination.mkdir(parents=True, exist_ok=True)
for element_id, values in items.items():
filename = Path(element_id).with_suffix(self.image_extension)
error = self.check_extraction(values)
if error:
logger.warning(f"{destination / filename}: {error}")
continue
image_path = destination / values["dataset_id"] / filename
image_path.parent.mkdir(parents=True, exist_ok=True)
self.data[split][str(image_path)] = values["text"]
# Create task for multithreading pool if image does not exist yet
if image_path.exists():
continue
polygon = values["image"]["polygon"]
iiif_url = values["image"]["iiif_url"]
tasks.append(
{
"split": split,
"polygon": polygon,
"image_url": self.build_iiif_url(polygon, iiif_url),
"destination": image_path,
}
)
return tasks
def get_image(
self,
split: str,
polygon: List[List[int]],
image_url: str,
destination: Path,
) -> None:
"""Save the element's image to the given path and applies any image operations needed.
:param split: Dataset split this image belongs to.
:param polygon: Polygon of the processed element.
:param image_url: Base IIIF URL of the image.
:param destination: Where the image should be saved.
"""
bbox = polygon_to_bbox(polygon)
try:
img: Image.Image = download_image(image_url)
# The polygon's coordinates are relative to the full image
# We need to remove the offset of the bounding rectangle
polygon = [(x - bbox.x, y - bbox.y) for x, y in polygon]
# Normalize bbox
bbox = BoundingBox(x=0, y=0, width=bbox.width, height=bbox.height)
image = extract(
img=cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR),
polygon=np.asarray(polygon).clip(0),
bbox=bbox,
extraction_mode=Extraction.boundingRect,
max_deskew_angle=45,
)
cv2.imwrite(str(destination), image)
except Exception as e:
raise ImageDownloadError(
split=split, path=destination, url=image_url, exc=e
)
def download_images(self, tasks: List[Dict[str, str]]) -> None:
"""
Execute each image download task in parallel
:param tasks: List of tasks to execute.
"""
failed_downloads = []
with tqdm(
desc="Downloading images", total=len(tasks)
) as pbar, ThreadPoolExecutor() as executor:
def process_future(future: Future):
"""
Callback function called at the end of the thread
"""
# Update the progress bar count
pbar.update(1)
exc = future.exception()
if exc is None:
# No error
return
# If failed, tag for removal
assert isinstance(exc, ImageDownloadError)
# Remove transcription from labels dict
del self.data[exc.split][exc.path]
# Save tried URL
failed_downloads.append((exc.url, exc.message))
# Submit all tasks
for task in tasks:
executor.submit(self.get_image, **task).add_done_callback(
process_future
)
if failed_downloads:
logger.error(f"Failed to download {len(failed_downloads)} image(s).")
print(*list(map(": ".join, failed_downloads)), sep="\n")
def export(self) -> None:
"""
Write a `labels.json` file mapping each successfully downloaded image (identified by its path)
to its ground-truth transcription (with NER tokens if needed).
"""
(self.output / "labels.json").write_text(
json.dumps(
self.data,
sort_keys=True,
indent=4,
)
)
def run(self) -> None:
"""
Download the missing images listed in a `split.json` file and build a `labels.json` file
mapping each successfully downloaded image (identified by its path)
to its ground-truth transcription (with NER tokens if needed).
"""
tasks: List[Dict[str, str]] = self.build_tasks()
self.download_images(tasks)
self.export()
def run(
output: Path,
max_width: int | None,
max_height: int | None,
image_format: str,
):
"""
Download the missing images listed in a `split.json` file and build a `labels.json` file
mapping each successfully downloaded image (identified by its path)
to its ground-truth transcription (with NER tokens if needed).
:param output: Path where the `split.json` file is stored and where the data will be generated
:param max_width: Images larger than this width will be resized to this width
:param max_height: Images larger than this height will be resized to this height
:param image_format: Images will be saved under this format
"""
ImageDownloader(
output=output,
max_width=max_width,
max_height=max_height,
image_extension=image_format,
).run()
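For reference, a sketch of the split.json structure this downloader expects; the field names come from check_extraction and the extractor's process_element, while every value below is a placeholder:

split = {
    "train": {
        "<element_id>": {
            "dataset_id": "<dataset_uuid>",
            "text": "ⓃJohn Smith\nⓈParis",
            "image": {
                "iiif_url": "https://iiif.example.com/iiif/<image_id>",
                "polygon": [[0, 0], [1000, 0], [1000, 600], [0, 600]],
            },
        },
    },
}

The exported labels.json then maps each downloaded image path to its transcription text.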
# -*- coding: utf-8 -*-
import logging
from io import BytesIO
from typing import List
import requests
from PIL import Image, ImageOps
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
logger = logging.getLogger(__name__)
# See http://docs.python-requests.org/en/master/user/advanced/#timeouts
DOWNLOAD_TIMEOUT = (30, 60)
def _retry_log(retry_state, *args, **kwargs):
logger.warning(
f"Request to {retry_state.args[0]} failed ({repr(retry_state.outcome.exception())}), "
f"retrying in {retry_state.idle_for} seconds"
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2),
retry=retry_if_exception_type(requests.RequestException),
before_sleep=_retry_log,
reraise=True,
)
def _retried_request(url: str) -> requests.Response:
resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
resp.raise_for_status()
return resp
def download_image(url: str) -> Image.Image:
"""
Download an image and open it with Pillow
"""
assert url.startswith("http"), "Image URL must be HTTP(S)"
# Download the image
# Cannot use stream=True as urllib's responses do not support the seek(int) method,
# which is explicitly required by Image.open on file-like objects
try:
resp = _retried_request(url)
except requests.HTTPError as e:
if "/full/" in url and 400 <= e.response.status_code < 500:
# Retry with max instead of full as IIIF size
resp = _retried_request(url.replace("/full/", "/max/"))
else:
raise e
# Preprocess the image and prepare it for classification
image = Image.open(BytesIO(resp.content)).convert("RGB")
# Do not rotate JPEG images (see https://github.com/python-pillow/Pillow/issues/4703)
image = ImageOps.exif_transpose(image)
logger.debug(
"Downloaded image {} - size={}x{}".format(url, image.size[0], image.size[1])
)
return image
def get_bbox(polygon: List[List[int]]) -> str:
"""
Return a comma-separated string of the upper-left-most pixel coordinates and the width and height of the bounding box
"""
all_x, all_y = zip(*polygon)
x, y = min(all_x), min(all_y)
width, height = max(all_x) - x, max(all_y) - y
return ",".join(list(map(str, [int(x), int(y), int(width), int(height)])))
......@@ -5,7 +5,6 @@ Extract dataset from Arkindex using a corpus export.
import argparse
import pathlib
from typing import Union
from uuid import UUID
from dan.datasets.extract.arkindex import run
......@@ -13,7 +12,7 @@ from dan.datasets.extract.arkindex import run
MANUAL_SOURCE = "manual"
def parse_worker_version(worker_version_id) -> Union[str, bool]:
def parse_worker_version(worker_version_id) -> str | bool:
if worker_version_id == MANUAL_SOURCE:
return False
......@@ -34,13 +33,6 @@ def validate_char(char):
return char
def _valid_image_format(value: str):
im_format = value
if not im_format.startswith("."):
im_format = "." + im_format
return im_format
def add_extract_parser(subcommands) -> None:
parser = subcommands.add_parser(
"extract",
......@@ -55,18 +47,19 @@ def add_extract_parser(subcommands) -> None:
help="Path where the data were exported from Arkindex.",
)
parser.add_argument(
"--element-type",
"--dataset-id",
nargs="+",
type=str,
help="Type of elements to retrieve.",
type=UUID,
help="ID of the dataset to extract from Arkindex.",
required=True,
dest="dataset_ids",
)
parser.add_argument(
"--parent-element-type",
"--element-type",
nargs="+",
type=str,
help="Type of the parent element containing the data.",
required=False,
default="page",
help="Type of elements to retrieve.",
required=True,
)
parser.add_argument(
"--output",
......@@ -75,25 +68,6 @@ def add_extract_parser(subcommands) -> None:
required=True,
)
parser.add_argument(
"--train-folder",
type=UUID,
help="ID of the training folder to extract from Arkindex.",
required=True,
)
parser.add_argument(
"--val-folder",
type=UUID,
help="ID of the validation folder to extract from Arkindex.",
required=True,
)
parser.add_argument(
"--test-folder",
type=UUID,
help="ID of the testing folder to extract from Arkindex.",
required=True,
)
# Optional arguments.
parser.add_argument(
"--entity-separators",
......@@ -131,18 +105,6 @@ def add_extract_parser(subcommands) -> None:
required=False,
)
parser.add_argument(
"--max-width",
type=int,
help="Images larger than this width will be resized to this width.",
)
parser.add_argument(
"--max-height",
type=int,
help="Images larger than this height will be resized to this height.",
)
parser.add_argument(
"--subword-vocab-size",
type=int,
......@@ -151,13 +113,6 @@ def add_extract_parser(subcommands) -> None:
)
# Formatting arguments
parser.add_argument(
"--image-format",
type=_valid_image_format,
default=".jpg",
help="Images will be saved under this format.",
)
parser.add_argument(
"--keep-spaces",
action="store_true",
......
......@@ -5,55 +5,40 @@ import logging
import pickle
import random
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List
from uuid import UUID
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
from arkindex_export import open_database
from arkindex_export import Dataset, DatasetElement, Element, open_database
from dan.datasets.extract.db import (
Element,
get_dataset_elements,
get_elements,
get_transcription_entities,
get_transcriptions,
)
from dan.datasets.extract.exceptions import (
ImageDownloadError,
NoTranscriptionError,
ProcessingError,
UnknownTokenInText,
)
from dan.datasets.extract.utils import (
Tokenizer,
download_image,
entities_to_xml,
get_bbox,
get_translation_map,
get_vocabulary,
normalize_linebreaks,
normalize_spaces,
)
from dan.utils import LMTokenMapping, parse_tokens
from line_image_extractor.extractor import extract
from line_image_extractor.image_utils import (
BoundingBox,
Extraction,
polygon_to_bbox,
)
IMAGES_DIR = "images" # Subpath to the images directory.
LANGUAGE_DIR = "language_model" # Subpath to the language model directory.
TRAIN_NAME = "train"
SPLIT_NAMES = [TRAIN_NAME, "val", "test"]
IIIF_URL = "{image_url}/{bbox}/{size}/0/default.jpg"
# IIIF 2.0 uses `full`
IIIF_FULL_SIZE = "full"
VAL_NAME = "val"
TEST_NAME = "test"
SPLIT_NAMES = [TRAIN_NAME, VAL_NAME, TEST_NAME]
logger = logging.getLogger(__name__)
......@@ -65,34 +50,26 @@ class ArkindexExtractor:
def __init__(
self,
folders: list = [],
dataset_ids: List[UUID] | None = None,
element_type: List[str] = [],
parent_element_type: str = None,
output: Path = None,
output: Path | None = None,
entity_separators: List[str] = ["\n", " "],
unknown_token: str = "",
tokens: Path = None,
transcription_worker_version: Optional[Union[str, bool]] = None,
entity_worker_version: Optional[Union[str, bool]] = None,
max_width: Optional[int] = None,
max_height: Optional[int] = None,
tokens: Path | None = None,
transcription_worker_version: str | bool | None = None,
entity_worker_version: str | bool | None = None,
keep_spaces: bool = False,
image_extension: str = "",
allow_empty: bool = False,
subword_vocab_size: int = 1000,
) -> None:
self.folders = folders
self.dataset_ids = dataset_ids
self.element_type = element_type
self.parent_element_type = parent_element_type
self.output = output
self.entity_separators = entity_separators
self.unknown_token = unknown_token
self.tokens = parse_tokens(tokens) if tokens else {}
self.transcription_worker_version = transcription_worker_version
self.entity_worker_version = entity_worker_version
self.max_width = max_width
self.max_height = max_height
self.image_extension = image_extension
self.allow_empty = allow_empty
self.mapping = LMTokenMapping()
self.keep_spaces = keep_spaces
......@@ -104,41 +81,9 @@ class ArkindexExtractor:
self.language_tokens = []
self.language_lexicon = defaultdict(list)
# Image download tasks to process
self.tasks: List[Dict[str, str]] = []
# NER extraction
self.translation_map: Dict[str, str] | None = get_translation_map(self.tokens)
def get_iiif_size_arg(self, width: int, height: int) -> str:
if (self.max_width is None or width <= self.max_width) and (
self.max_height is None or height <= self.max_height
):
return IIIF_FULL_SIZE
bigger_width = self.max_width and width >= self.max_width
bigger_height = self.max_height and height >= self.max_height
if bigger_width and bigger_height:
# Resize to the biggest dim to keep aspect ratio
# Only resize width is bigger than max size
# This ratio tells which dim needs the biggest shrinking
ratio = width * self.max_height / (height * self.max_width)
return f"{self.max_width}," if ratio > 1 else f",{self.max_height}"
elif bigger_width:
return f"{self.max_width},"
# Only resize height is bigger than max size
elif bigger_height:
return f",{self.max_height}"
def build_iiif_url(self, polygon, image_url) -> Tuple[BoundingBox, str]:
bbox = polygon_to_bbox(json.loads(str(polygon)))
size = self.get_iiif_size_arg(width=bbox.width, height=bbox.height)
# Rotations are done using the lib
return bbox, IIIF_URL.format(
image_url=image_url, bbox=get_bbox(polygon), size=size
)
def translate(self, text: str):
"""
Use translation map to replace XML tags to actual tokens
......@@ -177,48 +122,7 @@ class ArkindexExtractor:
)
)
def get_image(
self,
split: str,
polygon: List[List[int]],
image_url: str,
destination: Path,
) -> None:
"""Save the element's image to the given path and applies any image operations needed.
:param split: Dataset split this image belongs to.
:param polygon: Polygon of the processed element.
:param image_url: Base IIIF URL of the image.
:param destination: Where the image should be saved.
"""
bbox, download_url = self.build_iiif_url(polygon=polygon, image_url=image_url)
try:
img: Image.Image = download_image(download_url)
# The polygon's coordinate are in the referential of the full image
# We need to remove the offset of the bounding rectangle
polygon = [(x - bbox.x, y - bbox.y) for x, y in polygon]
# Normalize bbox
bbox = BoundingBox(x=0, y=0, width=bbox.width, height=bbox.height)
image = extract(
img=cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR),
polygon=np.asarray(polygon).clip(0),
bbox=bbox,
extraction_mode=Extraction.boundingRect,
max_deskew_angle=45,
)
destination.parent.mkdir(parents=True, exist_ok=True)
cv2.imwrite(str(destination), image)
except Exception as e:
raise ImageDownloadError(
split=split, path=destination, url=download_url, exc=e
)
def format_text(self, text: str, charset: Optional[set] = None):
def format_text(self, text: str, charset: set | None = None):
if not self.keep_spaces:
text = normalize_spaces(text)
text = normalize_linebreaks(text)
......@@ -234,11 +138,7 @@ class ArkindexExtractor:
)
return text.strip()
def process_element(
self,
element: Element,
split: str,
):
def process_element(self, dataset_parent: DatasetElement, element: Element):
"""
Extract an element's data and save it to disk.
The output path is directly related to the split of the element.
......@@ -248,46 +148,33 @@ class ArkindexExtractor:
if self.unknown_token in text:
raise UnknownTokenInText(element_id=element.id)
image_path = Path(self.output, IMAGES_DIR, split, element.id).with_suffix(
self.image_extension
)
# Create task for multithreading pool if image does not exist yet
if not image_path.exists():
self.tasks.append(
{
"split": split,
"polygon": json.loads(str(element.polygon)),
"image_url": element.image.url,
"destination": image_path,
}
)
text = self.format_text(
text,
# Do not replace unknown characters in train split
charset=self.charset if split != TRAIN_NAME else None,
charset=self.charset if dataset_parent.set_name != TRAIN_NAME else None,
)
self.data[split][str(image_path)] = text
self.data[dataset_parent.set_name][element.id] = {
"dataset_id": dataset_parent.dataset_id,
"text": text,
"image": {
"iiif_url": element.image.url,
"polygon": json.loads(element.polygon),
},
}
self.charset = self.charset.union(set(text))
def process_parent(
self,
pbar,
parent: Element,
split: str,
):
def process_parent(self, pbar, dataset_parent: DatasetElement):
"""
Extract data from a parent element.
"""
base_description = (
f"Extracting data from {parent.type} ({parent.id}) for split ({split})"
)
parent = dataset_parent.element
base_description = f"Extracting data from {parent.type} ({parent.id}) for split ({dataset_parent.set_name})"
pbar.set_description(desc=base_description)
if self.element_type == [parent.type]:
try:
self.process_element(parent, split)
self.process_element(dataset_parent, parent)
except ProcessingError as e:
logger.warning(f"Skipping {parent.id}: {str(e)}")
# Extract children elements
......@@ -302,7 +189,7 @@ class ArkindexExtractor:
# Update description to update the children processing progress
pbar.set_description(desc=base_description + f" ({idx}/{nb_children})")
try:
self.process_element(element, split)
self.process_element(dataset_parent, element)
except ProcessingError as e:
logger.warning(f"Skipping {element.id}: {str(e)}")
......@@ -326,8 +213,10 @@ class ArkindexExtractor:
# Build LM corpus
train_corpus = [
text.replace(self.mapping.linebreak.display, self.mapping.space.display)
for text in self.data["train"].values()
values["text"].replace(
self.mapping.linebreak.display, self.mapping.space.display
)
for values in self.data[TRAIN_NAME].values()
]
tokenizer = Tokenizer(
......@@ -361,7 +250,7 @@ class ArkindexExtractor:
]
def export(self):
(self.output / "labels.json").write_text(
(self.output / "split.json").write_text(
json.dumps(
self.data,
sort_keys=True,
......@@ -382,87 +271,52 @@ class ArkindexExtractor:
pickle.dumps(sorted(list(self.charset)))
)
def download_images(self):
failed_downloads = []
with tqdm(
desc="Downloading images", total=len(self.tasks)
) as pbar, ThreadPoolExecutor() as executor:
def process_future(future: Future):
"""
Callback function called at the end of the thread
"""
# Update the progress bar count
pbar.update(1)
exc = future.exception()
if exc is None:
# No error
return
# If failed, tag for removal
assert isinstance(exc, ImageDownloadError)
# Remove transcription from labels dict
del self.data[exc.split][exc.path]
# Save tried URL
failed_downloads.append((exc.url, exc.message))
# Submit all tasks
for task in self.tasks:
executor.submit(self.get_image, **task).add_done_callback(
process_future
)
if failed_downloads:
logger.error(f"Failed to download {len(failed_downloads)} image(s).")
print(*list(map(": ".join, failed_downloads)), sep="\n")
def run(self):
# Iterate over the subsets to find the page images and labels.
for folder_id, split in zip(self.folders, SPLIT_NAMES):
with tqdm(
get_elements(
folder_id,
[self.parent_element_type],
),
desc=f"Extracting data from ({folder_id}) for split ({split})",
) as pbar:
# Iterate over the pages to create splits at page level.
for parent in pbar:
self.process_parent(
pbar=pbar,
parent=parent,
split=split,
)
# Progress bar updates
pbar.update()
pbar.refresh()
# Retrieve the Dataset and its splits from the cache
for dataset_id in self.dataset_ids:
dataset = Dataset.get_by_id(dataset_id)
splits = dataset.sets.split(",")
if not set(splits).issubset(set(SPLIT_NAMES)):
logger.warning(
f'Dataset {dataset.name} ({dataset.id}) does not have "{TRAIN_NAME}", "{VAL_NAME}" and "{TEST_NAME}" steps'
)
continue
# Iterate over the subsets to find the page images and labels.
for split in splits:
with tqdm(
get_dataset_elements(dataset, split),
desc=f"Extracting data from ({dataset_id}) for split ({split})",
) as pbar:
# Iterate over the pages to create splits at page level.
for parent in pbar:
self.process_parent(
pbar=pbar,
dataset_parent=parent,
)
# Progress bar updates
pbar.update()
pbar.refresh()
if not self.data:
raise Exception(
"No data was extracted using the provided export database and parameters."
)
self.download_images()
self.format_lm_files()
self.export()
def run(
database: Path,
dataset_ids: List[UUID],
element_type: List[str],
parent_element_type: str,
output: Path,
entity_separators: List[str],
unknown_token: str,
tokens: Path,
train_folder: UUID,
val_folder: UUID,
test_folder: UUID,
transcription_worker_version: Optional[Union[str, bool]],
entity_worker_version: Optional[Union[str, bool]],
max_width: Optional[int],
max_height: Optional[int],
image_format: str,
transcription_worker_version: str | bool | None,
entity_worker_version: str | bool | None,
keep_spaces: bool,
allow_empty: bool,
subword_vocab_size: int,
......@@ -470,27 +324,19 @@ def run(
assert database.exists(), f"No file found @ {database}"
open_database(path=database)
folders = [str(train_folder), str(val_folder), str(test_folder)]
# Create directories
for split in SPLIT_NAMES:
Path(output, IMAGES_DIR, split).mkdir(parents=True, exist_ok=True)
Path(output, LANGUAGE_DIR).mkdir(parents=True, exist_ok=True)
ArkindexExtractor(
folders=folders,
dataset_ids=dataset_ids,
element_type=element_type,
parent_element_type=parent_element_type,
output=output,
entity_separators=entity_separators,
unknown_token=unknown_token,
tokens=tokens,
transcription_worker_version=transcription_worker_version,
entity_worker_version=entity_worker_version,
max_width=max_width,
max_height=max_height,
keep_spaces=keep_spaces,
image_extension=image_format,
allow_empty=allow_empty,
subword_vocab_size=subword_vocab_size,
).run()
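To summarize the new two-step workflow introduced here (extraction no longer downloads images: it writes split.json, which the download command then consumes), a hedged sketch under assumed placeholder paths, UUIDs and tokens:

from pathlib import Path
from uuid import UUID

from arkindex_export import open_database

from dan.datasets.download.images import ImageDownloader
from dan.datasets.extract.arkindex import ArkindexExtractor

output = Path("my_dataset")  # placeholder output directory
(output / "language_model").mkdir(parents=True, exist_ok=True)

# The extractor reads from an SQLite export of the Arkindex corpus
open_database(path=Path("corpus_export.sqlite"))  # placeholder path

# 1. Extract transcriptions, polygons and IIIF URLs into <output>/split.json
ArkindexExtractor(
    dataset_ids=[UUID("11111111-1111-1111-1111-111111111111")],  # placeholder dataset
    element_type=["text_line"],  # placeholder element type
    output=output,
    unknown_token="⁇",  # placeholder unknown-character token
).run()

# 2. Download the images listed in split.json and write <output>/labels.json
ImageDownloader(
    output=output,
    max_width=1500,  # optional resizing, illustrative value
    image_extension=".jpg",
).run()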
# -*- coding: utf-8 -*-
from typing import List, Optional, Union
from typing import List
from arkindex_export import Image
from arkindex_export.models import (
Dataset,
DatasetElement,
Element,
Entity,
EntityType,
......@@ -13,6 +14,26 @@ from arkindex_export.models import (
from arkindex_export.queries import list_children
def get_dataset_elements(
dataset: Dataset,
split: str,
):
"""
Retrieve dataset elements in a specific split from an SQLite export of an Arkindex corpus
"""
query = (
DatasetElement.select()
.join(Element)
.join(Image, on=(DatasetElement.element.image == Image.id))
.where(
DatasetElement.dataset == dataset,
DatasetElement.set_name == split,
)
)
return query
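A hedged usage sketch (the dataset UUID is a placeholder and open_database must have been called on the export beforehand):

dataset = Dataset.get_by_id("11111111-1111-1111-1111-111111111111")
for dataset_element in get_dataset_elements(dataset, split="train"):
    element = dataset_element.element
    print(dataset_element.set_name, element.id, element.type)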
def get_elements(
parent_id: str,
element_type: List[str],
......@@ -41,7 +62,7 @@ def build_worker_version_filter(ArkindexModel, worker_version):
def get_transcriptions(
element_id: str, transcription_worker_version: Union[str, bool]
element_id: str, transcription_worker_version: str | bool
) -> List[Transcription]:
"""
Retrieve transcriptions from an SQLite export of an Arkindex corpus
......@@ -61,7 +82,7 @@ def get_transcriptions(
def get_transcription_entities(
transcription_id: str,
entity_worker_version: Optional[Union[str, bool]],
entity_worker_version: str | bool | None,
supported_types: List[str],
) -> List[TranscriptionEntity]:
"""
......
# -*- coding: utf-8 -*-
from pathlib import Path
class ProcessingError(Exception):
......@@ -21,21 +20,6 @@ class ElementProcessingError(ProcessingError):
self.element_id = element_id
class ImageDownloadError(Exception):
"""
Raised when an element's image could not be downloaded
"""
def __init__(
self, split: str, path: Path, url: str, exc: Exception, *args: object
) -> None:
super().__init__(*args)
self.split: str = split
self.path: str = str(path)
self.url: str = url
self.message = f"{str(exc)} for element {path.stem}"
class NoTranscriptionError(ElementProcessingError):
"""
Raised when there are no transcriptions on an element
......
......@@ -4,31 +4,19 @@ import logging
import operator
import re
from dataclasses import dataclass, field
from io import BytesIO
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Dict, Iterator, List, Optional, Union
from typing import Dict, Iterator, List
import requests
import sentencepiece as spm
from lxml.etree import Element, SubElement, tostring
from nltk import wordpunct_tokenize
from PIL import Image, ImageOps
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from arkindex_export import TranscriptionEntity
from dan.utils import EntityType, LMTokenMapping
logger = logging.getLogger(__name__)
# See http://docs.python-requests.org/en/master/user/advanced/#timeouts
DOWNLOAD_TIMEOUT = (30, 60)
# replace \t with regular space and consecutive spaces
TRIM_SPACE_REGEX = re.compile(r"[\t ]+")
TRIM_RETURN_REGEX = re.compile(r"[\r\n]+")
......@@ -42,57 +30,6 @@ ENCODING_MAP = {
}
def _retry_log(retry_state, *args, **kwargs):
logger.warning(
f"Request to {retry_state.args[0]} failed ({repr(retry_state.outcome.exception())}), "
f"retrying in {retry_state.idle_for} seconds"
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2),
retry=retry_if_exception_type(requests.RequestException),
before_sleep=_retry_log,
reraise=True,
)
def _retried_request(url):
resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
resp.raise_for_status()
return resp
def download_image(url):
"""
Download an image and open it with Pillow
"""
assert url.startswith("http"), "Image URL must be HTTP(S)"
# Download the image
# Cannot use stream=True as urllib's responses do not support the seek(int) method,
# which is explicitly required by Image.open on file-like objects
try:
resp = _retried_request(url)
except requests.HTTPError as e:
if "/full/" in url and 400 <= e.response.status_code < 500:
# Retry with max instead of full as IIIF size
resp = _retried_request(url.replace("/full/", "/max/"))
else:
raise e
# Preprocess the image and prepare it for classification
image = Image.open(BytesIO(resp.content)).convert("RGB")
# Do not rotate JPEG images (see https://github.com/python-pillow/Pillow/issues/4703)
image = ImageOps.exif_transpose(image)
logger.debug(
"Downloaded image {} - size={}x{}".format(url, image.size[0], image.size[1])
)
return image
def normalize_linebreaks(text: str) -> str:
"""
Remove leading/trailing line breaks.
......@@ -111,17 +48,6 @@ def normalize_spaces(text: str) -> str:
return TRIM_SPACE_REGEX.sub(" ", text.strip())
def get_bbox(polygon: List[List[int]]) -> str:
"""
Arkindex polygon stored as string
returns a comma-separated string of upper left-most pixel, width + height of the image
"""
all_x, all_y = zip(*polygon)
x, y = min(all_x), min(all_y)
width, height = max(all_x) - x, max(all_y) - y
return ",".join(list(map(str, [int(x), int(y), int(width), int(height)])))
def get_vocabulary(tokenized_text: List[str]) -> set[str]:
"""
Compute the vocabulary set from tokenized text.
......@@ -146,7 +72,7 @@ class Tokenizer:
unknown_token: str
outdir: Path
mapping: LMTokenMapping
tokens: Optional[EntityType] = None
tokens: EntityType | None = None
subword_vocab_size: int = 1000
sentencepiece_model: spm.SentencePieceProcessor = field(init=False)
......@@ -155,7 +81,7 @@ class Tokenizer:
return self.outdir / "subword_tokenizer"
@property
def ner_tokens(self) -> Union[List[str], Iterator[str]]:
def ner_tokens(self) -> List[str] | Iterator[str]:
if self.tokens is None:
return []
return itertools.chain(
......@@ -253,7 +179,7 @@ def slugify(text: str):
return text.replace(" ", "_")
def get_translation_map(tokens: Dict[str, EntityType]) -> Optional[Dict[str, str]]:
def get_translation_map(tokens: Dict[str, EntityType]) -> Dict[str, str] | None:
if not tokens:
return
......@@ -321,7 +247,7 @@ class XMLEntity:
def entities_to_xml(
text: str,
predictions: List[TranscriptionEntity],
entity_separators: Optional[List[str]] = None,
entity_separators: List[str] | None = None,
) -> str:
"""Represent the transcription and its entities in XML format. Each entity will be exposed with an XML tag.
Its type will be used to name the tag.
......@@ -341,7 +267,7 @@ def entities_to_xml(
return separator
return ""
def add_portion(entity_offset: Optional[int] = None):
def add_portion(entity_offset: int | None = None):
"""
Add the portion of text between entities either:
- after the last node, if there is one before
......
# -*- coding: utf-8 -*-
from typing import Dict, List, Union
from typing import Dict, List
import numpy as np
import torch
......@@ -559,7 +559,7 @@ class CTCLanguageDecoder:
def post_process(
self, hypotheses: List[CTCHypothesis], batch_sizes: torch.LongTensor
) -> Dict[str, List[Union[str, float]]]:
) -> Dict[str, List[str | float]]:
"""
Post-process hypotheses to output JSON. Exports only the best hypothesis for each image.
:param hypotheses: List of hypotheses returned by the decoder.
......@@ -594,7 +594,7 @@ class CTCLanguageDecoder:
def __call__(
self, batch_features: torch.FloatTensor, batch_frames: torch.LongTensor
) -> Dict[str, List[Union[str, float]]]:
) -> Dict[str, List[str | float]]:
"""
Decode a feature vector using n-gram language modelling.
:param batch_features: Feature vector of size (batch_size, n_tokens, n_frames).
......
......@@ -11,7 +11,7 @@ import torch
import torch.multiprocessing as mp
from dan.ocr.manager.training import Manager
from dan.ocr.utils import update_config
from dan.ocr.utils import add_metrics_table_row, create_metrics_table, update_config
from dan.utils import read_json
logger = logging.getLogger(__name__)
......@@ -50,20 +50,34 @@ def eval(rank, config, mlflow_logging):
model = Manager(config)
model.load_model()
metrics = ["cer", "wer", "wer_no_punct", "time"]
metric_names = [
"cer",
"cer_no_token",
"wer",
"wer_no_punct",
"wer_no_token",
"time",
]
if config["dataset"]["tokens"] is not None:
metric_names.append("ner")
metrics_table = create_metrics_table(metric_names)
for dataset_name in config["dataset"]["datasets"]:
for set_name in ["test", "val", "train"]:
for set_name in ["train", "val", "test"]:
logger.info(f"Evaluating on set `{set_name}`")
model.evaluate(
metrics = model.evaluate(
"{}-{}".format(dataset_name, set_name),
[
(dataset_name, set_name),
],
metrics,
output=True,
metric_names,
mlflow_logging=mlflow_logging,
)
add_metrics_table_row(metrics_table, set_name, metrics)
print(metrics_table)
def run(config: dict):
update_config(config)
......
......@@ -3,7 +3,7 @@ import re
from collections import defaultdict
from operator import attrgetter
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List
import editdistance
import numpy as np
......@@ -19,11 +19,12 @@ REGEX_CONSECUTIVE_SPACES = re.compile(r" +")
# Keep only one space character
REGEX_ONLY_ONE_SPACE = re.compile(r"\s+")
# Mapping between computation tasks (CER, WER, NER) and their metric keyword
METRICS_KEYWORD = {"cer": "chars", "wer": "words", "ner": "tokens"}
class MetricManager:
def __init__(
self, metric_names: List[str], dataset_name: str, tokens: Optional[Path]
):
def __init__(self, metric_names: List[str], dataset_name: str, tokens: Path | None):
self.dataset_name: str = dataset_name
self.remove_tokens: str = None
......@@ -34,42 +35,50 @@ class MetricManager:
+ list(map(attrgetter("end"), tokens.values()))
)
self.remove_tokens: re.Pattern = re.compile(r"([" + layout_tokens + "])")
self.keep_tokens: re.Pattern = re.compile(r"([^" + layout_tokens + "])")
self.metric_names: List[str] = metric_names
self.epoch_metrics = defaultdict(list)
def edit_cer_from_string(self, gt: str, pred: str):
def format_string_for_cer(self, text: str, remove_token: bool = False):
"""
Format and compute edit distance between two strings at character level
Format string for CER computation: remove layout tokens and extra spaces
"""
gt = self.format_string_for_cer(gt)
pred = self.format_string_for_cer(pred)
return editdistance.eval(gt, pred)
if remove_token and self.remove_tokens is not None:
text = self.remove_tokens.sub("", text)
def nb_chars_cer_from_string(self, gt: str) -> int:
"""
Compute length after formatting of ground truth string
"""
return len(self.format_string_for_cer(gt))
text = REGEX_CONSECUTIVE_LINEBREAKS.sub("\n", text)
return REGEX_CONSECUTIVE_SPACES.sub(" ", text).strip()
def format_string_for_wer(self, text: str, remove_punct: bool = False):
def format_string_for_wer(
self, text: str, remove_punct: bool = False, remove_token: bool = False
):
"""
Format string for WER computation: remove layout tokens, treat punctuation as word, replace line break by space
"""
if remove_punct:
text = REGEX_PUNCTUATION.sub("", text)
if self.remove_tokens is not None:
if remove_token and self.remove_tokens is not None:
text = self.remove_tokens.sub("", text)
return REGEX_ONLY_ONE_SPACE.sub(" ", text).strip().split(" ")
def format_string_for_cer(self, text: str):
def format_string_for_ner(self, text: str):
"""
Format string for CER computation: remove layout tokens and extra spaces
Format string for NER computation: only keep layout tokens
"""
if self.remove_tokens is not None:
text = self.remove_tokens.sub("", text)
return self.keep_tokens.sub("", text)
text = REGEX_CONSECUTIVE_LINEBREAKS.sub("\n", text)
return REGEX_CONSECUTIVE_SPACES.sub(" ", text).strip()
def _format_string(self, task: str, *args, **kwargs):
"""
Call the proper `format_string_for_*` method for the given task
"""
match task:
case "cer":
return self.format_string_for_cer(*args, **kwargs)
case "wer":
return self.format_string_for_wer(*args, **kwargs)
case "ner":
return self.format_string_for_ner(*args, **kwargs)
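To illustrate what the formatters do, a standalone sketch of the same regex logic (the layout tokens below are placeholders, not taken from this change):

import re

layout_tokens = "ⓃⓈ"  # hypothetical NER layout tokens
remove_tokens = re.compile(r"([" + layout_tokens + "])")
keep_tokens = re.compile(r"([^" + layout_tokens + "])")

text = "Ⓝ John Smith Ⓢ Paris"
# The *_no_token variants strip the layout tokens before computing CER/WER
print(re.sub(r"\s+", " ", remove_tokens.sub("", text)).strip())  # John Smith Paris
# NER keeps only the layout tokens
print(keep_tokens.sub("", text))  # ⓃⓈ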
def update_metrics(self, batch_metrics):
"""
......@@ -97,11 +106,20 @@ class MetricManager:
display_values["sample_time"] = float(round(sample_time, 4))
display_values[metric_name] = value
continue
case "cer":
num_name, denom_name = "edit_chars", "nb_chars"
case "wer" | "wer_no_punct":
case (
"cer"
| "cer_no_token"
| "wer"
| "wer_no_punct"
| "wer_no_token"
| "ner"
):
keyword = METRICS_KEYWORD[metric_name[:3]]
suffix = metric_name[3:]
num_name, denom_name = "edit_words" + suffix, "nb_words" + suffix
num_name, denom_name = (
"edit_" + keyword + suffix,
"nb_" + keyword + suffix,
)
case "loss" | "loss_ce":
display_values[metric_name] = round(
float(
......@@ -139,21 +157,37 @@ class MetricManager:
gt, prediction = values["str_y"], values["str_x"]
for metric_name in metric_names:
match metric_name:
case "cer":
metrics["edit_chars"] = list(
map(self.edit_cer_from_string, gt, prediction)
)
metrics["nb_chars"] = list(map(self.nb_chars_cer_from_string, gt))
case "wer" | "wer_no_punct":
case (
"cer"
| "cer_no_token"
| "wer"
| "wer_no_punct"
| "wer_no_token"
| "ner"
):
task = metric_name[:3]
keyword = METRICS_KEYWORD[task]
suffix = metric_name[3:]
split_gt = list(map(self.format_string_for_wer, gt, [bool(suffix)]))
# Add extra parameters for the format functions
extras = []
if suffix == "_no_punct":
extras.append([{"remove_punct": True}])
elif suffix == "_no_token":
extras.append([{"remove_token": True}])
# Run the format function for the desired computation (CER, WER or NER)
split_gt = list(map(self._format_string, [task], gt, *extras))
split_pred = list(
map(self.format_string_for_wer, prediction, [bool(suffix)])
map(self._format_string, [task], prediction, *extras)
)
metrics["edit_words" + suffix] = list(
# Compute and store edit distance/length for the desired level
# (chars, words or tokens) as metrics
metrics["edit_" + keyword + suffix] = list(
map(editdistance.eval, split_gt, split_pred)
)
metrics["nb_words" + suffix] = list(map(len, split_gt))
metrics["nb_" + keyword + suffix] = list(map(len, split_gt))
case "loss" | "loss_ce":
metrics[metric_name] = [
values[metric_name],
......
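The metric bookkeeping relies on a naming convention: the first three characters of the metric name select the task (and the METRICS_KEYWORD entry), and the remainder is a suffix. A small sketch of that slicing:

METRICS_KEYWORD = {"cer": "chars", "wer": "words", "ner": "tokens"}

for metric_name in ("cer_no_token", "wer_no_punct", "ner"):
    task, suffix = metric_name[:3], metric_name[3:]
    keyword = METRICS_KEYWORD[task]
    print(metric_name, "->", "edit_" + keyword + suffix, "/", "nb_" + keyword + suffix)

# cer_no_token -> edit_chars_no_token / nb_chars_no_token
# wer_no_punct -> edit_words_no_punct / nb_words_no_punct
# ner -> edit_tokens / nb_tokens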
......@@ -749,8 +749,8 @@ class GenericTrainingManager:
return display_values
def evaluate(
self, custom_name, sets_list, metric_names, mlflow_logging=False, output=False
):
self, custom_name, sets_list, metric_names, mlflow_logging=False
) -> Dict[str, int | float]:
"""
Main loop for evaluation
"""
......@@ -798,19 +798,19 @@ class GenericTrainingManager:
display_values, logging_name, mlflow_logging, self.is_master
)
# output metrics values if requested
if output:
if "pred" in metric_names:
self.output_pred(custom_name)
metrics = self.metric_manager[custom_name].get_display_values(output=True)
path = self.paths["results"] / "predict_{}_{}.yaml".format(
custom_name, self.latest_epoch
)
path.write_text(yaml.dump(metrics))
if "pred" in metric_names:
self.output_pred(custom_name)
metrics = self.metric_manager[custom_name].get_display_values(output=True)
path = self.paths["results"] / "predict_{}_{}.yaml".format(
custom_name, self.latest_epoch
)
path.write_text(yaml.dump(metrics))
if mlflow_logging:
# Log mlflow artifacts
mlflow.log_artifact(path, "predictions")
if mlflow_logging:
# Log mlflow artifacts
mlflow.log_artifact(path, "predictions")
return metrics
def output_pred(self, name):
path = self.paths["results"] / "predict_{}_{}.yaml".format(
......