Commit e316fcc1 authored by Yoann Schneider, committed by Solene Tarride

Implement extraction command

parent 825bb4ab
1 merge request: !11 Implement extraction command
Showing 667 additions and 310 deletions
@@ -32,7 +32,7 @@ bump-python-deps:
- schedules
script:
- devops python-deps requirements.txt tests-requirements.txt
- devops python-deps requirements.txt
release-notes:
stage: deploy
include requirements.txt
include VERSION
@@ -21,9 +21,9 @@ We evaluate the DAN on two public datasets: RIMES and READ 2016 at single-page and double-page level.
We obtained the following results:
| | CER (%) | WER (%) | LOER (%) | mAP_cer (%) |
| :---------------------: | ------- | :-----: | :------: | ----------- |
| RIMES (single page)     | 4.54    | 11.85   | 3.82     | 93.74       |
| READ 2016 (single page) | 3.53    | 13.33   | 5.94     | 92.57       |
| READ 2016 (double page) | 3.69 | 14.20 | 4.60 | 93.92 |
@@ -92,3 +92,106 @@ To run the inference on a GPU, one can replace `cpu` by the name of the GPU.
```python
text, confidence_scores = model.predict(image, confidences=True)
```
### Commands
This package provides three subcommands. To get more information about any subcommand, use the `--help` option.
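For example, the following commands print the usage of the top-level CLI and of the extraction subcommand (output omitted here):
```shell
teklia-dan --help
teklia-dan dataset extract --help
```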
#### Data extraction from Arkindex
Use the `teklia-dan dataset extract` command to extract a dataset from Arkindex. This will generate the images and the labels needed to train a DAN model.
The available arguments are:
| Parameter | Description | Type | Default |
| ------------------------------ | ----------------------------------------------------------------------------------- | -------- | ------- |
| `--parent` | UUID of the folder to import from Arkindex. You may specify multiple UUIDs. | `str/uuid` | |
| `--element-type` | Type of the elements to extract. You may specify multiple types. | `str` | |
| `--parent-element-type` | Type of the parent element containing the data. | `str` | `page` |
| `--output` | Folder where the data will be generated. | `Path` | |
| `--load-entities` | Extract text with their entities. Needed for NER tasks. | `bool` | `False` |
| `--tokens` | Mapping between starting tokens and end tokens. Needed for NER tasks. | `Path` | |
| `--use-existing-split` | Use the specified folder IDs for the dataset split. | `bool` | |
| `--train-folder` | ID of the training folder to import from Arkindex. | `uuid` | |
| `--val-folder` | ID of the validation folder to import from Arkindex. | `uuid` | |
| `--test-folder`                | ID of the testing folder to import from Arkindex.                                    | `uuid`   |         |
| `--transcription-worker-version` | Filter transcriptions by worker_version. Use `manual` for manual filtering.        | `str/uuid` |       |
| `--entity-worker-version`      | Filter transcription entities by worker_version. Use `manual` for manual filtering.  | `str/uuid` |       |
| `--train-prob` | Training set split size | `float` | `0.7` |
| `--val-prob` | Validation set split size | `float` | `0.15` |
The `--tokens` argument expects a file with the following format.
```yaml
---
INTITULE:
start: ⓘ
end: Ⓘ
DATE:
start: ⓓ
end: Ⓓ
COTE_SERIE:
start: ⓢ
end: Ⓢ
ANALYSE_COMPL.:
start: ⓒ
end: Ⓒ
PRECISIONS_SUR_COTE:
start: ⓟ
end: Ⓟ
COTE_ARTICLE:
start: ⓐ
end: Ⓐ
CLASSEMENT:
start: ⓛ
end: Ⓛ
```
To extract HTR+NER data from **pages** from [this folder](https://arkindex.teklia.com/element/665e84ea-bd97-4912-91b0-1f4a844287ff), use the following command:
```shell
teklia-dan dataset extract \
--parent 665e84ea-bd97-4912-91b0-1f4a844287ff \
--element-type page \
--output data \
--load-entities \
--tokens tokens.yml
```
with `tokens.yml` having the content described just above.
To do the same but using only the data from three specific folders, the command becomes:
```shell
teklia-dan dataset extract \
--parent 2275529a-1ec5-40ce-a516-42ea7ada858c af9b38b5-5d95-417d-87ec-730537cb1898 6ff44957-0e65-48c5-9d77-a178116405b2 \
--element-type page \
--output data \
--load-entities \
--tokens tokens.yml
```
To use the data from three folders as the **training**, **validation** and **testing** datasets respectively, the command becomes:
```shell
teklia-dan dataset extract \
--use-existing-split \
    --train-folder 2275529a-1ec5-40ce-a516-42ea7ada858c \
--val-folder af9b38b5-5d95-417d-87ec-730537cb1898 \
--test-folder 6ff44957-0e65-48c5-9d77-a178116405b2 \
--element-type page \
--output data \
--load-entities \
--tokens tokens.yml
```
To extract HTR data from **annotations** and **text_zones** from [this folder](https://demo.arkindex.org/element/48852284-fc02-41bb-9a42-4458e5a51615) that are children of **single_pages**, use the following command:
```shell
teklia-dan dataset extract \
--parent 48852284-fc02-41bb-9a42-4458e5a51615 \
--element-type text_zone annotation \
--parent-element-type single_page \
--output data
```
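Based on the extraction code, the folder passed through `--output` should roughly look as follows once extraction is done (a sketch rather than an exhaustive listing; file names follow the `<element_type>_<element_id>` pattern):
```
data/
├── images/
│   ├── train/   # one <element_type>_<element_id>.jpg per extracted element
│   ├── val/
│   └── test/
├── labels/
│   ├── train/   # one <element_type>_<element_id>.txt per extracted element
│   ├── val/
│   └── test/
└── split.json   # mapping from each split to the processed parents and extracted element IDs
```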
#### Model training
Use the `teklia-dan train` command to train a new DAN model. It exposes two subcommands, `line` and `document`, each taking multiple arguments.
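As a minimal sketch (the training options themselves are defined in the code and not listed here), both subcommands expose their arguments through `--help`:
```shell
teklia-dan train line --help
teklia-dan train document --help
```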
#### Synthetic data generation
Use the `teklia-dan generate` command to generate synthetic data to train DAN models. It takes multiple arguments.
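As above, a minimal sketch (generation options are defined in the code):
```shell
teklia-dan generate --help
```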
# -*- coding: utf-8 -*-
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s/%(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
@@ -2,17 +2,17 @@
import argparse
import errno
from dan.datasets import add_dataset_parser
from dan.ocr import add_train_parser
from dan.ocr.line import add_generate_parser
def get_parser():
parser = argparse.ArgumentParser(prog="teklia-dan")
subcommands = parser.add_subparsers(metavar="subcommand")
add_dataset_parser(subcommands)
add_train_parser(subcommands)
add_generate_parser(subcommands)
return parser
# -*- coding: utf-8 -*-
"""
Preprocess datasets for training.
"""
from dan.datasets.extract import add_extract_parser
def add_dataset_parser(subcommands) -> None:
parser = subcommands.add_parser(
"dataset",
description=__doc__,
help=__doc__,
)
subcommands = parser.add_subparsers(metavar="subcommand")
add_extract_parser(subcommands)
# -*- coding: utf-8 -*-
"""
Extract dataset from Arkindex using API.
"""
import pathlib
import uuid
from dan.datasets.extract.extract_from_arkindex import run
MANUAL_SOURCE = "manual"
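# Map the CLI sentinel value "manual" to False, which the Arkindex API filters interpret
# as "no worker version", i.e. manually created transcriptions and entities.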
def parse_worker_version(worker_version_id):
if worker_version_id == MANUAL_SOURCE:
return False
return worker_version_id
def add_extract_parser(subcommands) -> None:
parser = subcommands.add_parser(
"extract",
description=__doc__,
help=__doc__,
)
# Required arguments.
parser.add_argument(
"--parent",
type=uuid.UUID,
nargs="+",
help="ID of the parent folder to import from Arkindex.",
required=False,
)
parser.add_argument(
"--element-type",
nargs="+",
type=str,
help="Type of elements to retrieve",
required=True,
)
parser.add_argument(
"--parent-element-type",
type=str,
help="Type of the parent element containing the data.",
required=False,
default="page",
)
parser.add_argument(
"--output",
type=pathlib.Path,
help="Path where the data will be generated.",
required=True,
)
# Optional arguments.
parser.add_argument(
"--load-entities", action="store_true", help="Extract text with their entities"
)
parser.add_argument(
"--tokens",
type=pathlib.Path,
help="Mapping between starting tokens and end tokens. Needed for entities.",
required=False,
)
parser.add_argument(
"--use-existing-split",
action="store_true",
help="Use the specified folder IDs for the dataset split.",
)
parser.add_argument(
"--train-folder",
type=uuid.UUID,
help="ID of the training folder to import from Arkindex.",
required=False,
)
parser.add_argument(
"--val-folder",
type=uuid.UUID,
help="ID of the validation folder to import from Arkindex.",
required=False,
)
parser.add_argument(
"--test-folder",
type=uuid.UUID,
help="ID of the testing folder to import from Arkindex.",
required=False,
)
parser.add_argument(
"--transcription-worker-version",
type=parse_worker_version,
help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
required=False,
default=MANUAL_SOURCE,
)
parser.add_argument(
"--entity-worker-version",
type=parse_worker_version,
help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
required=False,
default=MANUAL_SOURCE,
)
parser.add_argument(
"--train-prob", type=float, default=0.7, help="Training set split size."
)
parser.add_argument(
"--val-prob", type=float, default=0.15, help="Validation set split size"
)
parser.set_defaults(func=run)
# -*- coding: utf-8 -*-
"""
The arkindex_utils module
======================
"""
import errno
import logging
import sys
from apistar.exceptions import ErrorResponse
def retrieve_corpus(client, corpus_name: str) -> str:
"""
Retrieve the corpus id from the corpus name.
:param client: The arkindex client.
:param corpus_name: The name of the corpus to retrieve.
:return target_corpus: The id of the retrieved corpus.
"""
for corpus in client.request("ListCorpus"):
if corpus["name"] == corpus_name:
target_corpus = corpus["id"]
try:
logging.info(f"Corpus id retrieved: {target_corpus}")
except NameError:
logging.error(f"Corpus {corpus_name} not found")
sys.exit(errno.EINVAL)
return target_corpus
def retrieve_subsets(
client, corpus: str, parents_types: list, parents_names: list
) -> list:
"""
Retrieve the requested subsets.
:param client: The arkindex client.
:param corpus: The id of the retrieved corpus.
:param parents_types: The types of parents of the elements to retrieve.
:param parents_names: The names of parents of the elements to retrieve.
:return subsets: The retrieved subsets.
"""
subsets = []
for parent_type in parents_types:
try:
subsets.extend(
client.request("ListElements", corpus=corpus, type=parent_type)[
"results"
]
)
except ErrorResponse as e:
logging.error(f"{e.content}: {parent_type}")
sys.exit(errno.EINVAL)
# Retrieve subsets with name in parents-names. If no parents-names given, keep all subsets.
if parents_names is not None:
logging.info(f"Retrieving {parents_names} subset(s)")
subsets = [subset for subset in subsets if subset["name"] in parents_names]
else:
logging.info("Retrieving all subsets")
if len(subsets) == 0:
logging.info("No subset found")
return subsets
# -*- coding: utf-8 -*-
"""
The extraction module
======================
"""
import logging
import os
import random
from collections import defaultdict
from pathlib import Path
from typing import List, NamedTuple
import cv2
import imageio.v2 as iio
from arkindex import ArkindexClient, options_from_env
from tqdm import tqdm
from dan import logger
from dan.datasets.extract.utils import (
insert_token,
parse_tokens,
save_image,
save_json,
save_text,
)
IMAGES_DIR = "images" # Subpath to the images directory.
LABELS_DIR = "labels" # Subpath to the labels directory.
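# A named entity carried by a transcription, located by its character offset and length in the text, with its label (e.g. DATE).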
Entity = NamedTuple("Entity", offset=int, length=int, label=str)
class ArkindexExtractor:
"""
Extract data from Arkindex
"""
def __init__(
self,
client: ArkindexClient,
folders: list = [],
element_type: list = [],
parent_element_type: list = ["page"],
split_names: list = [],
output: Path = None,
load_entities: bool = None,
tokens: Path = None,
use_existing_split: bool = None,
transcription_worker_version: str = None,
entity_worker_version: str = None,
train_prob: float = None,
val_prob: float = None,
) -> None:
self.client = client
self.element_type = element_type
self.parent_element_type = parent_element_type
self.split_names = split_names
self.output = output
self.load_entities = load_entities
self.tokens = parse_tokens(tokens) if self.load_entities else None
self.use_existing_split = use_existing_split
self.transcription_worker_version = transcription_worker_version
self.entity_worker_version = entity_worker_version
self.train_prob = train_prob
self.val_prob = val_prob
self.get_subsets(folders)
def get_subsets(self, folders: list):
"""
Assign each folder to its split if it's already known.
Assign None if it's unknown.
"""
if self.use_existing_split:
self.subsets = [
(folder, split) for folder, split in zip(folders, self.split_names)
]
else:
self.subsets = [(folder, None) for folder in folders]
def _assign_random_split(self):
"""
Yields a randomly chosen split for an element.
Assumes that train_prob + val_prob + test_prob = 1.
"""
prob = random.random()
if prob <= self.train_prob:
yield self.split_names[0]
elif prob <= self.train_prob + self.val_prob:
yield self.split_names[1]
else:
yield self.split_names[2]
def get_random_split(self):
return next(self._assign_random_split())
def extract_entities(self, transcription: dict):
entities = self.client.paginate(
"ListTranscriptionEntities",
id=transcription["id"],
worker_version=self.entity_worker_version,
)
if entities is None:
logger.warning(
f"No entities found on transcription ({transcription['id']})."
)
return
return [
Entity(
offset=entity["offset"],
length=entity["length"],
label=entity["entity"]["metas"]["subtype"],
)
for entity in entities
]
def reconstruct_text(self, text: str, entities: List[Entity]):
"""
Insert tokens delimiting the start/end of each entity on the transcription.
"""
count = 0
for entity in entities:
matching_tokens = self.tokens[entity.label]
start_token, end_token = (
matching_tokens["start"],
matching_tokens["end"],
)
text, count = insert_token(
text,
count,
start_token,
end_token,
offset=entity.offset,
length=entity.length,
)
return text
def extract_transcription(
self,
element: dict,
):
"""
Extract the element's transcription.
If the entities are needed, they are added to the transcription using tokens.
"""
transcriptions = self.client.request(
"ListTranscriptions",
id=element["id"],
worker_version=self.transcription_worker_version,
)
if transcriptions["count"] != 1:
logger.warning(
f"More than one transcription found on element ({element['id']}) with this config."
)
return
transcription = transcriptions["results"].pop()
if self.load_entities:
entities = self.extract_entities(transcription)
return self.reconstruct_text(transcription["text"], entities)
else:
return transcription["text"].strip()
def process_element(
self,
element: dict,
split: str,
):
"""
Extract an element's data and save it to disk.
The output path is directly related to the split of the element.
"""
text = self.extract_transcription(
element,
)
IMAGES_DIR = "./images/" # Path to the images directory.
LABELS_DIR = "./labels/" # Path to the labels directory.
# Layout string to token
SEM_MATCHING_TOKENS_STR = {
"INTITULE": "",
"DATE": "",
"COTE_SERIE": "",
"ANALYSE_COMPL.": "",
"PRECISIONS_SUR_COTE": "",
"COTE_ARTICLE": "",
}
# Layout begin-token to end-token
SEM_MATCHING_TOKENS = {"": "", "": "", "": "", "": "", "": "", "": ""}
def add_extract_parser(subcommands) -> None:
parser = subcommands.add_parser(
"extract",
description=__doc__,
help=__doc__,
)
# Required arguments.
parser.add_argument(
"--corpus",
type=str,
help="Name of the corpus from which the data will be retrieved.",
required=True,
)
parser.add_argument(
"--element-type",
nargs="+",
type=str,
help="Type of elements to retrieve",
required=True,
)
parser.add_argument(
"--parents-types",
nargs="+",
type=str,
help="Type of parents of the elements.",
required=True,
)
parser.add_argument(
"--output-dir",
type=str,
help="Path to the output directory.",
required=True,
)
# Optional arguments.
parser.add_argument(
"--parents-names",
nargs="+",
type=str,
help="Names of parents of the elements.",
default=None,
)
parser.add_argument(
"--no-entities", action="store_true", help="Extract text without entities"
)
parser.add_argument(
"--use-existing-split",
action="store_true",
help="Do not partition pages into train/val/test",
)
parser.add_argument(
"--train-prob", type=float, default=0.7, help="Training set probability"
)
parser.add_argument(
"--val-prob", type=float, default=0.15, help="Validation set probability"
)
parser.set_defaults(func=run)
if not text:
logging.warning(f"Skipping {element['id']}")
else:
logging.info(f"Processed {element['type']} {element['id']}")
im_path = os.path.join(
self.output, IMAGES_DIR, split, f"{element['type']}_{element['id']}.jpg"
)
txt_path = os.path.join(
self.output, LABELS_DIR, split, f"{element['type']}_{element['id']}.txt"
)
save_text(txt_path, text)
try:
image = iio.imread(element["zone"]["url"])
save_image(im_path, image)
except Exception:
logger.error(f"Couldn't retrieve image of element ({element['id']}")
raise
return element["id"]
def process_parent(
self,
parent: dict,
split: str,
):
"""
Extract data from a parent element.
Depending on the given types, process the parent element itself or its children.
"""
data = defaultdict(list)
if self.element_type == [parent["type"]]:
data[self.element_type[0]] = [
self.process_element(
parent,
split,
)
]
# Extract children elements
else:
for element_type in self.element_type:
for element in self.client.paginate(
"ListElementChildren",
id=parent["id"],
type=element_type,
recursive=True,
):
data[element_type].append(
self.process_element(
element,
split,
)
)
return data
def run(self):
split_dict = defaultdict(dict)
# Iterate over the subsets to find the page images and labels.
for subset_id, subset_split in self.subsets:
page_idx = 0
# Iterate over the pages to create splits at page level.
for parent in tqdm(
self.client.paginate(
"ListElementChildren",
id=subset_id,
type=self.parent_element_type,
recursive=True,
)
):
page_idx += 1
split = subset_split or self.get_random_split()
split_dict[split][parent["id"]] = self.process_parent(
parent=parent,
split=split,
)
save_json(self.output / "split.json", split_dict)
def run(
parent,
element_type,
parent_element_type,
output,
load_entities,
tokens,
use_existing_split,
train_folder,
val_folder,
test_folder,
transcription_worker_version,
entity_worker_version,
train_prob,
val_prob,
):
assert (
use_existing_split or parent
), "One of `--use-existing-split` and `--parent` must be set"
if use_existing_split:
assert (
train_folder
), "If you use an existing split, you must specify the training folder."
assert (
val_folder
), "If you use an existing split, you must specify the validation folder."
assert (
test_folder
), "If you use an existing split, you must specify the testing folder."
folders = [train_folder, val_folder, test_folder]
else:
folders = parent
if load_entities:
assert tokens, "Please provide the entities to match."
# Login to arkindex.
assert (
"ARKINDEX_API_URL" in os.environ
), "The ARKINDEX API URL was not found in your environment."
assert (
"ARKINDEX_API_TOKEN" in os.environ
), "Your API credentials was not found in your environment."
client = ArkindexClient(**options_from_env())
image = iio.imread(page["zone"]["url"])
cv2.imwrite(
os.path.join(
output_dir, IMAGES_DIR, subset["name"], f"{page['id']}.jpg"
),
cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
)
# Create out directories
split_names = ["train", "val", "test"]
for split in split_names:
os.makedirs(os.path.join(output, IMAGES_DIR, split), exist_ok=True)
os.makedirs(os.path.join(output, LABELS_DIR, split), exist_ok=True)
ArkindexExtractor(
client=client,
folders=folders,
element_type=element_type,
parent_element_type=parent_element_type,
split_names=split_names,
output=output,
load_entities=load_entities,
tokens=tokens,
use_existing_split=use_existing_split,
transcription_worker_version=transcription_worker_version,
entity_worker_version=entity_worker_version,
train_prob=train_prob,
val_prob=val_prob,
).run()
# -*- coding: utf-8 -*-
import json
import cv2
import yaml
def save_text(path, text):
with open(path, "w") as f:
f.write(text)
def save_image(path, image):
cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
def save_json(path, dict):
with open(path, "w") as outfile:
json.dump(dict, outfile, indent=4)
def insert_token(text, count, start_token, end_token, offset, length):
"""
Insert the given tokens at the right position in the text
"""
text = (
# Text before entity
text[: count + offset]
# Starting token
+ start_token
# Entity
+ text[count + offset : count + 1 + offset + length]
# End token
+ end_token
# Text after entity
+ text[count + 1 + offset + length :]
)
return text, count + 2
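# Load the YAML token file (see the documentation), e.g. {"DATE": {"start": "ⓓ", "end": "Ⓓ"}, ...}.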
def parse_tokens(filename):
with open(filename) as f:
return yaml.safe_load(f)
# -*- coding: utf-8 -*-
import json
import random
import re
import cv2
random.seed(42)
def convert(text):
return int(text) if text.isdigit() else text.lower()
@@ -14,30 +8,3 @@ def convert(text):
def natural_sort(data):
return sorted(data, key=lambda key: [convert(c) for c in re.split("([0-9]+)", key)])
def assign_random_split(train_prob, val_prob):
"""
assuming train_prob + val_prob + test_prob = 1
"""
prob = random.random()
if prob <= train_prob:
return "train"
elif prob <= train_prob + val_prob:
return "val"
else:
return "test"
def save_text(path, text):
with open(path, "w") as f:
f.write(text)
def save_image(path, image):
cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
def save_json(path, dict):
with open(path, "w") as outfile:
json.dump(dict, outfile, indent=4)
# -*- coding: utf-8 -*-
"""
Train a new DAN model.
"""
from dan.ocr.document import add_document_parser
from dan.ocr.line import add_line_parser
def add_train_parser(subcommands) -> None:
parser = subcommands.add_parser(
"train",
description=__doc__,
help=__doc__,
)
subcommands = parser.add_subparsers(metavar="subcommand")
add_line_parser(subcommands)
add_document_parser(subcommands)
# -*- coding: utf-8 -*-
"""
Train a DAN model at document level.
"""
from dan.ocr.document.train import run
def add_document_parser(subcommands) -> None:
parser = subcommands.add_parser(
"document",
description=__doc__,
help=__doc__,
)
parser.set_defaults(func=run)
@@ -14,15 +14,6 @@ from dan.schedulers import exponential_dropout_scheduler
from dan.transforms import aug_config
def add_document_parser(subcommands) -> None:
parser = subcommands.add_parser(
"document",
description=__doc__,
help=__doc__,
)
parser.set_defaults(func=run)
def train_and_test(rank, params):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# -*- coding: utf-8 -*-
from dan.ocr.line.generate_synthetic import run as run_generate
from dan.ocr.line.train import run as run_train
def add_generate_parser(subcommands) -> None:
parser = subcommands.add_parser(
"generate",
description=__doc__,
help="Generate synthetic data to train DAN models.",
)
parser.set_defaults(func=run_generate)
def add_line_parser(subcommands) -> None:
parser = subcommands.add_parser(
"line",
description=__doc__,
help="Train a DAN model at line level.",
)
parser.set_defaults(func=run_train)
@@ -14,15 +14,6 @@ from dan.schedulers import exponential_dropout_scheduler, exponential_scheduler
from dan.transforms import line_aug_config
def add_generate_parser(subcommands) -> None:
parser = subcommands.add_parser(
"generate",
description=__doc__,
help=__doc__,
)
parser.set_defaults(func=run)
def train_and_test(rank, params):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
@@ -14,15 +14,6 @@ from dan.schedulers import exponential_dropout_scheduler, exponential_scheduler
from dan.transforms import line_aug_config
def add_line_parser(subcommands) -> None:
parser = subcommands.add_parser(
"line",
description=__doc__,
help=__doc__,
)
parser.set_defaults(func=run)
def train_and_test(rank, params):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# -*- coding: utf-8 -*-
from dan.ocr.document.train import add_document_parser
from dan.ocr.line.train import add_line_parser
def add_train_parser(subcommands) -> None:
parser = subcommands.add_parser(
"train",
description=__doc__,
help=__doc__,
)
subcommands = parser.add_subparsers(metavar="subcommand")
add_line_parser(subcommands)
add_document_parser(subcommands)
@@ -6,7 +6,7 @@ networkx==2.6.3
numpy==1.22.3
opencv-python==4.5.5.64
PyYAML==6.0
tensorboard==0.2.1
tensorboard==2.8.0
torch==1.11.0
torchvision==0.12.0
tqdm==4.62.3
# -*- coding: utf-8 -*-
import os
import pytest
@pytest.fixture(autouse=True)
def setup_environment(responses):
"""Setup needed environment variables"""
# Allow accessing remote API schemas
# defaulting to the prod environment
schema_url = os.environ.get(
"ARKINDEX_API_SCHEMA_URL",
"https://arkindex.teklia.com/api/v1/openapi/?format=openapi-json",
)
responses.add_passthru(schema_url)
# Set schema url in environment
os.environ["ARKINDEX_API_SCHEMA_URL"] = schema_url