Compare revisions

Target project: atr/dan
Commits on Source (5)
......@@ -83,6 +83,8 @@ pages:
- main
- tags
except:
- schedules
docs-deploy:
image: node:18
......
......@@ -2,13 +2,18 @@
class ProcessingError(Exception):
"""
Raised when there is a problem somewhere in the processing
"""
...
class ElementProcessingError(ProcessingError):
"""
Raised when a problem is encountered while processing an element
"""
element_id: str
"""
ID of the element being processed.
"""
def __init__(self, element_id: str, *args: object) -> None:
super().__init__(*args)
......@@ -21,6 +26,9 @@ class ImageDownloadError(ElementProcessingError):
"""
error: Exception
"""
Error encountered.
"""
def __init__(self, element_id: str, error: Exception, *args: object) -> None:
super().__init__(element_id, *args)
......
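For orientation (not part of the diff): a minimal usage sketch of the exception hierarchy above, assuming the dan package is importable and that the collapsed __init__ bodies store element_id and error as the annotations suggest. The fetch_image helper and the URL are hypothetical.

# Hypothetical sketch of raising and catching the exceptions above.
import urllib.request

from dan.datasets.extract.exceptions import ImageDownloadError, ProcessingError

def fetch_image(element_id: str, url: str) -> bytes:
    try:
        # Plain download, for illustration only
        with urllib.request.urlopen(url) as response:
            return response.read()
    except Exception as e:
        # Wrap the original error so callers keep both the element ID and the cause
        raise ImageDownloadError(element_id, e) from e

try:
    fetch_image("1234", "https://example.com/iiif/image/full.jpg")  # placeholder URL
except ImageDownloadError as err:
    print(f"Could not download image for element {err.element_id}: {err.error}")
except ProcessingError:
    # The base class also covers any other ElementProcessingError subclass
    print("Some other processing problem")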
# -*- coding: utf-8 -*-
import logging
import os
import random
from collections import defaultdict
from pathlib import Path
from typing import List, NamedTuple
import imageio.v2 as iio
from arkindex import ArkindexClient, options_from_env
from tqdm import tqdm
from dan import logger
from dan.datasets.extract.utils import (
insert_token,
parse_tokens,
save_image,
save_json,
save_text,
)
IMAGES_DIR = "images" # Subpath to the images directory.
LABELS_DIR = "labels" # Subpath to the labels directory.
Entity = NamedTuple("Entity", offset=int, length=int, label=str)
class ArkindexExtractor:
"""
Extract data from Arkindex
"""
def __init__(
self,
client: ArkindexClient,
folders: list = [],
element_type: list = [],
parent_element_type: list = ["page"],
split_names: list = [],
output: Path = None,
load_entities: bool = None,
tokens: Path = None,
use_existing_split: bool = None,
transcription_worker_version: str = None,
entity_worker_version: str = None,
train_prob: float = None,
val_prob: float = None,
) -> None:
self.client = client
self.element_type = element_type
self.parent_element_type = parent_element_type
self.split_names = split_names
self.output = output
self.load_entities = load_entities
self.tokens = parse_tokens(tokens) if self.load_entities else None
self.use_existing_split = use_existing_split
self.transcription_worker_version = transcription_worker_version
self.entity_worker_version = entity_worker_version
self.train_prob = train_prob
self.val_prob = val_prob
self.get_subsets(folders)
def get_subsets(self, folders: list):
"""
Assign each folder to its split if it's already known.
Assign None if it's unknown.
"""
if self.use_existing_split:
self.subsets = [
(folder, split) for folder, split in zip(folders, self.split_names)
]
else:
self.subsets = [(folder, None) for folder in folders]
def _assign_random_split(self):
"""
Yields a randomly chosen split for an element.
Assumes that train_prob + val_prob + the implicit test probability sum to 1.
"""
prob = random.random()
if prob <= self.train_prob:
yield self.split_names[0]
elif prob <= self.train_prob + self.val_prob:
yield self.split_names[1]
else:
yield self.split_names[2]
def get_random_split(self):
return next(self._assign_random_split())
def extract_entities(self, transcription: dict):
entities = self.client.paginate(
"ListTranscriptionEntities",
id=transcription["id"],
worker_version=self.entity_worker_version,
)
if entities is None:
logger.warning(
f"No entities found on transcription ({transcription['id']})."
)
return
return [
Entity(
offset=entity["offset"],
length=entity["length"],
label=entity["entity"]["metas"]["subtype"],
)
for entity in entities
]
def reconstruct_text(self, text: str, entities: List[Entity]):
"""
Insert tokens delimiting the start/end of each entity on the transcription.
"""
count = 0
for entity in entities:
matching_tokens = self.tokens[entity.label]
start_token, end_token = (
matching_tokens["start"],
matching_tokens["end"],
)
text, count = insert_token(
text,
count,
start_token,
end_token,
offset=entity.offset,
length=entity.length,
)
return text
def extract_transcription(
self,
element: dict,
):
"""
Extract the element's transcription.
If the entities are needed, they are added to the transcription using tokens.
"""
transcriptions = self.client.request(
"ListTranscriptions",
id=element["id"],
worker_version=self.transcription_worker_version,
)
if transcriptions["count"] != 1:
logger.warning(
f"More than one transcription found on element ({element['id']}) with this config."
)
return
transcription = transcriptions["results"].pop()
if self.load_entities:
entities = self.extract_entities(transcription)
return self.reconstruct_text(transcription["text"], entities)
else:
return transcription["text"].strip()
def process_element(
self,
element: dict,
split: str,
):
"""
Extract an element's data and save it to disk.
The output path is directly related to the split of the element.
"""
text = self.extract_transcription(
element,
)
if not text:
logging.warning(f"Skipping {element['id']}")
else:
logging.info(f"Processed {element['type']} {element['id']}")
im_path = os.path.join(
self.output, IMAGES_DIR, split, f"{element['type']}_{element['id']}.jpg"
)
txt_path = os.path.join(
self.output, LABELS_DIR, split, f"{element['type']}_{element['id']}.txt"
)
save_text(txt_path, text)
try:
image = iio.imread(element["zone"]["url"])
save_image(im_path, image)
except Exception:
logger.error(f"Couldn't retrieve image of element ({element['id']}")
raise
return element["id"]
def process_parent(
self,
parent: dict,
split: str,
):
"""
Extract data from a parent element.
Depending on the given types, process either the parent element itself or its children of those types.
"""
data = defaultdict(list)
if self.element_type == [parent["type"]]:
data[self.element_type[0]] = [
self.process_element(
parent,
split,
)
]
# Extract children elements
else:
for element_type in self.element_type:
for element in self.client.paginate(
"ListElementChildren",
id=parent["id"],
type=element_type,
recursive=True,
):
data[element_type].append(
self.process_element(
element,
split,
)
)
return data
def run(self):
split_dict = defaultdict(dict)
# Iterate over the subsets to find the page images and labels.
for subset_id, subset_split in self.subsets:
page_idx = 0
# Iterate over the pages to create splits at page level.
for parent in tqdm(
self.client.paginate(
"ListElementChildren",
id=subset_id,
type=self.parent_element_type,
recursive=True,
)
):
page_idx += 1
split = subset_split or self.get_random_split()
split_dict[split][parent["id"]] = self.process_parent(
parent=parent,
split=split,
)
save_json(self.output / "split.json", split_dict)
def run(
parent,
element_type,
parent_element_type,
output,
load_entities,
tokens,
use_existing_split,
train_folder,
val_folder,
test_folder,
transcription_worker_version,
entity_worker_version,
train_prob,
val_prob,
):
assert (
use_existing_split or parent
), "One of `--use-existing-split` and `--parent` must be set"
if use_existing_split:
assert (
train_folder
), "If you use an existing split, you must specify the training folder."
assert (
val_folder
), "If you use an existing split, you must specify the validation folder."
assert (
test_folder
), "If you use an existing split, you must specify the testing folder."
folders = [train_folder, val_folder, test_folder]
else:
folders = parent
if load_entities:
assert tokens, "Please provide the entities to match."
# Login to arkindex.
assert (
"ARKINDEX_API_URL" in os.environ
), "The ARKINDEX API URL was not found in your environment."
assert (
"ARKINDEX_API_TOKEN" in os.environ
), "Your API credentials was not found in your environment."
client = ArkindexClient(**options_from_env())
# Create out directories
split_names = ["train", "val", "test"]
for split in split_names:
os.makedirs(os.path.join(output, IMAGES_DIR, split), exist_ok=True)
os.makedirs(os.path.join(output, LABELS_DIR, split), exist_ok=True)
ArkindexExtractor(
client=client,
folders=folders,
element_type=element_type,
parent_element_type=parent_element_type,
split_names=split_names,
output=output,
load_entities=load_entities,
tokens=tokens,
use_existing_split=use_existing_split,
transcription_worker_version=transcription_worker_version,
entity_worker_version=entity_worker_version,
train_prob=train_prob,
val_prob=val_prob,
).run()
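For orientation (not part of the diff): a hedged sketch of calling the module-level run() entrypoint above. The module path is assumed from the updated reference page (dan.datasets.extract.extract), the folder UUID is a placeholder, and the element type and split probabilities are assumptions.

# Sketch only: placeholder folder UUID, assumed element type and split probabilities.
# Requires ARKINDEX_API_URL and ARKINDEX_API_TOKEN to be exported beforehand.
from pathlib import Path

from dan.datasets.extract.extract import run  # module path assumed from the reference page

run(
    parent=["11111111-1111-1111-1111-111111111111"],  # placeholder folder UUID
    element_type=["text_line"],                       # assumed element type
    parent_element_type=["page"],                     # matches the class default above
    output=Path("data"),
    load_entities=False,
    tokens=None,
    use_existing_split=False,
    train_folder=None,
    val_folder=None,
    test_folder=None,
    transcription_worker_version=None,
    entity_worker_version=None,
    train_prob=0.8,  # assumed; the remainder goes to the test split
    val_prob=0.1,
)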
......@@ -100,7 +100,7 @@ class DAN:
:param confidences: Return the characters probabilities.
:param attentions: Return characters attention weights.
"""
input_tensor.to(self.device)
input_tensor = input_tensor.to(self.device)
start_token = len(self.charset) + 1
end_token = len(self.charset)
......
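For context on the one-line fix above (not part of the diff): torch.Tensor.to() is out of place, so its result must be assigned back. A minimal sketch, assuming PyTorch is installed, using a dtype change to make the effect visible:

import torch

t = torch.zeros(2)        # float32 tensor
t.to(torch.float64)       # returns a new tensor; `t` itself is unchanged
print(t.dtype)            # torch.float32

t = t.to(torch.float64)   # the returned tensor must be assigned back
print(t.dtype)            # torch.float64

The same holds for device moves: without the reassignment, input_tensor would stay on the CPU while the model runs on self.device.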
# Arkindex
::: dan.datasets.extract.extract_from_arkindex
::: dan.datasets.extract.extract
# Database management
::: dan.datasets.extract.db
# Exceptions
::: dan.datasets.extract.exceptions
options:
show_source: false
......@@ -74,6 +74,8 @@ nav:
- ref/datasets/extract/index.md
- Arkindex: ref/datasets/extract/arkindex.md
- Utils: ref/datasets/extract/utils.md
- Database management: ref/datasets/extract/db.md
- Exceptions: ref/datasets/extract/exceptions.md
- Formatting:
- ref/datasets/format/index.md
- Automatic Text Recognition: ref/datasets/format/atr.md
......
arkindex-export==0.1.1
boto3==1.26.97
arkindex-export==0.1.2
boto3==1.26.98
editdistance==0.6.2
fontTools==4.39.2
imageio==2.26.1
......