Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • atr/dan
1 result
Show changes
Commits on Source (6)
# DAN: a Segmentation-free Document Attention Network for Handwritten Document Recognition
[![Python >= 3.10](https://img.shields.io/badge/Python-%3E%3D3.10-blue.svg)](https://www.python.org/downloads/release/python-3100/)
## Documentation
To use DAN in your own scripts, install it using pip:
......
......@@ -146,4 +146,16 @@ def add_extract_parser(subcommands) -> None:
help="Validation set split size.",
)
parser.add_argument(
"--max-width",
type=int,
help="Images larger than this width will be resized to this width.",
)
parser.add_argument(
"--max-height",
type=int,
help="Images larger than this height will be resized to this width.",
)
parser.set_defaults(func=run)
# -*- coding: utf-8 -*-
import ast
from dataclasses import dataclass
from itertools import starmap
from typing import List, NamedTuple, Union
from typing import List, NamedTuple, Optional, Union
from urllib.parse import urljoin
from arkindex_export import Image
......@@ -41,13 +42,21 @@ Entity = NamedTuple(
)
class Element(NamedTuple):
@dataclass
class Element:
id: str
type: str
polygon: str
url: str
width: str
height: str
width: int
height: int
max_width: Optional[int] = None
max_height: Optional[int] = None
def __post_init__(self):
self.max_height = self.max_height or self.height
self.max_width = self.max_width or self.width
@property
def bounding_box(self):
......@@ -56,10 +65,18 @@ class Element(NamedTuple):
@property
def image_url(self):
x, y, width, height = self.bounding_box
return urljoin(self.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg")
return urljoin(
self.url + "/",
f"{x},{y},{width},{height}/!{self.max_width},{self.max_height}/0/default.jpg",
)
def get_elements(parent_id: str, element_type: str) -> List[Element]:
def get_elements(
parent_id: str,
element_type: str,
max_width: Optional[int] = None,
max_height: Optional[int] = None,
) -> List[Element]:
"""
Retrieve elements from an SQLite export of an Arkindex corpus
"""
......@@ -77,10 +94,9 @@ def get_elements(parent_id: str, element_type: str) -> List[Element]:
Image.height,
)
)
return list(
starmap(
Element,
lambda *x: Element(*x, max_width=max_width, max_height=max_height),
query.tuples(),
)
)
......
......@@ -3,7 +3,7 @@
import random
from collections import defaultdict
from pathlib import Path
from typing import List, Union
from typing import List, Optional, Union
from uuid import UUID
from arkindex_export import open_database
......@@ -56,6 +56,8 @@ class ArkindexExtractor:
entity_worker_version: str = None,
train_prob: float = None,
val_prob: float = None,
max_width: Optional[int] = None,
max_height: Optional[int] = None,
) -> None:
self.element_type = element_type
self.parent_element_type = parent_element_type
......@@ -67,6 +69,8 @@ class ArkindexExtractor:
self.entity_worker_version = entity_worker_version
self.train_prob = train_prob
self.val_prob = val_prob
self.max_width = max_width
self.max_height = max_height
self.subsets = self.get_subsets(folders)
......@@ -179,7 +183,12 @@ class ArkindexExtractor:
# Extract children elements
else:
for element_type in self.element_type:
for element in get_elements(parent.id, element_type):
for element in get_elements(
parent.id,
element_type,
max_width=self.max_width,
max_height=self.max_height,
):
try:
data[element_type].append(self.process_element(element, split))
except ProcessingError as e:
......@@ -192,7 +201,12 @@ class ArkindexExtractor:
for idx, subset in enumerate(self.subsets, start=1):
# Iterate over the pages to create splits at page level.
for parent in tqdm(
get_elements(subset.id, self.parent_element_type),
get_elements(
subset.id,
self.parent_element_type,
max_width=self.max_width,
max_height=self.max_height,
),
desc=f"Processing {subset} {idx}/{len(self.subsets)}",
):
split = subset.split or self.get_random_split()
......@@ -219,6 +233,8 @@ def run(
entity_worker_version: Union[str, bool],
train_prob,
val_prob,
max_width: Optional[int],
max_height: Optional[int],
):
assert (
use_existing_split or parent
......@@ -265,4 +281,6 @@ def run(
entity_worker_version=entity_worker_version,
train_prob=train_prob,
val_prob=val_prob,
max_width=max_width,
max_height=max_height,
).run()
......@@ -268,7 +268,10 @@ class GenericDataset(Dataset):
def apply_preprocessing(self, preprocessings):
for i in range(len(self.samples)):
self.samples[i] = apply_preprocessing(self.samples[i], preprocessings)
(
self.samples[i]["img"],
self.samples[i]["resize_ratio"],
) = apply_preprocessing(self.samples[i]["img"], preprocessings)
def compute_std_mean(self):
"""
......@@ -276,46 +279,33 @@ class GenericDataset(Dataset):
"""
if self.mean is not None and self.std is not None:
return self.mean, self.std
if not self.load_in_memory:
sample = self.samples[0].copy()
sample["img"] = self.get_sample_img(0)
img = apply_preprocessing(sample, self.params["config"]["preprocessings"])[
"img"
]
else:
img = self.get_sample_img(0)
_, _, c = img.shape
sum = np.zeros((c,))
sum = np.zeros((3,))
diff = np.zeros((3,))
nb_pixels = 0
for metric in ["mean", "std"]:
for ind in range(len(self.samples)):
img = (
self.get_sample_img(ind)
if self.load_in_memory
else apply_preprocessing(
self.get_sample_img(ind),
self.params["config"]["preprocessings"],
)[0]
)
for i in range(len(self.samples)):
if not self.load_in_memory:
sample = self.samples[i].copy()
sample["img"] = self.get_sample_img(i)
img = apply_preprocessing(
sample, self.params["config"]["preprocessings"]
)["img"]
else:
img = self.get_sample_img(i)
sum += np.sum(img, axis=(0, 1))
nb_pixels += np.prod(img.shape[:2])
mean = sum / nb_pixels
diff = np.zeros((c,))
for i in range(len(self.samples)):
if not self.load_in_memory:
sample = self.samples[i].copy()
sample["img"] = self.get_sample_img(i)
img = apply_preprocessing(
sample, self.params["config"]["preprocessings"]
)["img"]
else:
img = self.get_sample_img(i)
diff += [np.sum((img[:, :, k] - mean[k]) ** 2) for k in range(c)]
std = np.sqrt(diff / nb_pixels)
self.mean = mean
self.std = std
return mean, std
if metric == "mean":
sum += np.sum(img, axis=(0, 1))
nb_pixels += np.prod(img.shape[:2])
elif metric == "std":
diff += [
np.sum((img[:, :, k] - self.mean[k]) ** 2) for k in range(3)
]
if metric == "mean":
self.mean = sum / nb_pixels
elif metric == "std":
self.std = np.sqrt(diff / nb_pixels)
return self.mean, self.std
def apply_data_augmentation(self, img):
"""
......@@ -340,12 +330,11 @@ class GenericDataset(Dataset):
return GenericDataset.load_image(self.samples[i]["path"])
def apply_preprocessing(sample, preprocessings):
def apply_preprocessing(img, preprocessings):
"""
Apply preprocessings on each sample
Apply preprocessings on an image
"""
resize_ratio = [1, 1]
img = sample["img"]
for preprocessing in preprocessings:
if preprocessing["type"] == "to_grayscaled":
temp_img = img
......@@ -394,6 +383,4 @@ def apply_preprocessing(sample, preprocessings):
img = temp_img
resize_ratio = [ratio, ratio]
sample["img"] = img
sample["resize_ratio"] = resize_ratio
return sample
return img, resize_ratio
# -*- coding: utf-8 -*-
import copy
import os
import pickle
......@@ -62,12 +61,12 @@ class OCRDataset(GenericDataset):
self.collate_function = OCRCollateFunction
def __getitem__(self, idx):
sample = copy.deepcopy(self.samples[idx])
sample = dict(**self.samples[idx])
if not self.load_in_memory:
sample["img"] = self.get_sample_img(idx)
sample = apply_preprocessing(
sample, self.params["config"]["preprocessings"]
sample["img"], sample["resize_ratio"] = apply_preprocessing(
sample["img"], self.params["config"]["preprocessings"]
)
# Data augmentation
......
# -*- coding: utf-8 -*-
import copy
import json
import os
import random
from copy import deepcopy
from time import time
import numpy as np
......@@ -481,8 +481,7 @@ class GenericTrainingManager:
path = os.path.join(self.paths["results"], "params")
if os.path.isfile(path):
return
params = copy.deepcopy(self.params)
params = class_to_str_dict(params)
params = class_to_str_dict(my_dict=deepcopy(self.params))
total_params = 0
for model_name in self.models.keys():
current_params = compute_nb_params(self.models[model_name])
......
......@@ -353,8 +353,8 @@ def run(
result["confidences"]["by ner token"] = [
{
"text": f"{text[current: next_token-1]}",
"confidence_ner": f"{np.around(np.mean(char_confidences[current : next_token-1]), 2)}",
"text": f"{text[current: next_token]}".replace("\n", " "),
"confidence_ner": f"{np.around(np.mean(char_confidences[current : next_token]), 2)}",
}
for current, next_token in pairwise(index + [0])
]
......
......@@ -21,6 +21,8 @@ Use the `teklia-dan dataset extract` command to extract a dataset from an Arkind
| `--entity-worker-version` | Filter transcriptions entities by worker_version. Use `manual` for manual filtering | `str|uuid` | |
| `--train-prob` | Training set split size | `float` | `0.7` |
| `--val-prob` | Validation set split size | `float` | `0.15` |
| `--max-width` | Images larger than this width will be resized to this width. | `int` | |
| `--max-height` | Images larger than this height will be resized to this height. | `int` | |
The `--tokens` argument expects a YAML-formatted file with a specific format. A list of entries with each entry describing a NER entity. The label of the entity is the key to a dict mapping the starting and ending tokens respectively.
```yaml
......
......@@ -21,6 +21,7 @@ setup(
author="Teklia",
author_email="contact@teklia.com",
url="https://gitlab.com/teklia/atr/dan",
python_requires=">=3.10",
install_requires=parse_requirements("requirements.txt"),
packages=find_packages(),
entry_points={
......