
Compare revisions

Commits on Source (19)
Showing with 658 additions and 835 deletions
......@@ -127,9 +127,3 @@ dmypy.json
# Pyre type checker
.pyre/
Datasets/formatted/*
Datasets/raw/*
**/outputs
Fonts/*
.idea
......@@ -4,7 +4,7 @@ stages:
- deploy
lint:
image: python:3.8
image: python:3.10
stage: test
cache:
......@@ -24,6 +24,114 @@ lint:
script:
- pre-commit run -a
test:
image: python:3.10
stage: test
cache:
paths:
- .cache/pip
variables:
PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
ARKINDEX_API_SCHEMA_URL: schema.yml
before_script:
- pip install tox
# Download OpenAPI schema from last backend build
- curl https://assets.teklia.com/arkindex/openapi.yml > schema.yml
# Add system deps for opencv
- apt-get update -q
- apt-get install -q -y libgl1
except:
- schedules
script:
- tox
# Make sure docs still build correctly
.docs:
image: python:3.10
artifacts:
paths:
- public
before_script:
- pip install -e .[docs]
script:
- mkdocs build --strict --verbose
docs-build:
extends: .docs
stage: build
# Test job outside of tags to ensure the docs still can build before merging
# Does not use the `pages` name, therefore will be ignored by GitLab Pages
except:
- tags
- schedules
pages:
extends: .docs
stage: deploy
only:
- main
- tags
docs-deploy:
image: node:18
stage: deploy
dependencies:
- docs-build
before_script:
- npm install -g surge
except:
- main
- tags
- schedules
environment:
name: ${CI_COMMIT_REF_SLUG}
url: https://${CI_COMMIT_REF_SLUG}-teklia-atr-dan.surge.sh
on_stop: docs-stop-surge
script:
- surge public ${CI_ENVIRONMENT_URL}
docs-stop-surge:
image: node:18
stage: deploy
when: manual
# Do not try to checkout the branch if it was deleted
variables:
GIT_STRATEGY: none
except:
- main
- tags
- schedules
environment:
name: ${CI_COMMIT_REF_SLUG}
url: https://${CI_COMMIT_REF_SLUG}-teklia-atr-dan.surge.sh
action: stop
before_script:
- npm install -g surge
script:
- surge teardown ${CI_ENVIRONMENT_URL}
bump-python-deps:
stage: deploy
image: registry.gitlab.com/teklia/devops:latest
......@@ -32,7 +140,7 @@ bump-python-deps:
- schedules
script:
- devops python-deps requirements.txt
- devops python-deps requirements.txt doc-requirements.txt
release-notes:
stage: deploy
......
......@@ -5,19 +5,18 @@ repos:
- id: isort
args: ["--profile", "black"]
- repo: https://github.com/ambv/black
rev: 22.6.0
rev: 22.10.0
hooks:
- id: black
- repo: https://github.com/pycqa/flake8
rev: 3.9.2
rev: 6.0.0
hooks:
- id: flake8
additional_dependencies:
- 'flake8-coding==1.3.1'
- 'flake8-copyright==0.2.2'
- 'flake8-debugger==4.0.0'
- 'flake8-coding==1.3.2'
- 'flake8-debugger==4.1.2'
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.4.0
hooks:
- id: check-ast
- id: check-docstring-first
......@@ -40,6 +39,10 @@ repos:
hooks:
- id: codespell
args: ['--write-changes']
- repo: https://github.com/PyCQA/doc8
rev: v1.0.0
hooks:
- id: doc8
- repo: meta
hooks:
- id: check-useless-excludes
include requirements.txt
include doc-requirements.txt
include VERSION
# DAN: a Segmentation-free Document Attention Network for Handwritten Document Recognition
This repository is a public implementation of the paper: "DAN: a Segmentation-free Document Attention Network for Handwritten Document Recognition".
## Documentation
![Prediction visualization](images/visual.png)
For more details about this package, make sure to see the documentation available at https://teklia.gitlab.io/atr/dan/.
The model uses a character-level attention to handle slanted lines:
![Prediction visualization on slanted lines](images/visual_slanted_lines.png)
The paper is available at https://arxiv.org/abs/2203.12273.
To discover my other works, here is my [academic page](https://factodeeplearning.github.io/).
Click to see the demo:
[![Click to see demo](https://img.youtube.com/vi/HrrUsQfW66E/0.jpg)](https://www.youtube.com/watch?v=HrrUsQfW66E)
This work focuses on handwritten text and layout recognition through the use of an end-to-end segmentation-free attention-based network.
We evaluate the DAN on two public datasets: RIMES and READ 2016 at single-page and double-page levels.
We obtained the following results:
| | CER (%) | WER (%) | LOER (%) | mAP_cer (%) |
| :---------------------: | ------- | :-----: | :------: | ----------- |
| RIMES (single page) | 4.54 | 11.85 | 3.82 | 93.74 |
| READ 2016 (single page) | 3.53 | 13.33 | 5.94 | 92.57 |
| READ 2016 (double page) | 3.69 | 14.20 | 4.60 | 93.92 |
Pretrained model weights are available [here](https://git.litislab.fr/dcoquenet/dan).
Table of contents:
1. [Getting Started](#Getting-Started)
2. [Datasets](#Datasets)
3. [Training And Evaluation](#Training-and-evaluation)
## Getting Started
We used Python 3.9.1, PyTorch 1.8.2 and CUDA 10.2 for the scripts.
Clone the repository:
```
git clone https://github.com/FactoDeepLearning/DAN.git
```
Install the dependencies:
```
pip install -r requirements.txt
```
### Remarks (for pre-training and training)
All hyperparameters are specified and editable in the training scripts (their meaning is given in comments).\
Evaluation is performed right after training ends (training is stopped when the maximum elapsed time is reached or after the maximum number of epochs specified in the training script).\
The output files are split into two subfolders: "checkpoints" and "results". \
"checkpoints" contains the model weights for the last trained epoch and for the epoch giving the best valid CER. \
"results" contains the tensorboard logs for the losses and metrics, as well as text files listing the hyperparameters used and the evaluation results.
## `Predict` module
This repository also contains a package to run a pre-trained model on an image.
### Installation
## Installation
To use DAN in your own scripts, install it using pip:
......@@ -68,7 +13,7 @@ To use DAN in your own scripts, install it using pip:
pip install -e .
```
### Usage
## Inference
To apply DAN to an image, one needs to first add a few imports and to load an image. Note that the image should be in RGB.
```python
......@@ -93,105 +38,21 @@ To run the inference on a GPU, one can replace `cpu` by the name of the GPU. In
text, confidence_scores = model.predict(image, confidences=True)
```
### Commands
## Training
This package provides three subcommands. To get more information about any subcommand, use the `--help` option.
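For instance, the following commands list the available subcommands and the arguments each of them accepts:
```shell
teklia-dan --help
teklia-dan dataset extract --help
```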
#### Data extraction from Arkindex
Use the `teklia-dan dataset extract` command to extract a dataset from Arkindex. This will generate the images and the labels needed to train a DAN model.
The available arguments are:
| Parameter | Description | Type | Default |
| ------------------------------ | ----------------------------------------------------------------------------------- | -------- | ------- |
| `--parent` | UUID of the folder to import from Arkindex. You may specify multiple UUIDs. | `str/uuid` | |
| `--element-type` | Type of the elements to extract. You may specify multiple types. | `str` | |
| `--parent-element-type` | Type of the parent element containing the data. | `str` | `page` |
| `--output` | Folder where the data will be generated. | `Path` | |
| `--load-entities` | Extract text with their entities. Needed for NER tasks. | `bool` | `False` |
| `--tokens` | Mapping between starting tokens and end tokens. Needed for NER tasks. | `Path` | |
| `--use-existing-split` | Use the specified folder IDs for the dataset split. | `bool` | |
| `--train-folder` | ID of the training folder to import from Arkindex. | `uuid` | |
| `--val-folder` | ID of the validation folder to import from Arkindex. | `uuid` | |
| `--test-folder` | ID of the testing folder to import from Arkindex. | `uuid` | |
| `--transcription-worker-version` | Filter transcriptions by worker_version. Use ‘manual’ for manual filtering. | `str/uuid` | |
| `--entity-worker-version` | Filter transcription entities by worker_version. Use ‘manual’ for manual filtering. | `str/uuid` | |
| `--train-prob` | Training set split size | `float` | `0.7` |
| `--val-prob` | Validation set split size | `float` | `0.15` |
The `--tokens` argument expects a file with the following format.
```yaml
---
INTITULE:
start: ⓘ
end: Ⓘ
DATE:
start: ⓓ
end: Ⓓ
COTE_SERIE:
start: ⓢ
end: Ⓢ
ANALYSE_COMPL.:
start: ⓒ
end: Ⓒ
PRECISIONS_SUR_COTE:
start: ⓟ
end: Ⓟ
COTE_ARTICLE:
start: ⓐ
end: Ⓐ
CLASSEMENT:
start: ⓛ
end: Ⓛ
```
### Data extraction from Arkindex
To extract HTR+NER data from **pages** from [this folder](https://arkindex.teklia.com/element/665e84ea-bd97-4912-91b0-1f4a844287ff), use the following command:
```shell
teklia-dan dataset extract \
--parent 665e84ea-bd97-4912-91b0-1f4a844287ff \
--element-type page \
--output data \
--load-entities \
--tokens tokens.yml
```
with `tokens.yml` having the content described just above.
See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/datasets/extract/) of the official DAN documentation.
### Dataset formatting
To do the same but only use the data from three folders, the command becomes:
```shell
teklia-dan dataset extract \
--parent 2275529a-1ec5-40ce-a516-42ea7ada858c af9b38b5-5d95-417d-87ec-730537cb1898 6ff44957-0e65-48c5-9d77-a178116405b2 \
--element-type page \
--output data \
--load-entities \
--tokens tokens.yml
```
See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/datasets/format/) of the official DAN documentation.
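Based on the `--dataset`, `--image-format` and `--keep-spaces` arguments of the `format` parser shown further down in this diff, a formatting call could look like the sketch below (here `data` is the folder produced by the extraction examples above, and `jpg` is an assumed image format):
```shell
teklia-dan dataset format \
    --dataset data \
    --image-format jpg
```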
To use the data from three folders as the **training**, **validation** and **testing** datasets respectively, the command becomes:
```shell
teklia-dan dataset extract \
--use-existing-split \
--train-folder 2275529a-1ec5-40ce-a516-42ea7ada858c \
--val-folder af9b38b5-5d95-417d-87ec-730537cb1898 \
--test-folder 6ff44957-0e65-48c5-9d77-a178116405b2 \
--element-type page \
--output data \
--load-entities \
--tokens tokens.yml
```
To extract HTR data from **annotations** and **text_zones** from [this folder](https://demo.arkindex.org/element/48852284-fc02-41bb-9a42-4458e5a51615) that are children of **single_pages**, use the following command:
```shell
teklia-dan dataset extract \
--parent 48852284-fc02-41bb-9a42-4458e5a51615 \
--element-type text_zone annotation \
--parent-element-type single_page \
--output data
```
### Model training
#### Model training
Use the `teklia-dan train` command with the appropriate arguments.
See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/train/) of the official DAN documentation.
#### Synthetic data generation
Use the `teklia-dan generate` command with the appropriate arguments.
### Synthetic data generation
See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/generate/) of the official DAN documentation.
......@@ -4,6 +4,7 @@ Preprocess datasets for training.
"""
from dan.datasets.extract import add_extract_parser
from dan.datasets.format import add_format_parser
def add_dataset_parser(subcommands) -> None:
......@@ -15,3 +16,4 @@ def add_dataset_parser(subcommands) -> None:
subcommands = parser.add_subparsers(metavar="subcommand")
add_extract_parser(subcommands)
add_format_parser(subcommands)
# -*- coding: utf-8 -*-
"""
Format datasets for training.
"""
from pathlib import Path
from dan.datasets.format.atr import run
def add_format_parser(subcommands) -> None:
parser = subcommands.add_parser(
"format",
description=__doc__,
help=__doc__,
)
parser.add_argument(
"--dataset",
type=Path,
help="Path to the exported dataset.",
required=True,
)
parser.add_argument(
"--image-format",
type=str,
help="Format under which the images were saved.",
required=True,
)
parser.add_argument(
"--keep-spaces",
action="store_true",
help="Do not remove spaces in transcriptions.",
)
parser.set_defaults(func=run)
# -*- coding: utf-8 -*-
import json
import os
import pickle
import re
from collections import defaultdict
from pathlib import Path
from tqdm import tqdm
def remove_spaces(text):
# remove begin/ending spaces
text = text.strip()
# replace \t with regular space
text = re.sub("\t", " ", text)
# remove consecutive spaces
text = re.sub(" +", " ", text)
return text
class ATRDatasetFormatter:
"""
Global pipeline/functions for dataset formatting
"""
def __init__(self, dataset: Path, image_format: str, remove_spaces: bool):
self.dataset = dataset
self.set_names = ["train", "val", "test"]
self.remove_spaces = remove_spaces
self.image_folder = self.dataset / "images"
self.labels_folder = self.dataset / "labels"
self.image_format = image_format
if self.image_format.startswith("."):
self.image_format = self.image_format[1:]
def format(self):
"""
Format ATR dataset
"""
ground_truth = defaultdict(dict)
charset = set()
for set_name in self.set_names:
set_folder = self.labels_folder / set_name
for file_name in tqdm(
os.listdir(set_folder), desc="Formatting " + set_name
):
data = self.parse_labels(set_name, file_name)
charset = charset.union(set(data["label"]))
ground_truth[set_name][data["img_path"]] = {
"text": data["label"],
}
return ground_truth, charset
def read_file(self, file_name):
with open(file_name, "r") as f:
text = f.read()
if self.remove_spaces:
text = remove_spaces(text)
return text.strip()
def parse_labels(self, set_name, file_name):
return {
"img_path": os.path.join(
self.image_folder,
set_name,
f"{os.path.splitext(file_name)[0]}.{self.image_format}",
),
"label": self.read_file(
os.path.join(self.labels_folder, set_name, file_name)
),
}
def run(self):
ground_truth, charset = self.format()
with open(self.dataset / "labels.json", "w") as f:
json.dump(
ground_truth,
f,
sort_keys=True,
indent=4,
)
with open(self.dataset / "charset.pkl", "wb") as f:
pickle.dump(sorted(list(charset)), f)
def run(dataset, image_format, keep_spaces):
ATRDatasetFormatter(
dataset=dataset, image_format=image_format, remove_spaces=not keep_spaces
).run()
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import re
from collections import Counter
from tqdm import tqdm
from dan.datasets.format.generic import OCRDatasetFormatter
def remove_spaces(text):
# remove begin/ending spaces
text = text.strip()
# replace \t with regular space
text = re.sub("\t", " ", text)
# remove consecutive spaces
text = re.sub(" +", " ", text)
# text = text.encode('ascii', 'ignore').decode("utf-8")
return text
class BessinDatasetFormatter(OCRDatasetFormatter):
def __init__(
self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=False
):
super(BessinDatasetFormatter, self).__init__(
"bessin", level, "_sem" if sem_token else "", set_names
)
self.dpi = dpi
self.counter = Counter()
self.map_datasets_files.update(
{
"bessin": {
# (1,050 for train, 100 for validation and 100 for test)
"line": {
"needed_files": [],
"arx_files": [],
"format_function": self.format_bessin_zone,
}
}
}
)
def preformat_bessin_zone(self):
"""
Extract all information from dataset and correct some annotations
"""
dataset = {"train": list(), "valid": list(), "test": list()}
img_folder_path = os.path.join("Datasets", "raw", "bessin", "images")
labels_folder_path = os.path.join("Datasets", "raw", "bessin", "labels")
train_files = [
os.path.join(labels_folder_path, "train", name)
for name in os.listdir(os.path.join(labels_folder_path, "train"))
]
valid_files = [
os.path.join(labels_folder_path, "valid", name)
for name in os.listdir(os.path.join(labels_folder_path, "valid"))
]
test_files = [
os.path.join(labels_folder_path, "test", name)
for name in os.listdir(os.path.join(labels_folder_path, "test"))
]
for set_name, files in zip(
self.set_names, [train_files, valid_files, test_files]
):
for i, label_file in enumerate(
tqdm(files, desc="Pre-formatting " + set_name)
):
with open(label_file, "r") as f:
text = remove_spaces(f.read())
dataset[set_name].append(
{
"img_path": os.path.join(
img_folder_path,
set_name,
label_file.split("/")[-1].replace("txt", "jpg"),
),
"label": text.strip(),
}
)
return dataset
def format_bessin_zone(self):
"""
Format bessin zone dataset
"""
dataset = self.preformat_bessin_zone()
for set_name in self.set_names:
fold = os.path.join(self.target_fold_path, set_name)
for sample in tqdm(dataset[set_name], desc="Formatting " + set_name):
new_name = sample["img_path"].split("/")[-1]
new_img_path = os.path.join(fold, new_name)
self.load_resize_save(sample["img_path"], new_img_path)
zone = {
"text": sample["label"],
}
self.charset = self.charset.union(set(zone["text"]))
self.gt[set_name][new_name] = zone
self.counter.update(zone["text"])
if __name__ == "__main__":
formatter = BessinDatasetFormatter("line", sem_token=False)
formatter.format()
print("Character freq: ")
for k, v in formatter.counter.items():
print(k, v)
# -*- coding: utf-8 -*-
import os
import pickle
import shutil
import numpy as np
from PIL import Image
class DatasetFormatter:
"""
Global pipeline/functions for dataset formatting
"""
def __init__(
self, dataset_name, level, extra_name="", set_names=["train", "valid", "test"]
):
self.dataset_name = dataset_name
self.level = level
self.set_names = set_names
self.target_fold_path = os.path.join(
"Datasets", "formatted", "{}_{}{}".format(dataset_name, level, extra_name)
)
self.map_datasets_files = dict()
def format(self):
self.init_format()
self.map_datasets_files[self.dataset_name][self.level]["format_function"]()
self.end_format()
def init_format(self):
"""
Load and extracts needed files
"""
os.makedirs(self.target_fold_path, exist_ok=True)
for set_name in self.set_names:
os.makedirs(os.path.join(self.target_fold_path, set_name), exist_ok=True)
class OCRDatasetFormatter(DatasetFormatter):
"""
Specific pipeline/functions for OCR/HTR dataset formatting
"""
def __init__(
self, source_dataset, level, extra_name="", set_names=["train", "valid", "test"]
):
super(OCRDatasetFormatter, self).__init__(
source_dataset, level, extra_name, set_names
)
self.charset = set()
self.gt = dict()
for set_name in set_names:
self.gt[set_name] = dict()
def load_resize_save(self, source_path, target_path):
"""
Load image, apply resolution modification and save it
"""
shutil.copyfile(source_path, target_path)
def resize(self, img, source_dpi, target_dpi):
"""
Apply resolution modification to image
"""
if source_dpi == target_dpi:
return img
if isinstance(img, np.ndarray):
h, w = img.shape[:2]
img = Image.fromarray(img)
else:
w, h = img.size
ratio = target_dpi / source_dpi
img = img.resize((int(w * ratio), int(h * ratio)), Image.BILINEAR)
return np.array(img)
def end_format(self):
"""
Save label and charset files
"""
with open(os.path.join(self.target_fold_path, "labels.pkl"), "wb") as f:
pickle.dump(
{
"ground_truth": self.gt,
"charset": sorted(list(self.charset)),
},
f,
)
with open(os.path.join(self.target_fold_path, "charset.pkl"), "wb") as f:
pickle.dump(sorted(list(self.charset)), f)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
from collections import defaultdict
from tqdm import tqdm
from dan.datasets.format.generic import OCRDatasetFormatter
# Layout string to token
SEM_MATCHING_TOKENS_STR = {
    "INTITULE": "ⓘ",
    "DATE": "ⓓ",
    "COTE_SERIE": "ⓢ",
    "ANALYSE_COMPL": "ⓒ",
    "PRECISIONS_SUR_COTE": "ⓟ",
    "COTE_ARTICLE": "ⓐ",
}
# Layout begin-token to end-token
SEM_MATCHING_TOKENS = {"ⓘ": "Ⓘ", "ⓓ": "Ⓓ", "ⓢ": "Ⓢ", "ⓒ": "Ⓒ", "ⓟ": "Ⓟ", "ⓐ": "Ⓐ"}
class SimaraDatasetFormatter(OCRDatasetFormatter):
def __init__(
self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=True
):
super(SimaraDatasetFormatter, self).__init__(
"simara", level, "_sem" if sem_token else "", set_names
)
self.dpi = dpi
self.sem_token = sem_token
self.map_datasets_files.update(
{
"simara": {
# (1,050 for train, 100 for validation and 100 for test)
"page": {
"format_function": self.format_simara_page,
},
}
}
)
self.matching_tokens_str = SEM_MATCHING_TOKENS_STR
self.matching_tokens = SEM_MATCHING_TOKENS
def preformat_simara_page(self):
"""
Extract all information from dataset and correct some annotations
"""
dataset = defaultdict(list)
img_folder_path = os.path.join("Datasets", "raw", "simara", "images")
labels_folder_path = os.path.join("Datasets", "raw", "simara", "labels")
sem_labels_folder_path = os.path.join("Datasets", "raw", "simara", "labels_sem")
train_files = [
os.path.join(labels_folder_path, "train", name)
for name in os.listdir(os.path.join(sem_labels_folder_path, "train"))
]
valid_files = [
os.path.join(labels_folder_path, "valid", name)
for name in os.listdir(os.path.join(sem_labels_folder_path, "valid"))
]
test_files = [
os.path.join(labels_folder_path, "test", name)
for name in os.listdir(os.path.join(sem_labels_folder_path, "test"))
]
for set_name, files in zip(
self.set_names, [train_files, valid_files, test_files]
):
for i, label_file in enumerate(
tqdm(files, desc="Pre-formatting " + set_name)
):
with open(label_file, "r") as f:
text = f.read()
with open(label_file.replace("labels", "labels_sem"), "r") as f:
sem_text = f.read()
dataset[set_name].append(
{
"img_path": os.path.join(
img_folder_path,
set_name,
label_file.split("/")[-1].replace("txt", "jpg"),
),
"label": text,
"sem_label": sem_text,
}
)
return dataset
def format_simara_page(self):
"""
Format simara page dataset
"""
dataset = self.preformat_simara_page()
for set_name in self.set_names:
fold = os.path.join(self.target_fold_path, set_name)
for sample in tqdm(dataset[set_name], desc="Formatting " + set_name):
new_name = sample["img_path"].split("/")[-1]
new_img_path = os.path.join(fold, new_name)
self.load_resize_save(
sample["img_path"], new_img_path
) # , 300, self.dpi)
page = {
"text": sample["label"]
if not self.sem_token
else sample["sem_label"],
}
self.charset = self.charset.union(set(page["text"]))
self.gt[set_name][new_name] = page
if __name__ == "__main__":
formatter = SimaraDatasetFormatter("page", sem_token=False)
formatter.format()
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import re
from collections import Counter
from tqdm import tqdm
from dan.datasets.format.generic import OCRDatasetFormatter
def remove_spaces(text):
# remove begin/ending spaces
text = text.strip()
# replace \t with regular space
text = re.sub("\t", " ", text)
# remove consecutive spaces
text = re.sub(" +", " ", text)
return text
class SynistDatasetFormatter(OCRDatasetFormatter):
def __init__(
self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=False
):
super(SynistDatasetFormatter, self).__init__(
"synist_synth", level, "_sem" if sem_token else "", set_names
)
self.dpi = dpi
self.counter = Counter()
self.map_datasets_files.update(
{
"synist_synth": {
# (1,050 for train, 100 for validation and 100 for test)
"line": {
"needed_files": [],
"arx_files": [],
"format_function": self.format_synist_page,
}
}
}
)
def preformat_synist_page(self):
"""
Extract all information from dataset and correct some annotations
"""
dataset = {"train": list(), "valid": list(), "test": list()}
img_folder_path = os.path.join(
"Datasets", "raw", "synist_synth_lines", "images"
)
labels_folder_path = os.path.join(
"Datasets", "raw", "synist_synth_lines", "labels"
)
train_files = [
os.path.join(labels_folder_path, "train", name)
for name in os.listdir(os.path.join(labels_folder_path, "train"))
]
valid_files = [
os.path.join(labels_folder_path, "valid", name)
for name in os.listdir(os.path.join(labels_folder_path, "valid"))
]
test_files = [
os.path.join(labels_folder_path, "test", name)
for name in os.listdir(os.path.join(labels_folder_path, "test"))
]
for set_name, files in zip(
self.set_names, [train_files, valid_files, test_files]
):
for i, label_file in enumerate(
tqdm(files, desc="Pre-formatting " + set_name)
):
with open(label_file, "r") as f:
text = remove_spaces(f.read())
dataset[set_name].append(
{
"img_path": os.path.join(
img_folder_path,
set_name,
label_file.split("/")[-1].replace("txt", "png"),
),
"label": text.strip(),
}
)
return dataset
def format_synist_page(self):
"""
Format synist page dataset
"""
dataset = self.preformat_synist_page()
for set_name in self.set_names:
fold = os.path.join(self.target_fold_path, set_name)
for sample in tqdm(dataset[set_name], desc="Formatting " + set_name):
new_name = sample["img_path"].split("/")[-1]
new_img_path = os.path.join(fold, new_name)
# self.load_resize_save(sample["img_path"], new_img_path, 300, self.dpi)
self.load_resize_save(sample["img_path"], new_img_path)
# self.load_flip_save(new_img_path, new_img_path)
page = {
"text": sample["label"],
}
self.charset = self.charset.union(set(page["text"]))
self.gt[set_name][new_name] = page
self.counter.update(page["text"])
if __name__ == "__main__":
formatter = SynistDatasetFormatter("line", sem_token=False)
formatter.format()
print(formatter.counter)
print(formatter.counter.most_common(80))
for k, v in formatter.counter.items():
print(k)
print(k.encode("utf-8"), v)
# -*- coding: utf-8 -*-
import json
import os
import pickle
import random
import cv2
......@@ -39,7 +39,7 @@ class DatasetManager:
self.batch_size = {
"train": self.params["batch_size"],
"valid": self.params["valid_batch_size"]
"val": self.params["valid_batch_size"]
if "valid_batch_size" in self.params
else self.params["batch_size"],
"test": self.params["test_batch_size"]
......@@ -70,12 +70,12 @@ class DatasetManager:
)
self.apply_specific_treatment_after_dataset_loading(self.train_dataset)
for custom_name in self.params["valid"].keys():
for custom_name in self.params["val"].keys():
self.valid_datasets[custom_name] = self.dataset_class(
self.params,
"valid",
"val",
custom_name,
self.get_paths_and_sets(self.params["valid"][custom_name]),
self.get_paths_and_sets(self.params["val"][custom_name]),
)
self.apply_specific_treatment_after_dataset_loading(
self.valid_datasets[custom_name]
......@@ -124,7 +124,7 @@ class DatasetManager:
for key in self.valid_datasets.keys():
self.valid_loaders[key] = DataLoader(
self.valid_datasets[key],
batch_size=self.batch_size["valid"],
batch_size=self.batch_size["val"],
sampler=self.valid_samplers[key],
batch_sampler=self.valid_samplers[key],
shuffle=False,
......@@ -250,39 +250,38 @@ class GenericDataset(Dataset):
Load images and labels
"""
samples = list()
for path_and_set in paths_and_sets:
path = path_and_set["path"]
with open(os.path.join(path, "labels.json"), "rb") as f:
gt_per_set = json.load(f)
set_name = path_and_set["set_name"]
with open(os.path.join(path, "labels.pkl"), "rb") as f:
info = pickle.load(f)
gt = info["ground_truth"][set_name]
for filename in natural_sort(gt.keys()):
name = os.path.join(os.path.basename(path), set_name, filename)
full_path = os.path.join(path, set_name, filename)
if isinstance(gt[filename], dict) and "text" in gt[filename]:
label = gt[filename]["text"]
else:
label = gt[filename]
samples.append(
{
"name": name,
"label": label,
"unchanged_label": label,
"path": full_path,
"nb_cols": 1
if "nb_cols" not in gt[filename]
else gt[filename]["nb_cols"],
}
)
if load_in_memory:
samples[-1]["img"] = GenericDataset.load_image(full_path)
if type(gt[filename]) is dict:
if "lines" in gt[filename].keys():
samples[-1]["raw_line_seg_label"] = gt[filename]["lines"]
if "paragraphs" in gt[filename].keys():
samples[-1]["paragraphs_label"] = gt[filename]["paragraphs"]
if "pages" in gt[filename].keys():
samples[-1]["pages_label"] = gt[filename]["pages"]
gt = gt_per_set[set_name]
for filename in natural_sort(gt.keys()):
if isinstance(gt[filename], dict) and "text" in gt[filename]:
label = gt[filename]["text"]
else:
label = gt[filename]
samples.append(
{
"name": os.path.basename(filename),
"label": label,
"unchanged_label": label,
"path": os.path.abspath(filename),
"nb_cols": 1
if "nb_cols" not in gt[filename]
else gt[filename]["nb_cols"],
}
)
if load_in_memory:
samples[-1]["img"] = GenericDataset.load_image(filename)
if type(gt[filename]) is dict:
if "lines" in gt[filename].keys():
samples[-1]["raw_line_seg_label"] = gt[filename]["lines"]
if "paragraphs" in gt[filename].keys():
samples[-1]["paragraphs_label"] = gt[filename]["paragraphs"]
if "pages" in gt[filename].keys():
samples[-1]["pages_label"] = gt[filename]["pages"]
return samples
def apply_preprocessing(self, preprocessings):
......@@ -344,7 +343,7 @@ class GenericDataset(Dataset):
self.params["config"][key] if key in self.params["config"].keys() else None
for key in ["augmentation", "valid_augmentation", "test_augmentation"]
]
for aug, set_name in zip(augs, ["train", "valid", "test"]):
for aug, set_name in zip(augs, ["train", "val", "test"]):
if aug and self.set_name == set_name:
return apply_data_augmentation(img, aug)
return img, list()
......
......@@ -5,8 +5,8 @@ import editdistance
import networkx as nx
import numpy as np
from dan.datasets.format.simara import SEM_MATCHING_TOKENS as SIMARA_MATCHING_TOKENS
from dan.post_processing import PostProcessingModuleSIMARA
from dan.utils import SEM_MATCHING_TOKENS as SIMARA_MATCHING_TOKENS
class MetricManager:
......@@ -31,6 +31,7 @@ class MetricManager:
self.linked_metrics = {
"cer": ["edit_chars", "nb_chars"],
"wer": ["edit_words", "nb_words"],
"wer_no_punct": ["edit_words_no_punct", "nb_words_no_punct"],
"loer": [
"edit_graph",
"nb_nodes_and_edges",
......@@ -127,7 +128,21 @@ class MetricManager:
)
if output:
display_values["nb_words"] = np.sum(self.epoch_metrics["nb_words"])
elif metric_name in ["loss", "loss_ctc", "loss_ce", "syn_max_lines"]:
elif metric_name == "wer_no_punct":
value = np.sum(self.epoch_metrics["edit_words_no_punct"]) / np.sum(
self.epoch_metrics["nb_words_no_punct"]
)
if output:
display_values["nb_words_no_punct"] = np.sum(
self.epoch_metrics["nb_words_no_punct"]
)
elif metric_name in [
"loss",
"loss_ctc",
"loss_ce",
"syn_max_lines",
"syn_prob_lines",
]:
value = np.average(
self.epoch_metrics[metric_name],
weights=np.array(self.epoch_metrics["nb_samples"]),
......@@ -177,11 +192,26 @@ class MetricManager:
for (gt, pred) in zip(split_gt, split_pred)
]
metrics["nb_words"] = [len(gt) for gt in split_gt]
elif metric_name == "wer_no_punct":
split_gt = [
format_string_for_wer(gt, self.layout_tokens, remove_punct=True)
for gt in values["str_y"]
]
split_pred = [
format_string_for_wer(pred, self.layout_tokens, remove_punct=True)
for pred in values["str_x"]
]
metrics["edit_words_no_punct"] = [
edit_wer_from_formatted_split_text(gt, pred)
for (gt, pred) in zip(split_gt, split_pred)
]
metrics["nb_words_no_punct"] = [len(gt) for gt in split_gt]
elif metric_name in [
"loss_ctc",
"loss_ce",
"loss",
"syn_max_lines",
"syn_prob_lines",
]:
metrics[metric_name] = [
values[metric_name],
......@@ -251,22 +281,23 @@ def nb_chars_cer_from_string(gt, layout_tokens=None):
return len(format_string_for_cer(gt, layout_tokens))
def edit_wer_from_string(gt, pred, layout_tokens=None):
def edit_wer_from_string(gt, pred, layout_tokens=None, remove_punct=False):
"""
Format and compute edit distance between two strings at word level
"""
split_gt = format_string_for_wer(gt, layout_tokens)
split_pred = format_string_for_wer(pred, layout_tokens)
split_gt = format_string_for_wer(gt, layout_tokens, remove_punct)
split_pred = format_string_for_wer(pred, layout_tokens, remove_punct)
return edit_wer_from_formatted_split_text(split_gt, split_pred)
def format_string_for_wer(str, layout_tokens):
def format_string_for_wer(str, layout_tokens, remove_punct=False):
"""
Format string for WER computation: remove layout tokens, treat punctuation as word, replace line break by space
"""
str = re.sub(
r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])", r" \1 ", str
) # punctuation processed as word
if remove_punct:
str = re.sub(
r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])", "", str
) # remove punctuation
if layout_tokens is not None:
str = keep_all_but_tokens(
str, layout_tokens
......
......@@ -9,10 +9,12 @@ import torch
from fontTools.ttLib import TTFont
from PIL import Image, ImageDraw, ImageFont
from dan import logger
from dan.manager.dataset import DatasetManager, GenericDataset, apply_preprocessing
from dan.ocr.utils import LM_str_to_ind
from dan.utils import (
pad_image,
pad_image_width_random,
pad_image_width_right,
pad_images,
pad_sequences_1D,
......@@ -37,17 +39,10 @@ class OCRDatasetManager(DatasetManager):
if (
"synthetic_data" in self.params["config"]
and self.params["config"]["synthetic_data"]
and "config" in self.params["config"]["synthetic_data"]
):
self.char_only_set = self.charset.copy()
for token in [
"\n",
]:
if token in self.char_only_set:
self.char_only_set.remove(token)
self.params["config"]["synthetic_data"]["config"][
"valid_fonts"
] = get_valid_fonts(self.char_only_set)
self.synthetic_data = self.params["config"]["synthetic_data"]
if "config" in self.synthetic_data:
self.synthetic_data["config"]["valid_fonts"] = self.get_valid_fonts()
if "new_tokens" in params:
self.charset = sorted(
......@@ -78,9 +73,8 @@ class OCRDatasetManager(DatasetManager):
datasets = self.params["datasets"]
charset = set()
for key in datasets.keys():
with open(os.path.join(datasets[key], "labels.pkl"), "rb") as f:
info = pickle.load(f)
charset = charset.union(set(info["charset"]))
with open(os.path.join(datasets[key], "charset.pkl"), "rb") as f:
charset = charset.union(set(pickle.load(f)))
if (
"\n" in charset
and "remove_linebreaks" in self.params["config"]["constraints"]
......@@ -109,6 +103,34 @@ class OCRDatasetManager(DatasetManager):
[s["img"].shape[1] for s in self.train_dataset.samples]
)
def get_valid_fonts(self):
"""
Select fonts that are compatible with the alphabet
"""
font_path = self.synthetic_data["font_path"]
alphabet = self.charset.copy()
special_chars = ["\n"]
alphabet = [char for char in alphabet if char not in special_chars]
valid_fonts = list()
for fold_detail in os.walk(font_path):
if fold_detail[2]:
for font_name in fold_detail[2]:
if ".ttf" not in font_name:
continue
font_path = os.path.join(fold_detail[0], font_name)
to_add = True
if alphabet is not None:
for char in alphabet:
if not char_in_font(char, font_path):
to_add = False
break
if to_add:
valid_fonts.append(font_path)
else:
valid_fonts.append(font_path)
logger.info(f"Found {len(valid_fonts)} fonts.")
return valid_fonts
class OCRDataset(GenericDataset):
"""
......@@ -124,6 +146,13 @@ class OCRDataset(GenericDataset):
)
self.collate_function = OCRCollateFunction
self.synthetic_id = 0
if (
"synthetic_data" in self.params["config"]
and self.params["config"]["synthetic_data"]
):
self.synthetic_data = self.params["config"]["synthetic_data"]
else:
self.synthetic_data = None
def __getitem__(self, idx):
sample = copy.deepcopy(self.samples[idx])
......@@ -282,67 +311,75 @@ class OCRDataset(GenericDataset):
return sample
def generate_synthetic_data(self, sample):
config = self.params["config"]["synthetic_data"]
if not (config["init_proba"] == config["end_proba"] == 1):
nb_samples = self.training_info["step"] * self.params["batch_size"]
if config["start_scheduler_at_max_line"]:
max_step = config["num_steps_proba"]
current_step = max(
0,
min(
nb_samples
- config["curr_step"]
* (config["max_nb_lines"] - config["min_nb_lines"]),
max_step,
),
)
proba = (
config["init_proba"]
if self.get_syn_max_lines() < config["max_nb_lines"]
else config["proba_scheduler_function"](
config["init_proba"],
config["end_proba"],
current_step,
max_step,
)
)
else:
proba = config["proba_scheduler_function"](
config["init_proba"],
config["end_proba"],
min(nb_samples, config["num_steps_proba"]),
config["num_steps_proba"],
)
if rand() > proba:
return sample
if "mode" in config and config["mode"] == "line_hw_to_printed":
proba = self.get_syn_proba_lines()
if rand() > proba:
return sample
if (
"mode" in self.synthetic_data
and self.synthetic_data["mode"] == "line_hw_to_printed"
):
sample["img"] = self.generate_typed_text_line_image(sample["label"])
return sample
return self.generate_synthetic_page_sample()
def get_syn_max_lines(self):
config = self.params["config"]["synthetic_data"]
if config["curriculum"]:
if self.synthetic_data["curriculum"]:
nb_samples = self.training_info["step"] * self.params["batch_size"]
max_nb_lines = min(
config["max_nb_lines"],
(nb_samples - config["curr_start"]) // config["curr_step"] + 1,
self.synthetic_data["max_nb_lines"],
(nb_samples - self.synthetic_data["curr_start"])
// self.synthetic_data["curr_step"]
+ 1,
)
return max(self.synthetic_data["min_nb_lines"], max_nb_lines)
return self.synthetic_data["max_nb_lines"]
def get_syn_proba_lines(self):
if self.synthetic_data["init_proba"] == self.synthetic_data["end_proba"]:
return self.synthetic_data["init_proba"]
nb_samples = self.training_info["step"] * self.params["batch_size"]
if self.synthetic_data["start_scheduler_at_max_line"]:
max_step = self.synthetic_data["num_steps_proba"]
current_step = max(
0,
min(
nb_samples
- self.synthetic_data["curr_step"]
* (
self.synthetic_data["max_nb_lines"]
- self.synthetic_data["min_nb_lines"]
),
max_step,
),
)
proba = (
self.synthetic_data["init_proba"]
if self.get_syn_max_lines() < self.synthetic_data["max_nb_lines"]
else self.synthetic_data["proba_scheduler_function"](
self.synthetic_data["init_proba"],
self.synthetic_data["end_proba"],
current_step,
max_step,
)
)
else:
proba = self.synthetic_data["proba_scheduler_function"](
self.synthetic_data["init_proba"],
self.synthetic_data["end_proba"],
min(nb_samples, self.synthetic_data["num_steps_proba"]),
self.synthetic_data["num_steps_proba"],
)
return max(config["min_nb_lines"], max_nb_lines)
return config["max_nb_lines"]
return proba
def generate_synthetic_page_sample(self):
config = self.params["config"]["synthetic_data"]
max_nb_lines_per_page = self.get_syn_max_lines()
crop = (
config["crop_curriculum"] and max_nb_lines_per_page < config["max_nb_lines"]
self.synthetic_data["crop_curriculum"]
and max_nb_lines_per_page < self.synthetic_data["max_nb_lines"]
)
sample = {"name": "synthetic_data_{}".format(self.synthetic_id), "path": None}
self.synthetic_id += 1
nb_pages = 2 if "double" in config["dataset_level"] else 1
nb_pages = 2 if "double" in self.synthetic_data["dataset_level"] else 1
background_sample = copy.deepcopy(self.samples[randint(0, len(self))])
pages = list()
backgrounds = list()
......@@ -351,7 +388,7 @@ class OCRDataset(GenericDataset):
page_width = w // 2 if nb_pages == 2 else w
for i in range(nb_pages):
nb_lines_per_page = randint(
config["min_nb_lines"], max_nb_lines_per_page + 1
self.synthetic_data["min_nb_lines"], max_nb_lines_per_page + 1
)
background = (
np.ones((h, page_width, c), dtype=background_sample["img"].dtype) * 255
......@@ -387,7 +424,19 @@ class OCRDataset(GenericDataset):
)
)
else:
raise NotImplementedError
# Get a page-level transcription and split it by lines
texts = self.samples[randint(0, len(self))]["label"].split("\n")
# Select some lines to be generated
n_lines = min(len(texts), nb_lines_per_page)
i = randint(0, len(texts) - n_lines + 1)
texts = texts[i : i + n_lines]
# Generate the synthetic document (of n_lines)
pages.append(
self.generate_typed_text_paragraph_image(
texts=texts,
same_font_size=True,
)
)
if nb_pages == 1:
sample["img"] = pages[0][0]
......@@ -420,9 +469,79 @@ class OCRDataset(GenericDataset):
return sample
def generate_typed_text_line_image(self, text):
return generate_typed_text_line_image(
text, self.params["config"]["synthetic_data"]["config"]
)
return generate_typed_text_line_image(text, self.synthetic_data["config"])
def generate_typed_text_paragraph_image(
self, texts, padding_value=255, max_pad_left_ratio=0.1, same_font_size=False
):
"""
Generate a synthetic paragraph from a list of texts where each line is generated with a different font.
"""
if same_font_size:
images = list()
txt_color = self.synthetic_data["config"]["text_color_default"]
bg_color = self.synthetic_data["config"]["background_color_default"]
font_size = randint(
self.synthetic_data["config"]["font_size_min"],
self.synthetic_data["config"]["font_size_max"] + 1,
)
for text in texts:
font_path = self.synthetic_data["config"]["valid_fonts"][
randint(0, len(self.synthetic_data["config"]["valid_fonts"]))
]
fnt = ImageFont.truetype(font_path, font_size)
text_width, text_height = fnt.getsize(text)
padding_top = get_random_padding(
self.synthetic_data["config"]["padding_top_ratio_min"],
self.synthetic_data["config"]["padding_top_ratio_max"],
text_height,
)
padding_bottom = get_random_padding(
self.synthetic_data["config"]["padding_bottom_ratio_min"],
self.synthetic_data["config"]["padding_bottom_ratio_max"],
text_height,
)
padding_left = get_random_padding(
self.synthetic_data["config"]["padding_left_ratio_min"],
self.synthetic_data["config"]["padding_left_ratio_max"],
text_width,
)
padding_right = get_random_padding(
self.synthetic_data["config"]["padding_right_ratio_min"],
self.synthetic_data["config"]["padding_right_ratio_max"],
text_width,
)
padding = [padding_top, padding_bottom, padding_left, padding_right]
images.append(
generate_typed_text_line_image_from_params(
text,
fnt,
bg_color,
txt_color,
self.synthetic_data["config"]["color_mode"],
padding,
)
)
else:
images = [generate_typed_text_line_image(t) for t in texts]
max_width = max([img.shape[1] for img in images])
padded_images = [
pad_image_width_random(
img,
max_width,
padding_value=padding_value,
max_pad_left_ratio=max_pad_left_ratio,
)
for img in images
]
label = {
"sem": "\n".join(texts),
"begin": "\n".join(texts),
"raw": "\n".join(texts),
}
# image, label, n_col
return [np.concatenate(padded_images, axis=0), label, 1]
class OCRCollateFunction:
......@@ -556,6 +675,13 @@ class OCRCollateFunction:
return formatted_batch_data
def get_random_padding(min_ratio, max_ratio, text_size):
"""
Compute random padding value as a ratio of text width or height
"""
return int(rand_uniform(min_ratio, max_ratio) * text_size)
def generate_typed_text_line_image(
text, config, bg_color=(255, 255, 255), txt_color=(0, 0, 0)
):
......@@ -571,25 +697,25 @@ def generate_typed_text_line_image(
fnt = ImageFont.truetype(font_path, font_size)
text_width, text_height = fnt.getsize(text)
padding_top = int(
rand_uniform(config["padding_top_ratio_min"], config["padding_top_ratio_max"])
* text_height
padding_top = get_random_padding(
config["padding_top_ratio_min"],
config["padding_top_ratio_max"],
text_height,
)
padding_bottom = int(
rand_uniform(
config["padding_bottom_ratio_min"], config["padding_bottom_ratio_max"]
)
* text_height
padding_bottom = get_random_padding(
config["padding_bottom_ratio_min"],
config["padding_bottom_ratio_max"],
text_height,
)
padding_left = int(
rand_uniform(config["padding_left_ratio_min"], config["padding_left_ratio_max"])
* text_width
padding_left = get_random_padding(
config["padding_left_ratio_min"],
config["padding_left_ratio_max"],
text_width,
)
padding_right = int(
rand_uniform(
config["padding_right_ratio_min"], config["padding_right_ratio_max"]
)
* text_width
padding_right = get_random_padding(
config["padding_right_ratio_min"],
config["padding_right_ratio_max"],
text_width,
)
padding = [padding_top, padding_bottom, padding_left, padding_right]
return generate_typed_text_line_image_from_params(
......@@ -617,24 +743,3 @@ def char_in_font(unicode_char, font_path):
if ord(unicode_char) in cmap.cmap:
return True
return False
def get_valid_fonts(alphabet=None):
valid_fonts = list()
for fold_detail in os.walk("../../../Fonts"):
if fold_detail[2]:
for font_name in fold_detail[2]:
if ".ttf" not in font_name:
continue
font_path = os.path.join(fold_detail[0], font_name)
to_add = True
if alphabet is not None:
for char in alphabet:
if not char_in_font(char, font_path):
to_add = False
break
if to_add:
valid_fonts.append(font_path)
else:
valid_fonts.append(font_path)
return valid_fonts
......@@ -828,7 +828,6 @@ class GenericTrainingManager:
pbar.set_postfix(values=str(display_values))
pbar.update(len(batch_data["names"]))
self.dataset.remove_test_dataset(custom_name)
# output metrics values if requested
if output:
if "pred" in metric_names:
......@@ -981,14 +980,14 @@ class OCRManager(GenericTrainingManager):
os.makedirs(path, exist_ok=True)
charset = set()
dataset = None
gt = {"train": dict(), "valid": dict(), "test": dict()}
for set_name in ["train", "valid", "test"]:
gt = {"train": dict(), "val": dict(), "test": dict()}
for set_name in ["train", "val", "test"]:
set_path = os.path.join(path, set_name)
os.makedirs(set_path, exist_ok=True)
if set_name == "train":
dataset = self.dataset.train_dataset
elif set_name == "valid":
dataset = self.dataset.valid_datasets["{}-valid".format(dataset_name)]
elif set_name == "val":
dataset = self.dataset.valid_datasets["{}-val".format(dataset_name)]
elif set_name == "test":
self.dataset.generate_test_loader(
"{}-test".format(dataset_name),
......@@ -1028,14 +1027,15 @@ class OCRManager(GenericTrainingManager):
if "line_label" in sample:
gt[set_name][img_name]["lines"] = sample["line_label"]
with open(os.path.join(path, "labels.pkl"), "wb") as f:
pickle.dump(
{
"ground_truth": gt,
"charset": sorted(list(charset)),
},
with open(os.path.join(path / "labels.json"), "w") as f:
json.dump(
gt,
f,
sort_keys=True,
indent=4,
)
with open(os.path.join(path / "charset.pkl"), "wb") as f:
pickle.dump(sorted(list(charset)), f)
class Manager(OCRManager):
......@@ -1160,6 +1160,9 @@ class Manager(OCRManager):
"syn_max_lines": self.dataset.train_dataset.get_syn_max_lines()
if self.params["dataset_params"]["config"]["synthetic_data"]
else 0,
"syn_prob_lines": self.dataset.train_dataset.get_syn_proba_lines()
if self.params["dataset_params"]["config"]["synthetic_data"]
else 0,
}
return values
......@@ -1259,7 +1262,7 @@ class Manager(OCRManager):
)
predicted_tokens_len += 1
prediction_len[reached_end is False] = i + 1
prediction_len[reached_end == False] = i + 1 # noqa E712
if torch.all(reached_end):
break
......
# -*- coding: utf-8 -*-
import json
import os
import pickle
......@@ -24,14 +25,14 @@ class OCRManager(GenericTrainingManager):
os.makedirs(path, exist_ok=True)
charset = set()
dataset = None
gt = {"train": dict(), "valid": dict(), "test": dict()}
for set_name in ["train", "valid", "test"]:
gt = {"train": dict(), "val": dict(), "test": dict()}
for set_name in ["train", "val", "test"]:
set_path = os.path.join(path, set_name)
os.makedirs(set_path, exist_ok=True)
if set_name == "train":
dataset = self.dataset.train_dataset
elif set_name == "valid":
dataset = self.dataset.valid_datasets["{}-valid".format(dataset_name)]
elif set_name == "val":
dataset = self.dataset.valid_datasets["{}-val".format(dataset_name)]
elif set_name == "test":
self.dataset.generate_test_loader(
"{}-test".format(dataset_name),
......@@ -71,11 +72,12 @@ class OCRManager(GenericTrainingManager):
if "line_label" in sample:
gt[set_name][img_name]["lines"] = sample["line_label"]
with open(os.path.join(path, "labels.pkl"), "wb") as f:
pickle.dump(
{
"ground_truth": gt,
"charset": sorted(list(charset)),
},
with open(os.path.join(path / "labels.json"), "w") as f:
json.dump(
gt,
f,
sort_keys=True,
indent=4,
)
with open(os.path.join(path / "charset.pkl"), "wb") as f:
pickle.dump(sorted(list(charset)), f)
......@@ -10,7 +10,7 @@ from dan.decoder import GlobalHTADecoder
from dan.manager.ocr import OCRDataset, OCRDatasetManager
from dan.manager.training import Manager
from dan.models import FCN_Encoder
from dan.schedulers import exponential_dropout_scheduler
from dan.schedulers import exponential_dropout_scheduler, linear_scheduler
from dan.transforms import aug_config
......@@ -32,9 +32,9 @@ def train_and_test(rank, params):
model.params["training_params"]["load_epoch"] = "best"
model.load_model()
metrics = ["cer", "wer", "time", "map_cer", "loer"]
metrics = ["cer", "wer", "wer_no_punct", "time"]
for dataset_name in params["dataset_params"]["datasets"].keys():
for set_name in ["test", "valid", "train"]:
for set_name in ["test", "val", "train"]:
model.predict(
"{}-{}".format(dataset_name, set_name),
[
......@@ -46,16 +46,16 @@ def train_and_test(rank, params):
def run():
dataset_name = "simara"
dataset_level = "page"
dataset_variant = "_sem"
dataset_name = "esposalles"
dataset_level = "record"
dataset_variant = ""
params = {
"dataset_params": {
"dataset_manager": OCRDatasetManager,
"dataset_class": OCRDataset,
"datasets": {
dataset_name: "../../../Datasets/formatted/{}_{}{}".format(
dataset_name: "{}_{}{}".format(
dataset_name, dataset_level, dataset_variant
),
},
......@@ -65,9 +65,9 @@ def run():
(dataset_name, "train"),
],
},
"valid": {
"{}-valid".format(dataset_name): [
(dataset_name, "valid"),
"val": {
"{}-val".format(dataset_name): [
(dataset_name, "val"),
],
},
"config": {
......@@ -90,40 +90,41 @@ def run():
},
],
"augmentation": aug_config(0.9, 0.1),
"synthetic_data": None,
# "synthetic_data": {
# "init_proba": 0.9, # begin proba to generate synthetic document
# "end_proba": 0.2, # end proba to generate synthetic document
# "num_steps_proba": 200000, # linearly decrease the percent of synthetic document from 90% to 20% through 200000 samples
# "proba_scheduler_function": linear_scheduler, # decrease proba rate linearly
# "start_scheduler_at_max_line": True, # start decreasing proba only after curriculum reach max number of lines
# "dataset_level": dataset_level,
# "curriculum": True, # use curriculum learning (slowly increase number of lines per synthetic samples)
# "crop_curriculum": True, # during curriculum learning, crop images under the last text line
# "curr_start": 0, # start curriculum at iteration
# "curr_step": 10000, # interval to increase the number of lines for curriculum learning
# "min_nb_lines": 1, # initial number of lines for curriculum learning
# "max_nb_lines": max_nb_lines[dataset_name], # maximum number of lines for curriculum learning
# "padding_value": 255,
# # config for synthetic line generation
# "config": {
# "background_color_default": (255, 255, 255),
# "background_color_eps": 15,
# "text_color_default": (0, 0, 0),
# "text_color_eps": 15,
# "font_size_min": 35,
# "font_size_max": 45,
# "color_mode": "RGB",
# "padding_left_ratio_min": 0.00,
# "padding_left_ratio_max": 0.05,
# "padding_right_ratio_min": 0.02,
# "padding_right_ratio_max": 0.2,
# "padding_top_ratio_min": 0.02,
# "padding_top_ratio_max": 0.1,
# "padding_bottom_ratio_min": 0.02,
# "padding_bottom_ratio_max": 0.1,
# },
# }
# "synthetic_data": None,
"synthetic_data": {
"init_proba": 0.9, # begin proba to generate synthetic document
"end_proba": 0.2, # end proba to generate synthetic document
"num_steps_proba": 200000, # linearly decrease the percent of synthetic document from 90% to 20% through 200000 samples
"proba_scheduler_function": linear_scheduler, # decrease proba rate linearly
"start_scheduler_at_max_line": True, # start decreasing proba only after curriculum reach max number of lines
"dataset_level": dataset_level,
"curriculum": True, # use curriculum learning (slowly increase number of lines per synthetic samples)
"crop_curriculum": True, # during curriculum learning, crop images under the last text line
"curr_start": 0, # start curriculum at iteration
"curr_step": 10000, # interval to increase the number of lines for curriculum learning
"min_nb_lines": 1, # initial number of lines for curriculum learning
"max_nb_lines": 4, # maximum number of lines for curriculum learning
"padding_value": 255,
"font_path": "fonts/",
# config for synthetic line generation
"config": {
"background_color_default": (255, 255, 255),
"background_color_eps": 15,
"text_color_default": (0, 0, 0),
"text_color_eps": 15,
"font_size_min": 35,
"font_size_max": 45,
"color_mode": "RGB",
"padding_left_ratio_min": 0.00,
"padding_left_ratio_max": 0.05,
"padding_right_ratio_min": 0.02,
"padding_right_ratio_max": 0.2,
"padding_top_ratio_min": 0.02,
"padding_top_ratio_max": 0.1,
"padding_bottom_ratio_min": 0.02,
"padding_bottom_ratio_max": 0.1,
},
},
},
},
"model_params": {
......@@ -134,8 +135,18 @@ def run():
# "transfer_learning": None,
"transfer_learning": {
# model_name: [state_dict_name, checkpoint_path, learnable, strict]
"encoder": ["encoder", "dan_rimes_page.pt", True, True],
"decoder": ["decoder", "dan_rimes_page.pt", True, False],
"encoder": [
"encoder",
"pretrained_models/dan_rimes_page.pt",
True,
True,
],
"decoder": [
"decoder",
"pretrained_models/dan_rimes_page.pt",
True,
False,
],
},
"transfered_charset": True, # Transfer learning of the decision layer based on charset of the line HTR model
"additional_tokens": 1, # for decision layer = [<eot>, ], only for transferred charset
......@@ -163,7 +174,7 @@ def run():
},
},
"training_params": {
"output_folder": "dan_simara_page", # folder name for checkpoint and results
"output_folder": "dan_esposalles_record", # folder name for checkpoint and results
"max_nb_epochs": 50000, # maximum number of epochs before to stop
"max_training_time": 3600
* 24
......@@ -190,19 +201,21 @@ def run():
"eval_on_valid_interval": 5, # Interval (in epochs) to evaluate during training
"focus_metric": "cer", # Metrics to focus on to determine best epoch
"expected_metric_value": "low", # ["high", "low"] What is best for the focus metric value
"set_name_focus_metric": "{}-valid".format(
"set_name_focus_metric": "{}-val".format(
dataset_name
), # Which dataset to focus on to select best weights
"train_metrics": [
"loss_ce",
"cer",
"wer",
"wer_no_punct",
"syn_max_lines",
"syn_prob_lines",
], # Metrics name for training
"eval_metrics": [
"cer",
"wer",
"map_cer",
"wer_no_punct",
], # Metrics name for evaluation on validation set during training
"force_cpu": False, # True for debug purposes
"max_char_prediction": 1000, # max number of token prediction
......
......@@ -48,9 +48,9 @@ def run():
(dataset_name, "train"),
],
},
"valid": {
"{}-valid".format(dataset_name): [
(dataset_name, "valid"),
"val": {
"{}-val".format(dataset_name): [
(dataset_name, "val"),
],
},
"config": {
......@@ -135,12 +135,18 @@ def run():
"eval_on_valid_interval": 2, # Interval (in epochs) to evaluate during training
"focus_metric": "cer", # Metrics to focus on to determine best epoch
"expected_metric_value": "low", # ["high", "low"] What is best for the focus metric value
"set_name_focus_metric": "{}-valid".format(dataset_name),
"train_metrics": ["loss_ctc", "cer", "wer"], # Metrics name for training
"set_name_focus_metric": "{}-val".format(dataset_name),
"train_metrics": [
"loss_ctc",
"cer",
"wer",
"wer_no_punct",
], # Metrics name for training
"eval_metrics": [
"loss_ctc",
"cer",
"wer",
"wer_no_punct",
], # Metrics name for evaluation on validation set during training
"force_cpu": False, # True for debug purposes to run on cpu only
},
......
......@@ -37,12 +37,13 @@ def train_and_test(rank, params):
metrics = [
"cer",
"wer",
"wer_no_punct",
"time",
]
for dataset_name in params["dataset_params"]["datasets"].keys():
for set_name in [
"test",
"valid",
"val",
"train",
]:
model.predict(
......@@ -73,9 +74,9 @@ def run():
(dataset_name, "train"),
],
},
"valid": {
"{}-valid".format(dataset_name): [
(dataset_name, "valid"),
"val": {
"{}-val".format(dataset_name): [
(dataset_name, "val"),
],
},
"config": {
......@@ -175,14 +176,20 @@ def run():
"eval_on_valid_interval": 2, # Interval (in epochs) to evaluate during training
"focus_metric": "cer", # Metrics to focus on to determine best epoch
"expected_metric_value": "low", # ["high", "low"] What is best for the focus metric value
"set_name_focus_metric": "{}-valid".format(
"set_name_focus_metric": "{}-val".format(
dataset_name
), # Which dataset to focus on to select best weights
"train_metrics": ["loss_ctc", "cer", "wer"], # Metrics name for training
"train_metrics": [
"loss_ctc",
"cer",
"wer",
"wer_no_punct",
], # Metrics name for training
"eval_metrics": [
"loss_ctc",
"cer",
"wer",
"wer_no_punct",
], # Metrics name for evaluation on validation set during training
"force_cpu": False, # True for debug purposes to run on cpu only
},
......