Compare revisions

Commits on Source (4)
Showing with 698 additions and 313 deletions
@@ -4,7 +4,7 @@ stages:
- deploy
lint:
image: python:3.8
image: python:3.10
stage: test
cache:
@@ -24,6 +24,34 @@ lint:
script:
- pre-commit run -a
test:
image: python:3.10
stage: test
cache:
paths:
- .cache/pip
variables:
PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
ARKINDEX_API_SCHEMA_URL: schema.yml
before_script:
- pip install tox
# Download OpenAPI schema from last backend build
- curl https://assets.teklia.com/arkindex/openapi.yml > schema.yml
# Add system deps for opencv
- apt-get update -q
- apt-get install -q -y libgl1
except:
- schedules
script:
- tox
bump-python-deps:
stage: deploy
image: registry.gitlab.com/teklia/devops:latest
@@ -32,7 +60,7 @@ bump-python-deps:
- schedules
script:
- devops python-deps requirements.txt tests-requirements.txt
- devops python-deps requirements.txt
release-notes:
stage: deploy
......
include requirements.txt
include VERSION
@@ -21,9 +21,9 @@ We evaluate the DAN on two public datasets: RIMES and READ 2016 at single-page a
We obtained the following results:
| | CER (%) | WER (%) | LOER (%) | mAP_cer (%) |
| :---------------------: | ------- | :-----: | :------: | ----------- |
| RIMES (single page)     | 4.54    |  11.85  |   3.82   | 93.74       |
| READ 2016 (single page) | 3.53    |  13.33  |   5.94   | 92.57       |
| READ 2016 (double page) | 3.69    |  14.20  |   4.60   | 93.92       |
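CER and WER are standard character- and word-level edit-distance rates. As a reminder (this helper is not part of the repository), they can be computed with the `editdistance` package pinned in `requirements.txt`:
```python
import editdistance


def cer(prediction: str, reference: str) -> float:
    """Character Error Rate: character-level edit distance over the reference length."""
    return editdistance.eval(prediction, reference) / len(reference)


def wer(prediction: str, reference: str) -> float:
    """Word Error Rate: word-level edit distance over the number of reference words."""
    reference_words = reference.split()
    return editdistance.eval(prediction.split(), reference_words) / len(reference_words)


print(cer("DAN modle", "DAN model"))  # 2 / 9 ≈ 0.222
print(wer("DAN modle", "DAN model"))  # 1 / 2 = 0.5
```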
@@ -92,3 +92,106 @@ To run the inference on a GPU, one can replace `cpu` by the name of the GPU. In
```python
text, confidence_scores = model.predict(image, confidences=True)
```
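Since `torch` is already a dependency, one way to pick the device name to use in place of `cpu` is to check for CUDA availability first. This is a minimal sketch; whether the model-loading step accepts any standard torch device string such as `cuda` or `cuda:0` is an assumption here:
```python
import torch

# Pick a device name to pass where the documentation above says "cpu".
# Assumption: the model-loading step accepts any torch device string
# ("cuda", "cuda:0", ...); adjust if it expects a specific GPU name.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Running DAN inference on {device}")
```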
### Commands
This package provides three subcommands. To get more information about any subcommand, use the `--help` option.
#### Data extraction from Arkindex
Use the `teklia-dan dataset extract` command to extract a dataset from Arkindex. This will generate the images and the labels needed to train a DAN model.
The available arguments are:
| Parameter | Description | Type | Default |
| ------------------------------ | ----------------------------------------------------------------------------------- | -------- | ------- |
| `--parent` | UUID of the folder to import from Arkindex. You may specify multiple UUIDs. | `str/uuid` | |
| `--element-type` | Type of the elements to extract. You may specify multiple types. | `str` | |
| `--parent-element-type` | Type of the parent element containing the data. | `str` | `page` |
| `--output` | Folder where the data will be generated. | `Path` | |
| `--load-entities` | Extract text with their entities. Needed for NER tasks. | `bool` | `False` |
| `--tokens` | Mapping between starting tokens and end tokens. Needed for NER tasks. | `Path` | |
| `--use-existing-split` | Use the specified folder IDs for the dataset split. | `bool` | |
| `--train-folder` | ID of the training folder to import from Arkindex. | `uuid` | |
| `--val-folder` | ID of the validation folder to import from Arkindex. | `uuid` | |
| `--test-folder`                  | ID of the testing folder to import from Arkindex.                                    | `uuid`     |         |
| `--transcription-worker-version` | Filter transcriptions by worker_version. Use ‘manual’ for manual filtering.          | `str/uuid` |         |
| `--entity-worker-version`        | Filter transcription entities by worker_version. Use ‘manual’ for manual filtering.  | `str/uuid` |         |
| `--train-prob` | Training set split size | `float` | `0.7` |
| `--val-prob` | Validation set split size | `float` | `0.15` |
The `--tokens` argument expects a file with the following format.
```yaml
---
INTITULE:
start: ⓘ
end: Ⓘ
DATE:
start: ⓓ
end: Ⓓ
COTE_SERIE:
start: ⓢ
end: Ⓢ
ANALYSE_COMPL.:
start: ⓒ
end: Ⓒ
PRECISIONS_SUR_COTE:
start: ⓟ
end: Ⓟ
COTE_ARTICLE:
start: ⓐ
end: Ⓐ
CLASSEMENT:
start: ⓛ
end: Ⓛ
```
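For illustration, the sketch below shows how such a mapping is applied: each named entity is wrapped between its `start` and `end` tokens inside the transcription before the label file is written. This is a minimal stand-alone example, not the package's internal implementation; the transcription, label and offsets are made up.
```python
import yaml

# Load the start/end token mapping (same format as the tokens.yml above).
with open("tokens.yml") as f:
    tokens = yaml.safe_load(f)


def wrap_entity(text: str, label: str, offset: int, length: int) -> str:
    """Surround the entity at text[offset:offset + length] with its start/end tokens."""
    start, end = tokens[label]["start"], tokens[label]["end"]
    return text[:offset] + start + text[offset:offset + length] + end + text[offset + length:]


# Hypothetical transcription with a DATE entity covering "1858".
print(wrap_entity("Paris, le 3 mai 1858", "DATE", offset=16, length=4))
# Paris, le 3 mai ⓓ1858Ⓓ
```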
To extract HTR+NER data from **pages** from [this folder](https://arkindex.teklia.com/element/665e84ea-bd97-4912-91b0-1f4a844287ff), use the following command:
```shell
teklia-dan dataset extract \
--parent 665e84ea-bd97-4912-91b0-1f4a844287ff \
--element-type page \
--output data \
--load-entities \
--tokens tokens.yml
```
with `tokens.yml` having the content described just above.
To do the same, but using only the data from three folders, the command becomes:
```shell
teklia-dan dataset extract \
--parent 2275529a-1ec5-40ce-a516-42ea7ada858c af9b38b5-5d95-417d-87ec-730537cb1898 6ff44957-0e65-48c5-9d77-a178116405b2 \
--element-type page \
--output data \
--load-entities \
--tokens tokens.yml
```
To use the data from three folders as **training**, **validation** and **testing** datasets respectively, the command becomes:
```shell
teklia-dan dataset extract \
--use-existing-split \
    --train-folder 2275529a-1ec5-40ce-a516-42ea7ada858c \
--val-folder af9b38b5-5d95-417d-87ec-730537cb1898 \
--test-folder 6ff44957-0e65-48c5-9d77-a178116405b2 \
--element-type page \
--output data \
--load-entities \
--tokens tokens.yml
```
To extract HTR data from **annotations** and **text_zones** from [this folder](https://demo.arkindex.org/element/48852284-fc02-41bb-9a42-4458e5a51615) that are children of **single_pages**, use the following command:
```shell
teklia-dan dataset extract \
--parent 48852284-fc02-41bb-9a42-4458e5a51615 \
--element-type text_zone annotation \
--parent-element-type single_page \
--output data
```
#### Model training
Use the `teklia-dan train` command with its `line` or `document` subcommand and the appropriate arguments to train a DAN model.
#### Synthetic data generation
Use the `teklia-dan generate` command with the appropriate arguments to generate synthetic data for training DAN models.
# -*- coding: utf-8 -*-
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s/%(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
......@@ -2,17 +2,17 @@
import argparse
import errno
from dan.datasets.extract.extract_from_arkindex import add_extract_parser
from dan.ocr.line.generate_synthetic import add_generate_parser
from dan.ocr.train import add_train_parser
from dan.datasets import add_dataset_parser
from dan.ocr import add_train_parser
from dan.ocr.line import add_generate_parser
def get_parser():
parser = argparse.ArgumentParser(prog="teklia-dan")
subcommands = parser.add_subparsers(metavar="subcommand")
add_dataset_parser(subcommands)
add_train_parser(subcommands)
add_extract_parser(subcommands)
add_generate_parser(subcommands)
return parser
......
# -*- coding: utf-8 -*-
"""
Preprocess datasets for training.
"""
from dan.datasets.extract import add_extract_parser
def add_dataset_parser(subcommands) -> None:
parser = subcommands.add_parser(
"dataset",
description=__doc__,
help=__doc__,
)
subcommands = parser.add_subparsers(metavar="subcommand")
add_extract_parser(subcommands)
# -*- coding: utf-8 -*-
"""
Extract dataset from Arkindex using API.
"""
import pathlib
import uuid
from dan.datasets.extract.extract_from_arkindex import run
MANUAL_SOURCE = "manual"
def parse_worker_version(worker_version_id):
if worker_version_id == MANUAL_SOURCE:
return False
return worker_version_id
def add_extract_parser(subcommands) -> None:
parser = subcommands.add_parser(
"extract",
description=__doc__,
help=__doc__,
)
# Required arguments.
parser.add_argument(
"--parent",
type=uuid.UUID,
nargs="+",
help="ID of the parent folder to import from Arkindex.",
required=False,
)
parser.add_argument(
"--element-type",
nargs="+",
type=str,
help="Type of elements to retrieve",
required=True,
)
parser.add_argument(
"--parent-element-type",
type=str,
help="Type of the parent element containing the data.",
required=False,
default="page",
)
parser.add_argument(
"--output",
type=pathlib.Path,
help="Path where the data will be generated.",
required=True,
)
# Optional arguments.
parser.add_argument(
"--load-entities", action="store_true", help="Extract text with their entities"
)
parser.add_argument(
"--tokens",
type=pathlib.Path,
help="Mapping between starting tokens and end tokens. Needed for entities.",
required=False,
)
parser.add_argument(
"--use-existing-split",
action="store_true",
help="Use the specified folder IDs for the dataset split.",
)
parser.add_argument(
"--train-folder",
type=uuid.UUID,
help="ID of the training folder to import from Arkindex.",
required=False,
)
parser.add_argument(
"--val-folder",
type=uuid.UUID,
help="ID of the validation folder to import from Arkindex.",
required=False,
)
parser.add_argument(
"--test-folder",
type=uuid.UUID,
help="ID of the testing folder to import from Arkindex.",
required=False,
)
parser.add_argument(
"--transcription-worker-version",
type=parse_worker_version,
help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
required=False,
default=MANUAL_SOURCE,
)
parser.add_argument(
"--entity-worker-version",
type=parse_worker_version,
help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
required=False,
default=MANUAL_SOURCE,
)
parser.add_argument(
"--train-prob", type=float, default=0.7, help="Training set split size."
)
parser.add_argument(
"--val-prob", type=float, default=0.15, help="Validation set split size"
)
parser.set_defaults(func=run)
# -*- coding: utf-8 -*-
"""
The arkindex_utils module
======================
"""
import errno
import logging
import sys
from apistar.exceptions import ErrorResponse
def retrieve_corpus(client, corpus_name: str) -> str:
"""
Retrieve the corpus id from the corpus name.
:param client: The arkindex client.
:param corpus_name: The name of the corpus to retrieve.
:return target_corpus: The id of the retrieved corpus.
"""
for corpus in client.request("ListCorpus"):
if corpus["name"] == corpus_name:
target_corpus = corpus["id"]
try:
logging.info(f"Corpus id retrieved: {target_corpus}")
except NameError:
logging.error(f"Corpus {corpus_name} not found")
sys.exit(errno.EINVAL)
return target_corpus
def retrieve_subsets(
client, corpus: str, parents_types: list, parents_names: list
) -> list:
"""
Retrieve the requested subsets.
:param client: The arkindex client.
:param corpus: The id of the retrieved corpus.
:param parents_types: The types of parents of the elements to retrieve.
:param parents_names: The names of parents of the elements to retrieve.
:return subsets: The retrieved subsets.
"""
subsets = []
for parent_type in parents_types:
try:
subsets.extend(
client.request("ListElements", corpus=corpus, type=parent_type)[
"results"
]
)
except ErrorResponse as e:
logging.error(f"{e.content}: {parent_type}")
sys.exit(errno.EINVAL)
# Retrieve subsets with name in parents-names. If no parents-names given, keep all subsets.
if parents_names is not None:
logging.info(f"Retrieving {parents_names} subset(s)")
subsets = [subset for subset in subsets if subset["name"] in parents_names]
else:
logging.info("Retrieving all subsets")
if len(subsets) == 0:
logging.info("No subset found")
return subsets
# -*- coding: utf-8 -*-
"""
The extraction module
======================
"""
import logging
import os
import random
from collections import defaultdict
from pathlib import Path
from typing import List, NamedTuple
import cv2
import imageio.v2 as iio
from arkindex import ArkindexClient, options_from_env
from tqdm import tqdm
from dan.datasets.extract.arkindex_utils import retrieve_corpus, retrieve_subsets
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
from dan import logger
from dan.datasets.extract.utils import (
insert_token,
parse_tokens,
save_image,
save_json,
save_text,
)
IMAGES_DIR = "images" # Subpath to the images directory.
LABELS_DIR = "labels" # Subpath to the labels directory.
Entity = NamedTuple("Entity", offset=int, length=int, label=str)
class ArkindexExtractor:
"""
Extract data from Arkindex
"""
def __init__(
self,
client: ArkindexClient,
folders: list = [],
element_type: list = [],
parent_element_type: list = ["page"],
split_names: list = [],
output: Path = None,
load_entities: bool = None,
tokens: Path = None,
use_existing_split: bool = None,
transcription_worker_version: str = None,
entity_worker_version: str = None,
train_prob: float = None,
val_prob: float = None,
) -> None:
self.client = client
self.element_type = element_type
self.parent_element_type = parent_element_type
self.split_names = split_names
self.output = output
self.load_entities = load_entities
self.tokens = parse_tokens(tokens) if self.load_entities else None
self.use_existing_split = use_existing_split
self.transcription_worker_version = transcription_worker_version
self.entity_worker_version = entity_worker_version
self.train_prob = train_prob
self.val_prob = val_prob
self.get_subsets(folders)
def get_subsets(self, folders: list):
"""
Assign each folder to its split if it's already known.
Assign None if it's unknown.
"""
if self.use_existing_split:
self.subsets = [
(folder, split) for folder, split in zip(folders, self.split_names)
]
else:
self.subsets = [(folder, None) for folder in folders]
def _assign_random_split(self):
"""
Yields a randomly chosen split for an element.
        Assumes that train_prob + val_prob + test_prob = 1.
"""
prob = random.random()
if prob <= self.train_prob:
yield self.split_names[0]
elif prob <= self.train_prob + self.val_prob:
yield self.split_names[1]
else:
yield self.split_names[2]
def get_random_split(self):
return next(self._assign_random_split())
def extract_entities(self, transcription: dict):
entities = self.client.paginate(
"ListTranscriptionEntities",
id=transcription["id"],
worker_version=self.entity_worker_version,
)
if entities is None:
logger.warning(
f"No entities found on transcription ({transcription['id']})."
)
return
return [
Entity(
offset=entity["offset"],
length=entity["length"],
label=entity["entity"]["metas"]["subtype"],
)
for entity in entities
]
def reconstruct_text(self, text: str, entities: List[Entity]):
"""
Insert tokens delimiting the start/end of each entity on the transcription.
"""
count = 0
for entity in entities:
matching_tokens = self.tokens[entity.label]
start_token, end_token = (
matching_tokens["start"],
matching_tokens["end"],
)
text, count = insert_token(
text,
count,
start_token,
end_token,
offset=entity.offset,
length=entity.length,
)
return text
def extract_transcription(
self,
element: dict,
):
"""
Extract the element's transcription.
If the entities are needed, they are added to the transcription using tokens.
"""
transcriptions = self.client.request(
"ListTranscriptions",
id=element["id"],
worker_version=self.transcription_worker_version,
)
if transcriptions["count"] != 1:
logger.warning(
f"More than one transcription found on element ({element['id']}) with this config."
)
return
transcription = transcriptions["results"].pop()
if self.load_entities:
entities = self.extract_entities(transcription)
return self.reconstruct_text(transcription["text"], entities)
else:
return transcription["text"].strip()
def process_element(
self,
element: dict,
split: str,
):
"""
Extract an element's data and save it to disk.
The output path is directly related to the split of the element.
"""
text = self.extract_transcription(
element,
)
IMAGES_DIR = "./images/" # Path to the images directory.
LABELS_DIR = "./labels/" # Path to the labels directory.
# Layout string to token
SEM_MATCHING_TOKENS_STR = {
    "INTITULE": "ⓘ",
    "DATE": "ⓓ",
    "COTE_SERIE": "ⓢ",
    "ANALYSE_COMPL.": "ⓒ",
    "PRECISIONS_SUR_COTE": "ⓟ",
    "COTE_ARTICLE": "ⓐ",
}
# Layout begin-token to end-token
SEM_MATCHING_TOKENS = {"ⓘ": "Ⓘ", "ⓓ": "Ⓓ", "ⓢ": "Ⓢ", "ⓒ": "Ⓒ", "ⓟ": "Ⓟ", "ⓐ": "Ⓐ"}
def add_extract_parser(subcommands) -> None:
parser = subcommands.add_parser(
"extract",
description=__doc__,
help=__doc__,
)
# Required arguments.
parser.add_argument(
"--corpus",
type=str,
help="Name of the corpus from which the data will be retrieved.",
required=True,
)
parser.add_argument(
"--element-type",
nargs="+",
type=str,
help="Type of elements to retrieve",
required=True,
)
parser.add_argument(
"--parents-types",
nargs="+",
type=str,
help="Type of parents of the elements.",
required=True,
)
parser.add_argument(
"--output-dir",
type=str,
help="Path to the output directory.",
required=True,
)
# Optional arguments.
parser.add_argument(
"--parents-names",
nargs="+",
type=str,
help="Names of parents of the elements.",
default=None,
)
parser.add_argument(
"--no-entities", action="store_true", help="Extract text without entities"
)
parser.add_argument(
"--use-existing-split",
action="store_true",
help="Do not partition pages into train/val/test",
)
parser.add_argument(
"--train-prob", type=float, default=0.7, help="Training set probability"
)
parser.add_argument(
"--val-prob", type=float, default=0.15, help="Validation set probability"
)
parser.set_defaults(func=run)
if not text:
logging.warning(f"Skipping {element['id']}")
else:
logging.info(f"Processed {element['type']} {element['id']}")
im_path = os.path.join(
self.output, IMAGES_DIR, split, f"{element['type']}_{element['id']}.jpg"
)
txt_path = os.path.join(
self.output, LABELS_DIR, split, f"{element['type']}_{element['id']}.txt"
)
save_text(txt_path, text)
try:
image = iio.imread(element["zone"]["url"])
save_image(im_path, image)
except Exception:
            logger.error(f"Couldn't retrieve image of element ({element['id']})")
raise
return element["id"]
def process_parent(
self,
parent: dict,
split: str,
):
"""
Extract data from a parent element.
        Depending on the given types, process the parent element itself or extract its children.
"""
data = defaultdict(list)
if self.element_type == [parent["type"]]:
data[self.element_type[0]] = [
self.process_element(
parent,
split,
)
]
# Extract children elements
else:
for element_type in self.element_type:
for element in self.client.paginate(
"ListElementChildren",
id=parent["id"],
type=element_type,
recursive=True,
):
data[element_type].append(
self.process_element(
element,
split,
)
)
return data
def run(self):
split_dict = defaultdict(dict)
# Iterate over the subsets to find the page images and labels.
for subset_id, subset_split in self.subsets:
page_idx = 0
# Iterate over the pages to create splits at page level.
for parent in tqdm(
self.client.paginate(
"ListElementChildren",
id=subset_id,
type=self.parent_element_type,
recursive=True,
)
):
page_idx += 1
split = subset_split or self.get_random_split()
split_dict[split][parent["id"]] = self.process_parent(
parent=parent,
split=split,
)
save_json(self.output / "split.json", split_dict)
def run(
corpus,
parent,
element_type,
parents_types,
output_dir,
parents_names,
no_entities,
parent_element_type,
output,
load_entities,
tokens,
use_existing_split,
train_folder,
val_folder,
test_folder,
transcription_worker_version,
entity_worker_version,
train_prob,
val_prob,
):
# Get and initialize the parameters.
os.makedirs(IMAGES_DIR, exist_ok=True)
os.makedirs(LABELS_DIR, exist_ok=True)
# Login to arkindex.
client = ArkindexClient(**options_from_env())
corpus = retrieve_corpus(client, corpus)
subsets = retrieve_subsets(client, corpus, parents_types, parents_names)
assert (
use_existing_split or parent
), "One of `--use-existing-split` and `--parent` must be set"
# Iterate over the subsets to find the page images and labels.
for subset in subsets:
if use_existing_split:
assert (
train_folder
), "If you use an existing split, you must specify the training folder."
assert (
val_folder
), "If you use an existing split, you must specify the validation folder."
assert (
test_folder
), "If you use an existing split, you must specify the testing folder."
folders = [train_folder, val_folder, test_folder]
else:
folders = parent
os.makedirs(os.path.join(output_dir, IMAGES_DIR, subset["name"]), exist_ok=True)
os.makedirs(os.path.join(output_dir, LABELS_DIR, subset["name"]), exist_ok=True)
if load_entities:
assert tokens, "Please provide the entities to match."
for page in tqdm(
client.paginate(
"ListElementChildren", id=subset["id"], type="page", recursive=True
),
desc="Set " + subset["name"],
):
# Login to arkindex.
assert (
"ARKINDEX_API_URL" in os.environ
), "The ARKINDEX API URL was not found in your environment."
assert (
"ARKINDEX_API_TOKEN" in os.environ
    ), "Your API credentials were not found in your environment."
client = ArkindexClient(**options_from_env())
image = iio.imread(page["zone"]["url"])
cv2.imwrite(
os.path.join(
output_dir, IMAGES_DIR, subset["name"], f"{page['id']}.jpg"
),
cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
)
# Create out directories
split_names = ["train", "val", "test"]
for split in split_names:
os.makedirs(os.path.join(output, IMAGES_DIR, split), exist_ok=True)
os.makedirs(os.path.join(output, LABELS_DIR, split), exist_ok=True)
tr = client.request(
"ListTranscriptions", id=page["id"], worker_version=None
)["results"]
tr = [one for one in tr if one["worker_version_id"] is None]
assert len(tr) == 1, page["id"]
for one_tr in tr:
ent = client.request("ListTranscriptionEntities", id=one_tr["id"])[
"results"
]
ent = [one for one in ent if one["worker_version_id"] is None]
if len(ent) == 0:
continue
else:
text = one_tr["text"]
new_text = text
count = 0
for e in ent:
start_token = SEM_MATCHING_TOKENS_STR[e["entity"]["metas"]["subtype"]]
end_token = SEM_MATCHING_TOKENS[start_token]
new_text = (
new_text[: count + e["offset"]]
+ start_token
+ new_text[count + e["offset"] :]
)
count += 1
new_text = (
new_text[: count + e["offset"] + e["length"]]
+ end_token
+ new_text[count + e["offset"] + e["length"] :]
)
count += 1
with open(
os.path.join(
output_dir, LABELS_DIR, subset["name"], f"{page['id']}.txt"
),
"w",
) as f:
f.write(new_text)
ArkindexExtractor(
client=client,
folders=folders,
element_type=element_type,
parent_element_type=parent_element_type,
split_names=split_names,
output=output,
load_entities=load_entities,
tokens=tokens,
use_existing_split=use_existing_split,
transcription_worker_version=transcription_worker_version,
entity_worker_version=entity_worker_version,
train_prob=train_prob,
val_prob=val_prob,
).run()
# -*- coding: utf-8 -*-
import json
import cv2
import yaml
def save_text(path, text):
with open(path, "w") as f:
f.write(text)
def save_image(path, image):
cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
def save_json(path, dict):
with open(path, "w") as outfile:
json.dump(dict, outfile, indent=4)
def insert_token(text, count, start_token, end_token, offset, length):
"""
Insert the given tokens at the right position in the text
"""
text = (
# Text before entity
text[: count + offset]
# Starting token
+ start_token
# Entity
+ text[count + offset : count + 1 + offset + length]
# End token
+ end_token
# Text after entity
+ text[count + 1 + offset + length :]
)
return text, count + 2
def parse_tokens(filename):
with open(filename) as f:
return yaml.safe_load(f)
# -*- coding: utf-8 -*-
import json
import random
import re
import cv2
random.seed(42)
def convert(text):
return int(text) if text.isdigit() else text.lower()
@@ -14,30 +8,3 @@ def convert(text):
def natural_sort(data):
return sorted(data, key=lambda key: [convert(c) for c in re.split("([0-9]+)", key)])
def assign_random_split(train_prob, val_prob):
"""
assuming train_prob + val_prob + test_prob = 1
"""
prob = random.random()
if prob <= train_prob:
return "train"
elif prob <= train_prob + val_prob:
return "val"
else:
return "test"
def save_text(path, text):
with open(path, "w") as f:
f.write(text)
def save_image(path, image):
cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
def save_json(path, dict):
with open(path, "w") as outfile:
json.dump(dict, outfile, indent=4)
# -*- coding: utf-8 -*-
"""
Train a new DAN model.
"""
from dan.ocr.document import add_document_parser
from dan.ocr.line import add_line_parser
def add_train_parser(subcommands) -> None:
parser = subcommands.add_parser(
"train",
description=__doc__,
help=__doc__,
)
subcommands = parser.add_subparsers(metavar="subcommand")
add_line_parser(subcommands)
add_document_parser(subcommands)
# -*- coding: utf-8 -*-
"""
Train a DAN model at document level.
"""
from dan.ocr.document.train import run
def add_document_parser(subcommands) -> None:
parser = subcommands.add_parser(
"document",
description=__doc__,
help=__doc__,
)
parser.set_defaults(func=run)
@@ -14,15 +14,6 @@ from dan.schedulers import exponential_dropout_scheduler
from dan.transforms import aug_config
def add_document_parser(subcommands) -> None:
parser = subcommands.add_parser(
"document",
description=__doc__,
help=__doc__,
)
parser.set_defaults(func=run)
def train_and_test(rank, params):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
......
# -*- coding: utf-8 -*-
from dan.ocr.line.generate_synthetic import run as run_generate
from dan.ocr.line.train import run as run_train
def add_generate_parser(subcommands) -> None:
parser = subcommands.add_parser(
"generate",
description=__doc__,
help="Generate synthetic data to train DAN models.",
)
parser.set_defaults(func=run_generate)
def add_line_parser(subcommands) -> None:
parser = subcommands.add_parser(
"line",
description=__doc__,
help="Train a DAN model at line level.",
)
parser.set_defaults(func=run_train)
@@ -14,15 +14,6 @@ from dan.schedulers import exponential_dropout_scheduler, exponential_scheduler
from dan.transforms import line_aug_config
def add_generate_parser(subcommands) -> None:
parser = subcommands.add_parser(
"generate",
description=__doc__,
help=__doc__,
)
parser.set_defaults(func=run)
def train_and_test(rank, params):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
......
@@ -14,15 +14,6 @@ from dan.schedulers import exponential_dropout_scheduler, exponential_scheduler
from dan.transforms import line_aug_config
def add_line_parser(subcommands) -> None:
parser = subcommands.add_parser(
"line",
description=__doc__,
help=__doc__,
)
parser.set_defaults(func=run)
def train_and_test(rank, params):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
......
# -*- coding: utf-8 -*-
from dan.ocr.document.train import add_document_parser
from dan.ocr.line.train import add_line_parser
def add_train_parser(subcommands) -> None:
parser = subcommands.add_parser(
"train",
description=__doc__,
help=__doc__,
)
subcommands = parser.add_subparsers(metavar="subcommand")
add_line_parser(subcommands)
add_document_parser(subcommands)
arkindex-client==1.0.11
editdistance==0.6.0
fontTools==4.29.1
editdistance==0.6.1
fontTools==4.38.0
imageio==2.16.0
networkx==2.6.3
numpy==1.22.3
opencv-python==4.5.5.64
PyYAML==6.0
tensorboard==0.2.1
tensorboard==2.8.0
torch==1.11.0
torchvision==0.12.0
tqdm==4.62.3
# -*- coding: utf-8 -*-
import os
import pytest
@pytest.fixture(autouse=True)
def setup_environment(responses):
"""Setup needed environment variables"""
# Allow accessing remote API schemas
# defaulting to the prod environment
schema_url = os.environ.get(
"ARKINDEX_API_SCHEMA_URL",
"https://arkindex.teklia.com/api/v1/openapi/?format=openapi-json",
)
responses.add_passthru(schema_url)
# Set schema url in environment
os.environ["ARKINDEX_API_SCHEMA_URL"] = schema_url