
Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: atr/dan
Commits on Source (6)
Showing changes with 498 additions and 320 deletions
......@@ -4,6 +4,7 @@ Preprocess datasets for training.
"""
from dan.datasets.analyze import add_analyze_parser
from dan.datasets.download import add_download_parser
from dan.datasets.entities import add_entities_parser
from dan.datasets.extract import add_extract_parser
from dan.datasets.tokens import add_tokens_parser
......@@ -18,6 +19,7 @@ def add_dataset_parser(subcommands) -> None:
subcommands = parser.add_subparsers(metavar="subcommand")
add_extract_parser(subcommands)
add_download_parser(subcommands)
add_analyze_parser(subcommands)
add_entities_parser(subcommands)
add_tokens_parser(subcommands)
......@@ -3,7 +3,7 @@ import logging
from collections import Counter, defaultdict
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List
import imagesize
import numpy as np
......@@ -157,7 +157,7 @@ class Statistics:
level=3,
)
def run(self, labels: Dict, tokens: Optional[Dict]):
def run(self, labels: Dict, tokens: Dict | None):
# Iterate over each split
for split_name, split_data in labels.items():
self.document.new_header(level=1, title=split_name.capitalize())
......@@ -175,7 +175,7 @@ class Statistics:
self.document.create_md_file()
def run(labels: Dict, tokens: Optional[Dict], output: Path) -> None:
def run(labels: Dict, tokens: Dict | None, output: Path) -> None:
"""
Compute and save dataset statistics.
"""
......
# -*- coding: utf-8 -*-
"""
Download images of a dataset from a split extracted by DAN
"""
import pathlib
from dan.datasets.download.images import run
def _valid_image_format(value: str):
im_format = value
if not im_format.startswith("."):
im_format = "." + im_format
return im_format
def add_download_parser(subcommands) -> None:
parser = subcommands.add_parser(
"download",
description=__doc__,
help=__doc__,
)
# Required arguments.
parser.add_argument(
"--output",
type=pathlib.Path,
help="Path where the `split.json` file is stored and where the data will be generated.",
required=True,
)
parser.add_argument(
"--max-width",
type=int,
help="Images larger than this width will be resized to this width.",
)
parser.add_argument(
"--max-height",
type=int,
help="Images larger than this height will be resized to this height.",
)
# Formatting arguments
parser.add_argument(
"--image-format",
type=_valid_image_format,
default=".jpg",
help="Images will be saved under this format.",
)
parser.set_defaults(func=run)
# -*- coding: utf-8 -*-
from pathlib import Path
class ImageDownloadError(Exception):
"""
Raised when an element's image could not be downloaded
"""
def __init__(
self, split: str, path: Path, url: str, exc: Exception, *args: object
) -> None:
super().__init__(*args)
self.split: str = split
self.path: str = str(path)
self.url: str = url
self.message = f"{str(exc)} for element {path.stem}"
# -*- coding: utf-8 -*-
import json
import logging
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, Tuple
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
from dan.datasets.download.exceptions import ImageDownloadError
from dan.datasets.download.utils import download_image, get_bbox
from line_image_extractor.extractor import extract
from line_image_extractor.image_utils import (
BoundingBox,
Extraction,
polygon_to_bbox,
)
IMAGES_DIR = "images" # Subpath to the images directory.
IIIF_URL = "{image_url}/{bbox}/{size}/0/default.jpg"
# IIIF 2.0 uses `full`
IIIF_FULL_SIZE = "full"
logger = logging.getLogger(__name__)
class ImageDownloader:
"""
Download images from extracted data
"""
def __init__(
self,
output: Path | None = None,
max_width: int | None = None,
max_height: int | None = None,
image_extension: str = "",
) -> None:
self.output = output
self.max_width = max_width
self.max_height = max_height
self.image_extension = image_extension
# Load split file
split_file = output / "split.json" if output else None
self.split: Dict = (
json.loads(split_file.read_text())
if split_file and split_file.is_file()
else {}
)
# Create directories
for split_name in self.split:
Path(output, IMAGES_DIR, split_name).mkdir(parents=True, exist_ok=True)
self.data: Dict = defaultdict(dict)
def check_extraction(self, values: dict) -> str | None:
# Check image parameters
if not (image := values.get("image")):
return "Image information not found"
# Only support `iiif_url` with `polygon` for now
if not image.get("iiif_url"):
return "Image IIIF URL not found"
if not image.get("polygon"):
return "Image polygon not found"
# Check text parameter
if values.get("text") is None:
return "Text not found"
def get_iiif_size_arg(self, width: int, height: int) -> str:
if (self.max_width is None or width <= self.max_width) and (
self.max_height is None or height <= self.max_height
):
return IIIF_FULL_SIZE
bigger_width = self.max_width and width >= self.max_width
bigger_height = self.max_height and height >= self.max_height
if bigger_width and bigger_height:
# Resize along one dimension only, to keep the aspect ratio
# This ratio tells which dimension needs the biggest shrinking
ratio = width * self.max_height / (height * self.max_width)
return f"{self.max_width}," if ratio > 1 else f",{self.max_height}"
elif bigger_width:
return f"{self.max_width},"
# Only the height is bigger than the max size
elif bigger_height:
return f",{self.max_height}"
def build_iiif_url(
self, polygon: List[List[int]], image_url: str
) -> str:
bbox = polygon_to_bbox(polygon)
size = self.get_iiif_size_arg(width=bbox.width, height=bbox.height)
# Rotations are done using the lib
return IIIF_URL.format(image_url=image_url, bbox=get_bbox(polygon), size=size)
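# Worked example (editor's sketch, not part of this change): with max_width=2000
# and max_height=2000, `get_iiif_size_arg` yields:
#   1500 x 1000 -> "full"   (the image fits within both limits)
#   3000 x 1000 -> "2000,"  (only the width exceeds its limit)
#   4000 x 3000 -> "2000,"  (both exceed, ratio = 4/3 > 1, width needs the larger shrink)
#   3000 x 4500 -> ",2000"  (both exceed, ratio = 2/3 < 1, height needs the larger shrink)
# `build_iiif_url` then plugs that size and the bbox region into IIIF_URL, e.g.
#   https://iiif.example.com/image/10,20,100,50/2000,/0/default.jpg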
def build_tasks(self) -> List[Dict[str, str]]:
tasks = []
for split, items in self.split.items():
# Create directories
destination = self.output / IMAGES_DIR / split
destination.mkdir(parents=True, exist_ok=True)
for element_id, values in items.items():
image_path = (destination / element_id).with_suffix(
self.image_extension
)
error = self.check_extraction(values)
if error:
logger.warning(f"{image_path}: {error}")
continue
self.data[split][str(image_path)] = values["text"]
# Create task for multithreading pool if image does not exist yet
if image_path.exists():
continue
polygon = values["image"]["polygon"]
iiif_url = values["image"]["iiif_url"]
tasks.append(
{
"split": split,
"polygon": polygon,
"image_url": self.build_iiif_url(polygon, iiif_url),
"destination": image_path,
}
)
return tasks
def get_image(
self,
split: str,
polygon: List[List[int]],
image_url: str,
destination: Path,
) -> None:
"""Save the element's image to the given path and applies any image operations needed.
:param split: Dataset split this image belongs to.
:param polygon: Polygon of the processed element.
:param image_url: Base IIIF URL of the image.
:param destination: Where the image should be saved.
"""
bbox = polygon_to_bbox(polygon)
try:
img: Image.Image = download_image(image_url)
# The polygon's coordinates are in the frame of reference of the full image
# We need to remove the offset of the bounding rectangle
polygon = [(x - bbox.x, y - bbox.y) for x, y in polygon]
# Normalize bbox
bbox = BoundingBox(x=0, y=0, width=bbox.width, height=bbox.height)
image = extract(
img=cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR),
polygon=np.asarray(polygon).clip(0),
bbox=bbox,
extraction_mode=Extraction.boundingRect,
max_deskew_angle=45,
)
cv2.imwrite(str(destination), image)
except Exception as e:
raise ImageDownloadError(
split=split, path=destination, url=image_url, exc=e
)
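# Offset example (editor's note, numbers are made up): if the polygon's bounding
# box starts at (x=100, y=200), a polygon point (150, 250) becomes (50, 50) once
# the offset is removed, matching the cropped image whose bbox now starts at (0, 0).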
def download_images(self, tasks: List[Dict[str, str]]) -> None:
"""
Execute each image download task in parallel
:param tasks: List of tasks to execute.
"""
failed_downloads = []
with tqdm(
desc="Downloading images", total=len(tasks)
) as pbar, ThreadPoolExecutor() as executor:
def process_future(future: Future):
"""
Callback function called at the end of the thread
"""
# Update the progress bar count
pbar.update(1)
exc = future.exception()
if exc is None:
# No error
return
# If failed, tag for removal
assert isinstance(exc, ImageDownloadError)
# Remove transcription from labels dict
del self.data[exc.split][exc.path]
# Save tried URL
failed_downloads.append((exc.url, exc.message))
# Submit all tasks
for task in tasks:
executor.submit(self.get_image, **task).add_done_callback(
process_future
)
if failed_downloads:
logger.error(f"Failed to download {len(failed_downloads)} image(s).")
print(*list(map(": ".join, failed_downloads)), sep="\n")
def export(self) -> None:
"""
Write a `labels.json` file mapping each image that was successfully downloaded (identified by its path)
to its ground-truth transcription (with NER tokens if needed).
"""
(self.output / "labels.json").write_text(
json.dumps(
self.data,
sort_keys=True,
indent=4,
)
)
def run(self) -> None:
"""
Download the missing images from a `split.json` file and build a `labels.json` file mapping
each image that was successfully downloaded (identified by its path)
to its ground-truth transcription (with NER tokens if needed).
"""
tasks: List[Dict[str, str]] = self.build_tasks()
self.download_images(tasks)
self.export()
def run(
output: Path,
max_width: int | None,
max_height: int | None,
image_format: str,
):
"""
Download the missing images from a `split.json` file and build a `labels.json` file mapping
each image that was successfully downloaded (identified by its path)
to its ground-truth transcription (with NER tokens if needed).
:param output: Path where the `split.json` file is stored and where the data will be generated
:param max_width: Images larger than this width will be resized to this width
:param max_height: Images larger than this height will be resized to this height
:param image_format: Images will be saved under this format
"""
ImageDownloader(
output=output,
max_width=max_width,
max_height=max_height,
image_extension=image_format,
).run()
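# Data-flow sketch (editor's illustration; element IDs and paths are placeholders).
# The `split.json` written by the extract command maps each element to its text and
# image source:
#   {"train": {"<element_id>": {"text": "...", "image": {"iiif_url": "...", "polygon": [[x, y], ...]}}}}
# After the download step, `labels.json` maps the local image path to the transcription:
#   {"train": {"<output>/images/train/<element_id>.jpg": "..."}}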
# -*- coding: utf-8 -*-
import logging
from io import BytesIO
from typing import List
import requests
from PIL import Image, ImageOps
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
logger = logging.getLogger(__name__)
# See http://docs.python-requests.org/en/master/user/advanced/#timeouts
DOWNLOAD_TIMEOUT = (30, 60)
def _retry_log(retry_state, *args, **kwargs):
logger.warning(
f"Request to {retry_state.args[0]} failed ({repr(retry_state.outcome.exception())}), "
f"retrying in {retry_state.idle_for} seconds"
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2),
retry=retry_if_exception_type(requests.RequestException),
before_sleep=_retry_log,
reraise=True,
)
def _retried_request(url: str) -> requests.Response:
resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
resp.raise_for_status()
return resp
def download_image(url: str) -> Image.Image:
"""
Download an image and open it with Pillow
"""
assert url.startswith("http"), "Image URL must be HTTP(S)"
# Download the image
# Cannot use stream=True as urllib's responses do not support the seek(int) method,
# which is explicitly required by Image.open on file-like objects
try:
resp = _retried_request(url)
except requests.HTTPError as e:
if "/full/" in url and 400 <= e.response.status_code < 500:
# Retry with max instead of full as IIIF size
resp = _retried_request(url.replace("/full/", "/max/"))
else:
raise e
# Preprocess the image and prepare it for classification
image = Image.open(BytesIO(resp.content)).convert("RGB")
# Do not rotate JPEG images (see https://github.com/python-pillow/Pillow/issues/4703)
image = ImageOps.exif_transpose(image)
logger.debug(
"Downloaded image {} - size={}x{}".format(url, image.size[0], image.size[1])
)
return image
def get_bbox(polygon: List[List[int]]) -> str:
"""
Return a comma-separated string of the upper-left-most pixel coordinates followed by the width and height of the bounding box
"""
all_x, all_y = zip(*polygon)
x, y = min(all_x), min(all_y)
width, height = max(all_x) - x, max(all_y) - y
return ",".join(list(map(str, [int(x), int(y), int(width), int(height)])))
......@@ -5,7 +5,6 @@ Extract dataset from Arkindex using a corpus export.
import argparse
import pathlib
from typing import Union
from uuid import UUID
from dan.datasets.extract.arkindex import run
......@@ -13,7 +12,7 @@ from dan.datasets.extract.arkindex import run
MANUAL_SOURCE = "manual"
def parse_worker_version(worker_version_id) -> Union[str, bool]:
def parse_worker_version(worker_version_id) -> str | bool:
if worker_version_id == MANUAL_SOURCE:
return False
......@@ -34,13 +33,6 @@ def validate_char(char):
return char
def _valid_image_format(value: str):
im_format = value
if not im_format.startswith("."):
im_format = "." + im_format
return im_format
def add_extract_parser(subcommands) -> None:
parser = subcommands.add_parser(
"extract",
......@@ -131,18 +123,6 @@ def add_extract_parser(subcommands) -> None:
required=False,
)
parser.add_argument(
"--max-width",
type=int,
help="Images larger than this width will be resized to this width.",
)
parser.add_argument(
"--max-height",
type=int,
help="Images larger than this height will be resized to this height.",
)
parser.add_argument(
"--subword-vocab-size",
type=int,
......@@ -151,13 +131,6 @@ def add_extract_parser(subcommands) -> None:
)
# Formatting arguments
parser.add_argument(
"--image-format",
type=_valid_image_format,
default=".jpg",
help="Images will be saved under this format.",
)
parser.add_argument(
"--keep-spaces",
action="store_true",
......
......@@ -5,14 +5,10 @@ import logging
import pickle
import random
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List
from uuid import UUID
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
from arkindex_export import open_database
......@@ -23,37 +19,24 @@ from dan.datasets.extract.db import (
get_transcriptions,
)
from dan.datasets.extract.exceptions import (
ImageDownloadError,
NoTranscriptionError,
ProcessingError,
UnknownTokenInText,
)
from dan.datasets.extract.utils import (
Tokenizer,
download_image,
entities_to_xml,
get_bbox,
get_translation_map,
get_vocabulary,
normalize_linebreaks,
normalize_spaces,
)
from dan.utils import LMTokenMapping, parse_tokens
from line_image_extractor.extractor import extract
from line_image_extractor.image_utils import (
BoundingBox,
Extraction,
polygon_to_bbox,
)
IMAGES_DIR = "images" # Subpath to the images directory.
LANGUAGE_DIR = "language_model" # Subpath to the language model directory.
TRAIN_NAME = "train"
SPLIT_NAMES = [TRAIN_NAME, "val", "test"]
IIIF_URL = "{image_url}/{bbox}/{size}/0/default.jpg"
# IIIF 2.0 uses `full`
IIIF_FULL_SIZE = "full"
logger = logging.getLogger(__name__)
......@@ -67,17 +50,14 @@ class ArkindexExtractor:
self,
folders: list = [],
element_type: List[str] = [],
parent_element_type: str = None,
output: Path = None,
parent_element_type: str | None = None,
output: Path | None = None,
entity_separators: List[str] = ["\n", " "],
unknown_token: str = "",
tokens: Path = None,
transcription_worker_version: Optional[Union[str, bool]] = None,
entity_worker_version: Optional[Union[str, bool]] = None,
max_width: Optional[int] = None,
max_height: Optional[int] = None,
tokens: Path | None = None,
transcription_worker_version: str | bool | None = None,
entity_worker_version: str | bool | None = None,
keep_spaces: bool = False,
image_extension: str = "",
allow_empty: bool = False,
subword_vocab_size: int = 1000,
) -> None:
......@@ -90,9 +70,6 @@ class ArkindexExtractor:
self.tokens = parse_tokens(tokens) if tokens else {}
self.transcription_worker_version = transcription_worker_version
self.entity_worker_version = entity_worker_version
self.max_width = max_width
self.max_height = max_height
self.image_extension = image_extension
self.allow_empty = allow_empty
self.mapping = LMTokenMapping()
self.keep_spaces = keep_spaces
......@@ -104,41 +81,9 @@ class ArkindexExtractor:
self.language_tokens = []
self.language_lexicon = defaultdict(list)
# Image download tasks to process
self.tasks: List[Dict[str, str]] = []
# NER extraction
self.translation_map: Dict[str, str] | None = get_translation_map(self.tokens)
def get_iiif_size_arg(self, width: int, height: int) -> str:
if (self.max_width is None or width <= self.max_width) and (
self.max_height is None or height <= self.max_height
):
return IIIF_FULL_SIZE
bigger_width = self.max_width and width >= self.max_width
bigger_height = self.max_height and height >= self.max_height
if bigger_width and bigger_height:
# Resize to the biggest dim to keep aspect ratio
# Only resize width is bigger than max size
# This ratio tells which dim needs the biggest shrinking
ratio = width * self.max_height / (height * self.max_width)
return f"{self.max_width}," if ratio > 1 else f",{self.max_height}"
elif bigger_width:
return f"{self.max_width},"
# Only resize height is bigger than max size
elif bigger_height:
return f",{self.max_height}"
def build_iiif_url(self, polygon, image_url) -> Tuple[BoundingBox, str]:
bbox = polygon_to_bbox(json.loads(str(polygon)))
size = self.get_iiif_size_arg(width=bbox.width, height=bbox.height)
# Rotations are done using the lib
return bbox, IIIF_URL.format(
image_url=image_url, bbox=get_bbox(polygon), size=size
)
def translate(self, text: str):
"""
Use translation map to replace XML tags to actual tokens
......@@ -177,48 +122,7 @@ class ArkindexExtractor:
)
)
def get_image(
self,
split: str,
polygon: List[List[int]],
image_url: str,
destination: Path,
) -> None:
"""Save the element's image to the given path and applies any image operations needed.
:param split: Dataset split this image belongs to.
:param polygon: Polygon of the processed element.
:param image_url: Base IIIF URL of the image.
:param destination: Where the image should be saved.
"""
bbox, download_url = self.build_iiif_url(polygon=polygon, image_url=image_url)
try:
img: Image.Image = download_image(download_url)
# The polygon's coordinate are in the referential of the full image
# We need to remove the offset of the bounding rectangle
polygon = [(x - bbox.x, y - bbox.y) for x, y in polygon]
# Normalize bbox
bbox = BoundingBox(x=0, y=0, width=bbox.width, height=bbox.height)
image = extract(
img=cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR),
polygon=np.asarray(polygon).clip(0),
bbox=bbox,
extraction_mode=Extraction.boundingRect,
max_deskew_angle=45,
)
destination.parent.mkdir(parents=True, exist_ok=True)
cv2.imwrite(str(destination), image)
except Exception as e:
raise ImageDownloadError(
split=split, path=destination, url=download_url, exc=e
)
def format_text(self, text: str, charset: Optional[set] = None):
def format_text(self, text: str, charset: set | None = None):
if not self.keep_spaces:
text = normalize_spaces(text)
text = normalize_linebreaks(text)
......@@ -234,11 +138,7 @@ class ArkindexExtractor:
)
return text.strip()
def process_element(
self,
element: Element,
split: str,
):
def process_element(self, element: Element, split: str):
"""
Extract an element's data and save it to disk.
The output path is directly related to the split of the element.
......@@ -248,36 +148,23 @@ class ArkindexExtractor:
if self.unknown_token in text:
raise UnknownTokenInText(element_id=element.id)
image_path = Path(self.output, IMAGES_DIR, split, element.id).with_suffix(
self.image_extension
)
# Create task for multithreading pool if image does not exist yet
if not image_path.exists():
self.tasks.append(
{
"split": split,
"polygon": json.loads(str(element.polygon)),
"image_url": element.image.url,
"destination": image_path,
}
)
text = self.format_text(
text,
# Do not replace unknown characters in train split
charset=self.charset if split != TRAIN_NAME else None,
)
self.data[split][str(image_path)] = text
self.data[split][element.id] = {
"text": text,
"image": {
"iiif_url": element.image.url,
"polygon": json.loads(element.polygon),
},
}
self.charset = self.charset.union(set(text))
def process_parent(
self,
pbar,
parent: Element,
split: str,
):
def process_parent(self, pbar, parent: Element, split: str):
"""
Extract data from a parent element.
"""
......@@ -326,8 +213,10 @@ class ArkindexExtractor:
# Build LM corpus
train_corpus = [
text.replace(self.mapping.linebreak.display, self.mapping.space.display)
for text in self.data["train"].values()
values["text"].replace(
self.mapping.linebreak.display, self.mapping.space.display
)
for values in self.data[TRAIN_NAME].values()
]
tokenizer = Tokenizer(
......@@ -361,7 +250,7 @@ class ArkindexExtractor:
]
def export(self):
(self.output / "labels.json").write_text(
(self.output / "split.json").write_text(
json.dumps(
self.data,
sort_keys=True,
......@@ -382,40 +271,6 @@ class ArkindexExtractor:
pickle.dumps(sorted(list(self.charset)))
)
def download_images(self):
failed_downloads = []
with tqdm(
desc="Downloading images", total=len(self.tasks)
) as pbar, ThreadPoolExecutor() as executor:
def process_future(future: Future):
"""
Callback function called at the end of the thread
"""
# Update the progress bar count
pbar.update(1)
exc = future.exception()
if exc is None:
# No error
return
# If failed, tag for removal
assert isinstance(exc, ImageDownloadError)
# Remove transcription from labels dict
del self.data[exc.split][exc.path]
# Save tried URL
failed_downloads.append((exc.url, exc.message))
# Submit all tasks
for task in self.tasks:
executor.submit(self.get_image, **task).add_done_callback(
process_future
)
if failed_downloads:
logger.error(f"Failed to download {len(failed_downloads)} image(s).")
print(*list(map(": ".join, failed_downloads)), sep="\n")
def run(self):
# Iterate over the subsets to find the page images and labels.
for folder_id, split in zip(self.folders, SPLIT_NAMES):
......@@ -442,7 +297,6 @@ class ArkindexExtractor:
"No data was extracted using the provided export database and parameters."
)
self.download_images()
self.format_lm_files()
self.export()
......@@ -458,11 +312,8 @@ def run(
train_folder: UUID,
val_folder: UUID,
test_folder: UUID,
transcription_worker_version: Optional[Union[str, bool]],
entity_worker_version: Optional[Union[str, bool]],
max_width: Optional[int],
max_height: Optional[int],
image_format: str,
transcription_worker_version: str | bool | None,
entity_worker_version: str | bool | None,
keep_spaces: bool,
allow_empty: bool,
subword_vocab_size: int,
......@@ -473,8 +324,6 @@ def run(
folders = [str(train_folder), str(val_folder), str(test_folder)]
# Create directories
for split in SPLIT_NAMES:
Path(output, IMAGES_DIR, split).mkdir(parents=True, exist_ok=True)
Path(output, LANGUAGE_DIR).mkdir(parents=True, exist_ok=True)
ArkindexExtractor(
......@@ -487,10 +336,7 @@ def run(
tokens=tokens,
transcription_worker_version=transcription_worker_version,
entity_worker_version=entity_worker_version,
max_width=max_width,
max_height=max_height,
keep_spaces=keep_spaces,
image_extension=image_format,
allow_empty=allow_empty,
subword_vocab_size=subword_vocab_size,
).run()
# -*- coding: utf-8 -*-
from typing import List, Optional, Union
from typing import List
from arkindex_export import Image
from arkindex_export.models import (
......@@ -41,7 +41,7 @@ def build_worker_version_filter(ArkindexModel, worker_version):
def get_transcriptions(
element_id: str, transcription_worker_version: Union[str, bool]
element_id: str, transcription_worker_version: str | bool
) -> List[Transcription]:
"""
Retrieve transcriptions from an SQLite export of an Arkindex corpus
......@@ -61,7 +61,7 @@ def get_transcriptions(
def get_transcription_entities(
transcription_id: str,
entity_worker_version: Optional[Union[str, bool]],
entity_worker_version: str | bool | None,
supported_types: List[str],
) -> List[TranscriptionEntity]:
"""
......
# -*- coding: utf-8 -*-
from pathlib import Path
class ProcessingError(Exception):
......@@ -21,21 +20,6 @@ class ElementProcessingError(ProcessingError):
self.element_id = element_id
class ImageDownloadError(Exception):
"""
Raised when an element's image could not be downloaded
"""
def __init__(
self, split: str, path: Path, url: str, exc: Exception, *args: object
) -> None:
super().__init__(*args)
self.split: str = split
self.path: str = str(path)
self.url: str = url
self.message = f"{str(exc)} for element {path.stem}"
class NoTranscriptionError(ElementProcessingError):
"""
Raised when there are no transcriptions on an element
......
......@@ -4,31 +4,19 @@ import logging
import operator
import re
from dataclasses import dataclass, field
from io import BytesIO
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Dict, Iterator, List, Optional, Union
from typing import Dict, Iterator, List
import requests
import sentencepiece as spm
from lxml.etree import Element, SubElement, tostring
from nltk import wordpunct_tokenize
from PIL import Image, ImageOps
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from arkindex_export import TranscriptionEntity
from dan.utils import EntityType, LMTokenMapping
logger = logging.getLogger(__name__)
# See http://docs.python-requests.org/en/master/user/advanced/#timeouts
DOWNLOAD_TIMEOUT = (30, 60)
# replace \t with regular space and consecutive spaces
TRIM_SPACE_REGEX = re.compile(r"[\t ]+")
TRIM_RETURN_REGEX = re.compile(r"[\r\n]+")
......@@ -42,57 +30,6 @@ ENCODING_MAP = {
}
def _retry_log(retry_state, *args, **kwargs):
logger.warning(
f"Request to {retry_state.args[0]} failed ({repr(retry_state.outcome.exception())}), "
f"retrying in {retry_state.idle_for} seconds"
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2),
retry=retry_if_exception_type(requests.RequestException),
before_sleep=_retry_log,
reraise=True,
)
def _retried_request(url):
resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
resp.raise_for_status()
return resp
def download_image(url):
"""
Download an image and open it with Pillow
"""
assert url.startswith("http"), "Image URL must be HTTP(S)"
# Download the image
# Cannot use stream=True as urllib's responses do not support the seek(int) method,
# which is explicitly required by Image.open on file-like objects
try:
resp = _retried_request(url)
except requests.HTTPError as e:
if "/full/" in url and 400 <= e.response.status_code < 500:
# Retry with max instead of full as IIIF size
resp = _retried_request(url.replace("/full/", "/max/"))
else:
raise e
# Preprocess the image and prepare it for classification
image = Image.open(BytesIO(resp.content)).convert("RGB")
# Do not rotate JPEG images (see https://github.com/python-pillow/Pillow/issues/4703)
image = ImageOps.exif_transpose(image)
logger.debug(
"Downloaded image {} - size={}x{}".format(url, image.size[0], image.size[1])
)
return image
def normalize_linebreaks(text: str) -> str:
"""
Remove beginning/ending linebreaks.
......@@ -111,17 +48,6 @@ def normalize_spaces(text: str) -> str:
return TRIM_SPACE_REGEX.sub(" ", text.strip())
def get_bbox(polygon: List[List[int]]) -> str:
"""
Arkindex polygon stored as string
returns a comma-separated string of upper left-most pixel, width + height of the image
"""
all_x, all_y = zip(*polygon)
x, y = min(all_x), min(all_y)
width, height = max(all_x) - x, max(all_y) - y
return ",".join(list(map(str, [int(x), int(y), int(width), int(height)])))
def get_vocabulary(tokenized_text: List[str]) -> set[str]:
"""
Compute the vocabulary set from tokenized text.
......@@ -146,7 +72,7 @@ class Tokenizer:
unknown_token: str
outdir: Path
mapping: LMTokenMapping
tokens: Optional[EntityType] = None
tokens: EntityType | None = None
subword_vocab_size: int = 1000
sentencepiece_model: spm.SentencePieceProcessor = field(init=False)
......@@ -155,7 +81,7 @@ class Tokenizer:
return self.outdir / "subword_tokenizer"
@property
def ner_tokens(self) -> Union[List[str], Iterator[str]]:
def ner_tokens(self) -> List[str] | Iterator[str]:
if self.tokens is None:
return []
return itertools.chain(
......@@ -253,7 +179,7 @@ def slugify(text: str):
return text.replace(" ", "_")
def get_translation_map(tokens: Dict[str, EntityType]) -> Optional[Dict[str, str]]:
def get_translation_map(tokens: Dict[str, EntityType]) -> Dict[str, str] | None:
if not tokens:
return
......@@ -321,7 +247,7 @@ class XMLEntity:
def entities_to_xml(
text: str,
predictions: List[TranscriptionEntity],
entity_separators: Optional[List[str]] = None,
entity_separators: List[str] | None = None,
) -> str:
"""Represent the transcription and its entities in XML format. Each entity will be exposed with an XML tag.
Its type will be used to name the tag.
......@@ -341,7 +267,7 @@ def entities_to_xml(
return separator
return ""
def add_portion(entity_offset: Optional[int] = None):
def add_portion(entity_offset: int | None = None):
"""
Add the portion of text between entities either:
- after the last node, if there is one before
......
# -*- coding: utf-8 -*-
from typing import Dict, List, Union
from typing import Dict, List
import numpy as np
import torch
......@@ -559,7 +559,7 @@ class CTCLanguageDecoder:
def post_process(
self, hypotheses: List[CTCHypothesis], batch_sizes: torch.LongTensor
) -> Dict[str, List[Union[str, float]]]:
) -> Dict[str, List[str | float]]:
"""
Post-process hypotheses to output JSON. Exports only the best hypothesis for each image.
:param hypotheses: List of hypotheses returned by the decoder.
......@@ -594,7 +594,7 @@ class CTCLanguageDecoder:
def __call__(
self, batch_features: torch.FloatTensor, batch_frames: torch.LongTensor
) -> Dict[str, List[Union[str, float]]]:
) -> Dict[str, List[str | float]]:
"""
Decode a feature vector using n-gram language modelling.
:param batch_features: Feature vector of size (batch_size, n_tokens, n_frames).
......
......@@ -51,6 +51,9 @@ def eval(rank, config, mlflow_logging):
model.load_model()
metrics = ["cer", "wer", "wer_no_punct", "time"]
if config["dataset"]["tokens"] is not None:
metrics.append("ner")
for dataset_name in config["dataset"]["datasets"]:
for set_name in ["test", "val", "train"]:
logger.info(f"Evaluating on set `{set_name}`")
......
......@@ -3,7 +3,7 @@ import re
from collections import defaultdict
from operator import attrgetter
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List
import editdistance
import numpy as np
......@@ -21,9 +21,7 @@ REGEX_ONLY_ONE_SPACE = re.compile(r"\s+")
class MetricManager:
def __init__(
self, metric_names: List[str], dataset_name: str, tokens: Optional[Path]
):
def __init__(self, metric_names: List[str], dataset_name: str, tokens: Path | None):
self.dataset_name: str = dataset_name
self.remove_tokens: str = None
......@@ -34,6 +32,8 @@ class MetricManager:
+ list(map(attrgetter("end"), tokens.values()))
)
self.remove_tokens: re.Pattern = re.compile(r"([" + layout_tokens + "])")
self.keep_tokens: re.Pattern = re.compile(r"([^" + layout_tokens + "])")
self.metric_names: List[str] = metric_names
self.epoch_metrics = defaultdict(list)
......@@ -71,6 +71,12 @@ class MetricManager:
text = REGEX_CONSECUTIVE_LINEBREAKS.sub("\n", text)
return REGEX_CONSECUTIVE_SPACES.sub(" ", text).strip()
def format_string_for_ner(self, text: str):
"""
Format string for NER computation: only keep layout tokens
"""
return self.keep_tokens.sub("", text)
def update_metrics(self, batch_metrics):
"""
Add batch metrics to the metrics
......@@ -102,6 +108,8 @@ class MetricManager:
case "wer" | "wer_no_punct":
suffix = metric_name[3:]
num_name, denom_name = "edit_words" + suffix, "nb_words" + suffix
case "ner":
num_name, denom_name = "edit_tokens", "nb_tokens"
case "loss" | "loss_ce":
display_values[metric_name] = round(
float(
......@@ -154,6 +162,13 @@ class MetricManager:
map(editdistance.eval, split_gt, split_pred)
)
metrics["nb_words" + suffix] = list(map(len, split_gt))
case "ner":
split_gt = list(map(self.format_string_for_ner, gt))
split_pred = list(map(self.format_string_for_ner, prediction))
metrics["edit_tokens"] = list(
map(editdistance.eval, split_gt, split_pred)
)
metrics["nb_tokens"] = list(map(len, split_gt))
case "loss" | "loss_ce":
metrics[metric_name] = [
values[metric_name],
......
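# NER metric sketch (editor's illustration; the layout tokens "Ⓝ" and "Ⓕ" are hypothetical).
# `format_string_for_ner` keeps only the layout tokens, so for a pair like:
#   ground truth "ⓃDupont ⒻJean"    -> "ⓃⒻ"
#   prediction   "ⓃDupond ⒻJean ⒻH" -> "ⓃⒻⒻ"
# edit_tokens = editdistance.eval("ⓃⒻ", "ⓃⒻⒻ") = 1 and nb_tokens = 2,
# i.e. a NER error rate of 1 / 2 = 50% for this pair.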
......@@ -363,7 +363,7 @@ def get_polygon(
max_value: np.float32,
offset: int,
weights: np.ndarray,
size: Tuple[int, int] = None,
size: Tuple[int, int] | None = None,
max_object_height: int = 50,
) -> Tuple[dict, np.ndarray]:
"""
......
......@@ -5,7 +5,7 @@ import logging
import pickle
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Tuple
import numpy as np
import torch
......@@ -159,7 +159,7 @@ class DAN:
word_separators: re.Pattern = parse_delimiters(["\n", " "]),
line_separators: re.Pattern = parse_delimiters(["\n"]),
tokens: Dict[str, EntityType] = {},
start_token: str = None,
start_token: str | None = None,
max_object_height: int = 50,
) -> dict:
"""
......@@ -426,7 +426,7 @@ def process_batch(
def run(
image_dir: Optional[Path],
image_dir: Path,
model: Path,
output: Path,
confidence_score: bool,
......
......@@ -34,6 +34,12 @@ def train(rank, params, mlflow_logging=False):
model = Manager(params)
model.load_model()
if params["dataset"]["tokens"] is not None:
if "ner" not in params["training"]["metrics"]["train"]:
params["training"]["metrics"]["train"].append("ner")
if "ner" not in params["training"]["metrics"]["eval"]:
params["training"]["metrics"]["eval"].append("ner")
if mlflow_logging:
logger.info("MLflow logging enabled")
......
......@@ -9,8 +9,9 @@ To extract the data, DAN uses an Arkindex export database in SQLite format. You
1. Structure the data into folders (`train` / `val` / `test`) in [Arkindex](https://demo.arkindex.org/).
1. [Export the project](https://doc.arkindex.org/howto/export/) in SQLite format.
1. Extract the data with the [extract command](../usage/datasets/extract.md).
1. Download images with the [download command](../usage/datasets/download.md).
This command will extract and format the images and labels needed to train DAN. It will also tokenize the training corpus at character, subword, and word levels, allowing you to combine DAN with an explicit statistical language model to improve performance.
These commands will extract and format the images and labels needed to train DAN. They will also tokenize the training corpus at character, subword, and word levels, allowing you to combine DAN with an explicit statistical language model to improve performance.
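
If you prefer to drive the download step from Python rather than the command line, the same entry point can be called directly. A minimal sketch, assuming a `split.json` already sits in the output directory (paths and sizes below are placeholders):

```python
from pathlib import Path

from dan.datasets.download.images import run

# Download the images listed in output/split.json and write output/labels.json
run(
    output=Path("output"),
    max_width=2000,    # resize images wider than 2000 pixels
    max_height=None,   # no height limit
    image_format=".jpg",
)
```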
At the end, you should get the following tree structure:
......
# Exceptions
::: dan.datasets.download.exceptions
# Image
::: dan.datasets.download.images