# -*- coding: utf-8 -*-
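"""
Extract images, transcriptions and language model resources from an Arkindex export database.
"""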
import json
import logging
import pickle
import random
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from uuid import UUID
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
from arkindex_export import open_database
from dan.datasets.extract.db import (
Element,
get_elements,
get_transcription_entities,
get_transcriptions,
)
from dan.datasets.extract.exceptions import (
ImageDownloadError,
NoEndTokenError,
NoTranscriptionError,
ProcessingError,
UnknownTokenInText,
)
from dan.datasets.extract.utils import (
Tokenizer,
download_image,
get_bbox,
insert_token,
normalize_linebreaks,
normalize_spaces,
)
from dan.utils import EntityType, LMTokenMapping, parse_tokens
from line_image_extractor.extractor import extract
from line_image_extractor.image_utils import (
BoundingBox,
Extraction,
polygon_to_bbox,
)
IMAGES_DIR = "images" # Subpath to the images directory.
LANGUAGE_DIR = "language_model" # Subpath to the language model directory.
TRAIN_NAME = "train"
SPLIT_NAMES = [TRAIN_NAME, "val", "test"]
IIIF_URL = "{image_url}/{bbox}/{size}/0/default.jpg"
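# `bbox` is the region as `x,y,width,height` and `size` is either `full` or a
# scaled `{width},` / `,{height}` value, e.g. (illustrative URL):
# https://iiif.example.org/image-id/100,150,800,600/full/0/default.jpg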
# IIIF Image API 2.0 uses `full` for the full image size (3.0 uses `max`)
IIIF_FULL_SIZE = "full"
logger = logging.getLogger(__name__)
class ArkindexExtractor:
"""
    Extract dataset images, transcriptions and entities from an Arkindex export database, and prepare the language model resources.
"""
def __init__(
self,
        folders: List[str] = [],
        element_type: List[str] = [],
        parent_element_type: Optional[str] = None,
        output: Optional[Path] = None,
        entity_separators: List[str] = ["\n", " "],
        unknown_token: str = "⁇",
        tokens: Optional[Path] = None,
transcription_worker_version: Optional[Union[str, bool]] = None,
entity_worker_version: Optional[Union[str, bool]] = None,
max_width: Optional[int] = None,
max_height: Optional[int] = None,
keep_spaces: bool = False,
image_extension: str = "",
allow_empty: bool = False,
subword_vocab_size: int = 1000,
) -> None:
self.folders = folders
self.element_type = element_type
self.parent_element_type = parent_element_type
self.output = output
self.entity_separators = entity_separators
self.unknown_token = unknown_token
self.tokens = parse_tokens(tokens) if tokens else None
self.transcription_worker_version = transcription_worker_version
self.entity_worker_version = entity_worker_version
self.max_width = max_width
self.max_height = max_height
self.image_extension = image_extension
self.allow_empty = allow_empty
self.mapping = LMTokenMapping()
self.keep_spaces = keep_spaces
self.subword_vocab_size = subword_vocab_size
self.data: Dict = defaultdict(dict)
self.charset = set()
self.language_corpus = defaultdict(list)
self.language_tokens = []
self.language_lexicon = defaultdict(list)
# Image download tasks to process
self.tasks: List[Dict[str, str]] = []
def get_iiif_size_arg(self, width: int, height: int) -> str:
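        """
        Build the IIIF size parameter for an image of the given dimensions:
        `full` when the image fits within the configured maximum dimensions,
        otherwise `{width},` or `,{height}` so the IIIF server shrinks the
        image while keeping its aspect ratio (e.g. with `max_width=2000`, a
        4000x1000 crop yields `"2000,"`).
        """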
if (self.max_width is None or width <= self.max_width) and (
self.max_height is None or height <= self.max_height
):
return IIIF_FULL_SIZE
bigger_width = self.max_width and width >= self.max_width
bigger_height = self.max_height and height >= self.max_height
if bigger_width and bigger_height:
            # Resize along the dimension that needs the most shrinking to keep
            # the aspect ratio: ratio > 1 means the width exceeds its maximum
            # proportionally more than the height does
ratio = width * self.max_height / (height * self.max_width)
return f"{self.max_width}," if ratio > 1 else f",{self.max_height}"
elif bigger_width:
return f"{self.max_width},"
        # Only resize the height if it is bigger than the max size
elif bigger_height:
return f",{self.max_height}"
def build_iiif_url(self, polygon, image_url) -> Tuple[BoundingBox, str]:
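        """
        Compute the bounding box of the polygon and build the IIIF URL used to
        download the corresponding image crop, resized according to the
        configured maximum dimensions.
        """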
bbox = polygon_to_bbox(json.loads(str(polygon)))
size = self.get_iiif_size_arg(width=bbox.width, height=bbox.height)
        # Rotations are handled by the line_image_extractor library, so the IIIF rotation parameter stays at 0
return bbox, IIIF_URL.format(
image_url=image_url, bbox=get_bbox(polygon), size=size
)
def _keep_char(self, char: str) -> bool:
        # Keep only entity separators between entities; keep everything when no separator was given
return not self.entity_separators or char in self.entity_separators
def reconstruct_text(self, full_text: str, entities) -> str:
"""
Insert tokens delimiting the start/end of each entity on the transcription.
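        For instance, with a start token `Ⓝ` and an end token `ⓝ` configured
        for a `name` entity type (illustrative tokens), the entity text is
        wrapped as `Ⓝ...ⓝ` in the returned transcription.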
"""
text, text_offset = "", 0
for entity in entities:
# Text before entity
text += "".join(
filter(self._keep_char, full_text[text_offset : entity.offset])
)
entity_type: EntityType = self.tokens.get(entity.type)
if not entity_type:
logger.warning(
f"Label `{entity.type}` is missing in the NER configuration."
)
            # Since the whole text is kept, each entity needs an end token so that its boundaries stay unambiguous
elif not entity_type.end and not self.entity_separators:
raise NoEndTokenError(entity.type)
# Entity text:
# - with tokens if there is an entity_type
# - without tokens if there is no entity_type but we want to keep the whole text
if entity_type or not self.entity_separators:
text += insert_token(
full_text,
entity_type,
offset=entity.offset,
length=entity.length,
)
text_offset = entity.offset + entity.length
# Remaining text after the last entity
text += "".join(filter(self._keep_char, full_text[text_offset:]))
if not self.entity_separators or self.keep_spaces:
return text
# Add some clean up to avoid several separators between entities
text, full_text = "", text
for char in full_text:
last_char = text[-1] if len(text) else ""
# Keep the current character if there are no two consecutive separators
if (
char not in self.entity_separators
or last_char not in self.entity_separators
):
text += char
# If several separators follow each other, keep only one according to the given order
elif self.entity_separators.index(char) < self.entity_separators.index(
last_char
):
text = text[:-1] + char
# Remove separators at the beginning and end of text
return text.strip("".join(self.entity_separators))
def extract_transcription(self, element: Element):
"""
Extract the element's transcription.
If the entities are needed, they are added to the transcription using tokens.
"""
transcriptions = get_transcriptions(
element.id, self.transcription_worker_version
)
if len(transcriptions) == 0:
if self.allow_empty:
return ""
raise NoTranscriptionError(element.id)
transcription = random.choice(transcriptions)
if not self.tokens:
return transcription.text.strip()
entities = get_transcription_entities(
transcription.id, self.entity_worker_version
)
return self.reconstruct_text(transcription.text, entities)
def get_image(
self,
split: str,
polygon: List[List[int]],
image_url: str,
destination: Path,
) -> None:
"""Save the element's image to the given path and applies any image operations needed.
:param split: Dataset split this image belongs to.
:param polygon: Polygon of the processed element.
:param image_url: Base IIIF URL of the image.
:param destination: Where the image should be saved.
"""
bbox, download_url = self.build_iiif_url(polygon=polygon, image_url=image_url)
try:
img: Image.Image = download_image(download_url)
            # The polygon's coordinates are in the frame of reference of the full image,
            # so the offset of the bounding rectangle must be removed
polygon = [(x - bbox.x, y - bbox.y) for x, y in polygon]
# Normalize bbox
bbox = BoundingBox(x=0, y=0, width=bbox.width, height=bbox.height)
image = extract(
img=cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR),
polygon=np.asarray(polygon).clip(0),
bbox=bbox,
extraction_mode=Extraction.boundingRect,
max_deskew_angle=45,
)
destination.parent.mkdir(parents=True, exist_ok=True)
cv2.imwrite(str(destination), image)
except Exception as e:
raise ImageDownloadError(
split=split, path=str(destination), url=download_url, exc=e
)
def format_text(self, text: str, charset: Optional[set] = None):
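        """
        Strip the text and normalize its spaces and linebreaks (unless
        `keep_spaces` is set). If a charset is given, characters outside of it
        are replaced by the unknown token.
        """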
if not self.keep_spaces:
text = normalize_spaces(text)
text = normalize_linebreaks(text)
        # Replace unknown characters with the unknown token
if charset is not None:
unknown_charset = set(text) - charset
text = text.translate(
{
ord(unknown_char): self.unknown_token
for unknown_char in unknown_charset
}
)
return text.strip()
def process_element(
self,
element: Element,
split: str,
):
"""
Extract an element's data and save it to disk.
The output path is directly related to the split of the element.
"""
text = self.extract_transcription(element)
if self.unknown_token in text:
raise UnknownTokenInText(element_id=element.id)
image_path = Path(self.output, IMAGES_DIR, split, element.id).with_suffix(
self.image_extension
)
# Create task for multithreading pool if image does not exist yet
if not image_path.exists():
self.tasks.append(
{
"split": split,
"polygon": json.loads(str(element.polygon)),
"image_url": element.image.url,
"destination": image_path,
}
)
text = self.format_text(
text,
# Do not replace unknown characters in train split
charset=self.charset if split != TRAIN_NAME else None,
)
self.data[split][str(image_path)] = text
self.charset = self.charset.union(set(text))
def process_parent(
self,
pbar,
parent: Element,
split: str,
):
"""
Extract data from a parent element.
"""
base_description = (
f"Extracting data from {parent.type} ({parent.id}) for split ({split})"
)
pbar.set_description(desc=base_description)
if self.element_type == [parent.type]:
try:
self.process_element(parent, split)
except ProcessingError as e:
logger.warning(f"Skipping {parent.id}: {str(e)}")
# Extract children elements
else:
children = get_elements(
parent.id,
self.element_type,
)
nb_children = children.count()
for idx, element in enumerate(children, start=1):
                # Update the description to reflect the children processing progress
pbar.set_description(desc=base_description + f" ({idx}/{nb_children})")
try:
self.process_element(element, split)
except ProcessingError as e:
logger.warning(f"Skipping {element.id}: {str(e)}")
def format_lm_files(self) -> None:
"""
Convert charset to a LM-compatible charset. Ensure that special LM tokens do not appear in the charset.
"""
logger.info("Preparing language resources")
# Build LM tokens
for token in sorted(list(self.charset)):
assert (
token not in self.mapping.encode.values()
), f"Special token {token} is reserved for language modeling."
            self.language_tokens.append(
                self.mapping.encode[token] if token in self.mapping.encode else token
            )
self.language_tokens.append(self.mapping.ctc.encoded)
        assert all(
            len(token) == 1 for token in self.language_tokens
        ), "Tokens should be single characters."
# Build LM corpus
        train_corpus = [
            text.replace("\n", " ") for text in self.data[TRAIN_NAME].values()
        ]
tokenizer = Tokenizer(
train_corpus,
            outdir=self.output / LANGUAGE_DIR,
mapping=self.mapping,
tokens=self.tokens,
subword_vocab_size=self.subword_vocab_size,
)
self.language_corpus["characters"] = [
tokenizer.char_tokenize(doc) for doc in train_corpus
]
self.language_corpus["words"] = [
tokenizer.word_tokenize(doc) for doc in train_corpus
]
self.language_corpus["subwords"] = [
tokenizer.subword_tokenize(doc) for doc in train_corpus
]
# Build vocabulary
        word_vocabulary = {
            word
            for doc in self.language_corpus["words"]
            for word in doc.split()
            if word != ""
        }
        subword_vocabulary = {
            subword
            for doc in self.language_corpus["subwords"]
            for subword in doc.split()
            if subword != ""
        }
# Build LM lexicon
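        # Each lexicon line maps a token to its decomposition, e.g. a word entry
        # may look like `cat c a t` (assuming char_tokenize joins characters with spaces)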
self.language_lexicon["characters"] = [
f"{token} {token}" for token in self.language_tokens
]
self.language_lexicon["words"] = [
f"{word} {tokenizer.char_tokenize(word)}"
for word in sorted(word_vocabulary)
if word != ""
]
self.language_lexicon["subwords"] = [
f"{subword} {tokenizer.char_tokenize(subword)}"
for subword in sorted(subword_vocabulary)
]
def export(self):
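        """
        Write the labels, the language model corpora, lexicons and tokens, and
        the charset to the output directory.
        """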
(self.output / "labels.json").write_text(
json.dumps(
self.data,
sort_keys=True,
indent=4,
)
)
        for level in ["characters", "words", "subwords"]:
            (self.output / LANGUAGE_DIR / f"corpus_{level}.txt").write_text(
                "\n".join(self.language_corpus[level])
            )
            (self.output / LANGUAGE_DIR / f"lexicon_{level}.txt").write_text(
                "\n".join(self.language_lexicon[level])
            )
        (self.output / LANGUAGE_DIR / "tokens.txt").write_text(
            "\n".join(self.language_tokens)
        )
(self.output / "charset.pkl").write_bytes(
pickle.dumps(sorted(list(self.charset)))
)
def download_images(self):
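        """
        Download all the queued images in a thread pool. Samples whose download
        failed are removed from the labels and their URL is reported.
        """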
failed_downloads = []
with tqdm(
desc="Downloading images", total=len(self.tasks)
) as pbar, ThreadPoolExecutor() as executor:
def process_future(future: Future):
"""
Callback function called at the end of the thread
"""
# Update the progress bar count
pbar.update(1)
exc = future.exception()
if exc is None:
# No error
return
                # The download failed: drop the corresponding sample
                assert isinstance(exc, ImageDownloadError)
                # Remove the transcription from the labels dict
                del self.data[exc.split][exc.path]
                # Record the failed URL and the error message
                failed_downloads.append((exc.url, exc.message))
# Submit all tasks
for task in self.tasks:
executor.submit(self.get_image, **task).add_done_callback(
process_future
)
if failed_downloads:
logger.error(f"Failed to download {len(failed_downloads)} image(s).")
            print(*map(": ".join, failed_downloads), sep="\n")
def run(self):
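        """
        Extract the data from the train, val and test folders, then download
        the images, build the language model files and export everything.
        """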
# Iterate over the subsets to find the page images and labels.
for folder_id, split in zip(self.folders, SPLIT_NAMES):
with tqdm(
get_elements(
folder_id,
[self.parent_element_type],
),
desc=f"Extracting data from ({folder_id}) for split ({split})",
) as pbar:
# Iterate over the pages to create splits at page level.
for parent in pbar:
self.process_parent(
pbar=pbar,
parent=parent,
split=split,
)
# Progress bar updates
pbar.update()
pbar.refresh()
self.download_images()
self.format_lm_files()
self.export()
def run(
database: Path,
element_type: List[str],
parent_element_type: str,
output: Path,
entity_separators: List[str],
unknown_token: str,
tokens: Path,
train_folder: UUID,
val_folder: UUID,
test_folder: UUID,
transcription_worker_version: Optional[Union[str, bool]],
entity_worker_version: Optional[Union[str, bool]],
max_width: Optional[int],
max_height: Optional[int],
image_format: str,
keep_spaces: bool,
allow_empty: bool,
subword_vocab_size: int,
):
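    """
    Entry point of the extraction: open the Arkindex export database, create
    the output directories and run the ArkindexExtractor on the train, val and
    test folders.
    """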
assert database.exists(), f"No file found @ {database}"
open_database(path=database)
folders = [str(train_folder), str(val_folder), str(test_folder)]
# Create directories
for split in SPLIT_NAMES:
Path(output, IMAGES_DIR, split).mkdir(parents=True, exist_ok=True)
Path(output, LANGUAGE_DIR).mkdir(parents=True, exist_ok=True)
ArkindexExtractor(
folders=folders,
element_type=element_type,
parent_element_type=parent_element_type,
output=output,
entity_separators=entity_separators,
unknown_token=unknown_token,
tokens=tokens,
transcription_worker_version=transcription_worker_version,
entity_worker_version=entity_worker_version,
max_width=max_width,
max_height=max_height,
keep_spaces=keep_spaces,
image_extension=image_format,
allow_empty=allow_empty,
subword_vocab_size=subword_vocab_size,
).run()
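# Example invocation (illustrative values only; in practice these arguments are
# filled in by the dataset extraction command of the DAN CLI):
# run(
#     database=Path("arkindex_export.sqlite"),
#     element_type=["text_line"],
#     parent_element_type="page",
#     output=Path("output"),
#     entity_separators=["\n", " "],
#     unknown_token="⁇",
#     tokens=Path("tokens.yml"),
#     train_folder=UUID("..."),
#     val_folder=UUID("..."),
#     test_folder=UUID("..."),
#     transcription_worker_version=None,
#     entity_worker_version=None,
#     max_width=None,
#     max_height=None,
#     image_format=".jpg",
#     keep_spaces=False,
#     allow_empty=False,
#     subword_vocab_size=1000,
# )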