Commits on Source (3)
Showing with 255 additions and 251 deletions
......@@ -3,6 +3,10 @@ stages:
- build
- deploy
variables:
# Submodule clone
GIT_SUBMODULE_STRATEGY: recursive
lint:
image: python:3.10
stage: test
......
[submodule "line_image_extractor"]
path = teklia_line_image_extractor
url = ../line_image_extractor.git
......@@ -4,9 +4,21 @@
## Documentation
To use DAN in your own scripts, install it using pip:
To use DAN in your own environment, you need to first clone it with its submodules via:
```console
```shell
git clone --recurse-submodules git@gitlab.teklia.com:atr/dan.git
```
If you forgot the `--recurse-submodules` flag, you can initialize the submodule using:
```shell
git submodule update --init
```
Then you can install it via pip:
```shell
pip install -e .
```
......
......@@ -138,4 +138,12 @@ def add_extract_parser(subcommands) -> None:
help="Images larger than this height will be resized to this width.",
)
parser.add_argument(
"--cache",
dest="cache_dir",
type=pathlib.Path,
help="Where the images should be cached.",
default=pathlib.Path(".cache"),
)
parser.set_defaults(func=run)
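As a minimal, self-contained sketch of the Path-typed `--cache` option added above (this is not the project's actual CLI entry point, just an illustration of the argparse pattern):

```python
import argparse
import pathlib

# Hypothetical standalone parser exercising the same --cache option.
parser = argparse.ArgumentParser(description="Dataset extraction (sketch)")
parser.add_argument(
    "--cache",
    dest="cache_dir",
    type=pathlib.Path,
    default=pathlib.Path(".cache"),
    help="Where the images should be cached.",
)

args = parser.parse_args([])                       # no flag: the default is used
assert args.cache_dir == pathlib.Path(".cache")
args = parser.parse_args(["--cache", "/tmp/dan"])  # explicit cache directory
assert args.cache_dir == pathlib.Path("/tmp/dan")
```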
# -*- coding: utf-8 -*-
import ast
from dataclasses import dataclass
from itertools import starmap
from typing import List, Optional, Union
from urllib.parse import urljoin
from typing import List, Union
from arkindex_export import Image
from arkindex_export.models import Element as ArkindexElement
from arkindex_export.models import (
Element,
Entity,
EntityType,
Transcription,
......@@ -17,51 +13,10 @@ from arkindex_export.models import (
from arkindex_export.queries import list_children
def bounding_box(polygon: list):
"""
Returns a 4-tuple (x, y, width, height) for the bounding box of a Polygon (list of points)
"""
all_x, all_y = zip(*polygon)
x, y = min(all_x), min(all_y)
width, height = max(all_x) - x, max(all_y) - y
return int(x), int(y), int(width), int(height)
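A short worked example of the helper above (the polygon values are arbitrary):

```python
# Triangle with points (10, 20), (30, 25) and (25, 40).
polygon = [[10, 20], [30, 25], [25, 40]]
# all_x = (10, 30, 25) -> x = 10 ; all_y = (20, 25, 40) -> y = 20
# width = 30 - 10 = 20 ; height = 40 - 20 = 20
assert bounding_box(polygon) == (10, 20, 20, 20)
```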
@dataclass
class Element:
id: str
type: str
polygon: str
url: str
width: int
height: int
max_width: Optional[int] = None
max_height: Optional[int] = None
def __post_init__(self):
self.max_height = self.max_height or self.height
self.max_width = self.max_width or self.width
@property
def bounding_box(self):
return bounding_box(ast.literal_eval(self.polygon))
@property
def image_url(self):
x, y, width, height = self.bounding_box
return urljoin(
self.url + "/",
f"{x},{y},{width},{height}/!{self.max_width},{self.max_height}/0/default.jpg",
)
def get_elements(
parent_id: str,
element_type: List[str],
max_width: Optional[int] = None,
max_height: Optional[int] = None,
) -> List[Element]:
):
"""
Retrieve elements from an SQLite export of an Arkindex corpus
"""
......@@ -69,23 +24,10 @@ def get_elements(
query = (
list_children(parent_id=parent_id)
.join(Image)
.where(ArkindexElement.type.in_(element_type))
.select(
ArkindexElement.id,
ArkindexElement.type,
ArkindexElement.polygon,
Image.url,
Image.width,
Image.height,
)
.where(Element.type.in_(element_type))
)
return list(
starmap(
lambda *x: Element(*x, max_width=max_width, max_height=max_height),
query.tuples(),
)
)
return query
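Hypothetical usage of `get_elements` (the parent ID and element type below are made up): the returned peewee query is lazy, so it can be counted and iterated as the extractor does later in this diff.

```python
# Assumes an Arkindex SQLite export has already been opened with open_database().
children = get_elements(
    parent_id="11111111-2222-3333-4444-555555555555",
    element_type=["text_line"],
)
print(children.count())          # number of matching children, without fetching them all
for element in children:
    print(element.id, element.type)
```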
def build_worker_version_filter(ArkindexModel, worker_version):
......
# -*- coding: utf-8 -*-
import json
import random
from collections import defaultdict
from pathlib import Path
from typing import List, Optional, Union
from uuid import UUID
import numpy as np
from tqdm import tqdm
from arkindex_export import open_database
......@@ -27,10 +29,13 @@ from dan.datasets.extract.utils import (
insert_token,
parse_tokens,
)
from line_image_extractor.extractor import extract, read_img, save_img
from line_image_extractor.image_utils import Extraction, polygon_to_bbox, resize
IMAGES_DIR = "images" # Subpath to the images directory.
LABELS_DIR = "labels" # Subpath to the labels directory.
SPLIT_NAMES = ["train", "val", "test"]
IIIF_URL_SUFFIX = "/full/full/0/default.jpg"
class ArkindexExtractor:
......@@ -51,6 +56,7 @@ class ArkindexExtractor:
entity_worker_version: Optional[Union[str, bool]] = None,
max_width: Optional[int] = None,
max_height: Optional[int] = None,
cache_dir: Path = Path(".cache"),
) -> None:
self.folders = folders
self.element_type = element_type
......@@ -64,6 +70,19 @@ class ArkindexExtractor:
self.max_width = max_width
self.max_height = max_height
self.cache_dir = cache_dir
# Create the cache dir if it does not exist
self.cache_dir.mkdir(exist_ok=True, parents=True)
def find_image_in_cache(self, image_id: str) -> Path:
"""Images are cached to avoid downloading them twice. They are stored under a specific name,
based on their Arkindex ID. Images are saved in JPEG format.
:param image_id: ID of the image. The image is saved under this name.
:return: Where the image should be saved in the cache folder.
"""
return self.cache_dir / f"{image_id}.jpg"
def _keep_char(self, char: str) -> bool:
# Keep all text by default if no separator was given
return not self.entity_separators or char in self.entity_separators
......@@ -148,6 +167,52 @@ class ArkindexExtractor:
)
return self.reconstruct_text(transcription.text, entities)
def retrieve_image(self, child: Element):
"""Get or download image of the element. Checks in cache before downloading.
:param child: Processed element.
:return: The element's image.
"""
cached_img_path = self.find_image_in_cache(child.image.id)
if not cached_img_path.exists():
# Save in cache
download_image(child.image.url + IIIF_URL_SUFFIX).save(
cached_img_path, format="jpeg"
)
return read_img(cached_img_path)
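A standalone sketch of this cache-or-download pattern, using Pillow directly instead of the project's `download_image`/`read_img` helpers (the `download` and `retrieve` names below are hypothetical):

```python
from pathlib import Path

from PIL import Image


def download(url: str) -> Image.Image:
    # Stand-in for download_image(url): returns a blank image instead of hitting the network.
    return Image.new("RGB", (32, 32))


def retrieve(cache_dir: Path, image_id: str, url: str) -> Image.Image:
    cached = cache_dir / f"{image_id}.jpg"   # same naming as find_image_in_cache
    if not cached.exists():                  # only download once per image
        cache_dir.mkdir(parents=True, exist_ok=True)
        download(url).save(cached, format="jpeg")
    return Image.open(cached)
```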
def get_image(self, child: Element, destination: Path) -> None:
"""Save the element's image to the given path and applies any image operations needed.
:param child: Processed element.
:param destination: Where the image should be saved.
"""
polygon = json.loads(str(child.polygon))
if self.max_height or self.max_width:
polygon = resize(
polygon,
self.max_width,
self.max_height,
scale_x=1.0,
scale_y_top=1.0,
scale_y_bottom=1.0,
)
# Extract the polygon in the image
image = extract(
img=self.retrieve_image(child),
polygon=np.array(polygon),
bbox=polygon_to_bbox(polygon),
# Hardcoded while we don't have a configuration file
extraction_mode=Extraction.deskew_min_area_rect,
max_deskew_angle=45,
)
# Save the image to disk
save_img(path=destination, img=image)
def process_element(
self,
element: Element,
......@@ -162,13 +227,14 @@ class ArkindexExtractor:
base_path = Path(split, f"{element.type}_{element.id}")
Path(self.output, LABELS_DIR, base_path).with_suffix(".txt").write_text(text)
download_image(
element, Path(self.output, LABELS_DIR, base_path).with_suffix(".jpg")
self.get_image(
element, Path(self.output, IMAGES_DIR, base_path).with_suffix(".jpg")
)
return element.id
def process_parent(
self,
pbar,
parent: Element,
split: str,
):
......@@ -176,7 +242,10 @@ class ArkindexExtractor:
Extract data from a parent element.
"""
data = defaultdict(list)
base_description = (
f"Extracting data from {parent.type} ({parent.id}) for split ({split})"
)
pbar.set_description(desc=base_description)
if self.element_type == [parent.type]:
try:
data[parent.type].append(self.process_element(parent, split))
......@@ -184,12 +253,15 @@ class ArkindexExtractor:
logger.warning(f"Skipping {parent.id}: {str(e)}")
# Extract children elements
else:
for element in get_elements(
children = get_elements(
parent.id,
self.element_type,
max_width=self.max_width,
max_height=self.max_height,
):
)
nb_children = children.count()
for idx, element in enumerate(children, start=1):
# Update the description to track the progress of children processing
pbar.set_description(desc=base_description + f" ({idx}/{nb_children})")
try:
data[element.type].append(self.process_element(element, split))
except ProcessingError as e:
......@@ -198,23 +270,24 @@ class ArkindexExtractor:
def run(self):
# Iterate over the subsets to find the page images and labels.
for idx, (folder_id, split) in enumerate(
zip(self.folders, SPLIT_NAMES), start=1
):
# Iterate over the pages to create splits at page level.
for parent in tqdm(
for folder_id, split in zip(self.folders, SPLIT_NAMES):
with tqdm(
get_elements(
folder_id,
[self.parent_element_type],
max_width=self.max_width,
max_height=self.max_height,
),
desc=f"Processing {folder_id} {idx}/{len(self.subsets)}",
):
self.process_parent(
parent=parent,
split=split,
)
desc=f"Extracting data from ({folder_id}) for split ({split})",
) as pbar:
# Iterate over the pages to create splits at page level.
for parent in pbar:
self.process_parent(
pbar=pbar,
parent=parent,
split=split,
)
# Progress bar updates
pbar.update()
pbar.refresh()
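A minimal sketch of this progress-reporting pattern with tqdm, on made-up data: one bar per split, with a description that also tracks per-child progress.

```python
from tqdm import tqdm

# Hypothetical parent pages and their children.
pages = {"page_1": ["line_a", "line_b"], "page_2": ["line_c"]}

with tqdm(pages, desc="Extracting data for split (train)") as pbar:
    for page in pbar:
        children = pages[page]
        base = f"Extracting data from page ({page}) for split (train)"
        pbar.set_description(base)
        for idx, child in enumerate(children, start=1):
            # Show the children processing progress in the description.
            pbar.set_description(base + f" ({idx}/{len(children)})")
```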
def run(
......@@ -232,6 +305,7 @@ def run(
entity_worker_version: Optional[Union[str, bool]],
max_width: Optional[int],
max_height: Optional[int],
cache_dir: Path,
):
assert database.exists(), f"No file found @ {database}"
open_database(path=database)
......@@ -258,4 +332,5 @@ def run(
entity_worker_version=entity_worker_version,
max_width=max_width,
max_height=max_height,
cache_dir=cache_dir,
).run()
# -*- coding: utf-8 -*-
import logging
import time
from io import BytesIO
from pathlib import Path
from typing import NamedTuple
import cv2
import imageio.v2 as iio
import requests
import yaml
from numpy import ndarray
from dan.datasets.extract.db import Element
from dan.datasets.extract.exceptions import ImageDownloadError
from PIL import Image
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
logger = logging.getLogger(__name__)
MAX_RETRIES = 5
# See http://docs.python-requests.org/en/master/user/advanced/#timeouts
DOWNLOAD_TIMEOUT = (30, 60)
def _retry_log(retry_state, *args, **kwargs):
logger.warning(
f"Request to {retry_state.args[0]} failed ({repr(retry_state.outcome.exception())}), "
f"retrying in {retry_state.idle_for} seconds"
)
class EntityType(NamedTuple):
......@@ -26,29 +36,36 @@ class EntityType(NamedTuple):
return len(self.start) + len(self.end)
def download_image(element: Element, im_path: Path):
if im_path.exists():
return im_path
tries = 1
# retry loop
while True:
if tries > MAX_RETRIES:
raise ImageDownloadError(element.id, Exception("Maximum retries reached."))
try:
image = iio.imread(element.image_url)
save_image(im_path, image)
return
except TimeoutError:
logger.warning("Timeout, retry in 1 second.")
time.sleep(1)
tries += 1
except Exception as e:
raise ImageDownloadError(element.id, e)
def save_image(path: Path, image: ndarray):
cv2.imwrite(str(path), cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2),
retry=retry_if_exception_type(requests.RequestException),
before_sleep=_retry_log,
reraise=True,
)
def _retried_request(url):
resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
resp.raise_for_status()
return resp
def download_image(url):
"""
Download an image and open it with Pillow
"""
assert url.startswith("http"), "Image URL must be HTTP(S)"
# Download the image
# Cannot use stream=True as urllib's responses do not support the seek(int) method,
# which is explicitly required by Image.open on file-like objects
resp = _retried_request(url)
# Open the downloaded image with Pillow
image = Image.open(BytesIO(resp.content))
logger.debug(
"Downloaded image {} - size={}x{}".format(url, image.size[0], image.size[1])
)
return image
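A hypothetical call to this helper (the IIIF host below is made up); the crop/size suffix matches the `IIIF_URL_SUFFIX` constant used by the extractor.

```python
# Illustrative only: download a full-size IIIF image and save it to disk.
url = "https://iiif.example.com/iiif/2/some-image" + "/full/full/0/default.jpg"
image = download_image(url)
image.save("some-image.jpg", format="jpeg")
```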
def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -> str:
......
# -*- coding: utf-8 -*-
import re
from operator import attrgetter
from pathlib import Path
from typing import Optional
import editdistance
import numpy as np
from dan.post_processing import PostProcessingModuleSIMARA
from dan.utils import SEM_MATCHING_TOKENS as SIMARA_MATCHING_TOKENS
from dan.datasets.extract.utils import parse_tokens
class MetricManager:
def __init__(self, metric_names, dataset_name):
def __init__(self, metric_names, dataset_name, tokens: Optional[Path]):
self.dataset_name = dataset_name
if "simara" in dataset_name and "page" in dataset_name:
self.post_processing_module = PostProcessingModuleSIMARA
self.matching_tokens = SIMARA_MATCHING_TOKENS
else:
self.matching_tokens = dict()
self.layout_tokens = "".join(
list(self.matching_tokens.keys()) + list(self.matching_tokens.values())
)
if len(self.layout_tokens) == 0:
self.layout_tokens = None
self.layout_tokens = None
if tokens:
tokens = parse_tokens(tokens)
self.layout_tokens = "".join(
list(map(attrgetter("start"), tokens.values()))
+ list(map(attrgetter("end"), tokens.values()))
)
self.metric_names = metric_names
self.epoch_metrics = None
......
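A sketch of how `layout_tokens` is assembled above, assuming `parse_tokens` returns a mapping of entity names to objects exposing `start` and `end` attributes (the token strings below are made up):

```python
from operator import attrgetter
from typing import NamedTuple


class EntityType(NamedTuple):
    start: str
    end: str


# Hypothetical parsed tokens configuration.
tokens = {
    "surname": EntityType(start="<s>", end="</s>"),
    "firstname": EntityType(start="<f>", end="</f>"),
}
layout_tokens = "".join(
    list(map(attrgetter("start"), tokens.values()))
    + list(map(attrgetter("end"), tokens.values()))
)
assert layout_tokens == "<s><f></s></f>"
```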
......@@ -60,6 +60,7 @@ class GenericTrainingManager:
if self.params["training_params"]["use_ddp"]
else 1
)
self.tokens = self.params["dataset_params"].get("tokens")
def init_paths(self):
"""
......@@ -297,6 +298,7 @@ class GenericTrainingManager:
if (
"end_conv" in key
and "transfered_charset" in self.params["model_params"]
and self.params["model_params"]["transfered_charset"]
):
self.adapt_decision_layer_to_old_charset(
model_name, key, checkpoint, state_dict_name
......@@ -616,7 +618,9 @@ class GenericTrainingManager:
] = self.latest_epoch
# init epoch metrics values
self.metric_manager["train"] = MetricManager(
metric_names=metric_names, dataset_name=self.dataset_name
metric_names=metric_names,
dataset_name=self.dataset_name,
tokens=self.tokens,
)
with tqdm(total=len(self.dataset.train_loader.dataset)) as pbar:
pbar.set_description("EPOCH {}/{}".format(num_epoch, nb_epochs))
......@@ -737,7 +741,9 @@ class GenericTrainingManager:
# initialize epoch metrics
self.metric_manager[set_name] = MetricManager(
metric_names, dataset_name=self.dataset_name
metric_names=metric_names,
dataset_name=self.dataset_name,
tokens=self.tokens,
)
with tqdm(total=len(loader.dataset)) as pbar:
pbar.set_description("Evaluation E{}".format(self.latest_epoch))
......@@ -786,7 +792,9 @@ class GenericTrainingManager:
# initialize epoch metrics
self.metric_manager[custom_name] = MetricManager(
metric_names, self.dataset_name
metric_names=metric_names,
dataset_name=self.dataset_name,
tokens=self.tokens,
)
with tqdm(total=len(loader.dataset)) as pbar:
......
......@@ -114,6 +114,7 @@ def get_config():
],
"augmentation": True,
},
"tokens": None,
},
"model_params": {
"models": {
......
# -*- coding: utf-8 -*-
import numpy as np
from dan.utils import SEM_MATCHING_TOKENS as SIMARA_MATCHING_TOKENS
class PostProcessingModule:
"""
Forward pass post processing
Add/remove layout tokens only to:
- respect token hierarchy
- complete/remove unpaired tokens
"""
def __init__(self):
self.prediction = None
self.confidence = None
def post_processing(self):
raise NotImplementedError
def post_process(self, prediction, confidence_score=None):
"""
Apply dataset-specific post-processing
"""
self.prediction = list(prediction)
self.confidence = (
list(confidence_score) if confidence_score is not None else None
)
if self.confidence is not None:
assert len(self.prediction) == len(self.confidence)
return self.post_processing()
def insert_label(self, index, label):
"""
Insert a token at a specific index. The associated confidence score is set to 0.
"""
self.prediction.insert(index, label)
if self.confidence is not None:
self.confidence.insert(index, 0)
def del_label(self, index):
"""
Remove the token at a specific index.
"""
del self.prediction[index]
if self.confidence is not None:
del self.confidence[index]
class PostProcessingModuleSIMARA(PostProcessingModule):
"""
Specific post-processing for the SIMARA dataset at page level
"""
def __init__(self):
super(PostProcessingModuleSIMARA, self).__init__()
self.matching_tokens = SIMARA_MATCHING_TOKENS
self.reverse_matching_tokens = dict()
for key in self.matching_tokens:
self.reverse_matching_tokens[self.matching_tokens[key]] = key
def post_processing(self):
ind = 0
begin_token = None
while ind != len(self.prediction):
char = self.prediction[ind]
# a tag must be closed before starting a new one
if char in self.matching_tokens.keys():
if begin_token is None:
ind += 1
else:
self.insert_label(ind, self.matching_tokens[begin_token])
ind += 2
begin_token = char
continue
# an end token without a prior corresponding begin token is removed
elif char in self.matching_tokens.values():
if begin_token == self.reverse_matching_tokens[char]:
ind += 1
begin_token = None
else:
self.del_label(ind)
continue
else:
ind += 1
# a tag must be closed
if begin_token is not None:
self.insert_label(ind + 1, self.matching_tokens[begin_token])
res = "".join(self.prediction)
if self.confidence is not None:
return res, np.array(self.confidence)
return res
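Hypothetical usage of the pairing logic above, with made-up single-character tokens ("A" opens, "a" closes) instead of the real SIMARA tokens:

```python
class DemoPostProcessing(PostProcessingModuleSIMARA):
    def __init__(self):
        super().__init__()
        # Replace the SIMARA tokens with illustrative ones.
        self.matching_tokens = {"A": "a"}
        self.reverse_matching_tokens = {"a": "A"}


# An unpaired begin token is closed at the end of the sequence...
assert DemoPostProcessing().post_process("Axy") == "Axya"
# ...and an end token with no prior begin token is removed.
assert DemoPostProcessing().post_process("xay") == "xy"
```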
......@@ -4,9 +4,6 @@ from itertools import islice
import torch
import torchvision.io as torchvision
# Layout begin-token to end-token
SEM_MATCHING_TOKENS = {"": "", "": "", "": "", "": "", "": "", "": ""}
class MLflowNotInstalled(Exception):
"""
......
# Get started
To use DAN in your own environment, install it using pip:
To use DAN in your own environment, you need to first clone it with its submodules via:
```shell
git clone --recurse-submodules git@gitlab.teklia.com:atr/dan.git
```
If you forgot the `--recurse-submodules` flag, you can initialize the submodule using:
```shell
git submodule update --init
```
Then you can install it via pip:
```shell
pip install -e .
......
# Post processing
::: dan.post_processing
......@@ -4,16 +4,17 @@ All hyperparameters are specified and editable in the training scripts `dan/ocr/
## Dataset parameters
| Parameter | Description | Type | Default |
| -------------------------------------- | -------------------------------------------------------------------------------------- | ------ | ---------------------------------------------------- |
| `dataset_name` | Name of the dataset. | `str` | |
| `dataset_level` | Level of the dataset. Should be named after the element type. | `str` | |
| `dataset_variant` | Variant of the dataset. Usually empty for HTR datasets, `"_sem"` for HTR+NER datasets. | `str` | |
| `dataset_path` | Path to the dataset. | `str` | |
| `dataset_params.config.load_in_memory` | Load all images in CPU memory. | `bool` | `True` |
| `dataset_params.config.worker_per_gpu` | Number of parallel processes per gpu for data loading. | `int` | `4` |
| `dataset_params.config.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#data-preprocessing)) |
| `dataset_params.config.augmentation` | Whether to use data augmentation on the training set. | `bool` | `True` (see [dedicated section](#data-augmentation)) |
| Parameter | Description | Type | Default |
| -------------------------------------- | --------------------------------------------------------------------------------------------------------------------- | -------------- | ---------------------------------------------------- |
| `dataset_name` | Name of the dataset. | `str` | |
| `dataset_level` | Level of the dataset. Should be named after the element type. | `str` | |
| `dataset_variant` | Variant of the dataset. Usually empty for HTR datasets, `"_sem"` for HTR+NER datasets. | `str` | |
| `dataset_path` | Path to the dataset. | `str` | |
| `dataset_params.config.load_in_memory` | Load all images in CPU memory. | `bool` | `True` |
| `dataset_params.config.worker_per_gpu` | Number of parallel processes per gpu for data loading. | `int` | `4` |
| `dataset_params.config.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#data-preprocessing)) |
| `dataset_params.config.augmentation` | Whether to use data augmentation on the training set. | `bool` | `True` (see [dedicated section](#data-augmentation)) |
| `dataset_params.tokens` | Path to a NER tokens configuration file similar to [the one used for extraction](../datasets/extract.md#description). | `pathlib.Path` | None |
!!! warning
The variables `dataset_name`, `dataset_level`, `dataset_variant` and `dataset_path` must have values such that the data is located in `{dataset_path}/{dataset_name}_{dataset_level}{dataset_variant}`.
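As an illustration only, the `dataset_params` entries described above might appear in the training configuration like this (all values are hypothetical, not a recommended setup):

```python
from pathlib import Path

dataset_params = {
    "config": {
        "load_in_memory": True,   # load all images in CPU memory
        "worker_per_gpu": 4,      # data-loading processes per GPU
        "augmentation": True,     # data augmentation on the training set
    },
    # Path to a NER tokens configuration file, or None to train without layout tokens.
    "tokens": Path("tokens.yml"),
}
```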
......
......@@ -100,7 +100,6 @@ nav:
- Decoders: ref/decoder.md
- Models: ref/encoder.md
- MLflow: ref/mlflow.md
- Post Processing: ref/post_processing.md
- Schedulers: ref/schedulers.md
- Transformations: ref/transforms.md
- Utils: ref/utils.md
......
......@@ -4,7 +4,7 @@ ignore = ["E501"]
select = ["E", "F", "T1", "W", "I"]
[tool.ruff.isort]
known-first-party = ["arkindex_export"]
known-first-party = ["arkindex_export", "line_image_extractor"]
known-third-party = [
"albumentations",
"cv2",
......
-e ./teklia_line_image_extractor
albumentations==1.3.1
arkindex-export==0.1.3
boto3==1.26.124
editdistance==0.6.2
imageio==2.26.1
numpy==1.24.3
opencv-python==4.7.0.72
PyYAML==6.0
scipy==1.10.1
tenacity==8.2.2
tensorboard==2.12.2
torch==2.0.0
torchvision==0.15.1
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
from pathlib import Path
from typing import List
from setuptools import find_packages, setup
def parse_requirements(path):
assert os.path.exists(path), "Missing requirements {}".format(path)
with open(path) as f:
return list(map(str.strip, f.read().splitlines()))
def parse_requirements_line(line) -> str:
# Special case for git requirements
if line.startswith("git+http"):
assert "@" in line, "Branch should be specified with suffix (ex: @master)"
assert (
"#egg=" in line
), "Package name should be specified with suffix (ex: #egg=kraken)"
package_name: str = line.split("#egg=")[-1]
return f"{package_name} @ {line}"
# Special case for submodule requirements
elif line.startswith("-e"):
package_path: str = line.split(" ")[-1]
package = Path(package_path).resolve()
return f"{package.name} @ file://{package}"
else:
return line
def parse_requirements(filename: str) -> List[str]:
path = Path(__file__).parent.resolve() / filename
assert path.exists(), f"Missing requirements: {path}"
return list(
map(parse_requirements_line, map(str.strip, path.read_text().splitlines()))
)
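A few illustrative inputs and outputs for `parse_requirements_line` (the package names and URLs below are made up):

```python
# Git requirement: the package name is taken from the #egg= suffix.
assert (
    parse_requirements_line("git+https://example.com/repo.git@master#egg=kraken")
    == "kraken @ git+https://example.com/repo.git@master#egg=kraken"
)
# Submodule requirement: "-e ./teklia_line_image_extractor" becomes
# "teklia_line_image_extractor @ file:///<absolute path to the submodule>".
# Plain pinned requirements are returned unchanged.
assert parse_requirements_line("albumentations==1.3.1") == "albumentations==1.3.1"
```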
setup(
......
Subproject commit 210c64939d62a8d915dcedbca7dcd529652e5a8b