
Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: atr/dan
Commits on source: 43
Showing 863 additions and 893 deletions
[flake8]
max-line-length = 120
exclude=.cache,.eggs,.git
# Flake8 ignores multiple errors by default;
# the only interesting ignore is W503, which goes against PEP8.
# See https://lintlyci.github.io/Flake8Rules/rules/W503.html
ignore = E203,E501,W503
@@ -3,6 +3,10 @@ stages:
   - build
   - deploy
 
+variables:
+  # Submodule clone
+  GIT_SUBMODULE_STRATEGY: recursive
+
 lint:
   image: python:3.10
   stage: test
@@ -34,14 +38,10 @@ test:
   variables:
     PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
-    ARKINDEX_API_SCHEMA_URL: schema.yml
   before_script:
     - pip install tox
-    # Download OpenAPI schema from last backend build
-    - curl https://assets.teklia.com/arkindex/openapi.yml > schema.yml
     # Add system deps for opencv
     - apt-get update -q
     - apt-get install -q -y libgl1
...
[submodule "line_image_extractor"]
path = teklia_line_image_extractor
url = ../line_image_extractor.git
 repos:
-  - repo: https://github.com/PyCQA/isort
-    rev: 5.12.0
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.0.282
     hooks:
-      - id: isort
-        args: ["--profile", "black"]
+      - id: ruff
+        args: [--fix, --exit-non-zero-on-fix]
   - repo: https://github.com/ambv/black
-    rev: 23.1.0
+    rev: 23.7.0
     hooks:
       - id: black
-  - repo: https://github.com/pycqa/flake8
-    rev: 6.0.0
-    hooks:
-      - id: flake8
-        additional_dependencies:
-          - 'flake8-coding==1.3.2'
-          - 'flake8-debugger==4.1.2'
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.4.0
     hooks:
@@ -35,7 +29,7 @@ repos:
       - id: end-of-file-fixer
       - id: mixed-line-ending
   - repo: https://github.com/codespell-project/codespell
-    rev: v2.2.2
+    rev: v2.2.5
     hooks:
       - id: codespell
         args: ['--write-changes']
@@ -46,3 +40,10 @@ repos:
   - repo: meta
     hooks:
       - id: check-useless-excludes
+  - repo: https://github.com/executablebooks/mdformat
+    rev: 0.7.16
+    hooks:
+      - id: mdformat
+        # Optionally add plugins
+        additional_dependencies:
+          - mdformat-mkdocs[recommended]
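For reference (not part of the diff itself), the updated hook set can be exercised locally with pre-commit; a minimal sketch:

```shell
pip install pre-commit
pre-commit install   # run the hooks automatically on each commit
pre-commit run -a    # or run them once across the whole repository
```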
@@ -4,31 +4,48 @@
 ## Documentation
 
-To use DAN in your own scripts, install it using pip:
+To use DAN in your own environment, you need to first clone with its submodules via:
 
-```console
+```shell
+git clone --recurse-submodules git@gitlab.teklia.com:atr/dan.git
+```
+
+If you forgot the `--recurse-submodules`, you can initialize the submodule using:
+
+```shell
+git submodule update --init
+```
+
+Then you can install it via pip:
+
+```shell
 pip install -e .
 ```
 
-For more details about this package, make sure to see the documentation available at https://teklia.gitlab.io/atr/dan/.
+For more details about this package, make sure to see the documentation available at <https://atr.pages.teklia.com/dan/>.
 
 ## Development
 
 For development and tests purpose it may be useful to install the project as a editable package with pip.
 
-* Use a virtualenv (e.g. with virtualenvwrapper `mkvirtualenv -a . dan`)
-* Install `dan` as a package (e.g. `pip install -e .`)
+- Use a virtualenv (e.g. with virtualenvwrapper `mkvirtualenv -a . dan`)
+- Install `dan` as a package (e.g. `pip install -e .`)
 
 ### Linter
 
 Code syntax is analyzed before submitting the code.\
 To run the linter tools suite you may use pre-commit.
 
 ```shell
 pip install pre-commit
 pre-commit run -a
 ```
 
 ### Run tests
 
 Tests are executed with `tox` using [pytest](https://pytest.org).
 To install `tox`,
 
 ```shell
 pip install tox
 tox
@@ -41,10 +58,22 @@ Run a single test: `tox -- <test_path>::<test_function>`
 
 The tests use a large file stored via [Git-LFS](https://docs.gitlab.com/ee/topics/git/lfs/). Make sure to run `git-lfs pull` before running them.
 
+### Update documentation
+
+Please keep the documentation updated when modifying or adding features.
+
+It's pretty easy to do:
+
+```shell
+pip install -r doc-requirements.txt
+mkdocs serve
+```
+
+You can then write in Markdown in the relevant `docs/*.md` files, and see live output on <http://localhost:8000>.
+
 ## Inference
 
 To apply DAN to an image, one needs to first add a few imports and to load an image. Note that the image should be in RGB.
 
 ```python
 import cv2
 from dan.predict import DAN
@@ -53,16 +82,18 @@ image = cv2.cvtColor(cv2.imread(IMAGE_PATH), cv2.COLOR_BGR2RGB)
 ```
 
 Then one can initialize and load the trained model with the parameters used during training.
+
 ```python
-model_path = 'model.pt'
-params_path = 'parameters.yml'
-charset_path = 'charset.pkl'
-model = DAN('cpu')
+model_path = "model.pt"
+params_path = "parameters.yml"
+charset_path = "charset.pkl"
+model = DAN("cpu")
 model.load(model_path, params_path, charset_path, mode="eval")
 ```
 
 To run the inference on a GPU, one can replace `cpu` by the name of the GPU. In the end, one can run the prediction:
+
 ```python
 text, confidence_scores = model.predict(image, confidences=True)
 ```
@@ -71,18 +102,22 @@ text, confidence_scores = model.predict(image, confidences=True)
 
 This package provides three subcommands. To get more information about any subcommand, use the `--help` option.
 
+### Get started
+
+See the [dedicated section](https://atr.pages.teklia.com/dan/get_started/training/) on the official DAN documentation.
+
 ### Data extraction from Arkindex
 
-See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/datasets/extract/) on the official DAN documentation.
+See the [dedicated section](https://atr.pages.teklia.com/dan/usage/datasets/extract/) on the official DAN documentation.
 
 ### Dataset formatting
 
-See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/datasets/format/) on the official DAN documentation.
+See the [dedicated section](https://atr.pages.teklia.com/dan/usage/datasets/format/) on the official DAN documentation.
 
 ### Model training
 
-See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/train/) on the official DAN documentation.
+See the [dedicated section](https://atr.pages.teklia.com/dan/usage/train/) on the official DAN documentation.
 
 ### Model prediction
 
-See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/predict/) on the official DAN documentation.
+See the [dedicated section](https://atr.pages.teklia.com/dan/usage/predict/) on the official DAN documentation.
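As a quick illustration of the `--help` hint in the README above, a minimal sketch (assuming the console entry point installed by `pip install -e .` is called `teklia-dan`; adjust to whatever entry point your install actually provides):

```shell
teklia-dan --help                # list the available subcommands
teklia-dan <subcommand> --help   # detailed options for a given subcommand
```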
-0.2.0-dev2
+0.2.0-dev3
@@ -27,17 +27,13 @@ def parse_worker_version(worker_version_id):
     return worker_version_id
 
 
-def validate_probability(proba):
-    try:
-        proba = float(proba)
-    except ValueError:
-        raise argparse.ArgumentTypeError(f"`{proba}` is not a valid float.")
-    if proba > 1 or proba < 0:
+def validate_char(char):
+    if len(char) != 1:
         raise argparse.ArgumentTypeError(
-            f"`{proba}` is not a valid probability. Must be between 0 and 1 (both exclusive)."
+            f"`{char}` (of length {len(char)}) is not a valid character. Must be a string of length 1."
         )
-    return proba
+
+    return char
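As an aside, a minimal sketch (not from the repository) of how `validate_char` behaves once wired into an `argparse` parser, mirroring the `--entity-separators` option defined further down:

```python
import argparse


def validate_char(char):
    # Same logic as the new validator above: only single characters are accepted.
    if len(char) != 1:
        raise argparse.ArgumentTypeError(
            f"`{char}` (of length {len(char)}) is not a valid character. Must be a string of length 1."
        )
    return char


parser = argparse.ArgumentParser()
parser.add_argument("--entity-separators", type=validate_char, nargs="+")

print(parser.parse_args(["--entity-separators", " ", "\n"]))  # two single characters: accepted
# parser.parse_args(["--entity-separators", "ab"])  # exits: "`ab` (of length 2) is not a valid character..."
```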
def add_extract_parser(subcommands) -> None: def add_extract_parser(subcommands) -> None:
...@@ -53,13 +49,6 @@ def add_extract_parser(subcommands) -> None: ...@@ -53,13 +49,6 @@ def add_extract_parser(subcommands) -> None:
type=pathlib.Path, type=pathlib.Path,
help="Path where the data were exported from Arkindex.", help="Path where the data were exported from Arkindex.",
) )
parser.add_argument(
"--parent",
type=validate_uuid,
nargs="+",
help="ID of the parent folder to import from Arkindex.",
required=False,
)
parser.add_argument( parser.add_argument(
"--element-type", "--element-type",
nargs="+", nargs="+",
...@@ -81,39 +70,46 @@ def add_extract_parser(subcommands) -> None: ...@@ -81,39 +70,46 @@ def add_extract_parser(subcommands) -> None:
required=True, required=True,
) )
# Optional arguments.
parser.add_argument( parser.add_argument(
"--load-entities", action="store_true", help="Extract text with their entities." "--train-folder",
type=validate_uuid,
help="ID of the training folder to extract from Arkindex.",
required=True,
) )
parser.add_argument( parser.add_argument(
"--tokens", "--val-folder",
type=pathlib.Path, type=validate_uuid,
help="Mapping between starting tokens and end tokens. Needed for entities.", help="ID of the validation folder to extract from Arkindex.",
required=False, required=True,
) )
parser.add_argument( parser.add_argument(
"--use-existing-split", "--test-folder",
action="store_true", type=validate_uuid,
help="Use the specified folder IDs for the dataset split.", help="ID of the testing folder to extract from Arkindex.",
required=True,
) )
# Optional arguments.
parser.add_argument( parser.add_argument(
"--train-folder", "--load-entities",
type=validate_uuid, action="store_true",
help="ID of the training folder to import from Arkindex.", help="Extract text with their entities.",
required=False,
) )
parser.add_argument( parser.add_argument(
"--val-folder", "--entity-separators",
type=validate_uuid, type=validate_char,
help="ID of the validation folder to import from Arkindex.", nargs="+",
help="""
Removes all text that does not appear in an entity or in the list of given ordered characters.
If several separators follow each other, keep only the first to appear in the list.
Do not give any arguments to keep the whole text.
""",
required=False, required=False,
) )
parser.add_argument( parser.add_argument(
"--test-folder", "--tokens",
type=validate_uuid, type=pathlib.Path,
help="ID of the testing folder to import from Arkindex.", help="Mapping between starting tokens and end tokens. Needed for entities.",
required=False, required=False,
) )
...@@ -122,28 +118,12 @@ def add_extract_parser(subcommands) -> None: ...@@ -122,28 +118,12 @@ def add_extract_parser(subcommands) -> None:
type=parse_worker_version, type=parse_worker_version,
help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.", help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
required=False, required=False,
default=False,
) )
parser.add_argument( parser.add_argument(
"--entity-worker-version", "--entity-worker-version",
type=parse_worker_version, type=parse_worker_version,
help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.", help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
required=False, required=False,
default=False,
)
parser.add_argument(
"--train-prob",
type=validate_probability,
default=0.7,
help="Training set split size.",
)
parser.add_argument(
"--val-prob",
type=validate_probability,
default=0.15,
help="Validation set split size.",
) )
parser.add_argument( parser.add_argument(
...@@ -158,4 +138,12 @@ def add_extract_parser(subcommands) -> None: ...@@ -158,4 +138,12 @@ def add_extract_parser(subcommands) -> None:
help="Images larger than this height will be resized to this width.", help="Images larger than this height will be resized to this width.",
) )
parser.add_argument(
"--cache",
dest="cache_dir",
type=pathlib.Path,
help="Where the images should be cached.",
default=pathlib.Path(".cache"),
)
parser.set_defaults(func=run) parser.set_defaults(func=run)
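Putting the reworked arguments together: a hypothetical invocation of the extraction subcommand under the new interface (the subcommand path and all placeholder values are illustrative, not taken from the diff; `--train-folder`, `--val-folder` and `--test-folder` are now required, while `--parent`, `--use-existing-split`, `--train-prob` and `--val-prob` have been removed):

```shell
teklia-dan dataset extract path/to/export.sqlite \
  --element-type text_line \
  --train-folder <train-folder-uuid> \
  --val-folder <val-folder-uuid> \
  --test-folder <test-folder-uuid> \
  --cache .cache
```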
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import ast from typing import List, Union
from dataclasses import dataclass
from itertools import starmap
from typing import List, NamedTuple, Optional, Union
from urllib.parse import urljoin
from arkindex_export import Image from arkindex_export import Image
from arkindex_export.models import Element as ArkindexElement from arkindex_export.models import (
from arkindex_export.models import Entity as ArkindexEntity Element,
from arkindex_export.models import EntityType as ArkindexEntityType Entity,
from arkindex_export.models import Transcription as ArkindexTranscription EntityType,
from arkindex_export.models import TranscriptionEntity as ArkindexTranscriptionEntity Transcription,
from arkindex_export.queries import list_children TranscriptionEntity,
def bounding_box(polygon: list):
"""
Returns a 4-tuple (x, y, width, height) for the bounding box of a Polygon (list of points)
"""
all_x, all_y = zip(*polygon)
x, y = min(all_x), min(all_y)
width, height = max(all_x) - x, max(all_y) - y
return int(x), int(y), int(width), int(height)
# DB models
Transcription = NamedTuple(
"Transcription",
id=str,
text=str,
)
Entity = NamedTuple(
"Entity",
type=str,
value=str,
offset=float,
length=float,
) )
from arkindex_export.queries import list_children
@dataclass
class Element:
id: str
type: str
polygon: str
url: str
width: int
height: int
max_width: Optional[int] = None
max_height: Optional[int] = None
def __post_init__(self):
self.max_height = self.max_height or self.height
self.max_width = self.max_width or self.width
@property
def bounding_box(self):
return bounding_box(ast.literal_eval(self.polygon))
@property
def image_url(self):
x, y, width, height = self.bounding_box
return urljoin(
self.url + "/",
f"{x},{y},{width},{height}/!{self.max_width},{self.max_height}/0/default.jpg",
)
def get_elements( def get_elements(
parent_id: str, parent_id: str,
element_type: str, element_type: List[str],
max_width: Optional[int] = None, ):
max_height: Optional[int] = None,
) -> List[Element]:
""" """
Retrieve elements from an SQLite export of an Arkindex corpus Retrieve elements from an SQLite export of an Arkindex corpus
""" """
...@@ -84,23 +24,11 @@ def get_elements( ...@@ -84,23 +24,11 @@ def get_elements(
query = ( query = (
list_children(parent_id=parent_id) list_children(parent_id=parent_id)
.join(Image) .join(Image)
.where(ArkindexElement.type == element_type) .where(Element.type.in_(element_type))
.select(
ArkindexElement.id,
ArkindexElement.type,
ArkindexElement.polygon,
Image.url,
Image.width,
Image.height,
)
)
return list(
starmap(
lambda *x: Element(*x, max_width=max_width, max_height=max_height),
query.tuples(),
)
) )
return query
def build_worker_version_filter(ArkindexModel, worker_version): def build_worker_version_filter(ArkindexModel, worker_version):
""" """
...@@ -118,47 +46,43 @@ def get_transcriptions( ...@@ -118,47 +46,43 @@ def get_transcriptions(
""" """
Retrieve transcriptions from an SQLite export of an Arkindex corpus Retrieve transcriptions from an SQLite export of an Arkindex corpus
""" """
query = ArkindexTranscription.select( query = Transcription.select(
ArkindexTranscription.id, ArkindexTranscription.text Transcription.id, Transcription.text, Transcription.worker_version
).where( ).where((Transcription.element == element_id))
(ArkindexTranscription.element == element_id)
& build_worker_version_filter( if transcription_worker_version is not None:
ArkindexTranscription, worker_version=transcription_worker_version query = query.where(
) build_worker_version_filter(
) Transcription, worker_version=transcription_worker_version
return list( )
starmap(
Transcription,
query.tuples(),
) )
) return query
def get_transcription_entities( def get_transcription_entities(
transcription_id: str, entity_worker_version: Union[str, bool] transcription_id: str, entity_worker_version: Union[str, bool]
) -> List[Entity]: ) -> List[TranscriptionEntity]:
""" """
Retrieve transcription entities from an SQLite export of an Arkindex corpus Retrieve transcription entities from an SQLite export of an Arkindex corpus
""" """
query = ( query = (
ArkindexTranscriptionEntity.select( TranscriptionEntity.select(
ArkindexEntityType.name, EntityType.name.alias("type"),
ArkindexEntity.name, Entity.name.alias("name"),
ArkindexTranscriptionEntity.offset, TranscriptionEntity.offset,
ArkindexTranscriptionEntity.length, TranscriptionEntity.length,
) TranscriptionEntity.worker_version,
.join(ArkindexEntity, on=ArkindexTranscriptionEntity.entity)
.join(ArkindexEntityType, on=ArkindexEntity.type)
.where(
(ArkindexTranscriptionEntity.transcription == transcription_id)
& build_worker_version_filter(
ArkindexTranscriptionEntity, worker_version=entity_worker_version
)
) )
.join(Entity, on=TranscriptionEntity.entity)
.join(EntityType, on=Entity.type)
.where((TranscriptionEntity.transcription == transcription_id))
) )
return list(
starmap( if entity_worker_version is not None:
Entity, query = query.where(
query.order_by(ArkindexTranscriptionEntity.offset).tuples(), build_worker_version_filter(
TranscriptionEntity, worker_version=entity_worker_version
)
) )
)
return query.order_by(TranscriptionEntity.offset).namedtuples()
@@ -49,9 +49,9 @@ class NoTranscriptionError(ElementProcessingError):
         return f"No transcriptions found on element ({self.element_id}) with this config. Skipping."
 
 
-class UnknownLabelError(ProcessingError):
+class NoEndTokenError(ProcessingError):
     """
-    Raised when the specified label is not known
+    Raised when the specified label has no end token and there is potentially additional text around the labels
     """
 
     label: str
@@ -61,4 +61,4 @@ class UnknownLabelError(ProcessingError):
         self.label = label
 
     def __str__(self) -> str:
-        return f"Label `{self.label}` is missing in the NER configuration."
+        return f"Label `{self.label}` has no end token."
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import json
import random import random
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import List, Optional, Union from typing import List, Optional, Union
from uuid import UUID from uuid import UUID
from arkindex_export import open_database import numpy as np
from tqdm import tqdm from tqdm import tqdm
from arkindex_export import open_database
from dan import logger from dan import logger
from dan.datasets.extract.db import ( from dan.datasets.extract.db import (
Element, Element,
Entity,
get_elements, get_elements,
get_transcription_entities, get_transcription_entities,
get_transcriptions, get_transcriptions,
) )
from dan.datasets.extract.exceptions import ( from dan.datasets.extract.exceptions import (
NoEndTokenError,
NoTranscriptionError, NoTranscriptionError,
ProcessingError, ProcessingError,
UnknownLabelError,
) )
from dan.datasets.extract.utils import ( from dan.datasets.extract.utils import (
EntityType, EntityType,
Subset,
download_image, download_image,
insert_token, insert_token,
parse_tokens, parse_tokens,
save_json,
save_text,
) )
from line_image_extractor.extractor import extract, read_img, save_img
from line_image_extractor.image_utils import Extraction, polygon_to_bbox, resize
IMAGES_DIR = "images" # Subpath to the images directory. IMAGES_DIR = "images" # Subpath to the images directory.
LABELS_DIR = "labels" # Subpath to the labels directory. LABELS_DIR = "labels" # Subpath to the labels directory.
SPLIT_NAMES = ["train", "val", "test"] SPLIT_NAMES = ["train", "val", "test"]
IIIF_URL_SUFFIX = "/full/full/0/default.jpg"
class ArkindexExtractor: class ArkindexExtractor:
...@@ -45,80 +46,105 @@ class ArkindexExtractor: ...@@ -45,80 +46,105 @@ class ArkindexExtractor:
def __init__( def __init__(
self, self,
folders: list = [], folders: list = [],
element_type: list = [], element_type: List[str] = [],
parent_element_type: str = None, parent_element_type: str = None,
output: Path = None, output: Path = None,
load_entities: bool = None, load_entities: bool = False,
entity_separators: list = [],
tokens: Path = None, tokens: Path = None,
use_existing_split: bool = None, transcription_worker_version: Optional[Union[str, bool]] = None,
transcription_worker_version: str = None, entity_worker_version: Optional[Union[str, bool]] = None,
entity_worker_version: str = None,
train_prob: float = None,
val_prob: float = None,
max_width: Optional[int] = None, max_width: Optional[int] = None,
max_height: Optional[int] = None, max_height: Optional[int] = None,
cache_dir: Path = Path(".cache"),
) -> None: ) -> None:
self.folders = folders
self.element_type = element_type self.element_type = element_type
self.parent_element_type = parent_element_type self.parent_element_type = parent_element_type
self.output = output self.output = output
self.load_entities = load_entities self.load_entities = load_entities
self.entity_separators = entity_separators
self.tokens = parse_tokens(tokens) if self.load_entities else None self.tokens = parse_tokens(tokens) if self.load_entities else None
self.use_existing_split = use_existing_split
self.transcription_worker_version = transcription_worker_version self.transcription_worker_version = transcription_worker_version
self.entity_worker_version = entity_worker_version self.entity_worker_version = entity_worker_version
self.train_prob = train_prob
self.val_prob = val_prob
self.max_width = max_width self.max_width = max_width
self.max_height = max_height self.max_height = max_height
self.subsets = self.get_subsets(folders) self.cache_dir = cache_dir
# Create cache dir if non existent
self.cache_dir.mkdir(exist_ok=True, parents=True)
def get_subsets(self, folders: list) -> List[Subset]: def find_image_in_cache(self, image_id: str) -> Path:
""" """Images are cached to avoid downloading them twice. They are stored under a specific name,
Assign each folder to its split if it's already known. based on their Arkindex ID. Images are saved under the JPEG format.
"""
if self.use_existing_split:
return [
Subset(folder, split) for folder, split in zip(folders, SPLIT_NAMES)
]
else:
return [Subset(folder) for folder in folders]
def _assign_random_split(self): :param image_id: ID of the image. The image is saved under this name.
""" :return: Where the image should be saved in the cache folder.
Yields a randomly chosen split for an element.
Assumes that train_prob + valid_prob + test_prob = 1
""" """
prob = random.random() return self.cache_dir / f"{image_id}.jpg"
if prob <= self.train_prob:
yield SPLIT_NAMES[0]
elif prob <= self.train_prob + self.val_prob:
yield SPLIT_NAMES[1]
else:
yield SPLIT_NAMES[2]
def get_random_split(self): def _keep_char(self, char: str) -> bool:
return next(self._assign_random_split()) # Keep all text by default if no separator was given
return not self.entity_separators or char in self.entity_separators
def reconstruct_text(self, text: str, entities: List[Entity]): def reconstruct_text(self, full_text: str, entities) -> str:
""" """
Insert tokens delimiting the start/end of each entity on the transcription. Insert tokens delimiting the start/end of each entity on the transcription.
""" """
count = 0 text, text_offset = "", 0
# Keep all text by default if no separator was given
for entity in entities: for entity in entities:
if entity.type not in self.tokens: # Text before entity
raise UnknownLabelError(entity.type) text += "".join(
filter(self._keep_char, full_text[text_offset : entity.offset])
entity_type: EntityType = self.tokens[entity.type]
text = insert_token(
text,
count,
entity_type,
offset=entity.offset,
length=entity.length,
) )
count += entity_type.offset
return text entity_type: EntityType = self.tokens.get(entity.type)
if not entity_type:
logger.warning(
f"Label `{entity.type}` is missing in the NER configuration."
)
# We keep the whole text, so we need an end token for each entity to know exactly when an entity begins and ends
elif not entity_type.end and not self.entity_separators:
raise NoEndTokenError(entity.type)
# Entity text:
# - with tokens if there is an entity_type
# - without tokens if there is no entity_type but we want to keep the whole text
if entity_type or not self.entity_separators:
text += insert_token(
full_text,
entity_type,
offset=entity.offset,
length=entity.length,
)
text_offset = entity.offset + entity.length
# Remaining text after the last entity
text += "".join(filter(self._keep_char, full_text[text_offset:]))
if not self.entity_separators:
return text
# Add some clean up to avoid several separators between entities
text, full_text = "", text
for char in full_text:
last_char = text[-1] if len(text) else ""
# Keep the current character if there are no two consecutive separators
if (
char not in self.entity_separators
or last_char not in self.entity_separators
):
text += char
# If several separators follow each other, keep only one according to the given order
elif self.entity_separators.index(char) < self.entity_separators.index(
last_char
):
text = text[:-1] + char
# Remove separators at the beginning and end of text
return text.strip("".join(self.entity_separators))
def extract_transcription(self, element: Element): def extract_transcription(self, element: Element):
""" """
...@@ -133,14 +159,60 @@ class ArkindexExtractor: ...@@ -133,14 +159,60 @@ class ArkindexExtractor:
transcription = random.choice(transcriptions) transcription = random.choice(transcriptions)
if self.load_entities: if not self.load_entities:
entities = get_transcription_entities(
transcription.id, self.entity_worker_version
)
return self.reconstruct_text(transcription.text, entities)
else:
return transcription.text.strip() return transcription.text.strip()
entities = get_transcription_entities(
transcription.id, self.entity_worker_version
)
return self.reconstruct_text(transcription.text, entities)
def retrieve_image(self, child: Element):
"""Get or download image of the element. Checks in cache before downloading.
:param child: Processed element.
:return: The element's image.
"""
cached_img_path = self.find_image_in_cache(child.image.id)
if not cached_img_path.exists():
# Save in cache
download_image(child.image.url + IIIF_URL_SUFFIX).save(
cached_img_path, format="jpeg"
)
return read_img(cached_img_path)
def get_image(self, child: Element, destination: Path) -> None:
"""Save the element's image to the given path and applies any image operations needed.
:param child: Processed element.
:param destination: Where the image should be saved.
"""
polygon = json.loads(str(child.polygon))
if self.max_height or self.max_width:
polygon = resize(
polygon,
self.max_width,
self.max_height,
scale_x=1.0,
scale_y_top=1.0,
scale_y_bottom=1.0,
)
# Extract the polygon in the image
image = extract(
img=self.retrieve_image(child),
polygon=np.array(polygon),
bbox=polygon_to_bbox(polygon),
# Hardcoded while we don't have a configuration file
extraction_mode=Extraction.deskew_min_area_rect,
max_deskew_angle=45,
)
# Save the image to disk
save_img(path=destination, img=image)
def process_element( def process_element(
self, self,
element: Element, element: Element,
...@@ -152,18 +224,17 @@ class ArkindexExtractor: ...@@ -152,18 +224,17 @@ class ArkindexExtractor:
""" """
text = self.extract_transcription(element) text = self.extract_transcription(element)
txt_path = Path( base_path = Path(split, f"{element.type}_{element.id}")
self.output, LABELS_DIR, split, f"{element.type}_{element.id}.txt" Path(self.output, LABELS_DIR, base_path).with_suffix(".txt").write_text(text)
)
save_text(txt_path, text) self.get_image(
im_path = Path( element, Path(self.output, IMAGES_DIR, base_path).with_suffix(".jpg")
self.output, IMAGES_DIR, split, f"{element.type}_{element.id}.jpg"
) )
download_image(element, im_path)
return element.id return element.id
def process_parent( def process_parent(
self, self,
pbar,
parent: Element, parent: Element,
split: str, split: str,
): ):
...@@ -171,7 +242,10 @@ class ArkindexExtractor: ...@@ -171,7 +242,10 @@ class ArkindexExtractor:
Extract data from a parent element. Extract data from a parent element.
""" """
data = defaultdict(list) data = defaultdict(list)
base_description = (
f"Extracting data from {parent.type} ({parent.id}) for split ({split})"
)
pbar.set_description(desc=base_description)
if self.element_type == [parent.type]: if self.element_type == [parent.type]:
try: try:
data[parent.type].append(self.process_element(parent, split)) data[parent.type].append(self.process_element(parent, split))
...@@ -179,84 +253,64 @@ class ArkindexExtractor: ...@@ -179,84 +253,64 @@ class ArkindexExtractor:
logger.warning(f"Skipping {parent.id}: {str(e)}") logger.warning(f"Skipping {parent.id}: {str(e)}")
# Extract children elements # Extract children elements
else: else:
for element_type in self.element_type: children = get_elements(
for element in get_elements( parent.id,
parent.id, self.element_type,
element_type, )
max_width=self.max_width,
max_height=self.max_height, nb_children = children.count()
): for idx, element in enumerate(children, start=1):
try: # Update description to update the children processing progress
data[element_type].append(self.process_element(element, split)) pbar.set_description(desc=base_description + f" ({idx}/{nb_children})")
except ProcessingError as e: try:
logger.warning(f"Skipping {element.id}: {str(e)}") data[element.type].append(self.process_element(element, split))
except ProcessingError as e:
logger.warning(f"Skipping {element.id}: {str(e)}")
return data return data
def run(self): def run(self):
split_dict = defaultdict(dict)
# Iterate over the subsets to find the page images and labels. # Iterate over the subsets to find the page images and labels.
for idx, subset in enumerate(self.subsets, start=1): for folder_id, split in zip(self.folders, SPLIT_NAMES):
# Iterate over the pages to create splits at page level. with tqdm(
for parent in tqdm(
get_elements( get_elements(
subset.id, folder_id,
self.parent_element_type, [self.parent_element_type],
max_width=self.max_width,
max_height=self.max_height,
), ),
desc=f"Processing {subset} {idx}/{len(self.subsets)}", desc=f"Extracting data from ({folder_id}) for split ({split})",
): ) as pbar:
split = subset.split or self.get_random_split() # Iterate over the pages to create splits at page level.
split_dict[split][parent.id] = self.process_parent( for parent in pbar:
parent=parent, self.process_parent(
split=split, pbar=pbar,
) parent=parent,
save_json(self.output / "split.json", split_dict) split=split,
)
# Progress bar updates
pbar.update()
pbar.refresh()
def run( def run(
database: Path, database: Path,
parent: list, element_type: List[str],
element_type: str,
parent_element_type: str, parent_element_type: str,
output: Path, output: Path,
load_entities: bool, load_entities: bool,
entity_separators: list,
tokens: Path, tokens: Path,
use_existing_split: bool,
train_folder: UUID, train_folder: UUID,
val_folder: UUID, val_folder: UUID,
test_folder: UUID, test_folder: UUID,
transcription_worker_version: Union[str, bool], transcription_worker_version: Optional[Union[str, bool]],
entity_worker_version: Union[str, bool], entity_worker_version: Optional[Union[str, bool]],
train_prob,
val_prob,
max_width: Optional[int], max_width: Optional[int],
max_height: Optional[int], max_height: Optional[int],
cache_dir: Path,
): ):
assert (
use_existing_split or parent
), "One of `--use-existing-split` and `--parent` must be set"
assert use_existing_split ^ bool(
parent
), "Only one of `--use-existing-split` and `--parent` must be set"
assert database.exists(), f"No file found @ {database}" assert database.exists(), f"No file found @ {database}"
open_database(path=database) open_database(path=database)
if use_existing_split: folders = [str(train_folder), str(val_folder), str(test_folder)]
assert (
train_folder
), "If you use an existing split, you must specify the training folder."
assert (
val_folder
), "If you use an existing split, you must specify the validation folder."
assert (
test_folder
), "If you use an existing split, you must specify the testing folder."
folders = [str(train_folder), str(val_folder), str(test_folder)]
else:
folders = [str(parent_id) for parent_id in parent]
if load_entities: if load_entities:
assert tokens, "Please provide the entities to match." assert tokens, "Please provide the entities to match."
...@@ -272,12 +326,11 @@ def run( ...@@ -272,12 +326,11 @@ def run(
parent_element_type=parent_element_type, parent_element_type=parent_element_type,
output=output, output=output,
load_entities=load_entities, load_entities=load_entities,
entity_separators=entity_separators,
tokens=tokens, tokens=tokens,
use_existing_split=use_existing_split,
transcription_worker_version=transcription_worker_version, transcription_worker_version=transcription_worker_version,
entity_worker_version=entity_worker_version, entity_worker_version=entity_worker_version,
train_prob=train_prob,
val_prob=val_prob,
max_width=max_width, max_width=max_width,
max_height=max_height, max_height=max_height,
cache_dir=cache_dir,
).run() ).run()
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import json
import logging import logging
import time from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import NamedTuple from typing import NamedTuple
import cv2 import requests
import imageio.v2 as iio
import yaml import yaml
from numpy import ndarray from PIL import Image
from tenacity import (
from dan.datasets.extract.db import Element retry,
from dan.datasets.extract.exceptions import ImageDownloadError retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MAX_RETRIES = 5 # See http://docs.python-requests.org/en/master/user/advanced/#timeouts
DOWNLOAD_TIMEOUT = (30, 60)
class Subset(NamedTuple):
id: str
split: str = None
def __str__(self) -> str: def _retry_log(retry_state, *args, **kwargs):
return ( logger.warning(
f"Subset(id='{self.id}', split='{self.split.capitalize()}')" f"Request to {retry_state.args[0]} failed ({repr(retry_state.outcome.exception())}), "
if self.split f"retrying in {retry_state.idle_for} seconds"
else f"Subset(id='{self.id}')" )
)
class EntityType(NamedTuple): class EntityType(NamedTuple):
...@@ -39,60 +36,54 @@ class EntityType(NamedTuple): ...@@ -39,60 +36,54 @@ class EntityType(NamedTuple):
return len(self.start) + len(self.end) return len(self.start) + len(self.end)
def download_image(element: Element, im_path: Path): @retry(
tries = 1 stop=stop_after_attempt(3),
# retry loop wait=wait_exponential(multiplier=2),
while True: retry=retry_if_exception_type(requests.RequestException),
if tries > MAX_RETRIES: before_sleep=_retry_log,
raise ImageDownloadError(element.id, Exception("Maximum retries reached.")) reraise=True,
try: )
image = iio.imread(element.image_url) def _retried_request(url):
save_image(im_path, image) resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
return resp.raise_for_status()
except TimeoutError: return resp
logger.warning("Timeout, retry in 1 second.")
time.sleep(1)
tries += 1
except Exception as e:
raise ImageDownloadError(element.id, e)
def save_text(path: Path, text: str):
with path.open("w") as f:
f.write(text)
def save_image(path: Path, image: ndarray): def download_image(url):
cv2.imwrite(str(path), cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) """
Download an image and open it with Pillow
"""
assert url.startswith("http"), "Image URL must be HTTP(S)"
# Download the image
# Cannot use stream=True as urllib's responses do not support the seek(int) method,
# which is explicitly required by Image.open on file-like objects
resp = _retried_request(url)
# Preprocess the image and prepare it for classification
image = Image.open(BytesIO(resp.content))
logger.debug(
"Downloaded image {} - size={}x{}".format(url, image.size[0], image.size[1])
)
def save_json(path: Path, data: dict): return image
with path.open("w") as outfile:
json.dump(data, outfile, indent=4)
def insert_token( def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -> str:
text: str, count: int, entity_type: EntityType, offset: int, length: int
) -> str:
""" """
Insert the given tokens at the right position in the text Insert the given tokens at the right position in the text
""" """
return ( return (
# Text before entity
text[: count + offset]
# Starting token # Starting token
+ entity_type.start (entity_type.start if entity_type else "")
# Entity # Entity
+ text[count + offset : count + offset + length] + text[offset : offset + length]
# End token # End token
+ entity_type.end + (entity_type.end if entity_type else "")
# Text after entity
+ text[count + offset + length :]
) )
def parse_tokens(filename: Path) -> dict: def parse_tokens(filename: Path) -> dict:
with filename.open() as f: return {
return { name: EntityType(**tokens)
name: EntityType(**tokens) for name, tokens in yaml.safe_load(f).items() for name, tokens in yaml.safe_load(filename.read_text()).items()
} }
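For context, the `--tokens` file consumed by `parse_tokens` above is a YAML mapping from entity names to their start and end tokens; each entry is unpacked into an `EntityType`. A hypothetical `tokens.yml` (entity names and token symbols are purely illustrative):

```yaml
# Hypothetical tokens file: one entry per entity type, unpacked into EntityType(start=..., end=...)
surname:
  start: Ⓢ
  end: Ⓔ
firstname:
  start: Ⓕ
  end: Ⓔ
```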
@@ -63,10 +63,12 @@ class ATRDatasetFormatter:
     def parse_labels(self, set_name, file_name):
         return {
-            "img_path": os.path.join(
-                self.image_folder,
-                set_name,
-                f"{os.path.splitext(file_name)[0]}.{self.image_format}",
+            "img_path": os.path.realpath(
+                os.path.join(
+                    self.image_folder,
+                    set_name,
+                    f"{os.path.splitext(file_name)[0]}.{self.image_format}",
+                )
             ),
             "label": self.read_file(
                 os.path.join(self.labels_folder, set_name, file_name)
...
...@@ -2,24 +2,13 @@ ...@@ -2,24 +2,13 @@
import torch import torch
from torch import relu, softmax from torch import relu, softmax
from torch.nn import ( from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, ModuleList
LSTM,
Conv1d,
Dropout,
Embedding,
LayerNorm,
Linear,
Module,
ModuleList,
)
from torch.nn.init import xavier_uniform_ from torch.nn.init import xavier_uniform_
class PositionalEncoding1D(Module): class PositionalEncoding1D(Module):
def __init__(self, dim, len_max, device): def __init__(self, dim, len_max, device):
super(PositionalEncoding1D, self).__init__() super(PositionalEncoding1D, self).__init__()
self.len_max = len_max
self.dim = dim
self.pe = torch.zeros((1, dim, len_max), device=device, requires_grad=False) self.pe = torch.zeros((1, dim, len_max), device=device, requires_grad=False)
div = torch.exp( div = torch.exp(
...@@ -46,9 +35,6 @@ class PositionalEncoding1D(Module): ...@@ -46,9 +35,6 @@ class PositionalEncoding1D(Module):
class PositionalEncoding2D(Module): class PositionalEncoding2D(Module):
def __init__(self, dim, h_max, w_max, device): def __init__(self, dim, h_max, w_max, device):
super(PositionalEncoding2D, self).__init__() super(PositionalEncoding2D, self).__init__()
self.h_max = h_max
self.max_w = w_max
self.dim = dim
self.pe = torch.zeros( self.pe = torch.zeros(
(1, dim, h_max, w_max), device=device, requires_grad=False (1, dim, h_max, w_max), device=device, requires_grad=False
) )
...@@ -177,31 +163,28 @@ class GlobalDecoderLayer(Module): ...@@ -177,31 +163,28 @@ class GlobalDecoderLayer(Module):
def __init__(self, params): def __init__(self, params):
super(GlobalDecoderLayer, self).__init__() super(GlobalDecoderLayer, self).__init__()
self.emb_dim = params["enc_dim"]
self.dim_feedforward = params["dec_dim_feedforward"]
self.self_att = CustomMultiHeadAttention( self.self_att = CustomMultiHeadAttention(
embed_dim=self.emb_dim, embed_dim=params["enc_dim"],
num_heads=params["dec_num_heads"], num_heads=params["dec_num_heads"],
proj_value=True, proj_value=True,
dropout=params["dec_att_dropout"], dropout=params["dec_att_dropout"],
) )
self.norm1 = LayerNorm(self.emb_dim) self.norm1 = LayerNorm(params["enc_dim"])
self.att = CustomMultiHeadAttention( self.att = CustomMultiHeadAttention(
embed_dim=self.emb_dim, embed_dim=params["enc_dim"],
num_heads=params["dec_num_heads"], num_heads=params["dec_num_heads"],
proj_value=True, proj_value=True,
dropout=params["dec_att_dropout"], dropout=params["dec_att_dropout"],
) )
self.linear1 = Linear(self.emb_dim, self.dim_feedforward) self.linear1 = Linear(params["enc_dim"], params["dec_dim_feedforward"])
self.linear2 = Linear(self.dim_feedforward, self.emb_dim) self.linear2 = Linear(params["dec_dim_feedforward"], params["enc_dim"])
self.dropout = Dropout(params["dec_res_dropout"]) self.dropout = Dropout(params["dec_res_dropout"])
self.norm2 = LayerNorm(self.emb_dim) self.norm2 = LayerNorm(params["enc_dim"])
self.norm3 = LayerNorm(self.emb_dim) self.norm3 = LayerNorm(params["enc_dim"])
def forward( def forward(
self, self,
...@@ -319,20 +302,12 @@ class FeaturesUpdater(Module): ...@@ -319,20 +302,12 @@ class FeaturesUpdater(Module):
def __init__(self, params): def __init__(self, params):
super(FeaturesUpdater, self).__init__() super(FeaturesUpdater, self).__init__()
self.enc_dim = params["enc_dim"]
self.enc_h_max = params["h_max"]
self.enc_w_max = params["w_max"]
self.pe_2d = PositionalEncoding2D( self.pe_2d = PositionalEncoding2D(
self.enc_dim, self.enc_h_max, self.enc_w_max, params["device"] params["enc_dim"], params["h_max"], params["w_max"], params["device"]
)
self.use_2d_positional_encoding = (
"use_2d_pe" not in params or params["use_2d_pe"]
) )
def get_pos_features(self, features): def get_pos_features(self, features):
if self.use_2d_positional_encoding: return self.pe_2d(features)
return self.pe_2d(features)
return features
class GlobalHTADecoder(Module): class GlobalHTADecoder(Module):
...@@ -342,31 +317,23 @@ class GlobalHTADecoder(Module): ...@@ -342,31 +317,23 @@ class GlobalHTADecoder(Module):
def __init__(self, params): def __init__(self, params):
super(GlobalHTADecoder, self).__init__() super(GlobalHTADecoder, self).__init__()
self.enc_dim = params["enc_dim"]
self.dec_l_max = params["l_max"]
self.dropout = Dropout(params["dec_pred_dropout"]) self.dropout = Dropout(params["dec_pred_dropout"])
self.dec_att_win = ( self.dec_att_win = (
params["attention_win"] if params["attention_win"] is not None else 1 params["attention_win"] if params["attention_win"] is not None else 1
) )
self.use_1d_pe = "use_1d_pe" not in params or params["use_1d_pe"]
self.use_lstm = params["use_lstm"]
self.features_updater = FeaturesUpdater(params) self.features_updater = FeaturesUpdater(params)
self.att_decoder = GlobalAttDecoder(params) self.att_decoder = GlobalAttDecoder(params)
self.emb = Embedding( self.emb = Embedding(
num_embeddings=params["vocab_size"] + 3, embedding_dim=self.enc_dim num_embeddings=params["vocab_size"] + 3, embedding_dim=params["enc_dim"]
) )
self.pe_1d = PositionalEncoding1D( self.pe_1d = PositionalEncoding1D(
self.enc_dim, self.dec_l_max, params["device"] params["enc_dim"], params["l_max"], params["device"]
) )
if self.use_lstm:
self.lstm_predict = LSTM(self.enc_dim, self.enc_dim)
vocab_size = params["vocab_size"] + 1 vocab_size = params["vocab_size"] + 1
self.end_conv = Conv1d(self.enc_dim, vocab_size, kernel_size=1) self.end_conv = Conv1d(params["enc_dim"], vocab_size, kernel_size=1)
def forward( def forward(
self, self,
...@@ -388,9 +355,7 @@ class GlobalHTADecoder(Module): ...@@ -388,9 +355,7 @@ class GlobalHTADecoder(Module):
pos_tokens = self.emb(tokens).permute(0, 2, 1) pos_tokens = self.emb(tokens).permute(0, 2, 1)
# Add 1D Positional Encoding # Add 1D Positional Encoding
if self.use_1d_pe: pos_tokens = self.pe_1d(pos_tokens, start=start).permute(2, 0, 1)
pos_tokens = self.pe_1d(pos_tokens, start=start)
pos_tokens = pos_tokens.permute(2, 0, 1)
if num_pred is None: if num_pred is None:
num_pred = tokens.size(1) num_pred = tokens.size(1)
...@@ -440,9 +405,6 @@ class GlobalHTADecoder(Module): ...@@ -440,9 +405,6 @@ class GlobalHTADecoder(Module):
keep_all_weights=keep_all_weights, keep_all_weights=keep_all_weights,
) )
if self.use_lstm:
output, hidden_predict = self.lstm_predict(output, hidden_predict)
dp_output = self.dropout(relu(output)) dp_output = self.dropout(relu(output))
preds = self.end_conv(dp_output.permute(1, 2, 0)) preds = self.end_conv(dp_output.permute(1, 2, 0))
......
@@ -92,9 +92,7 @@ class FCN_Encoder(Module):
         self.init_blocks = ModuleList(
             [
-                ConvBlock(
-                    params["input_channels"], 16, stride=(1, 1), dropout=self.dropout
-                ),
+                ConvBlock(3, 16, stride=(1, 1), dropout=self.dropout),
                 ConvBlock(16, 32, stride=(2, 2), dropout=self.dropout),
                 ConvBlock(32, 64, stride=(2, 2), dropout=self.dropout),
                 ConvBlock(64, 128, stride=(2, 2), dropout=self.dropout),
...
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import copy
import json import json
import os import os
import random
import numpy as np import numpy as np
import torch from torch.utils.data import Dataset
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from dan.datasets.utils import natural_sort from dan.datasets.utils import natural_sort
from dan.transforms import get_augmentation_transforms, get_preprocessing_transforms from dan.utils import read_image, token_to_ind
class DatasetManager: class OCRDataset(Dataset):
def __init__(self, params, device: str): """
self.params = params Dataset class to handle dataset loading
self.dataset_class = None """
self.my_collate_function = None
# Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
self.pin_memory = device != "cpu"
self.train_dataset = None
self.valid_datasets = dict()
self.test_datasets = dict()
self.train_loader = None def __init__(
self.valid_loaders = dict() self,
self.test_loaders = dict() set_name,
paths_and_sets,
charset,
tokens,
preprocessing_transforms,
augmentation_transforms,
load_in_memory=False,
mean=None,
std=None,
):
self.set_name = set_name
self.charset = charset
self.tokens = tokens
self.load_in_memory = load_in_memory
self.mean = mean
self.std = std
self.train_sampler = None # Pre-processing, augmentation
self.valid_samplers = dict() self.preprocessing_transforms = preprocessing_transforms
self.test_samplers = dict() self.augmentation_transforms = augmentation_transforms
self.generator = torch.Generator() # Factor to reduce the height and width of the feature vector before feeding the decoder.
self.generator.manual_seed(0) self.reduce_dims_factor = np.array([32, 8, 1])
self.batch_size = { # Load samples and preprocess images if load_in_memory is True
"train": self.params["batch_size"], self.samples = self.load_samples(paths_and_sets)
"val": self.params["valid_batch_size"]
if "valid_batch_size" in self.params
else self.params["batch_size"],
"test": self.params["test_batch_size"]
if "test_batch_size" in self.params
else 1,
}
def apply_specific_treatment_after_dataset_loading(self, dataset): # Curriculum config
raise NotImplementedError self.curriculum_config = None
def load_datasets(self): def __len__(self):
""" """
Load training and validation datasets Return the dataset size
""" """
self.train_dataset = self.dataset_class( return len(self.samples)
self.params,
"train",
self.params["train"]["name"],
self.get_paths_and_sets(self.params["train"]["datasets"]),
augmentation_transforms=(
get_augmentation_transforms()
if self.params["config"]["augmentation"]
else None
),
)
(
self.params["config"]["mean"],
self.params["config"]["std"],
) = self.train_dataset.compute_std_mean()
self.my_collate_function = self.train_dataset.collate_function(
self.params["config"]
)
self.apply_specific_treatment_after_dataset_loading(self.train_dataset)
for custom_name in self.params["val"].keys():
self.valid_datasets[custom_name] = self.dataset_class(
self.params,
"val",
custom_name,
self.get_paths_and_sets(self.params["val"][custom_name]),
augmentation_transforms=None,
)
self.apply_specific_treatment_after_dataset_loading(
self.valid_datasets[custom_name]
)
def load_ddp_samplers(self): def __getitem__(self, idx):
""" """
Load training and validation data samplers Return an item from the dataset (image and label)
""" """
if self.params["use_ddp"]: # Load preprocessed image
self.train_sampler = DistributedSampler( sample = copy.deepcopy(self.samples[idx])
self.train_dataset, if not self.load_in_memory:
num_replicas=self.params["num_gpu"], sample["img"] = self.get_sample_img(idx)
rank=self.params["ddp_rank"],
shuffle=True,
)
for custom_name in self.valid_datasets.keys():
self.valid_samplers[custom_name] = DistributedSampler(
self.valid_datasets[custom_name],
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=False,
)
else:
for custom_name in self.valid_datasets.keys():
self.valid_samplers[custom_name] = None
def load_dataloaders(self): # Convert to numpy
""" sample["img"] = np.array(sample["img"])
Load training and validation data loaders
"""
self.train_loader = DataLoader(
self.train_dataset,
batch_size=self.batch_size["train"],
shuffle=True if self.train_sampler is None else False,
drop_last=False,
batch_sampler=self.train_sampler,
sampler=self.train_sampler,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
)
for key in self.valid_datasets.keys(): # Apply data augmentation
self.valid_loaders[key] = DataLoader( if self.augmentation_transforms:
self.valid_datasets[key], sample["img"] = self.augmentation_transforms(image=sample["img"])["image"]
batch_size=self.batch_size["val"],
sampler=self.valid_samplers[key],
batch_sampler=self.valid_samplers[key],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
)
@staticmethod # Image normalization
def seed_worker(worker_id): sample["img"] = (sample["img"] - self.mean) / self.std
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
def generate_test_loader(self, custom_name, sets_list): # Get final height and width
""" sample["img_reduced_shape"], sample["img_position"] = self.compute_final_size(
Load test dataset, data sampler and data loader sample["img"]
"""
if custom_name in self.test_loaders.keys():
return
paths_and_sets = list()
for set_info in sets_list:
paths_and_sets.append(
{"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
)
self.test_datasets[custom_name] = self.dataset_class(
self.params, "test", custom_name, paths_and_sets
)
self.apply_specific_treatment_after_dataset_loading(
self.test_datasets[custom_name]
)
if self.params["use_ddp"]:
self.test_samplers[custom_name] = DistributedSampler(
self.test_datasets[custom_name],
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=False,
)
else:
self.test_samplers[custom_name] = None
self.test_loaders[custom_name] = DataLoader(
self.test_datasets[custom_name],
batch_size=self.batch_size["test"],
sampler=self.test_samplers[custom_name],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
) )
def get_paths_and_sets(self, dataset_names_folds): # Convert label into tokens
paths_and_sets = list() sample["token_label"], sample["label_len"] = self.convert_sample_label(
for dataset_name, fold in dataset_names_folds: sample["label"]
path = self.params["datasets"][dataset_name]
paths_and_sets.append({"path": path, "set_name": fold})
return paths_and_sets
class GenericDataset(Dataset):
"""
Main class to handle dataset loading
"""
def __init__(self, params, set_name, custom_name, paths_and_sets):
self.params = params
self.name = custom_name
self.set_name = set_name
self.mean = (
np.array(params["config"]["mean"])
if "mean" in params["config"].keys()
else None
)
self.std = (
np.array(params["config"]["std"])
if "std" in params["config"].keys()
else None
) )
self.preprocessing_transforms = get_preprocessing_transforms( return sample
params["config"]["preprocessings"]
)
self.load_in_memory = (
self.params["config"]["load_in_memory"]
if "load_in_memory" in self.params["config"]
else True
)
self.samples = self.load_samples(
paths_and_sets, load_in_memory=self.load_in_memory
)
if self.load_in_memory:
self.preprocess_all_images()
self.curriculum_config = None
def __len__(self):
return len(self.samples)
@staticmethod def load_samples(self, paths_and_sets):
def load_image(path):
with Image.open(path) as pil_img:
return pil_img.convert("RGB")
@staticmethod
def load_samples(paths_and_sets, load_in_memory=True):
""" """
Load images and labels Load images and labels
""" """
...@@ -266,16 +107,20 @@ class GenericDataset(Dataset): ...@@ -266,16 +107,20 @@ class GenericDataset(Dataset):
"path": os.path.abspath(filename), "path": os.path.abspath(filename),
} }
) )
if load_in_memory: if self.load_in_memory:
samples[-1]["img"] = GenericDataset.load_image(filename) samples[-1]["img"] = self.preprocessing_transforms(
read_image(filename)
)
return samples return samples
def preprocess_all_images(self) -> None: def get_sample_img(self, i):
""" """
Iterate over all samples and apply pre-processing Get image by index
""" """
for i, sample in enumerate(self.samples): if self.load_in_memory:
self.samples[i]["img"] = self.preprocessing_transforms(sample["img"]) return self.samples[i]["img"]
return self.preprocessing_transforms(read_image(self.samples[i]["path"]))
def compute_std_mean(self): def compute_std_mean(self):
""" """
...@@ -284,34 +129,46 @@ class GenericDataset(Dataset): ...@@ -284,34 +129,46 @@ class GenericDataset(Dataset):
if self.mean is not None and self.std is not None: if self.mean is not None and self.std is not None:
return self.mean, self.std return self.mean, self.std
sum = np.zeros((3,)) total = np.zeros((3,))
diff = np.zeros((3,)) diff = np.zeros((3,))
nb_pixels = 0 nb_pixels = 0
for metric in ["mean", "std"]: for metric in ["mean", "std"]:
for ind in range(len(self.samples)): for ind in range(len(self.samples)):
img = np.array( img = np.array(self.get_sample_img(ind))
self.get_sample_img(ind)
if self.load_in_memory
else self.preprocessing_transforms(self.get_sample_img(ind)),
)
if metric == "mean": if metric == "mean":
sum += np.sum(img, axis=(0, 1)) total += np.sum(img, axis=(0, 1))
nb_pixels += np.prod(img.shape[:2]) nb_pixels += np.prod(img.shape[:2])
elif metric == "std": elif metric == "std":
diff += [ diff += [
np.sum((img[:, :, k] - self.mean[k]) ** 2) for k in range(3) np.sum((img[:, :, k] - self.mean[k]) ** 2) for k in range(3)
] ]
if metric == "mean": if metric == "mean":
self.mean = sum / nb_pixels self.mean = total / nb_pixels
elif metric == "std": elif metric == "std":
self.std = np.sqrt(diff / nb_pixels) self.std = np.sqrt(diff / nb_pixels)
return self.mean, self.std return self.mean, self.std
def get_sample_img(self, i): def compute_final_size(self, img):
""" """
Get image by index Compute the final image size and position after feature extraction
""" """
if self.load_in_memory: image_reduced_shape = np.ceil(img.shape / self.reduce_dims_factor).astype(int)
return self.samples[i]["img"]
else: if self.set_name == "train":
return GenericDataset.load_image(self.samples[i]["path"]) image_reduced_shape = [max(1, t) for t in image_reduced_shape]
image_position = [
[0, img.shape[0]],
[0, img.shape[1]],
]
return image_reduced_shape, image_position
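The division above uses the dataset's `reduce_dims_factor` (set to `[32, 8, 1]` elsewhere in this diff), i.e. the encoder divides the height by 32 and the width by 8 while keeping the channel dimension. A quick worked example of the reduced shape, assuming a 128x1024 RGB crop:

```python
import numpy as np

reduce_dims_factor = np.array([32, 8, 1])  # height / 32, width / 8, channels / 1
img_shape = np.array([128, 1024, 3])       # (height, width, channels)

image_reduced_shape = np.ceil(img_shape / reduce_dims_factor).astype(int)
print(image_reduced_shape)  # [  4 128   3] -> a 4 x 128 grid of feature positions for the decoder
```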
def convert_sample_label(self, label):
"""
Tokenize the label and return its length
"""
token_label = token_to_ind(self.charset, label)
token_label.append(self.tokens["end"])
label_len = len(token_label)
token_label.insert(0, self.tokens["start"])
return token_label, label_len
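Label encoding is character-level: each character is mapped to its index in the charset, the `<end>` token is appended, the length is measured, and the `<start>` token is prepended afterwards so it is not counted in `label_len`. A small sketch of the same scheme with a toy charset; `to_indices` below is an inline stand-in for `token_to_ind` from `dan.utils`, not the real helper:

```python
charset = ["a", "b", "c", " "]
tokens = {"end": len(charset), "start": len(charset) + 1, "pad": len(charset) + 2}

def to_indices(charset, label):
    # Inline stand-in for dan.utils.token_to_ind: one charset index per character
    return [charset.index(char) for char in label]

label = "ab c"
token_label = to_indices(charset, label)
token_label.append(tokens["end"])
label_len = len(token_label)           # characters + <end>; <start> is not counted
token_label.insert(0, tokens["start"])
print(token_label, label_len)          # [5, 0, 1, 3, 2, 4] 5
```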
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os import os
import pickle import pickle
import random
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from dan.manager.dataset import DatasetManager, GenericDataset from dan.manager.dataset import OCRDataset
from dan.utils import pad_images, pad_sequences_1D, token_to_ind from dan.transforms import get_augmentation_transforms, get_preprocessing_transforms
from dan.utils import pad_images, pad_sequences_1D
class OCRDatasetManager(DatasetManager): class OCRDatasetManager:
"""
Specific class to handle OCR/HTR tasks
"""
def __init__(self, params, device: str): def __init__(self, params, device: str):
super(OCRDatasetManager, self).__init__(params, device) self.params = params
# Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
self.pin_memory = device != "cpu"
self.train_dataset = None
self.valid_datasets = dict()
self.test_datasets = dict()
self.train_loader = None
self.valid_loaders = dict()
self.test_loaders = dict()
self.train_sampler = None
self.valid_samplers = dict()
self.test_samplers = dict()
self.dataset_class = OCRDataset self.mean = (
self.charset = ( np.array(params["config"]["mean"])
params["charset"] if "charset" in params else self.get_merged_charsets() if "mean" in params["config"].keys()
else None
) )
self.std = (
np.array(params["config"]["std"])
if "std" in params["config"].keys()
else None
)
self.generator = torch.Generator()
self.generator.manual_seed(0)
self.tokens = {"pad": len(self.charset) + 2} self.load_in_memory = (
self.tokens["end"] = len(self.charset) self.params["config"]["load_in_memory"]
self.tokens["start"] = len(self.charset) + 1 if "load_in_memory" in self.params["config"]
else True
)
self.charset = self.get_charset()
self.tokens = self.get_tokens()
self.params["config"]["padding_token"] = self.tokens["pad"] self.params["config"]["padding_token"] = self.tokens["pad"]
def get_merged_charsets(self): self.my_collate_function = OCRCollateFunction(self.params["config"])
self.augmentation = (
get_augmentation_transforms()
if self.params["config"]["augmentation"]
else None
)
self.preprocessing = get_preprocessing_transforms(
params["config"]["preprocessings"], to_pil_image=True
)
def load_datasets(self):
"""
Load training and validation datasets
"""
self.train_dataset = OCRDataset(
set_name="train",
paths_and_sets=self.get_paths_and_sets(self.params["train"]["datasets"]),
charset=self.charset,
tokens=self.tokens,
preprocessing_transforms=self.preprocessing,
augmentation_transforms=self.augmentation,
load_in_memory=self.load_in_memory,
mean=self.mean,
std=self.std,
)
self.mean, self.std = self.train_dataset.compute_std_mean()
for custom_name in self.params["val"].keys():
self.valid_datasets[custom_name] = OCRDataset(
set_name="val",
paths_and_sets=self.get_paths_and_sets(self.params["val"][custom_name]),
charset=self.charset,
tokens=self.tokens,
preprocessing_transforms=self.preprocessing,
augmentation_transforms=None,
load_in_memory=self.load_in_memory,
mean=self.mean,
std=self.std,
)
def load_ddp_samplers(self):
"""
Load training and validation data samplers
"""
if self.params["use_ddp"]:
self.train_sampler = DistributedSampler(
self.train_dataset,
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=True,
)
for custom_name in self.valid_datasets.keys():
self.valid_samplers[custom_name] = DistributedSampler(
self.valid_datasets[custom_name],
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=False,
)
else:
for custom_name in self.valid_datasets.keys():
self.valid_samplers[custom_name] = None
def load_dataloaders(self):
"""
Load training and validation data loaders
"""
self.train_loader = DataLoader(
self.train_dataset,
batch_size=self.params["batch_size"],
shuffle=True if self.train_sampler is None else False,
drop_last=False,
sampler=self.train_sampler,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
)
for key in self.valid_datasets.keys():
self.valid_loaders[key] = DataLoader(
self.valid_datasets[key],
batch_size=1,
sampler=self.valid_samplers[key],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
)
@staticmethod
def seed_worker(worker_id):
"""
Set worker seed
"""
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
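`seed_worker`, combined with the `torch.Generator` seeded in the constructor and passed to every `DataLoader`, is the usual PyTorch recipe for reproducible shuffling and deterministic numpy/random state inside worker processes. A minimal sketch outside the manager, with an arbitrary toy dataset:

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id):
    # Derive per-worker seeds from the main process seed (same recipe as above)
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

generator = torch.Generator()
generator.manual_seed(0)

loader = DataLoader(
    TensorDataset(torch.arange(10)),
    batch_size=2,
    shuffle=True,               # shuffling order is driven by the seeded generator
    num_workers=2,
    worker_init_fn=seed_worker,
    generator=generator,
)
# Re-seeding the generator before each epoch reproduces the exact same batch order.
```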
def generate_test_loader(self, custom_name, sets_list):
"""
Load test dataset, data sampler and data loader
"""
if custom_name in self.test_loaders.keys():
return
paths_and_sets = list()
for set_info in sets_list:
paths_and_sets.append(
{"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
)
self.test_datasets[custom_name] = OCRDataset(
set_name="test",
paths_and_sets=paths_and_sets,
charset=self.charset,
tokens=self.tokens,
preprocessing_transforms=self.preprocessing,
augmentation_transforms=None,
load_in_memory=self.load_in_memory,
mean=self.mean,
std=self.std,
)
if self.params["use_ddp"]:
self.test_samplers[custom_name] = DistributedSampler(
self.test_datasets[custom_name],
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=False,
)
else:
self.test_samplers[custom_name] = None
self.test_loaders[custom_name] = DataLoader(
self.test_datasets[custom_name],
batch_size=1,
sampler=self.test_samplers[custom_name],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
)
def get_paths_and_sets(self, dataset_names_folds):
"""
Set the right path for each data set
"""
paths_and_sets = list()
for dataset_name, fold in dataset_names_folds:
path = self.params["datasets"][dataset_name]
paths_and_sets.append({"path": path, "set_name": fold})
return paths_and_sets
def get_charset(self):
""" """
Merge the charset of the different datasets used Merge the charset of the different datasets used
""" """
if "charset" in self.params:
return self.params["charset"]
datasets = self.params["datasets"] datasets = self.params["datasets"]
charset = set() charset = set()
for key in datasets.keys(): for key in datasets.keys():
@@ -40,87 +229,15 @@ class OCRDatasetManager(DatasetManager):
charset.remove("") charset.remove("")
return sorted(list(charset)) return sorted(list(charset))
def apply_specific_treatment_after_dataset_loading(self, dataset): def get_tokens(self):
dataset.charset = self.charset
dataset.tokens = self.tokens
dataset.convert_labels()
class OCRDataset(GenericDataset):
"""
Specific class to handle OCR/HTR datasets
"""
def __init__(
self,
params,
set_name,
custom_name,
paths_and_sets,
augmentation_transforms=None,
):
super(OCRDataset, self).__init__(params, set_name, custom_name, paths_and_sets)
self.charset = None
self.tokens = None
# Factor to reduce the height and width of the feature vector before feeding the decoder.
self.reduce_dims_factor = np.array([32, 8, 1])
self.collate_function = OCRCollateFunction
self.augmentation_transforms = augmentation_transforms
def __getitem__(self, idx):
sample = dict(**self.samples[idx])
if not self.load_in_memory:
sample["img"] = self.get_sample_img(idx)
# Convert to numpy
sample["img"] = np.array(sample["img"])
# Get initial height and width
initial_h, initial_w, _ = sample["img"].shape
# Data augmentation
if self.augmentation_transforms:
sample["img"] = self.augmentation_transforms(image=sample["img"])["image"]
# Normalization
sample["img"] = (sample["img"] - self.mean) / self.std
# Get final height and width (tensor mode)
final_h, final_w, _ = sample["img"].shape
sample["resize_ratio"] = [final_h / initial_h, final_w / initial_w]
sample["img_reduced_shape"] = np.ceil(
sample["img"].shape / self.reduce_dims_factor
).astype(int)
if self.set_name == "train":
sample["img_reduced_shape"] = [
max(1, t) for t in sample["img_reduced_shape"]
]
sample["img_position"] = [
[0, sample["img"].shape[0]],
[0, sample["img"].shape[1]],
]
return sample
def convert_labels(self):
""" """
Label str to token at character level Get special tokens
""" """
for i in range(len(self.samples)): return {
self.samples[i] = self.convert_sample_labels(self.samples[i]) "end": len(self.charset),
"start": len(self.charset) + 1,
def convert_sample_labels(self, sample): "pad": len(self.charset) + 2,
label = sample["label"] }
sample["label"] = label
sample["token_label"] = token_to_ind(self.charset, label)
sample["token_label"].append(self.tokens["end"])
sample["label_len"] = len(sample["token_label"])
sample["token_label"].insert(0, self.tokens["start"])
return sample
class OCRCollateFunction: class OCRCollateFunction:
@@ -134,12 +251,13 @@ class OCRCollateFunction:
def __call__(self, batch_data): def __call__(self, batch_data):
labels = [batch_data[i]["token_label"] for i in range(len(batch_data))] labels = [batch_data[i]["token_label"] for i in range(len(batch_data))]
labels = pad_sequences_1D(labels, padding_value=self.label_padding_value) labels = pad_sequences_1D(labels, padding_value=self.label_padding_value).long()
labels = torch.tensor(labels).long()
imgs = [batch_data[i]["img"] for i in range(len(batch_data))] imgs = [
torch.from_numpy(batch_data[i]["img"]).permute(2, 0, 1)
for i in range(len(batch_data))
]
imgs = pad_images(imgs) imgs = pad_images(imgs)
imgs = torch.tensor(imgs).float().permute(0, 3, 1, 2)
formatted_batch_data = { formatted_batch_data = {
formatted_key: [batch_data[i][initial_key] for i in range(len(batch_data))] formatted_key: [batch_data[i][initial_key] for i in range(len(batch_data))]
...
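The collate function above performs two padding steps: token sequences are padded to the longest label in the batch with the padding token, and images are padded to the largest height and width before being stacked into one tensor. The sketch below only illustrates that general idea; the real implementations are `pad_sequences_1D` and `pad_images` from `dan.utils`, whose exact behaviour may differ:

```python
import torch

def pad_1d(sequences, padding_value):
    # Pad integer sequences to the longest one in the batch (illustration only)
    out = torch.full((len(sequences), max(len(s) for s in sequences)), padding_value, dtype=torch.long)
    for i, seq in enumerate(sequences):
        out[i, : len(seq)] = torch.tensor(seq)
    return out

def pad_imgs(images):
    # Pad (C, H, W) tensors to the largest height/width in the batch (illustration only)
    max_h = max(img.shape[1] for img in images)
    max_w = max(img.shape[2] for img in images)
    out = torch.zeros((len(images), images[0].shape[0], max_h, max_w))
    for i, img in enumerate(images):
        out[i, :, : img.shape[1], : img.shape[2]] = img
    return out

labels = pad_1d([[5, 0, 1, 4], [5, 2, 4]], padding_value=6)
imgs = pad_imgs([torch.rand(3, 32, 100), torch.rand(3, 40, 80)])
print(labels.shape, imgs.shape)  # torch.Size([2, 4]) torch.Size([2, 3, 40, 100])
```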
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import json
import os import os
import random import random
from copy import deepcopy from copy import deepcopy
from enum import Enum
from time import time from time import time
import numpy as np import numpy as np
@@ -21,7 +21,7 @@ from dan.manager.metrics import MetricManager
from dan.manager.ocr import OCRDatasetManager from dan.manager.ocr import OCRDatasetManager
from dan.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics from dan.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
from dan.schedulers import DropoutScheduler from dan.schedulers import DropoutScheduler
from dan.utils import ind_to_token from dan.utils import fix_ddp_layers_names, ind_to_token
if MLFLOW_AVAILABLE: if MLFLOW_AVAILABLE:
import mlflow import mlflow
@@ -34,7 +34,6 @@ class GenericTrainingManager:
self.params = params self.params = params
self.dropout_scheduler = None self.dropout_scheduler = None
self.models = {} self.models = {}
self.begin_time = None
self.dataset = None self.dataset = None
self.dataset_name = list(self.params["dataset_params"]["datasets"].values())[0] self.dataset_name = list(self.params["dataset_params"]["datasets"].values())[0]
self.paths = None self.paths = None
@@ -56,6 +55,11 @@ class GenericTrainingManager:
self.params["model_params"]["use_amp"] = self.params["training_params"][ self.params["model_params"]["use_amp"] = self.params["training_params"][
"use_amp" "use_amp"
] ]
self.nb_gpu = (
self.params["training_params"]["nb_gpu"]
if self.params["training_params"]["use_ddp"]
else 1
)
def init_paths(self): def init_paths(self):
""" """
@@ -84,14 +88,6 @@ class GenericTrainingManager:
self.params["dataset_params"]["batch_size"] = self.params["training_params"][ self.params["dataset_params"]["batch_size"] = self.params["training_params"][
"batch_size" "batch_size"
] ]
if "valid_batch_size" in self.params["training_params"]:
self.params["dataset_params"]["valid_batch_size"] = self.params[
"training_params"
]["valid_batch_size"]
if "test_batch_size" in self.params["training_params"]:
self.params["dataset_params"]["test_batch_size"] = self.params[
"training_params"
]["test_batch_size"]
self.params["dataset_params"]["num_gpu"] = self.params["training_params"][ self.params["dataset_params"]["num_gpu"] = self.params["training_params"][
"nb_gpu" "nb_gpu"
] ]
@@ -193,7 +189,9 @@ class GenericTrainingManager:
# make the model compatible with Distributed Data Parallel if used # make the model compatible with Distributed Data Parallel if used
if self.params["training_params"]["use_ddp"]: if self.params["training_params"]["use_ddp"]:
self.models[model_name] = DDP( self.models[model_name] = DDP(
self.models[model_name], [self.ddp_config["rank"]] self.models[model_name],
[self.ddp_config["rank"]],
output_device=self.ddp_config["rank"],
) )
# Handle curriculum dropout # Handle curriculum dropout
@@ -223,7 +221,10 @@ class GenericTrainingManager:
if self.params["training_params"]["load_epoch"] in ("best", "last"): if self.params["training_params"]["load_epoch"] in ("best", "last"):
for filename in os.listdir(self.paths["checkpoints"]): for filename in os.listdir(self.paths["checkpoints"]):
if self.params["training_params"]["load_epoch"] in filename: if self.params["training_params"]["load_epoch"] in filename:
return torch.load(os.path.join(self.paths["checkpoints"], filename)) return torch.load(
os.path.join(self.paths["checkpoints"], filename),
map_location=self.device,
)
return None return None
def load_existing_model(self, checkpoint, strict=True): def load_existing_model(self, checkpoint, strict=True):
@@ -239,8 +240,14 @@ class GenericTrainingManager:
self.scaler.load_state_dict(checkpoint["scaler_state_dict"]) self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
# Load model weights from past training # Load model weights from past training
for model_name in self.models.keys(): for model_name in self.models.keys():
# Transform to DDP/from DDP model
checkpoint[f"{model_name}_state_dict"] = fix_ddp_layers_names(
checkpoint[f"{model_name}_state_dict"],
self.params["training_params"]["use_ddp"],
)
self.models[model_name].load_state_dict( self.models[model_name].load_state_dict(
checkpoint["{}_state_dict".format(model_name)], strict=strict checkpoint[f"{model_name}_state_dict"], strict=strict
) )
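Checkpoints saved from a DDP-wrapped model prefix every parameter name with `module.`, so they cannot be loaded directly into a bare model (and vice versa); that is what the `fix_ddp_layers_names` call above works around. The helper below is only an illustration of the usual prefix add/strip logic, not the actual implementation from `dan.utils`:

```python
def rename_ddp_keys(state_dict, to_ddp):
    """Add or strip the 'module.' prefix that DDP puts on parameter names (illustrative)."""
    if to_ddp:
        return {
            key if key.startswith("module.") else f"module.{key}": value
            for key, value in state_dict.items()
        }
    return {key.removeprefix("module."): value for key, value in state_dict.items()}

# A checkpoint saved from a DDP-wrapped model...
ddp_state = {"module.encoder.conv.weight": 0, "module.encoder.conv.bias": 1}
# ...remapped for a single-process model:
print(rename_ddp_keys(ddp_state, to_ddp=False))
# {'encoder.conv.weight': 0, 'encoder.conv.bias': 1}
```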
def init_new_model(self): def init_new_model(self):
@@ -261,8 +268,15 @@ class GenericTrainingManager:
state_dict_name, path, learnable, strict = self.params["model_params"][ state_dict_name, path, learnable, strict = self.params["model_params"][
"transfer_learning" "transfer_learning"
][model_name] ][model_name]
# Loading pretrained weights file # Loading pretrained weights file
checkpoint = torch.load(path) checkpoint = torch.load(path, map_location=self.device)
# Transform to DDP/from DDP model
checkpoint[f"{model_name}_state_dict"] = fix_ddp_layers_names(
checkpoint[f"{model_name}_state_dict"],
self.params["training_params"]["use_ddp"],
)
try: try:
# Load pretrained weights for model # Load pretrained weights for model
self.models[model_name].load_state_dict( self.models[model_name].load_state_dict(
@@ -462,23 +476,37 @@ class GenericTrainingManager:
def save_params(self): def save_params(self):
""" """
Output text file containing a summary of all hyperparameters chosen for the training Output a yaml file containing a summary of all hyperparameters chosen for the training
and a yaml file containing parameters used for inference
""" """
def compute_nb_params(module): def compute_nb_params(module):
return sum([np.prod(p.size()) for p in list(module.parameters())]) return sum([np.prod(p.size()) for p in list(module.parameters())])
def class_to_str_dict(my_dict): def class_to_str_dict(my_dict):
for key in my_dict.keys(): for key in my_dict:
if callable(my_dict[key]): if key == "preprocessings":
my_dict[key] = [
{
key: value.value if isinstance(value, Enum) else value
for key, value in preprocessing.items()
}
for preprocessing in my_dict[key]
]
elif callable(my_dict[key]):
my_dict[key] = my_dict[key].__name__ my_dict[key] = my_dict[key].__name__
elif isinstance(my_dict[key], np.ndarray): elif isinstance(my_dict[key], np.ndarray):
my_dict[key] = my_dict[key].tolist() my_dict[key] = my_dict[key].tolist()
elif isinstance(my_dict[key], list) and isinstance(
my_dict[key][0], tuple
):
my_dict[key] = [list(elt) for elt in my_dict[key]]
elif isinstance(my_dict[key], dict): elif isinstance(my_dict[key], dict):
my_dict[key] = class_to_str_dict(my_dict[key]) my_dict[key] = class_to_str_dict(my_dict[key])
return my_dict return my_dict
path = os.path.join(self.paths["results"], "params") # Save training parameters
path = os.path.join(self.paths["results"], "training_parameters.yml")
if os.path.isfile(path): if os.path.isfile(path):
return return
params = class_to_str_dict(my_dict=deepcopy(self.params)) params = class_to_str_dict(my_dict=deepcopy(self.params))
@@ -491,8 +519,45 @@ class GenericTrainingManager:
] ]
total_params += current_params total_params += current_params
params["model_params"]["total_params"] = "{:,}".format(total_params) params["model_params"]["total_params"] = "{:,}".format(total_params)
params["mean"] = self.dataset.mean.tolist()
params["std"] = self.dataset.std.tolist()
with open(path, "w") as f:
yaml.dump(params, f)
# Save inference parameters
path = os.path.join(self.paths["results"], "inference_parameters.yml")
if os.path.isfile(path):
return
inference_params = {
"parameters": {
"mean": params["mean"],
"std": params["std"],
"max_char_prediction": params["training_params"]["max_char_prediction"],
"encoder": {
"dropout": params["model_params"]["dropout"],
},
"decoder": {
key: params["model_params"][key]
for key in [
"enc_dim",
"l_max",
"h_max",
"w_max",
"dec_num_layers",
"dec_num_heads",
"dec_res_dropout",
"dec_pred_dropout",
"dec_att_dropout",
"dec_dim_feedforward",
"vocab_size",
"attention_win",
]
},
"preprocessings": params["dataset_params"]["config"]["preprocessings"],
},
}
with open(path, "w") as f: with open(path, "w") as f:
json.dump(params, f, indent=4) yaml.dump(inference_params, f)
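`inference_parameters.yml` groups everything needed to run the trained model outside of training: normalization statistics, the maximum decoded sequence length, encoder/decoder hyperparameters and the preprocessing chain. A quick way to inspect it after a run, assuming PyYAML and a hypothetical results directory:

```python
import yaml

# Hypothetical location: the self.paths["results"] folder under the configured output folder
path = "outputs/dan_esposalles_record/results/inference_parameters.yml"
with open(path) as f:
    inference_params = yaml.safe_load(f)["parameters"]

print(inference_params["mean"], inference_params["std"])
print(inference_params["decoder"]["attention_win"])
print(inference_params["preprocessings"])
```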
def backward_loss(self, loss, retain_graph=False): def backward_loss(self, loss, retain_graph=False):
self.scaler.scale(loss).backward(retain_graph=retain_graph) self.scaler.scale(loss).backward(retain_graph=retain_graph)
@@ -529,10 +594,7 @@ class GenericTrainingManager:
self.writer = SummaryWriter(self.paths["results"]) self.writer = SummaryWriter(self.paths["results"])
self.save_params() self.save_params()
# init variables # init variables
self.begin_time = time()
focus_metric_name = self.params["training_params"]["focus_metric"]
nb_epochs = self.params["training_params"]["max_nb_epochs"] nb_epochs = self.params["training_params"]["max_nb_epochs"]
interval_save_weights = self.params["training_params"]["interval_save_weights"]
metric_names = self.params["training_params"]["train_metrics"] metric_names = self.params["training_params"]["train_metrics"]
display_values = None display_values = None
@@ -544,13 +606,6 @@ class GenericTrainingManager:
self.init_curriculum() self.init_curriculum()
# perform epochs # perform epochs
for num_epoch in range(self.latest_epoch + 1, nb_epochs): for num_epoch in range(self.latest_epoch + 1, nb_epochs):
# Check maximum training time stop condition
if (
self.params["training_params"]["max_training_time"]
and time() - self.begin_time
> self.params["training_params"]["max_training_time"]
):
break
# set models trainable # set models trainable
for model_name in self.models.keys(): for model_name in self.models.keys():
self.models[model_name].train() self.models[model_name].train()
@@ -563,7 +618,6 @@ class GenericTrainingManager:
self.metric_manager["train"] = MetricManager( self.metric_manager["train"] = MetricManager(
metric_names=metric_names, dataset_name=self.dataset_name metric_names=metric_names, dataset_name=self.dataset_name
) )
with tqdm(total=len(self.dataset.train_loader.dataset)) as pbar: with tqdm(total=len(self.dataset.train_loader.dataset)) as pbar:
pbar.set_description("EPOCH {}/{}".format(num_epoch, nb_epochs)) pbar.set_description("EPOCH {}/{}".format(num_epoch, nb_epochs))
# iterates over mini-batch data # iterates over mini-batch data
@@ -612,7 +666,7 @@ class GenericTrainingManager:
self.metric_manager["train"].update_metrics(batch_metrics) self.metric_manager["train"].update_metrics(batch_metrics)
display_values = self.metric_manager["train"].get_display_values() display_values = self.metric_manager["train"].get_display_values()
pbar.set_postfix(values=str(display_values)) pbar.set_postfix(values=str(display_values))
pbar.update(len(batch_data["names"])) pbar.update(len(batch_data["names"]) * self.nb_gpu)
# Log MLflow metrics # Log MLflow metrics
logging_metrics( logging_metrics(
@@ -651,25 +705,9 @@ class GenericTrainingManager:
) )
if valid_set_name == self.params["training_params"][ if valid_set_name == self.params["training_params"][
"set_name_focus_metric" "set_name_focus_metric"
] and ( ] and (self.best is None or eval_values["cer"] <= self.best):
self.best is None
or (
eval_values[focus_metric_name] <= self.best
and self.params["training_params"][
"expected_metric_value"
]
== "low"
)
or (
eval_values[focus_metric_name] >= self.best
and self.params["training_params"][
"expected_metric_value"
]
== "high"
)
):
self.save_model(epoch=num_epoch, name="best") self.save_model(epoch=num_epoch, name="best")
self.best = eval_values[focus_metric_name] self.best = eval_values["cer"]
# Handle curriculum learning update # Handle curriculum learning update
if self.dataset.train_dataset.curriculum_config: if self.dataset.train_dataset.curriculum_config:
@@ -684,8 +722,6 @@ class GenericTrainingManager:
# save model weights # save model weights
if self.is_master: if self.is_master:
self.save_model(epoch=num_epoch, name="last") self.save_model(epoch=num_epoch, name="last")
if interval_save_weights and num_epoch % interval_save_weights == 0:
self.save_model(epoch=num_epoch, name="weights", keep_weights=True)
self.writer.flush() self.writer.flush()
def evaluate(self, set_name, mlflow_logging=False, **kwargs): def evaluate(self, set_name, mlflow_logging=False, **kwargs):
@@ -723,7 +759,7 @@ class GenericTrainingManager:
display_values = self.metric_manager[set_name].get_display_values() display_values = self.metric_manager[set_name].get_display_values()
pbar.set_postfix(values=str(display_values)) pbar.set_postfix(values=str(display_values))
pbar.update(len(batch_data["names"])) pbar.update(len(batch_data["names"]) * self.nb_gpu)
# log metrics in MLflow # log metrics in MLflow
logging_metrics( logging_metrics(
@@ -775,7 +811,7 @@ class GenericTrainingManager:
].get_display_values() ].get_display_values()
pbar.set_postfix(values=str(display_values)) pbar.set_postfix(values=str(display_values))
pbar.update(len(batch_data["names"])) pbar.update(len(batch_data["names"]) * self.nb_gpu)
# log metrics in MLflow # log metrics in MLflow
logging_name = custom_name.split("-")[1] logging_name = custom_name.split("-")[1]
@@ -977,9 +1013,14 @@ class Manager(OCRManager):
features_size = raw_features.size() features_size = raw_features.size()
b, c, h, w = features_size b, c, h, w = features_size
pos_features = self.models["decoder"].features_updater.get_pos_features( if self.params["training_params"]["use_ddp"]:
raw_features pos_features = self.models[
) "decoder"
].module.features_updater.get_pos_features(raw_features)
else:
pos_features = self.models["decoder"].features_updater.get_pos_features(
raw_features
)
features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute( features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(
2, 0, 1 2, 0, 1
) )
@@ -1072,9 +1113,14 @@ class Manager(OCRManager):
else: else:
features = self.models["encoder"](x) features = self.models["encoder"](x)
features_size = features.size() features_size = features.size()
pos_features = self.models["decoder"].features_updater.get_pos_features( if self.params["training_params"]["use_ddp"]:
features pos_features = self.models[
) "decoder"
].module.features_updater.get_pos_features(features)
else:
pos_features = self.models["decoder"].features_updater.get_pos_features(
features
)
features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute( features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(
2, 0, 1 2, 0, 1
) )
...
@@ -138,7 +138,6 @@ def get_config():
}, },
"transfered_charset": True, # Transfer learning of the decision layer based on charset of the line HTR model "transfered_charset": True, # Transfer learning of the decision layer based on charset of the line HTR model
"additional_tokens": 1, # for decision layer = [<eot>, ], only for transferred charset "additional_tokens": 1, # for decision layer = [<eot>, ], only for transferred charset
"input_channels": 3, # number of channels of input image
"dropout": 0.5, # dropout rate for encoder "dropout": 0.5, # dropout rate for encoder
"enc_dim": 256, # dimension of extracted features "enc_dim": 256, # dimension of extracted features
"nb_layers": 5, # encoder "nb_layers": 5, # encoder
@@ -151,9 +150,6 @@ def get_config():
"dec_pred_dropout": 0.1, # dropout rate before decision layer "dec_pred_dropout": 0.1, # dropout rate before decision layer
"dec_att_dropout": 0.1, # dropout rate in multi head attention "dec_att_dropout": 0.1, # dropout rate in multi head attention
"dec_dim_feedforward": 256, # number of dimension for feedforward layer in transformer decoder layers "dec_dim_feedforward": 256, # number of dimension for feedforward layer in transformer decoder layers
"use_2d_pe": True, # use 2D positional embedding
"use_1d_pe": True, # use 1D positional embedding
"use_lstm": False,
"attention_win": 100, # length of attention window "attention_win": 100, # length of attention window
# Curriculum dropout # Curriculum dropout
"dropout_scheduler": { "dropout_scheduler": {
@@ -163,14 +159,9 @@ def get_config():
}, },
"training_params": { "training_params": {
"output_folder": "outputs/dan_esposalles_record", # folder name for checkpoint and results "output_folder": "outputs/dan_esposalles_record", # folder name for checkpoint and results
"max_nb_epochs": 710, # maximum number of epochs before to stop "max_nb_epochs": 800, # maximum number of epochs before to stop
"max_training_time": 3600
* 24
* 1.9, # maximum time before stopping (in seconds)
"load_epoch": "last", # ["best", "last"]: last to continue training, best to evaluate "load_epoch": "last", # ["best", "last"]: last to continue training, best to evaluate
"interval_save_weights": None, # None: keep best and last only
"batch_size": 2, # mini-batch size for training "batch_size": 2, # mini-batch size for training
"valid_batch_size": 4, # mini-batch size for valdiation
"use_ddp": False, # Use DistributedDataParallel "use_ddp": False, # Use DistributedDataParallel
"ddp_port": "20027", "ddp_port": "20027",
"use_amp": True, # Enable automatic mix-precision "use_amp": True, # Enable automatic mix-precision
@@ -187,8 +178,6 @@ def get_config():
"lr_schedulers": None, # Learning rate schedulers "lr_schedulers": None, # Learning rate schedulers
"eval_on_valid": True, # Whether to eval and logs metrics on validation set during training or not "eval_on_valid": True, # Whether to eval and logs metrics on validation set during training or not
"eval_on_valid_interval": 5, # Interval (in epochs) to evaluate during training "eval_on_valid_interval": 5, # Interval (in epochs) to evaluate during training
"focus_metric": "cer", # Metrics to focus on to determine best epoch
"expected_metric_value": "low", # ["high", "low"] What is best for the focus metric value
"set_name_focus_metric": "{}-val".format( "set_name_focus_metric": "{}-val".format(
dataset_name dataset_name
), # Which dataset to focus on to select best weights ), # Which dataset to focus on to select best weights
@@ -258,18 +247,18 @@ def serialize_config(config):
return serialized_config return serialized_config
def start_training(config) -> None: def start_training(config, mlflow_logging: bool) -> None:
if ( if (
config["training_params"]["use_ddp"] config["training_params"]["use_ddp"]
and not config["training_params"]["force_cpu"] and not config["training_params"]["force_cpu"]
): ):
mp.spawn( mp.spawn(
train_and_test, train_and_test,
args=(config, True), args=(config, mlflow_logging),
nprocs=config["training_params"]["nb_gpu"], nprocs=config["training_params"]["nb_gpu"],
) )
else: else:
train_and_test(0, config, True) train_and_test(0, config, mlflow_logging)
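In the DDP branch, `mp.spawn` starts `nb_gpu` processes and prepends the process index as the first argument of `train_and_test`, which is why `args` only carries `(config, mlflow_logging)` while the single-process branch passes an explicit rank of 0. A minimal sketch of that calling convention (the worker and config here are made up):

```python
import torch.multiprocessing as mp

def worker(rank, config, mlflow_logging):
    # mp.spawn calls worker(i, *args) for i in range(nprocs)
    print(f"rank={rank}, nb_gpu={config['nb_gpu']}, mlflow_logging={mlflow_logging}")

if __name__ == "__main__":
    mp.spawn(worker, args=({"nb_gpu": 2}, False), nprocs=2)
```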
def run(): def run():
@@ -286,7 +275,7 @@ def run():
raise MLflowNotInstalled() raise MLflowNotInstalled()
if "mlflow" not in config: if "mlflow" not in config:
start_training(config) start_training(config, mlflow_logging=False)
else: else:
labels_path = ( labels_path = (
Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json" Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json"
@@ -314,4 +303,4 @@ def run():
dictionary=artifact, dictionary=artifact,
artifact_file=filename, artifact_file=filename,
) )
start_training(config) start_training(config, mlflow_logging=True)
@@ -18,6 +18,7 @@ def add_predict_parser(subcommands) -> None:
image_or_folder_input = parser.add_mutually_exclusive_group(required=True) image_or_folder_input = parser.add_mutually_exclusive_group(required=True)
image_or_folder_input.add_argument( image_or_folder_input.add_argument(
"--image", "--image",
type=pathlib.Path,
help="Path to the image to predict.", help="Path to the image to predict.",
) )
image_or_folder_input.add_argument( image_or_folder_input.add_argument(
@@ -50,6 +51,12 @@ def add_predict_parser(subcommands) -> None:
help="Path to the output folder.", help="Path to the output folder.",
required=True, required=True,
) )
parser.add_argument(
"--tokens",
type=pathlib.Path,
required=True,
help="Path to a yaml file containing a mapping between starting tokens and end tokens. Needed for entities.",
)
# Optional arguments. # Optional arguments.
parser.add_argument( parser.add_argument(
"--image-extension", "--image-extension",
@@ -57,26 +64,12 @@ def add_predict_parser(subcommands) -> None:
help="The extension of the images in the folder.", help="The extension of the images in the folder.",
default=".jpg", default=".jpg",
) )
parser.add_argument(
"--scale",
type=float,
default=1.0,
required=False,
help="Image scaling factor before feeding it to DAN",
)
parser.add_argument(
"--image-max-width",
type=int,
default=1800,
required=False,
help="Image resizing before feeding it to DAN",
)
parser.add_argument( parser.add_argument(
"--temperature", "--temperature",
type=float, type=float,
default=1.0, default=1.0,
help="Temperature scaling scalar parameter", help="Temperature scaling scalar parameter",
required=True, required=False,
) )
parser.add_argument( parser.add_argument(
"--confidence-score", "--confidence-score",
@@ -147,4 +140,18 @@ def add_predict_parser(subcommands) -> None:
type=int, type=int,
default=0, default=0,
) )
parser.add_argument(
"--gpu-device",
help="Use a specific GPU if available.",
type=int,
required=False,
)
parser.add_argument(
"--batch-size",
help="Size of prediction batches.",
type=int,
default=1,
required=False,
)
parser.set_defaults(func=run) parser.set_defaults(func=run)
@@ -4,6 +4,7 @@ import re
import cv2 import cv2
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from torchvision.transforms.functional import to_pil_image
from dan import logger from dan import logger
@@ -70,14 +71,18 @@ def split_text_and_confidences(
texts = list(text) texts = list(text)
offset = 0 offset = 0
elif level == "word": elif level == "word":
texts, probs = compute_prob_by_separator(text, confidences, word_separators) texts, confidences = compute_prob_by_separator(
text, confidences, word_separators
)
offset = 1 offset = 1
elif level == "line": elif level == "line":
texts, probs = compute_prob_by_separator(text, confidences, line_separators) texts, confidences = compute_prob_by_separator(
text, confidences, line_separators
)
offset = 1 offset = 1
else: else:
logger.error("Level should be either 'char', 'word', or 'line'") logger.error("Level should be either 'char', 'word', or 'line'")
return texts, [np.around(num, 2) for num in probs], offset return texts, [np.around(num, 2) for num in confidences], offset
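`compute_prob_by_separator`, used above for the word and line levels, regroups the per-character confidences returned by the decoder into one score per word or line by splitting on separator characters. The sketch below shows that regrouping with a mean per piece; it is only an illustration, and the project's own aggregation may differ:

```python
import re

import numpy as np

def split_confidences(text, confidences, separators):
    """Split text on separator characters and average the matching confidences (illustrative)."""
    pattern = f"[{re.escape(''.join(separators))}]"
    pieces, scores, start = [], [], 0
    for match in re.finditer(pattern, text + separators[0]):
        piece = text[start : match.start()]
        if piece:
            pieces.append(piece)
            scores.append(round(float(np.mean(confidences[start : match.start()])), 2))
        start = match.end()
    return pieces, scores

text = "the cat"
confidences = [0.9, 0.8, 0.7, 1.0, 0.6, 0.6, 0.9]
print(split_confidences(text, confidences, [" "]))  # (['the', 'cat'], [0.8, 0.7])
```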
def get_predicted_polygons_with_confidence( def get_predicted_polygons_with_confidence(
@@ -175,7 +180,7 @@ def blend_coverage(coverage_vector, image, mask, scale):
blend = Image.composite(image, coverage_vector, mask) blend = Image.composite(image, coverage_vector, mask)
# Resize to save time # Resize to save time
blend = blend.resize((int(width * scale), int(height * scale)), Image.ANTIALIAS) blend = blend.resize((int(width * scale), int(height * scale)), Image.LANCZOS)
return blend return blend
@@ -288,7 +293,7 @@ def plot_attention(
): ):
""" """
Create a gif by blending attention maps to the image for each text piece (char, word or line) Create a gif by blending attention maps to the image for each text piece (char, word or line)
:param image: Input image in PIL format :param image: Input image as torch.Tensor
:param text: Text predicted by DAN :param text: Text predicted by DAN
:param weights: Attention weights of size (n_char, feature_height, feature_width) :param weights: Attention weights of size (n_char, feature_height, feature_width)
:param level: Level to display (must be in [char, word, line]) :param level: Level to display (must be in [char, word, line])
@@ -298,13 +303,11 @@ def plot_attention(
:param line_separators: List of line separators :param line_separators: List of line separators
:param display_polygons: Whether to plot extracted polygons :param display_polygons: Whether to plot extracted polygons
""" """
image = to_pil_image(image)
height, width, _ = image.shape
attention_map = [] attention_map = []
# Convert to PIL Image and create mask # Convert to PIL Image and create mask
mask = Image.new("L", (width, height), color=(110)) mask = Image.new("L", (image.width, image.height), color=(110))
image = Image.fromarray(image)
# Split text into characters, words or lines # Split text into characters, words or lines
text_list, offset = split_text(text, level, word_separators, line_separators) text_list, offset = split_text(text, level, word_separators, line_separators)
@@ -316,7 +319,7 @@ def plot_attention(
for text_piece in text_list: for text_piece in text_list:
# Accumulate weights for the current word/line and resize to original image size # Accumulate weights for the current word/line and resize to original image size
coverage_vector = compute_coverage( coverage_vector = compute_coverage(
text_piece, max_value, tot_len, weights, (width, height) text_piece, max_value, tot_len, weights, (image.width, image.height)
) )
# Get polygons if flag is set: # Get polygons if flag is set:
@@ -329,7 +332,7 @@ def plot_attention(
weights, weights,
threshold_method=threshold_method, threshold_method=threshold_method,
threshold_value=threshold_value, threshold_value=threshold_value,
size=(width, height), size=(image.width, image.height),
) )
if contour is not None: if contour is not None:
...