Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: atr/dan
Commits on Source (43), showing 863 additions and 893 deletions.
[flake8]
max-line-length = 120
exclude=.cache,.eggs,.git
# Flake8 ignores multiple errors by default;
# the only interesting ignore is W503, which goes against PEP8.
# See https://lintlyci.github.io/Flake8Rules/rules/W503.html
ignore = E203,E501,W503
......@@ -3,6 +3,10 @@ stages:
- build
- deploy
variables:
# Submodule clone
GIT_SUBMODULE_STRATEGY: recursive
lint:
image: python:3.10
stage: test
......@@ -34,14 +38,10 @@ test:
variables:
PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
ARKINDEX_API_SCHEMA_URL: schema.yml
before_script:
- pip install tox
# Download OpenAPI schema from last backend build
- curl https://assets.teklia.com/arkindex/openapi.yml > schema.yml
# Add system deps for opencv
- apt-get update -q
- apt-get install -q -y libgl1
......
[submodule "line_image_extractor"]
path = teklia_line_image_extractor
url = ../line_image_extractor.git
repos:
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.0.282
hooks:
- id: isort
args: ["--profile", "black"]
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/ambv/black
rev: 23.1.0
rev: 23.7.0
hooks:
- id: black
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
hooks:
- id: flake8
additional_dependencies:
- 'flake8-coding==1.3.2'
- 'flake8-debugger==4.1.2'
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
......@@ -35,7 +29,7 @@ repos:
- id: end-of-file-fixer
- id: mixed-line-ending
- repo: https://github.com/codespell-project/codespell
rev: v2.2.2
rev: v2.2.5
hooks:
- id: codespell
args: ['--write-changes']
......@@ -46,3 +40,10 @@ repos:
- repo: meta
hooks:
- id: check-useless-excludes
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.16
hooks:
- id: mdformat
# Optionally add plugins
additional_dependencies:
- mdformat-mkdocs[recommended]
......@@ -4,31 +4,48 @@
## Documentation
To use DAN in your own scripts, install it using pip:
To use DAN in your own environment, you first need to clone the repository with its submodules:
```console
```shell
git clone --recurse-submodules git@gitlab.teklia.com:atr/dan.git
```
If you forgot the `--recurse-submodules`, you can initialize the submodule using:
```shell
git submodule update --init
```
Then you can install it via pip:
```shell
pip install -e .
```
For more details about this package, make sure to see the documentation available at https://teklia.gitlab.io/atr/dan/.
For more details about this package, make sure to see the documentation available at <https://atr.pages.teklia.com/dan/>.
## Development
For development and testing purposes, it may be useful to install the project as an editable package with pip.
* Use a virtualenv (e.g. with virtualenvwrapper `mkvirtualenv -a . dan`)
* Install `dan` as a package (e.g. `pip install -e .`)
- Use a virtualenv (e.g. with virtualenvwrapper `mkvirtualenv -a . dan`)
- Install `dan` as a package (e.g. `pip install -e .`)
### Linter
Code syntax is analyzed before the code is submitted.\
To run the linter tool suite, you may use pre-commit.
```shell
pip install pre-commit
pre-commit run -a
```
### Run tests
Tests are executed with `tox` using [pytest](https://pytest.org).
To install and run `tox`:
```shell
pip install tox
tox
......@@ -41,10 +58,22 @@ Run a single test: `tox -- <test_path>::<test_function>`
The tests use a large file stored via [Git-LFS](https://docs.gitlab.com/ee/topics/git/lfs/). Make sure to run `git-lfs pull` before running them.
### Update documentation
Please keep the documentation updated when modifying or adding features.
It's pretty easy to do:
```shell
pip install -r doc-requirements.txt
mkdocs serve
```
You can then write in Markdown in the relevant `docs/*.md` files, and see live output on <http://localhost:8000>.
## Inference
To apply DAN to an image, one first needs to add a few imports and load an image. Note that the image should be in RGB.
```python
import cv2
from dan.predict import DAN
......@@ -53,16 +82,18 @@ image = cv2.cvtColor(cv2.imread(IMAGE_PATH), cv2.COLOR_BGR2RGB)
```
Then one can initialize and load the trained model with the parameters used during training.
```python
model_path = 'model.pt'
params_path = 'parameters.yml'
charset_path = 'charset.pkl'
model_path = "model.pt"
params_path = "parameters.yml"
charset_path = "charset.pkl"
model = DAN('cpu')
model = DAN("cpu")
model.load(model_path, params_path, charset_path, mode="eval")
```
To run the inference on a GPU, one can replace `cpu` with the name of the GPU device. Finally, one can run the prediction:
```python
text, confidence_scores = model.predict(image, confidences=True)
```
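For a GPU run, a minimal sketch of the same pipeline, assuming a CUDA device is available and that PyTorch-style device names such as `cuda:0` are accepted:

```python
# Same pipeline as above, on the first CUDA device instead of the CPU.
model = DAN("cuda:0")
model.load(model_path, params_path, charset_path, mode="eval")
text, confidence_scores = model.predict(image, confidences=True)
```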
......@@ -71,18 +102,22 @@ text, confidence_scores = model.predict(image, confidences=True)
This package provides three subcommands. To get more information about any subcommand, use the `--help` option.
### Get started
See the [dedicated section](https://atr.pages.teklia.com/dan/get_started/training/) on the official DAN documentation.
### Data extraction from Arkindex
See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/datasets/extract/) on the official DAN documentation.
See the [dedicated section](https://atr.pages.teklia.com/dan/usage/datasets/extract/) on the official DAN documentation.
### Dataset formatting
See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/datasets/format/) on the official DAN documentation.
See the [dedicated section](https://atr.pages.teklia.com/dan/usage/datasets/format/) on the official DAN documentation.
### Model training
See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/train/) on the official DAN documentation.
See the [dedicated section](https://atr.pages.teklia.com/dan/usage/train/) on the official DAN documentation.
### Model prediction
See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/predict/) on the official DAN documentation.
See the [dedicated section](https://atr.pages.teklia.com/dan/usage/predict/) on the official DAN documentation.
0.2.0-dev2
0.2.0-dev3
......@@ -27,17 +27,13 @@ def parse_worker_version(worker_version_id):
return worker_version_id
def validate_probability(proba):
try:
proba = float(proba)
except ValueError:
raise argparse.ArgumentTypeError(f"`{proba}` is not a valid float.")
if proba > 1 or proba < 0:
def validate_char(char):
if len(char) != 1:
raise argparse.ArgumentTypeError(
f"`{proba}` is not a valid probability. Must be between 0 and 1 (both exclusive)."
f"`{char}` (of length {len(char)}) is not a valid character. Must be a string of length 1."
)
return proba
return char
def add_extract_parser(subcommands) -> None:
......@@ -53,13 +49,6 @@ def add_extract_parser(subcommands) -> None:
type=pathlib.Path,
help="Path where the data were exported from Arkindex.",
)
parser.add_argument(
"--parent",
type=validate_uuid,
nargs="+",
help="ID of the parent folder to import from Arkindex.",
required=False,
)
parser.add_argument(
"--element-type",
nargs="+",
......@@ -81,39 +70,46 @@ def add_extract_parser(subcommands) -> None:
required=True,
)
# Optional arguments.
parser.add_argument(
"--load-entities", action="store_true", help="Extract text with their entities."
"--train-folder",
type=validate_uuid,
help="ID of the training folder to extract from Arkindex.",
required=True,
)
parser.add_argument(
"--tokens",
type=pathlib.Path,
help="Mapping between starting tokens and end tokens. Needed for entities.",
required=False,
"--val-folder",
type=validate_uuid,
help="ID of the validation folder to extract from Arkindex.",
required=True,
)
parser.add_argument(
"--use-existing-split",
action="store_true",
help="Use the specified folder IDs for the dataset split.",
"--test-folder",
type=validate_uuid,
help="ID of the testing folder to extract from Arkindex.",
required=True,
)
# Optional arguments.
parser.add_argument(
"--train-folder",
type=validate_uuid,
help="ID of the training folder to import from Arkindex.",
required=False,
"--load-entities",
action="store_true",
help="Extract text with their entities.",
)
parser.add_argument(
"--val-folder",
type=validate_uuid,
help="ID of the validation folder to import from Arkindex.",
"--entity-separators",
type=validate_char,
nargs="+",
help="""
Removes all text that does not appear in an entity or in the list of given ordered characters.
If several separators follow each other, keep only the first to appear in the list.
Do not give any arguments to keep the whole text.
""",
required=False,
)
parser.add_argument(
"--test-folder",
type=validate_uuid,
help="ID of the testing folder to import from Arkindex.",
"--tokens",
type=pathlib.Path,
help="Mapping between starting tokens and end tokens. Needed for entities.",
required=False,
)
......@@ -122,28 +118,12 @@ def add_extract_parser(subcommands) -> None:
type=parse_worker_version,
help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
required=False,
default=False,
)
parser.add_argument(
"--entity-worker-version",
type=parse_worker_version,
help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
required=False,
default=False,
)
parser.add_argument(
"--train-prob",
type=validate_probability,
default=0.7,
help="Training set split size.",
)
parser.add_argument(
"--val-prob",
type=validate_probability,
default=0.15,
help="Validation set split size.",
)
parser.add_argument(
......@@ -158,4 +138,12 @@ def add_extract_parser(subcommands) -> None:
help="Images larger than this height will be resized to this width.",
)
parser.add_argument(
"--cache",
dest="cache_dir",
type=pathlib.Path,
help="Where the images should be cached.",
default=pathlib.Path(".cache"),
)
parser.set_defaults(func=run)
# -*- coding: utf-8 -*-
import ast
from dataclasses import dataclass
from itertools import starmap
from typing import List, NamedTuple, Optional, Union
from urllib.parse import urljoin
from typing import List, Union
from arkindex_export import Image
from arkindex_export.models import Element as ArkindexElement
from arkindex_export.models import Entity as ArkindexEntity
from arkindex_export.models import EntityType as ArkindexEntityType
from arkindex_export.models import Transcription as ArkindexTranscription
from arkindex_export.models import TranscriptionEntity as ArkindexTranscriptionEntity
from arkindex_export.queries import list_children
def bounding_box(polygon: list):
"""
Returns a 4-tuple (x, y, width, height) for the bounding box of a Polygon (list of points)
"""
all_x, all_y = zip(*polygon)
x, y = min(all_x), min(all_y)
width, height = max(all_x) - x, max(all_y) - y
return int(x), int(y), int(width), int(height)
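# Illustrative example with an arbitrary triangle polygon:
# bounding_box([(10, 20), (30, 25), (15, 40)]) returns (10, 20, 20, 20),
# i.e. x=10, y=20, width=30-10=20, height=40-20=20.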
# DB models
Transcription = NamedTuple(
"Transcription",
id=str,
text=str,
)
Entity = NamedTuple(
"Entity",
type=str,
value=str,
offset=float,
length=float,
from arkindex_export.models import (
Element,
Entity,
EntityType,
Transcription,
TranscriptionEntity,
)
@dataclass
class Element:
id: str
type: str
polygon: str
url: str
width: int
height: int
max_width: Optional[int] = None
max_height: Optional[int] = None
def __post_init__(self):
self.max_height = self.max_height or self.height
self.max_width = self.max_width or self.width
@property
def bounding_box(self):
return bounding_box(ast.literal_eval(self.polygon))
@property
def image_url(self):
x, y, width, height = self.bounding_box
return urljoin(
self.url + "/",
f"{x},{y},{width},{height}/!{self.max_width},{self.max_height}/0/default.jpg",
)
from arkindex_export.queries import list_children
def get_elements(
parent_id: str,
element_type: str,
max_width: Optional[int] = None,
max_height: Optional[int] = None,
) -> List[Element]:
element_type: List[str],
):
"""
Retrieve elements from an SQLite export of an Arkindex corpus
"""
......@@ -84,23 +24,11 @@ def get_elements(
query = (
list_children(parent_id=parent_id)
.join(Image)
.where(ArkindexElement.type == element_type)
.select(
ArkindexElement.id,
ArkindexElement.type,
ArkindexElement.polygon,
Image.url,
Image.width,
Image.height,
)
)
return list(
starmap(
lambda *x: Element(*x, max_width=max_width, max_height=max_height),
query.tuples(),
)
.where(Element.type.in_(element_type))
)
return query
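# Usage sketch (assumes an Arkindex SQLite export was already opened with open_database;
# the parent ID and element types below are placeholders):
# for element in get_elements(parent_id="folder-uuid", element_type=["text_line"]):
#     print(element.id, element.type)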
def build_worker_version_filter(ArkindexModel, worker_version):
"""
......@@ -118,47 +46,43 @@ def get_transcriptions(
"""
Retrieve transcriptions from an SQLite export of an Arkindex corpus
"""
query = ArkindexTranscription.select(
ArkindexTranscription.id, ArkindexTranscription.text
).where(
(ArkindexTranscription.element == element_id)
& build_worker_version_filter(
ArkindexTranscription, worker_version=transcription_worker_version
)
)
return list(
starmap(
Transcription,
query.tuples(),
query = Transcription.select(
Transcription.id, Transcription.text, Transcription.worker_version
).where((Transcription.element == element_id))
if transcription_worker_version is not None:
query = query.where(
build_worker_version_filter(
Transcription, worker_version=transcription_worker_version
)
)
)
return query
def get_transcription_entities(
transcription_id: str, entity_worker_version: Union[str, bool]
) -> List[Entity]:
) -> List[TranscriptionEntity]:
"""
Retrieve transcription entities from an SQLite export of an Arkindex corpus
"""
query = (
ArkindexTranscriptionEntity.select(
ArkindexEntityType.name,
ArkindexEntity.name,
ArkindexTranscriptionEntity.offset,
ArkindexTranscriptionEntity.length,
)
.join(ArkindexEntity, on=ArkindexTranscriptionEntity.entity)
.join(ArkindexEntityType, on=ArkindexEntity.type)
.where(
(ArkindexTranscriptionEntity.transcription == transcription_id)
& build_worker_version_filter(
ArkindexTranscriptionEntity, worker_version=entity_worker_version
)
TranscriptionEntity.select(
EntityType.name.alias("type"),
Entity.name.alias("name"),
TranscriptionEntity.offset,
TranscriptionEntity.length,
TranscriptionEntity.worker_version,
)
.join(Entity, on=TranscriptionEntity.entity)
.join(EntityType, on=Entity.type)
.where((TranscriptionEntity.transcription == transcription_id))
)
return list(
starmap(
Entity,
query.order_by(ArkindexTranscriptionEntity.offset).tuples(),
if entity_worker_version is not None:
query = query.where(
build_worker_version_filter(
TranscriptionEntity, worker_version=entity_worker_version
)
)
)
return query.order_by(TranscriptionEntity.offset).namedtuples()
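# Usage sketch (placeholder ID; passing entity_worker_version=None keeps all worker versions):
# for entity in get_transcription_entities(transcription_id, entity_worker_version=None):
#     print(entity.type, entity.name, entity.offset, entity.length)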
......@@ -49,9 +49,9 @@ class NoTranscriptionError(ElementProcessingError):
return f"No transcriptions found on element ({self.element_id}) with this config. Skipping."
class UnknownLabelError(ProcessingError):
class NoEndTokenError(ProcessingError):
"""
Raised when the specified label is not known
Raised when the specified label has no end token and there is potentially additional text around the labels
"""
label: str
......@@ -61,4 +61,4 @@ class UnknownLabelError(ProcessingError):
self.label = label
def __str__(self) -> str:
return f"Label `{self.label}` is missing in the NER configuration."
return f"Label `{self.label}` has no end token."
# -*- coding: utf-8 -*-
import json
import random
from collections import defaultdict
from pathlib import Path
from typing import List, Optional, Union
from uuid import UUID
from arkindex_export import open_database
import numpy as np
from tqdm import tqdm
from arkindex_export import open_database
from dan import logger
from dan.datasets.extract.db import (
Element,
Entity,
get_elements,
get_transcription_entities,
get_transcriptions,
)
from dan.datasets.extract.exceptions import (
NoEndTokenError,
NoTranscriptionError,
ProcessingError,
UnknownLabelError,
)
from dan.datasets.extract.utils import (
EntityType,
Subset,
download_image,
insert_token,
parse_tokens,
save_json,
save_text,
)
from line_image_extractor.extractor import extract, read_img, save_img
from line_image_extractor.image_utils import Extraction, polygon_to_bbox, resize
IMAGES_DIR = "images" # Subpath to the images directory.
LABELS_DIR = "labels" # Subpath to the labels directory.
SPLIT_NAMES = ["train", "val", "test"]
IIIF_URL_SUFFIX = "/full/full/0/default.jpg"
class ArkindexExtractor:
......@@ -45,80 +46,105 @@ class ArkindexExtractor:
def __init__(
self,
folders: list = [],
element_type: list = [],
element_type: List[str] = [],
parent_element_type: str = None,
output: Path = None,
load_entities: bool = None,
load_entities: bool = False,
entity_separators: list = [],
tokens: Path = None,
use_existing_split: bool = None,
transcription_worker_version: str = None,
entity_worker_version: str = None,
train_prob: float = None,
val_prob: float = None,
transcription_worker_version: Optional[Union[str, bool]] = None,
entity_worker_version: Optional[Union[str, bool]] = None,
max_width: Optional[int] = None,
max_height: Optional[int] = None,
cache_dir: Path = Path(".cache"),
) -> None:
self.folders = folders
self.element_type = element_type
self.parent_element_type = parent_element_type
self.output = output
self.load_entities = load_entities
self.entity_separators = entity_separators
self.tokens = parse_tokens(tokens) if self.load_entities else None
self.use_existing_split = use_existing_split
self.transcription_worker_version = transcription_worker_version
self.entity_worker_version = entity_worker_version
self.train_prob = train_prob
self.val_prob = val_prob
self.max_width = max_width
self.max_height = max_height
self.subsets = self.get_subsets(folders)
self.cache_dir = cache_dir
# Create cache dir if non existent
self.cache_dir.mkdir(exist_ok=True, parents=True)
def get_subsets(self, folders: list) -> List[Subset]:
"""
Assign each folder to its split if it's already known.
"""
if self.use_existing_split:
return [
Subset(folder, split) for folder, split in zip(folders, SPLIT_NAMES)
]
else:
return [Subset(folder) for folder in folders]
def find_image_in_cache(self, image_id: str) -> Path:
"""Images are cached to avoid downloading them twice. They are stored under a specific name,
based on their Arkindex ID. Images are saved in JPEG format.
def _assign_random_split(self):
"""
Yields a randomly chosen split for an element.
Assumes that train_prob + valid_prob + test_prob = 1
:param image_id: ID of the image. The image is saved under this name.
:return: Where the image should be saved in the cache folder.
"""
prob = random.random()
if prob <= self.train_prob:
yield SPLIT_NAMES[0]
elif prob <= self.train_prob + self.val_prob:
yield SPLIT_NAMES[1]
else:
yield SPLIT_NAMES[2]
return self.cache_dir / f"{image_id}.jpg"
def get_random_split(self):
return next(self._assign_random_split())
def _keep_char(self, char: str) -> bool:
# Keep all text by default if no separator was given
return not self.entity_separators or char in self.entity_separators
def reconstruct_text(self, text: str, entities: List[Entity]):
def reconstruct_text(self, full_text: str, entities) -> str:
"""
Insert tokens delimiting the start/end of each entity on the transcription.
"""
count = 0
text, text_offset = "", 0
# Keep all text by default if no separator was given
for entity in entities:
if entity.type not in self.tokens:
raise UnknownLabelError(entity.type)
entity_type: EntityType = self.tokens[entity.type]
text = insert_token(
text,
count,
entity_type,
offset=entity.offset,
length=entity.length,
# Text before entity
text += "".join(
filter(self._keep_char, full_text[text_offset : entity.offset])
)
count += entity_type.offset
return text
entity_type: EntityType = self.tokens.get(entity.type)
if not entity_type:
logger.warning(
f"Label `{entity.type}` is missing in the NER configuration."
)
# We keep the whole text, so we need an end token for each entity to know exactly when an entity begins and ends
elif not entity_type.end and not self.entity_separators:
raise NoEndTokenError(entity.type)
# Entity text:
# - with tokens if there is an entity_type
# - without tokens if there is no entity_type but we want to keep the whole text
if entity_type or not self.entity_separators:
text += insert_token(
full_text,
entity_type,
offset=entity.offset,
length=entity.length,
)
text_offset = entity.offset + entity.length
# Remaining text after the last entity
text += "".join(filter(self._keep_char, full_text[text_offset:]))
if not self.entity_separators:
return text
# Add some clean up to avoid several separators between entities
text, full_text = "", text
for char in full_text:
last_char = text[-1] if len(text) else ""
# Keep the current character if there are no two consecutive separators
if (
char not in self.entity_separators
or last_char not in self.entity_separators
):
text += char
# If several separators follow each other, keep only one according to the given order
elif self.entity_separators.index(char) < self.entity_separators.index(
last_char
):
text = text[:-1] + char
# Remove separators at the beginning and end of text
return text.strip("".join(self.entity_separators))
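# Worked example (hypothetical tokens, separators and text):
# with entity_separators=[" "], tokens for "person" (<person>...</person>) and "city" (<city>...</city>),
# full_text="John Doe, Paris" and entities person@(offset=0, length=8), city@(offset=10, length=5):
# both entities are wrapped with their tokens, and of the text in between (", ")
# only the separator " " is kept, giving "<person>John Doe</person> <city>Paris</city>".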
def extract_transcription(self, element: Element):
"""
......@@ -133,14 +159,60 @@ class ArkindexExtractor:
transcription = random.choice(transcriptions)
if self.load_entities:
entities = get_transcription_entities(
transcription.id, self.entity_worker_version
)
return self.reconstruct_text(transcription.text, entities)
else:
if not self.load_entities:
return transcription.text.strip()
entities = get_transcription_entities(
transcription.id, self.entity_worker_version
)
return self.reconstruct_text(transcription.text, entities)
def retrieve_image(self, child: Element):
"""Get or download image of the element. Checks in cache before downloading.
:param child: Processed element.
:return: The element's image.
"""
cached_img_path = self.find_image_in_cache(child.image.id)
if not cached_img_path.exists():
# Save in cache
download_image(child.image.url + IIIF_URL_SUFFIX).save(
cached_img_path, format="jpeg"
)
return read_img(cached_img_path)
def get_image(self, child: Element, destination: Path) -> None:
"""Save the element's image to the given path and applies any image operations needed.
:param child: Processed element.
:param destination: Where the image should be saved.
"""
polygon = json.loads(str(child.polygon))
if self.max_height or self.max_width:
polygon = resize(
polygon,
self.max_width,
self.max_height,
scale_x=1.0,
scale_y_top=1.0,
scale_y_bottom=1.0,
)
# Extract the polygon in the image
image = extract(
img=self.retrieve_image(child),
polygon=np.array(polygon),
bbox=polygon_to_bbox(polygon),
# Hardcoded while we don't have a configuration file
extraction_mode=Extraction.deskew_min_area_rect,
max_deskew_angle=45,
)
# Save the image to disk
save_img(path=destination, img=image)
def process_element(
self,
element: Element,
......@@ -152,18 +224,17 @@ class ArkindexExtractor:
"""
text = self.extract_transcription(element)
txt_path = Path(
self.output, LABELS_DIR, split, f"{element.type}_{element.id}.txt"
)
save_text(txt_path, text)
im_path = Path(
self.output, IMAGES_DIR, split, f"{element.type}_{element.id}.jpg"
base_path = Path(split, f"{element.type}_{element.id}")
Path(self.output, LABELS_DIR, base_path).with_suffix(".txt").write_text(text)
self.get_image(
element, Path(self.output, IMAGES_DIR, base_path).with_suffix(".jpg")
)
download_image(element, im_path)
return element.id
def process_parent(
self,
pbar,
parent: Element,
split: str,
):
......@@ -171,7 +242,10 @@ class ArkindexExtractor:
Extract data from a parent element.
"""
data = defaultdict(list)
base_description = (
f"Extracting data from {parent.type} ({parent.id}) for split ({split})"
)
pbar.set_description(desc=base_description)
if self.element_type == [parent.type]:
try:
data[parent.type].append(self.process_element(parent, split))
......@@ -179,84 +253,64 @@ class ArkindexExtractor:
logger.warning(f"Skipping {parent.id}: {str(e)}")
# Extract children elements
else:
for element_type in self.element_type:
for element in get_elements(
parent.id,
element_type,
max_width=self.max_width,
max_height=self.max_height,
):
try:
data[element_type].append(self.process_element(element, split))
except ProcessingError as e:
logger.warning(f"Skipping {element.id}: {str(e)}")
children = get_elements(
parent.id,
self.element_type,
)
nb_children = children.count()
for idx, element in enumerate(children, start=1):
# Update description to update the children processing progress
pbar.set_description(desc=base_description + f" ({idx}/{nb_children})")
try:
data[element.type].append(self.process_element(element, split))
except ProcessingError as e:
logger.warning(f"Skipping {element.id}: {str(e)}")
return data
def run(self):
split_dict = defaultdict(dict)
# Iterate over the subsets to find the page images and labels.
for idx, subset in enumerate(self.subsets, start=1):
# Iterate over the pages to create splits at page level.
for parent in tqdm(
for folder_id, split in zip(self.folders, SPLIT_NAMES):
with tqdm(
get_elements(
subset.id,
self.parent_element_type,
max_width=self.max_width,
max_height=self.max_height,
folder_id,
[self.parent_element_type],
),
desc=f"Processing {subset} {idx}/{len(self.subsets)}",
):
split = subset.split or self.get_random_split()
split_dict[split][parent.id] = self.process_parent(
parent=parent,
split=split,
)
save_json(self.output / "split.json", split_dict)
desc=f"Extracting data from ({folder_id}) for split ({split})",
) as pbar:
# Iterate over the pages to create splits at page level.
for parent in pbar:
self.process_parent(
pbar=pbar,
parent=parent,
split=split,
)
# Progress bar updates
pbar.update()
pbar.refresh()
def run(
database: Path,
parent: list,
element_type: str,
element_type: List[str],
parent_element_type: str,
output: Path,
load_entities: bool,
entity_separators: list,
tokens: Path,
use_existing_split: bool,
train_folder: UUID,
val_folder: UUID,
test_folder: UUID,
transcription_worker_version: Union[str, bool],
entity_worker_version: Union[str, bool],
train_prob,
val_prob,
transcription_worker_version: Optional[Union[str, bool]],
entity_worker_version: Optional[Union[str, bool]],
max_width: Optional[int],
max_height: Optional[int],
cache_dir: Path,
):
assert (
use_existing_split or parent
), "One of `--use-existing-split` and `--parent` must be set"
assert use_existing_split ^ bool(
parent
), "Only one of `--use-existing-split` and `--parent` must be set"
assert database.exists(), f"No file found @ {database}"
open_database(path=database)
if use_existing_split:
assert (
train_folder
), "If you use an existing split, you must specify the training folder."
assert (
val_folder
), "If you use an existing split, you must specify the validation folder."
assert (
test_folder
), "If you use an existing split, you must specify the testing folder."
folders = [str(train_folder), str(val_folder), str(test_folder)]
else:
folders = [str(parent_id) for parent_id in parent]
folders = [str(train_folder), str(val_folder), str(test_folder)]
if load_entities:
assert tokens, "Please provide the entities to match."
......@@ -272,12 +326,11 @@ def run(
parent_element_type=parent_element_type,
output=output,
load_entities=load_entities,
entity_separators=entity_separators,
tokens=tokens,
use_existing_split=use_existing_split,
transcription_worker_version=transcription_worker_version,
entity_worker_version=entity_worker_version,
train_prob=train_prob,
val_prob=val_prob,
max_width=max_width,
max_height=max_height,
cache_dir=cache_dir,
).run()
# -*- coding: utf-8 -*-
import json
import logging
import time
from io import BytesIO
from pathlib import Path
from typing import NamedTuple
import cv2
import imageio.v2 as iio
import requests
import yaml
from numpy import ndarray
from dan.datasets.extract.db import Element
from dan.datasets.extract.exceptions import ImageDownloadError
from PIL import Image
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
logger = logging.getLogger(__name__)
MAX_RETRIES = 5
# See http://docs.python-requests.org/en/master/user/advanced/#timeouts
DOWNLOAD_TIMEOUT = (30, 60)
class Subset(NamedTuple):
id: str
split: str = None
def __str__(self) -> str:
return (
f"Subset(id='{self.id}', split='{self.split.capitalize()}')"
if self.split
else f"Subset(id='{self.id}')"
)
def _retry_log(retry_state, *args, **kwargs):
logger.warning(
f"Request to {retry_state.args[0]} failed ({repr(retry_state.outcome.exception())}), "
f"retrying in {retry_state.idle_for} seconds"
)
class EntityType(NamedTuple):
......@@ -39,60 +36,54 @@ class EntityType(NamedTuple):
return len(self.start) + len(self.end)
def download_image(element: Element, im_path: Path):
tries = 1
# retry loop
while True:
if tries > MAX_RETRIES:
raise ImageDownloadError(element.id, Exception("Maximum retries reached."))
try:
image = iio.imread(element.image_url)
save_image(im_path, image)
return
except TimeoutError:
logger.warning("Timeout, retry in 1 second.")
time.sleep(1)
tries += 1
except Exception as e:
raise ImageDownloadError(element.id, e)
def save_text(path: Path, text: str):
with path.open("w") as f:
f.write(text)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2),
retry=retry_if_exception_type(requests.RequestException),
before_sleep=_retry_log,
reraise=True,
)
def _retried_request(url):
resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
resp.raise_for_status()
return resp
def save_image(path: Path, image: ndarray):
cv2.imwrite(str(path), cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
def download_image(url):
"""
Download an image and open it with Pillow
"""
assert url.startswith("http"), "Image URL must be HTTP(S)"
# Download the image
# Cannot use stream=True as urllib's responses do not support the seek(int) method,
# which is explicitly required by Image.open on file-like objects
resp = _retried_request(url)
# Preprocess the image and prepare it for classification
image = Image.open(BytesIO(resp.content))
logger.debug(
"Downloaded image {} - size={}x{}".format(url, image.size[0], image.size[1])
)
def save_json(path: Path, data: dict):
with path.open("w") as outfile:
json.dump(data, outfile, indent=4)
return image
def insert_token(
text: str, count: int, entity_type: EntityType, offset: int, length: int
) -> str:
def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -> str:
"""
Insert the given tokens at the right position in the text
"""
return (
# Text before entity
text[: count + offset]
# Starting token
+ entity_type.start
(entity_type.start if entity_type else "")
# Entity
+ text[count + offset : count + offset + length]
+ text[offset : offset + length]
# End token
+ entity_type.end
# Text after entity
+ text[count + offset + length :]
+ (entity_type.end if entity_type else "")
)
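# Illustrative example with a hypothetical entity type:
# insert_token("I live in Paris", EntityType(start="<city>", end="</city>"), offset=10, length=5)
# returns "<city>Paris</city>".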
def parse_tokens(filename: Path) -> dict:
with filename.open() as f:
return {
name: EntityType(**tokens) for name, tokens in yaml.safe_load(f).items()
}
return {
name: EntityType(**tokens)
for name, tokens in yaml.safe_load(filename.read_text()).items()
}
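# Sketch of the expected YAML layout (entity names and token characters are hypothetical):
#   surname:
#     start: Ⓢ
#     end: Ⓔ
# parse_tokens(Path("tokens.yml")) would then return {"surname": EntityType(start="Ⓢ", end="Ⓔ")}.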
......@@ -63,10 +63,12 @@ class ATRDatasetFormatter:
def parse_labels(self, set_name, file_name):
return {
"img_path": os.path.join(
self.image_folder,
set_name,
f"{os.path.splitext(file_name)[0]}.{self.image_format}",
"img_path": os.path.realpath(
os.path.join(
self.image_folder,
set_name,
f"{os.path.splitext(file_name)[0]}.{self.image_format}",
)
),
"label": self.read_file(
os.path.join(self.labels_folder, set_name, file_name)
......
......@@ -2,24 +2,13 @@
import torch
from torch import relu, softmax
from torch.nn import (
LSTM,
Conv1d,
Dropout,
Embedding,
LayerNorm,
Linear,
Module,
ModuleList,
)
from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, ModuleList
from torch.nn.init import xavier_uniform_
class PositionalEncoding1D(Module):
def __init__(self, dim, len_max, device):
super(PositionalEncoding1D, self).__init__()
self.len_max = len_max
self.dim = dim
self.pe = torch.zeros((1, dim, len_max), device=device, requires_grad=False)
div = torch.exp(
......@@ -46,9 +35,6 @@ class PositionalEncoding1D(Module):
class PositionalEncoding2D(Module):
def __init__(self, dim, h_max, w_max, device):
super(PositionalEncoding2D, self).__init__()
self.h_max = h_max
self.max_w = w_max
self.dim = dim
self.pe = torch.zeros(
(1, dim, h_max, w_max), device=device, requires_grad=False
)
......@@ -177,31 +163,28 @@ class GlobalDecoderLayer(Module):
def __init__(self, params):
super(GlobalDecoderLayer, self).__init__()
self.emb_dim = params["enc_dim"]
self.dim_feedforward = params["dec_dim_feedforward"]
self.self_att = CustomMultiHeadAttention(
embed_dim=self.emb_dim,
embed_dim=params["enc_dim"],
num_heads=params["dec_num_heads"],
proj_value=True,
dropout=params["dec_att_dropout"],
)
self.norm1 = LayerNorm(self.emb_dim)
self.norm1 = LayerNorm(params["enc_dim"])
self.att = CustomMultiHeadAttention(
embed_dim=self.emb_dim,
embed_dim=params["enc_dim"],
num_heads=params["dec_num_heads"],
proj_value=True,
dropout=params["dec_att_dropout"],
)
self.linear1 = Linear(self.emb_dim, self.dim_feedforward)
self.linear2 = Linear(self.dim_feedforward, self.emb_dim)
self.linear1 = Linear(params["enc_dim"], params["dec_dim_feedforward"])
self.linear2 = Linear(params["dec_dim_feedforward"], params["enc_dim"])
self.dropout = Dropout(params["dec_res_dropout"])
self.norm2 = LayerNorm(self.emb_dim)
self.norm3 = LayerNorm(self.emb_dim)
self.norm2 = LayerNorm(params["enc_dim"])
self.norm3 = LayerNorm(params["enc_dim"])
def forward(
self,
......@@ -319,20 +302,12 @@ class FeaturesUpdater(Module):
def __init__(self, params):
super(FeaturesUpdater, self).__init__()
self.enc_dim = params["enc_dim"]
self.enc_h_max = params["h_max"]
self.enc_w_max = params["w_max"]
self.pe_2d = PositionalEncoding2D(
self.enc_dim, self.enc_h_max, self.enc_w_max, params["device"]
)
self.use_2d_positional_encoding = (
"use_2d_pe" not in params or params["use_2d_pe"]
params["enc_dim"], params["h_max"], params["w_max"], params["device"]
)
def get_pos_features(self, features):
if self.use_2d_positional_encoding:
return self.pe_2d(features)
return features
return self.pe_2d(features)
class GlobalHTADecoder(Module):
......@@ -342,31 +317,23 @@ class GlobalHTADecoder(Module):
def __init__(self, params):
super(GlobalHTADecoder, self).__init__()
self.enc_dim = params["enc_dim"]
self.dec_l_max = params["l_max"]
self.dropout = Dropout(params["dec_pred_dropout"])
self.dec_att_win = (
params["attention_win"] if params["attention_win"] is not None else 1
)
self.use_1d_pe = "use_1d_pe" not in params or params["use_1d_pe"]
self.use_lstm = params["use_lstm"]
self.features_updater = FeaturesUpdater(params)
self.att_decoder = GlobalAttDecoder(params)
self.emb = Embedding(
num_embeddings=params["vocab_size"] + 3, embedding_dim=self.enc_dim
num_embeddings=params["vocab_size"] + 3, embedding_dim=params["enc_dim"]
)
self.pe_1d = PositionalEncoding1D(
self.enc_dim, self.dec_l_max, params["device"]
params["enc_dim"], params["l_max"], params["device"]
)
if self.use_lstm:
self.lstm_predict = LSTM(self.enc_dim, self.enc_dim)
vocab_size = params["vocab_size"] + 1
self.end_conv = Conv1d(self.enc_dim, vocab_size, kernel_size=1)
self.end_conv = Conv1d(params["enc_dim"], vocab_size, kernel_size=1)
def forward(
self,
......@@ -388,9 +355,7 @@ class GlobalHTADecoder(Module):
pos_tokens = self.emb(tokens).permute(0, 2, 1)
# Add 1D Positional Encoding
if self.use_1d_pe:
pos_tokens = self.pe_1d(pos_tokens, start=start)
pos_tokens = pos_tokens.permute(2, 0, 1)
pos_tokens = self.pe_1d(pos_tokens, start=start).permute(2, 0, 1)
if num_pred is None:
num_pred = tokens.size(1)
......@@ -440,9 +405,6 @@ class GlobalHTADecoder(Module):
keep_all_weights=keep_all_weights,
)
if self.use_lstm:
output, hidden_predict = self.lstm_predict(output, hidden_predict)
dp_output = self.dropout(relu(output))
preds = self.end_conv(dp_output.permute(1, 2, 0))
......
......@@ -92,9 +92,7 @@ class FCN_Encoder(Module):
self.init_blocks = ModuleList(
[
ConvBlock(
params["input_channels"], 16, stride=(1, 1), dropout=self.dropout
),
ConvBlock(3, 16, stride=(1, 1), dropout=self.dropout),
ConvBlock(16, 32, stride=(2, 2), dropout=self.dropout),
ConvBlock(32, 64, stride=(2, 2), dropout=self.dropout),
ConvBlock(64, 128, stride=(2, 2), dropout=self.dropout),
......
# -*- coding: utf-8 -*-
import copy
import json
import os
import random
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Dataset
from dan.datasets.utils import natural_sort
from dan.transforms import get_augmentation_transforms, get_preprocessing_transforms
from dan.utils import read_image, token_to_ind
class DatasetManager:
def __init__(self, params, device: str):
self.params = params
self.dataset_class = None
self.my_collate_function = None
# Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
self.pin_memory = device != "cpu"
self.train_dataset = None
self.valid_datasets = dict()
self.test_datasets = dict()
class OCRDataset(Dataset):
"""
Dataset class to handle dataset loading
"""
self.train_loader = None
self.valid_loaders = dict()
self.test_loaders = dict()
def __init__(
self,
set_name,
paths_and_sets,
charset,
tokens,
preprocessing_transforms,
augmentation_transforms,
load_in_memory=False,
mean=None,
std=None,
):
self.set_name = set_name
self.charset = charset
self.tokens = tokens
self.load_in_memory = load_in_memory
self.mean = mean
self.std = std
self.train_sampler = None
self.valid_samplers = dict()
self.test_samplers = dict()
# Pre-processing, augmentation
self.preprocessing_transforms = preprocessing_transforms
self.augmentation_transforms = augmentation_transforms
self.generator = torch.Generator()
self.generator.manual_seed(0)
# Factor to reduce the height and width of the feature vector before feeding the decoder.
self.reduce_dims_factor = np.array([32, 8, 1])
self.batch_size = {
"train": self.params["batch_size"],
"val": self.params["valid_batch_size"]
if "valid_batch_size" in self.params
else self.params["batch_size"],
"test": self.params["test_batch_size"]
if "test_batch_size" in self.params
else 1,
}
# Load samples and preprocess images if load_in_memory is True
self.samples = self.load_samples(paths_and_sets)
def apply_specific_treatment_after_dataset_loading(self, dataset):
raise NotImplementedError
# Curriculum config
self.curriculum_config = None
def load_datasets(self):
def __len__(self):
"""
Load training and validation datasets
Return the dataset size
"""
self.train_dataset = self.dataset_class(
self.params,
"train",
self.params["train"]["name"],
self.get_paths_and_sets(self.params["train"]["datasets"]),
augmentation_transforms=(
get_augmentation_transforms()
if self.params["config"]["augmentation"]
else None
),
)
(
self.params["config"]["mean"],
self.params["config"]["std"],
) = self.train_dataset.compute_std_mean()
self.my_collate_function = self.train_dataset.collate_function(
self.params["config"]
)
self.apply_specific_treatment_after_dataset_loading(self.train_dataset)
for custom_name in self.params["val"].keys():
self.valid_datasets[custom_name] = self.dataset_class(
self.params,
"val",
custom_name,
self.get_paths_and_sets(self.params["val"][custom_name]),
augmentation_transforms=None,
)
self.apply_specific_treatment_after_dataset_loading(
self.valid_datasets[custom_name]
)
return len(self.samples)
def load_ddp_samplers(self):
def __getitem__(self, idx):
"""
Load training and validation data samplers
Return an item from the dataset (image and label)
"""
if self.params["use_ddp"]:
self.train_sampler = DistributedSampler(
self.train_dataset,
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=True,
)
for custom_name in self.valid_datasets.keys():
self.valid_samplers[custom_name] = DistributedSampler(
self.valid_datasets[custom_name],
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=False,
)
else:
for custom_name in self.valid_datasets.keys():
self.valid_samplers[custom_name] = None
# Load preprocessed image
sample = copy.deepcopy(self.samples[idx])
if not self.load_in_memory:
sample["img"] = self.get_sample_img(idx)
def load_dataloaders(self):
"""
Load training and validation data loaders
"""
self.train_loader = DataLoader(
self.train_dataset,
batch_size=self.batch_size["train"],
shuffle=True if self.train_sampler is None else False,
drop_last=False,
batch_sampler=self.train_sampler,
sampler=self.train_sampler,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
)
# Convert to numpy
sample["img"] = np.array(sample["img"])
for key in self.valid_datasets.keys():
self.valid_loaders[key] = DataLoader(
self.valid_datasets[key],
batch_size=self.batch_size["val"],
sampler=self.valid_samplers[key],
batch_sampler=self.valid_samplers[key],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
)
# Apply data augmentation
if self.augmentation_transforms:
sample["img"] = self.augmentation_transforms(image=sample["img"])["image"]
@staticmethod
def seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
# Image normalization
sample["img"] = (sample["img"] - self.mean) / self.std
def generate_test_loader(self, custom_name, sets_list):
"""
Load test dataset, data sampler and data loader
"""
if custom_name in self.test_loaders.keys():
return
paths_and_sets = list()
for set_info in sets_list:
paths_and_sets.append(
{"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
)
self.test_datasets[custom_name] = self.dataset_class(
self.params, "test", custom_name, paths_and_sets
)
self.apply_specific_treatment_after_dataset_loading(
self.test_datasets[custom_name]
)
if self.params["use_ddp"]:
self.test_samplers[custom_name] = DistributedSampler(
self.test_datasets[custom_name],
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=False,
)
else:
self.test_samplers[custom_name] = None
self.test_loaders[custom_name] = DataLoader(
self.test_datasets[custom_name],
batch_size=self.batch_size["test"],
sampler=self.test_samplers[custom_name],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
# Get final height and width
sample["img_reduced_shape"], sample["img_position"] = self.compute_final_size(
sample["img"]
)
def get_paths_and_sets(self, dataset_names_folds):
paths_and_sets = list()
for dataset_name, fold in dataset_names_folds:
path = self.params["datasets"][dataset_name]
paths_and_sets.append({"path": path, "set_name": fold})
return paths_and_sets
class GenericDataset(Dataset):
"""
Main class to handle dataset loading
"""
def __init__(self, params, set_name, custom_name, paths_and_sets):
self.params = params
self.name = custom_name
self.set_name = set_name
self.mean = (
np.array(params["config"]["mean"])
if "mean" in params["config"].keys()
else None
)
self.std = (
np.array(params["config"]["std"])
if "std" in params["config"].keys()
else None
# Convert label into tokens
sample["token_label"], sample["label_len"] = self.convert_sample_label(
sample["label"]
)
self.preprocessing_transforms = get_preprocessing_transforms(
params["config"]["preprocessings"]
)
self.load_in_memory = (
self.params["config"]["load_in_memory"]
if "load_in_memory" in self.params["config"]
else True
)
self.samples = self.load_samples(
paths_and_sets, load_in_memory=self.load_in_memory
)
if self.load_in_memory:
self.preprocess_all_images()
self.curriculum_config = None
def __len__(self):
return len(self.samples)
return sample
@staticmethod
def load_image(path):
with Image.open(path) as pil_img:
return pil_img.convert("RGB")
@staticmethod
def load_samples(paths_and_sets, load_in_memory=True):
def load_samples(self, paths_and_sets):
"""
Load images and labels
"""
......@@ -266,16 +107,20 @@ class GenericDataset(Dataset):
"path": os.path.abspath(filename),
}
)
if load_in_memory:
samples[-1]["img"] = GenericDataset.load_image(filename)
if self.load_in_memory:
samples[-1]["img"] = self.preprocessing_transforms(
read_image(filename)
)
return samples
def preprocess_all_images(self) -> None:
def get_sample_img(self, i):
"""
Iterate over all samples and apply pre-processing
Get image by index
"""
for i, sample in enumerate(self.samples):
self.samples[i]["img"] = self.preprocessing_transforms(sample["img"])
if self.load_in_memory:
return self.samples[i]["img"]
return self.preprocessing_transforms(read_image(self.samples[i]["path"]))
def compute_std_mean(self):
"""
......@@ -284,34 +129,46 @@ class GenericDataset(Dataset):
if self.mean is not None and self.std is not None:
return self.mean, self.std
sum = np.zeros((3,))
total = np.zeros((3,))
diff = np.zeros((3,))
nb_pixels = 0
for metric in ["mean", "std"]:
for ind in range(len(self.samples)):
img = np.array(
self.get_sample_img(ind)
if self.load_in_memory
else self.preprocessing_transforms(self.get_sample_img(ind)),
)
img = np.array(self.get_sample_img(ind))
if metric == "mean":
sum += np.sum(img, axis=(0, 1))
total += np.sum(img, axis=(0, 1))
nb_pixels += np.prod(img.shape[:2])
elif metric == "std":
diff += [
np.sum((img[:, :, k] - self.mean[k]) ** 2) for k in range(3)
]
if metric == "mean":
self.mean = sum / nb_pixels
self.mean = total / nb_pixels
elif metric == "std":
self.std = np.sqrt(diff / nb_pixels)
return self.mean, self.std
def get_sample_img(self, i):
def compute_final_size(self, img):
"""
Get image by index
Compute the final image size and position after feature extraction
"""
if self.load_in_memory:
return self.samples[i]["img"]
else:
return GenericDataset.load_image(self.samples[i]["path"])
image_reduced_shape = np.ceil(img.shape / self.reduce_dims_factor).astype(int)
if self.set_name == "train":
image_reduced_shape = [max(1, t) for t in image_reduced_shape]
image_position = [
[0, img.shape[0]],
[0, img.shape[1]],
]
return image_reduced_shape, image_position
def convert_sample_label(self, label):
"""
Tokenize the label and return its length
"""
token_label = token_to_ind(self.charset, label)
token_label.append(self.tokens["end"])
label_len = len(token_label)
token_label.insert(0, self.tokens["start"])
return token_label, label_len
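# Illustrative example with a hypothetical charset ["a", "b", "c"]:
# the special tokens are end=3, start=4, pad=5 (see OCRDatasetManager.get_tokens),
# so convert_sample_label("ab") returns token_label=[4, 0, 1, 3] and label_len=3
# (the length counts the label characters plus the end token, not the start token).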
# -*- coding: utf-8 -*-
import os
import pickle
import random
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from dan.manager.dataset import DatasetManager, GenericDataset
from dan.utils import pad_images, pad_sequences_1D, token_to_ind
from dan.manager.dataset import OCRDataset
from dan.transforms import get_augmentation_transforms, get_preprocessing_transforms
from dan.utils import pad_images, pad_sequences_1D
class OCRDatasetManager(DatasetManager):
"""
Specific class to handle OCR/HTR tasks
"""
class OCRDatasetManager:
def __init__(self, params, device: str):
super(OCRDatasetManager, self).__init__(params, device)
self.params = params
# Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
self.pin_memory = device != "cpu"
self.train_dataset = None
self.valid_datasets = dict()
self.test_datasets = dict()
self.train_loader = None
self.valid_loaders = dict()
self.test_loaders = dict()
self.train_sampler = None
self.valid_samplers = dict()
self.test_samplers = dict()
self.dataset_class = OCRDataset
self.charset = (
params["charset"] if "charset" in params else self.get_merged_charsets()
self.mean = (
np.array(params["config"]["mean"])
if "mean" in params["config"].keys()
else None
)
self.std = (
np.array(params["config"]["std"])
if "std" in params["config"].keys()
else None
)
self.generator = torch.Generator()
self.generator.manual_seed(0)
self.tokens = {"pad": len(self.charset) + 2}
self.tokens["end"] = len(self.charset)
self.tokens["start"] = len(self.charset) + 1
self.load_in_memory = (
self.params["config"]["load_in_memory"]
if "load_in_memory" in self.params["config"]
else True
)
self.charset = self.get_charset()
self.tokens = self.get_tokens()
self.params["config"]["padding_token"] = self.tokens["pad"]
def get_merged_charsets(self):
self.my_collate_function = OCRCollateFunction(self.params["config"])
self.augmentation = (
get_augmentation_transforms()
if self.params["config"]["augmentation"]
else None
)
self.preprocessing = get_preprocessing_transforms(
params["config"]["preprocessings"], to_pil_image=True
)
def load_datasets(self):
"""
Load training and validation datasets
"""
self.train_dataset = OCRDataset(
set_name="train",
paths_and_sets=self.get_paths_and_sets(self.params["train"]["datasets"]),
charset=self.charset,
tokens=self.tokens,
preprocessing_transforms=self.preprocessing,
augmentation_transforms=self.augmentation,
load_in_memory=self.load_in_memory,
mean=self.mean,
std=self.std,
)
self.mean, self.std = self.train_dataset.compute_std_mean()
for custom_name in self.params["val"].keys():
self.valid_datasets[custom_name] = OCRDataset(
set_name="val",
paths_and_sets=self.get_paths_and_sets(self.params["val"][custom_name]),
charset=self.charset,
tokens=self.tokens,
preprocessing_transforms=self.preprocessing,
augmentation_transforms=None,
load_in_memory=self.load_in_memory,
mean=self.mean,
std=self.std,
)
def load_ddp_samplers(self):
"""
Load training and validation data samplers
"""
if self.params["use_ddp"]:
self.train_sampler = DistributedSampler(
self.train_dataset,
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=True,
)
for custom_name in self.valid_datasets.keys():
self.valid_samplers[custom_name] = DistributedSampler(
self.valid_datasets[custom_name],
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=False,
)
else:
for custom_name in self.valid_datasets.keys():
self.valid_samplers[custom_name] = None
def load_dataloaders(self):
"""
Load training and validation data loaders
"""
self.train_loader = DataLoader(
self.train_dataset,
batch_size=self.params["batch_size"],
shuffle=True if self.train_sampler is None else False,
drop_last=False,
sampler=self.train_sampler,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
)
for key in self.valid_datasets.keys():
self.valid_loaders[key] = DataLoader(
self.valid_datasets[key],
batch_size=1,
sampler=self.valid_samplers[key],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
)
@staticmethod
def seed_worker(worker_id):
"""
Set worker seed
"""
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
def generate_test_loader(self, custom_name, sets_list):
"""
Load test dataset, data sampler and data loader
"""
if custom_name in self.test_loaders.keys():
return
paths_and_sets = list()
for set_info in sets_list:
paths_and_sets.append(
{"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
)
self.test_datasets[custom_name] = OCRDataset(
set_name="test",
paths_and_sets=paths_and_sets,
charset=self.charset,
tokens=self.tokens,
preprocessing_transforms=self.preprocessing,
augmentation_transforms=None,
load_in_memory=self.load_in_memory,
mean=self.mean,
std=self.std,
)
if self.params["use_ddp"]:
self.test_samplers[custom_name] = DistributedSampler(
self.test_datasets[custom_name],
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=False,
)
else:
self.test_samplers[custom_name] = None
self.test_loaders[custom_name] = DataLoader(
self.test_datasets[custom_name],
batch_size=1,
sampler=self.test_samplers[custom_name],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
)
def get_paths_and_sets(self, dataset_names_folds):
"""
Set the right path for each data set
"""
paths_and_sets = list()
for dataset_name, fold in dataset_names_folds:
path = self.params["datasets"][dataset_name]
paths_and_sets.append({"path": path, "set_name": fold})
return paths_and_sets
def get_charset(self):
"""
Merge the charset of the different datasets used
"""
if "charset" in self.params:
return self.params["charset"]
datasets = self.params["datasets"]
charset = set()
for key in datasets.keys():
......@@ -40,87 +229,15 @@ class OCRDatasetManager(DatasetManager):
charset.remove("")
return sorted(list(charset))
def apply_specific_treatment_after_dataset_loading(self, dataset):
dataset.charset = self.charset
dataset.tokens = self.tokens
dataset.convert_labels()
class OCRDataset(GenericDataset):
"""
Specific class to handle OCR/HTR datasets
"""
def __init__(
self,
params,
set_name,
custom_name,
paths_and_sets,
augmentation_transforms=None,
):
super(OCRDataset, self).__init__(params, set_name, custom_name, paths_and_sets)
self.charset = None
self.tokens = None
# Factor to reduce the height and width of the feature vector before feeding the decoder.
self.reduce_dims_factor = np.array([32, 8, 1])
self.collate_function = OCRCollateFunction
self.augmentation_transforms = augmentation_transforms
def __getitem__(self, idx):
sample = dict(**self.samples[idx])
if not self.load_in_memory:
sample["img"] = self.get_sample_img(idx)
# Convert to numpy
sample["img"] = np.array(sample["img"])
# Get initial height and width
initial_h, initial_w, _ = sample["img"].shape
# Data augmentation
if self.augmentation_transforms:
sample["img"] = self.augmentation_transforms(image=sample["img"])["image"]
# Normalization
sample["img"] = (sample["img"] - self.mean) / self.std
# Get final height and width (tensor mode)
final_h, final_w, _ = sample["img"].shape
sample["resize_ratio"] = [final_h / initial_h, final_w / initial_w]
sample["img_reduced_shape"] = np.ceil(
sample["img"].shape / self.reduce_dims_factor
).astype(int)
if self.set_name == "train":
sample["img_reduced_shape"] = [
max(1, t) for t in sample["img_reduced_shape"]
]
sample["img_position"] = [
[0, sample["img"].shape[0]],
[0, sample["img"].shape[1]],
]
return sample
def convert_labels(self):
def get_tokens(self):
"""
Label str to token at character level
Get special tokens
"""
for i in range(len(self.samples)):
self.samples[i] = self.convert_sample_labels(self.samples[i])
def convert_sample_labels(self, sample):
label = sample["label"]
sample["label"] = label
sample["token_label"] = token_to_ind(self.charset, label)
sample["token_label"].append(self.tokens["end"])
sample["label_len"] = len(sample["token_label"])
sample["token_label"].insert(0, self.tokens["start"])
return sample
return {
"end": len(self.charset),
"start": len(self.charset) + 1,
"pad": len(self.charset) + 2,
}
class OCRCollateFunction:
......@@ -134,12 +251,13 @@ class OCRCollateFunction:
def __call__(self, batch_data):
labels = [batch_data[i]["token_label"] for i in range(len(batch_data))]
labels = pad_sequences_1D(labels, padding_value=self.label_padding_value)
labels = torch.tensor(labels).long()
labels = pad_sequences_1D(labels, padding_value=self.label_padding_value).long()
imgs = [batch_data[i]["img"] for i in range(len(batch_data))]
imgs = [
torch.from_numpy(batch_data[i]["img"]).permute(2, 0, 1)
for i in range(len(batch_data))
]
imgs = pad_images(imgs)
imgs = torch.tensor(imgs).float().permute(0, 3, 1, 2)
formatted_batch_data = {
formatted_key: [batch_data[i][initial_key] for i in range(len(batch_data))]
......
# -*- coding: utf-8 -*-
import json
import os
import random
from copy import deepcopy
from enum import Enum
from time import time
import numpy as np
......@@ -21,7 +21,7 @@ from dan.manager.metrics import MetricManager
from dan.manager.ocr import OCRDatasetManager
from dan.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
from dan.schedulers import DropoutScheduler
from dan.utils import ind_to_token
from dan.utils import fix_ddp_layers_names, ind_to_token
if MLFLOW_AVAILABLE:
import mlflow
......@@ -34,7 +34,6 @@ class GenericTrainingManager:
self.params = params
self.dropout_scheduler = None
self.models = {}
self.begin_time = None
self.dataset = None
self.dataset_name = list(self.params["dataset_params"]["datasets"].values())[0]
self.paths = None
......@@ -56,6 +55,11 @@ class GenericTrainingManager:
self.params["model_params"]["use_amp"] = self.params["training_params"][
"use_amp"
]
self.nb_gpu = (
self.params["training_params"]["nb_gpu"]
if self.params["training_params"]["use_ddp"]
else 1
)
def init_paths(self):
"""
@@ -84,14 +88,6 @@
self.params["dataset_params"]["batch_size"] = self.params["training_params"][
"batch_size"
]
if "valid_batch_size" in self.params["training_params"]:
self.params["dataset_params"]["valid_batch_size"] = self.params[
"training_params"
]["valid_batch_size"]
if "test_batch_size" in self.params["training_params"]:
self.params["dataset_params"]["test_batch_size"] = self.params[
"training_params"
]["test_batch_size"]
self.params["dataset_params"]["num_gpu"] = self.params["training_params"][
"nb_gpu"
]
@@ -193,7 +189,9 @@
# make the model compatible with Distributed Data Parallel if used
if self.params["training_params"]["use_ddp"]:
self.models[model_name] = DDP(
self.models[model_name], [self.ddp_config["rank"]]
self.models[model_name],
[self.ddp_config["rank"]],
output_device=self.ddp_config["rank"],
)
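
Passing `output_device` keeps the wrapped module's outputs on the GPU owned by the current process. A minimal single-process sketch of this wrapping, assuming a CUDA device at index 0 and the `ddp_port` value used elsewhere in the configuration:

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process sketch (world_size=1) just to show the call shape; in DAN this
# runs inside each process spawned by torch.multiprocessing.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "20027")
rank = 0
dist.init_process_group("nccl", rank=rank, world_size=1)

model = torch.nn.Linear(10, 10).to(rank)
# device_ids and output_device keep inputs and outputs on this process's GPU
model = DDP(model, device_ids=[rank], output_device=rank)

dist.destroy_process_group()
```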
# Handle curriculum dropout
@@ -223,7 +221,10 @@
if self.params["training_params"]["load_epoch"] in ("best", "last"):
for filename in os.listdir(self.paths["checkpoints"]):
if self.params["training_params"]["load_epoch"] in filename:
return torch.load(os.path.join(self.paths["checkpoints"], filename))
return torch.load(
os.path.join(self.paths["checkpoints"], filename),
map_location=self.device,
)
return None
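
Adding `map_location` relocates every tensor in the checkpoint onto the current device, so weights saved on `cuda:0` also load on CPU-only machines or on a different GPU. A short sketch (the checkpoint file name below is hypothetical):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical checkpoint path; keys follow the f"{model_name}_state_dict" pattern above
checkpoint_path = "outputs/dan_esposalles_record/checkpoints/last_100.pt"
checkpoint = torch.load(checkpoint_path, map_location=device)
encoder_state = checkpoint["encoder_state_dict"]
```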
def load_existing_model(self, checkpoint, strict=True):
@@ -239,8 +240,14 @@
self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
# Load model weights from past training
for model_name in self.models.keys():
# Transform to DDP/from DDP model
checkpoint[f"{model_name}_state_dict"] = fix_ddp_layers_names(
checkpoint[f"{model_name}_state_dict"],
self.params["training_params"]["use_ddp"],
)
self.models[model_name].load_state_dict(
checkpoint["{}_state_dict".format(model_name)], strict=strict
checkpoint[f"{model_name}_state_dict"], strict=strict
)
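
`fix_ddp_layers_names` reconciles checkpoint keys with the `module.` prefix that DDP wrapping introduces. The following is a hypothetical reconstruction of such a helper, not the actual `dan.utils` implementation:

```python
def fix_ddp_layers_names(state_dict, use_ddp):
    # Hypothetical sketch: DDP stores parameters under "module.<name>", so
    # checkpoints need their keys adjusted when moving between DDP and
    # non-DDP runs.
    if use_ddp:
        # plain checkpoint -> DDP model: make sure every key carries the prefix
        return {
            key if key.startswith("module.") else f"module.{key}": value
            for key, value in state_dict.items()
        }
    # DDP checkpoint -> plain model: strip the prefix where present
    return {key.removeprefix("module."): value for key, value in state_dict.items()}
```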
def init_new_model(self):
@@ -261,8 +268,15 @@
state_dict_name, path, learnable, strict = self.params["model_params"][
"transfer_learning"
][model_name]
# Loading pretrained weights file
checkpoint = torch.load(path)
checkpoint = torch.load(path, map_location=self.device)
# Transform to DDP/from DDP model
checkpoint[f"{model_name}_state_dict"] = fix_ddp_layers_names(
checkpoint[f"{model_name}_state_dict"],
self.params["training_params"]["use_ddp"],
)
try:
# Load pretrained weights for model
self.models[model_name].load_state_dict(
@@ -462,23 +476,37 @@
def save_params(self):
"""
Output text file containing a summary of all hyperparameters chosen for the training
Output a yaml file containing a summary of all hyperparameters chosen for the training
and a yaml file containing parameters used for inference
"""
def compute_nb_params(module):
return sum([np.prod(p.size()) for p in list(module.parameters())])
def class_to_str_dict(my_dict):
for key in my_dict.keys():
if callable(my_dict[key]):
for key in my_dict:
if key == "preprocessings":
my_dict[key] = [
{
key: value.value if isinstance(value, Enum) else value
for key, value in preprocessing.items()
}
for preprocessing in my_dict[key]
]
elif callable(my_dict[key]):
my_dict[key] = my_dict[key].__name__
elif isinstance(my_dict[key], np.ndarray):
my_dict[key] = my_dict[key].tolist()
elif isinstance(my_dict[key], list) and isinstance(
my_dict[key][0], tuple
):
my_dict[key] = [list(elt) for elt in my_dict[key]]
elif isinstance(my_dict[key], dict):
my_dict[key] = class_to_str_dict(my_dict[key])
return my_dict
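
`class_to_str_dict` walks the parameter dictionary and replaces anything YAML cannot represent (Enum members, callables, NumPy arrays) with plain values. A simplified analogue, using a hypothetical `Preprocessing` enum purely for illustration:

```python
from enum import Enum

import numpy as np
import yaml

class Preprocessing(Enum):  # hypothetical enum, only to exercise the Enum branch
    MaxResize = "max_resize"

params = {
    "preprocessings": [{"type": Preprocessing.MaxResize, "max_width": 2000}],
    "activation": np.tanh,                    # callable -> its __name__
    "mean": np.array([233.0, 233.0, 233.0]),  # ndarray  -> plain list
}

def to_serializable(d):
    for key, value in d.items():
        if isinstance(value, dict):
            d[key] = to_serializable(value)
        elif isinstance(value, list):
            d[key] = [to_serializable(v) if isinstance(v, dict) else v for v in value]
        elif isinstance(value, Enum):
            d[key] = value.value
        elif callable(value):
            d[key] = value.__name__
        elif isinstance(value, np.ndarray):
            d[key] = value.tolist()
    return d

print(yaml.dump(to_serializable(params)))
```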
path = os.path.join(self.paths["results"], "params")
# Save training parameters
path = os.path.join(self.paths["results"], "training_parameters.yml")
if os.path.isfile(path):
return
params = class_to_str_dict(my_dict=deepcopy(self.params))
@@ -491,8 +519,45 @@
]
total_params += current_params
params["model_params"]["total_params"] = "{:,}".format(total_params)
params["mean"] = self.dataset.mean.tolist()
params["std"] = self.dataset.std.tolist()
with open(path, "w") as f:
yaml.dump(params, f)
# Save inference parameters
path = os.path.join(self.paths["results"], "inference_parameters.yml")
if os.path.isfile(path):
return
inference_params = {
"parameters": {
"mean": params["mean"],
"std": params["std"],
"max_char_prediction": params["training_params"]["max_char_prediction"],
"encoder": {
"dropout": params["model_params"]["dropout"],
},
"decoder": {
key: params["model_params"][key]
for key in [
"enc_dim",
"l_max",
"h_max",
"w_max",
"dec_num_layers",
"dec_num_heads",
"dec_res_dropout",
"dec_pred_dropout",
"dec_att_dropout",
"dec_dim_feedforward",
"vocab_size",
"attention_win",
]
},
"preprocessings": params["dataset_params"]["config"]["preprocessings"],
},
}
with open(path, "w") as f:
json.dump(params, f, indent=4)
yaml.dump(inference_params, f)
def backward_loss(self, loss, retain_graph=False):
self.scaler.scale(loss).backward(retain_graph=retain_graph)
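
`backward_loss` delegates to the AMP `GradScaler`, which scales the loss before `backward()` and un-scales gradients when the optimizer steps. A sketch of the full pattern around it (model, optimizer and data are placeholders):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=device == "cuda")

x = torch.rand(4, 8, device=device)
y = torch.rand(4, 2, device=device)

# Forward pass in mixed precision when a GPU is available
with torch.autocast(device_type=device, enabled=device == "cuda"):
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()  # what backward_loss() does
scaler.step(optimizer)         # un-scales gradients, then optimizer.step()
scaler.update()
optimizer.zero_grad()
```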
@@ -529,10 +594,7 @@
self.writer = SummaryWriter(self.paths["results"])
self.save_params()
# init variables
self.begin_time = time()
focus_metric_name = self.params["training_params"]["focus_metric"]
nb_epochs = self.params["training_params"]["max_nb_epochs"]
interval_save_weights = self.params["training_params"]["interval_save_weights"]
metric_names = self.params["training_params"]["train_metrics"]
display_values = None
@@ -544,13 +606,6 @@
self.init_curriculum()
# perform epochs
for num_epoch in range(self.latest_epoch + 1, nb_epochs):
# Check maximum training time stop condition
if (
self.params["training_params"]["max_training_time"]
and time() - self.begin_time
> self.params["training_params"]["max_training_time"]
):
break
# set models trainable
for model_name in self.models.keys():
self.models[model_name].train()
@@ -563,7 +618,6 @@
self.metric_manager["train"] = MetricManager(
metric_names=metric_names, dataset_name=self.dataset_name
)
with tqdm(total=len(self.dataset.train_loader.dataset)) as pbar:
pbar.set_description("EPOCH {}/{}".format(num_epoch, nb_epochs))
# iterates over mini-batch data
@@ -612,7 +666,7 @@
self.metric_manager["train"].update_metrics(batch_metrics)
display_values = self.metric_manager["train"].get_display_values()
pbar.set_postfix(values=str(display_values))
pbar.update(len(batch_data["names"]))
pbar.update(len(batch_data["names"]) * self.nb_gpu)
# Log MLflow metrics
logging_metrics(
@@ -651,25 +705,9 @@
)
if valid_set_name == self.params["training_params"][
"set_name_focus_metric"
] and (
self.best is None
or (
eval_values[focus_metric_name] <= self.best
and self.params["training_params"][
"expected_metric_value"
]
== "low"
)
or (
eval_values[focus_metric_name] >= self.best
and self.params["training_params"][
"expected_metric_value"
]
== "high"
)
):
] and (self.best is None or eval_values["cer"] <= self.best):
self.save_model(epoch=num_epoch, name="best")
self.best = eval_values[focus_metric_name]
self.best = eval_values["cer"]
# Handle curriculum learning update
if self.dataset.train_dataset.curriculum_config:
@@ -684,8 +722,6 @@
# save model weights
if self.is_master:
self.save_model(epoch=num_epoch, name="last")
if interval_save_weights and num_epoch % interval_save_weights == 0:
self.save_model(epoch=num_epoch, name="weights", keep_weights=True)
self.writer.flush()
def evaluate(self, set_name, mlflow_logging=False, **kwargs):
@@ -723,7 +759,7 @@
display_values = self.metric_manager[set_name].get_display_values()
pbar.set_postfix(values=str(display_values))
pbar.update(len(batch_data["names"]))
pbar.update(len(batch_data["names"]) * self.nb_gpu)
# log metrics in MLflow
logging_metrics(
@@ -775,7 +811,7 @@
].get_display_values()
pbar.set_postfix(values=str(display_values))
pbar.update(len(batch_data["names"]))
pbar.update(len(batch_data["names"]) * self.nb_gpu)
# log metrics in MLflow
logging_name = custom_name.split("-")[1]
@@ -977,9 +1013,14 @@ class Manager(OCRManager):
features_size = raw_features.size()
b, c, h, w = features_size
pos_features = self.models["decoder"].features_updater.get_pos_features(
raw_features
)
if self.params["training_params"]["use_ddp"]:
pos_features = self.models[
"decoder"
].module.features_updater.get_pos_features(raw_features)
else:
pos_features = self.models["decoder"].features_updater.get_pos_features(
raw_features
)
features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(
2, 0, 1
)
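
When the decoder is wrapped in `DistributedDataParallel`, custom attributes such as `features_updater` live on the inner module, hence the `use_ddp` branch above. A generic unwrap helper (not part of the repo) would express the same idea:

```python
def unwrap(model):
    # DDP stores the wrapped network under .module; plain modules are returned as-is
    return model.module if hasattr(model, "module") else model

# pos_features = unwrap(self.models["decoder"]).features_updater.get_pos_features(raw_features)
```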
@@ -1072,9 +1113,14 @@ class Manager(OCRManager):
else:
features = self.models["encoder"](x)
features_size = features.size()
pos_features = self.models["decoder"].features_updater.get_pos_features(
features
)
if self.params["training_params"]["use_ddp"]:
pos_features = self.models[
"decoder"
].module.features_updater.get_pos_features(features)
else:
pos_features = self.models["decoder"].features_updater.get_pos_features(
features
)
features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(
2, 0, 1
)
......
@@ -138,7 +138,6 @@ def get_config():
},
"transfered_charset": True, # Transfer learning of the decision layer based on charset of the line HTR model
"additional_tokens": 1, # for decision layer = [<eot>, ], only for transferred charset
"input_channels": 3, # number of channels of input image
"dropout": 0.5, # dropout rate for encoder
"enc_dim": 256, # dimension of extracted features
"nb_layers": 5, # encoder
@@ -151,9 +150,6 @@
"dec_pred_dropout": 0.1, # dropout rate before decision layer
"dec_att_dropout": 0.1, # dropout rate in multi head attention
"dec_dim_feedforward": 256, # number of dimension for feedforward layer in transformer decoder layers
"use_2d_pe": True, # use 2D positional embedding
"use_1d_pe": True, # use 1D positional embedding
"use_lstm": False,
"attention_win": 100, # length of attention window
# Curriculum dropout
"dropout_scheduler": {
@@ -163,14 +159,9 @@
},
"training_params": {
"output_folder": "outputs/dan_esposalles_record", # folder name for checkpoint and results
"max_nb_epochs": 710, # maximum number of epochs before to stop
"max_training_time": 3600
* 24
* 1.9, # maximum time before to stop (in seconds)
"max_nb_epochs": 800, # maximum number of epochs before to stop
"load_epoch": "last", # ["best", "last"]: last to continue training, best to evaluate
"interval_save_weights": None, # None: keep best and last only
"batch_size": 2, # mini-batch size for training
"valid_batch_size": 4, # mini-batch size for valdiation
"use_ddp": False, # Use DistributedDataParallel
"ddp_port": "20027",
"use_amp": True, # Enable automatic mix-precision
@@ -187,8 +178,6 @@
"lr_schedulers": None, # Learning rate schedulers
"eval_on_valid": True, # Whether to eval and logs metrics on validation set during training or not
"eval_on_valid_interval": 5, # Interval (in epochs) to evaluate during training
"focus_metric": "cer", # Metrics to focus on to determine best epoch
"expected_metric_value": "low", # ["high", "low"] What is best for the focus metric value
"set_name_focus_metric": "{}-val".format(
dataset_name
), # Which dataset to focus on to select best weights
@@ -258,18 +247,18 @@ def serialize_config(config):
return serialized_config
def start_training(config) -> None:
def start_training(config, mlflow_logging: bool) -> None:
if (
config["training_params"]["use_ddp"]
and not config["training_params"]["force_cpu"]
):
mp.spawn(
train_and_test,
args=(config, True),
args=(config, mlflow_logging),
nprocs=config["training_params"]["nb_gpu"],
)
else:
train_and_test(0, config, True)
train_and_test(0, config, mlflow_logging)
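
`mp.spawn` starts one process per GPU and calls the target as `fn(rank, *args)`, which is why `train_and_test` receives the process rank as its first argument. A runnable sketch with a dummy worker in place of `train_and_test`:

```python
import torch.multiprocessing as mp

def worker(rank, config, mlflow_logging):
    # Dummy stand-in for train_and_test: each spawned process gets its rank first
    print(f"process {rank} sees mlflow_logging={mlflow_logging}")

if __name__ == "__main__":
    config = {"training_params": {"nb_gpu": 2}}
    mp.spawn(worker, args=(config, False), nprocs=config["training_params"]["nb_gpu"])
```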
def run():
@@ -286,7 +275,7 @@ def run():
raise MLflowNotInstalled()
if "mlflow" not in config:
start_training(config)
start_training(config, mlflow_logging=False)
else:
labels_path = (
Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json"
@@ -314,4 +303,4 @@ def run():
dictionary=artifact,
artifact_file=filename,
)
start_training(config)
start_training(config, mlflow_logging=True)
@@ -18,6 +18,7 @@ def add_predict_parser(subcommands) -> None:
image_or_folder_input = parser.add_mutually_exclusive_group(required=True)
image_or_folder_input.add_argument(
"--image",
type=pathlib.Path,
help="Path to the image to predict.",
)
image_or_folder_input.add_argument(
@@ -50,6 +51,12 @@ def add_predict_parser(subcommands) -> None:
help="Path to the output folder.",
required=True,
)
parser.add_argument(
"--tokens",
type=pathlib.Path,
required=True,
help="Path to a yaml file containing a mapping between starting tokens and end tokens. Needed for entities.",
)
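
The `--tokens` file maps each starting token to its end token. A hypothetical example of such a file and of reading it with PyYAML (the tokens shown are illustrative, not the project's documented schema):

```python
import yaml

# Hypothetical tokens.yml content: one starting token per line, mapped to the
# end token that closes the corresponding entity.
tokens_yaml = """
"Ⓘ": "Ⓙ"
"Ⓝ": "Ⓞ"
"""
print(yaml.safe_load(tokens_yaml))  # {'Ⓘ': 'Ⓙ', 'Ⓝ': 'Ⓞ'}
```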
# Optional arguments.
parser.add_argument(
"--image-extension",
@@ -57,26 +64,12 @@ def add_predict_parser(subcommands) -> None:
help="The extension of the images in the folder.",
default=".jpg",
)
parser.add_argument(
"--scale",
type=float,
default=1.0,
required=False,
help="Image scaling factor before feeding it to DAN",
)
parser.add_argument(
"--image-max-width",
type=int,
default=1800,
required=False,
help="Image resizing before feeding it to DAN",
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="Temperature scaling scalar parameter",
required=True,
required=False,
)
parser.add_argument(
"--confidence-score",
@@ -147,4 +140,18 @@
type=int,
default=0,
)
parser.add_argument(
"--gpu-device",
help="Use a specific GPU if available.",
type=int,
required=False,
)
parser.add_argument(
"--batch-size",
help="Size of prediction batches.",
type=int,
default=1,
required=False,
)
parser.set_defaults(func=run)
@@ -4,6 +4,7 @@ import re
import cv2
import numpy as np
from PIL import Image
from torchvision.transforms.functional import to_pil_image
from dan import logger
@@ -70,14 +71,18 @@ def split_text_and_confidences(
texts = list(text)
offset = 0
elif level == "word":
texts, probs = compute_prob_by_separator(text, confidences, word_separators)
texts, confidences = compute_prob_by_separator(
text, confidences, word_separators
)
offset = 1
elif level == "line":
texts, probs = compute_prob_by_separator(text, confidences, line_separators)
texts, confidences = compute_prob_by_separator(
text, confidences, line_separators
)
offset = 1
else:
logger.error("Level should be either 'char', 'word', or 'line'")
return texts, [np.around(num, 2) for num in probs], offset
return texts, [np.around(num, 2) for num in confidences], offset
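
`compute_prob_by_separator` splits the prediction on word or line separators and returns one confidence per piece. A hedged stand-in that averages per-character confidences (the averaging is an assumption about how the real helper aggregates scores):

```python
import numpy as np

def split_with_confidences(text, confidences, separators):
    # Split on separator characters and average the confidences of each piece
    pieces, probs, start = [], [], 0
    for i, char in enumerate(text + separators[0]):
        if i == len(text) or char in separators:
            if i > start:
                pieces.append(text[start:i])
                probs.append(float(np.mean(confidences[start:i])))
            start = i + 1
    return pieces, probs

texts, confs = split_with_confidences(
    "le 12 mai", [0.9, 0.8, 1.0, 0.7, 0.9, 1.0, 0.6, 0.8, 0.9], [" "]
)
# texts == ['le', '12', 'mai'], confs ≈ [0.85, 0.80, 0.77]
```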
def get_predicted_polygons_with_confidence(
@@ -175,7 +180,7 @@ def blend_coverage(coverage_vector, image, mask, scale):
blend = Image.composite(image, coverage_vector, mask)
# Resize to save time
blend = blend.resize((int(width * scale), int(height * scale)), Image.ANTIALIAS)
blend = blend.resize((int(width * scale), int(height * scale)), Image.LANCZOS)
return blend
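
`Image.ANTIALIAS` was removed in Pillow 10; `Image.LANCZOS` is the same resampling filter under its current name. A minimal blend-and-downscale sketch:

```python
from PIL import Image

image = Image.new("RGB", (400, 200), "white")
coverage_vector = Image.new("RGB", (400, 200), "red")
mask = Image.new("L", (400, 200), color=110)  # 0 -> coverage_vector, 255 -> image

# Blend the attention coverage with the original image through the mask
blend = Image.composite(image, coverage_vector, mask)

scale = 0.5
blend = blend.resize((int(400 * scale), int(200 * scale)), Image.LANCZOS)
```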
@@ -288,7 +293,7 @@ def plot_attention(
):
"""
Create a GIF by blending attention maps onto the image for each text piece (char, word or line)
:param image: Input image in PIL format
:param image: Input image as torch.Tensor
:param text: Text predicted by DAN
:param weights: Attention weights of size (n_char, feature_height, feature_width)
:param level: Level to display (must be in [char, word, line])
@@ -298,13 +303,11 @@
:param line_separators: List of line separators
:param display_polygons: Whether to plot extracted polygons
"""
height, width, _ = image.shape
image = to_pil_image(image)
attention_map = []
# Convert to PIL Image and create mask
mask = Image.new("L", (width, height), color=(110))
image = Image.fromarray(image)
mask = Image.new("L", (image.width, image.height), color=(110))
# Split text into characters, words or lines
text_list, offset = split_text(text, level, word_separators, line_separators)
@@ -316,7 +319,7 @@
for text_piece in text_list:
# Accumulate weights for the current word/line and resize to original image size
coverage_vector = compute_coverage(
text_piece, max_value, tot_len, weights, (width, height)
text_piece, max_value, tot_len, weights, (image.width, image.height)
)
# Get polygons if flag is set:
@@ -329,7 +332,7 @@
weights,
threshold_method=threshold_method,
threshold_value=threshold_value,
size=(width, height),
size=(image.width, image.height),
)
if contour is not None:
......