From 30524dfe4eac089c75778ec3caadbd6c6035d1ef Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Wed, 16 Nov 2022 13:20:08 +0000 Subject: [PATCH] implement extraction --- README.md | 108 ++++- dan/__init__.py | 8 + dan/cli.py | 4 +- dan/datasets/extract/arkindex_utils.py | 66 --- dan/datasets/extract/extract_from_arkindex.py | 431 ++++++++++++++---- dan/datasets/extract/utils.py | 102 ++--- dan/datasets/utils.py | 33 -- dan/ocr/document/train.py | 4 + dan/ocr/line/generate_synthetic.py | 4 + dan/ocr/line/train.py | 4 + dan/ocr/train.py | 4 + requirements.txt | 2 +- 12 files changed, 519 insertions(+), 251 deletions(-) delete mode 100644 dan/datasets/extract/arkindex_utils.py diff --git a/README.md b/README.md index 9b29f204..af2f58c5 100644 --- a/README.md +++ b/README.md @@ -21,9 +21,9 @@ We evaluate the DAN on two public datasets: RIMES and READ 2016 at single-page a We obtained the following results: | | CER (%) | WER (%) | LOER (%) | mAP_cer (%) | -|:-----------------------:|---------|:-------:|:--------:|-------------| -| RIMES (single page) | 4.54 | 11.85 | 3.82 | 93.74 | -| READ 2016 (single page) | 3.53 | 13.33 | 5.94 | 92.57 | +| :---------------------: | ------- | :-----: | :------: | ----------- | +| RIMES (single page) | 4.54 | 11.85 | 3.82 | 93.74 | +| READ 2016 (single page) | 3.53 | 13.33 | 5.94 | 92.57 | | READ 2016 (double page) | 3.69 | 14.20 | 4.60 | 93.92 | @@ -92,3 +92,105 @@ To run the inference on a GPU, one can replace `cpu` by the name of the GPU. In ```python text, confidence_scores = model.predict(image, confidences=True) ``` + +### Commands + +This package provides three subcommands. To get more information about any subcommand, use the `--help` option. + +#### Data extraction from Arkindex + +Use the `teklia-dan extract` command to extract a dataset from Arkindex. This will generate the images and the labels needed to train a DAN model. 
+The available arguments are: + +| Parameter | Description | Type | Default | +| ------------------------------ | ----------------------------------------------------------------------------------- | -------- | ------- | +| `--parent` | UUID of the folder to import from Arkindex. You may specify multiple UUIDs. | str/uuid | | +| `--element-type` | Type of the elements to extract. You may specify multiple types. | str | | +| `--output` | Folder where the data will be generated. Must exist. | Path | | +| `--load-entities` | Extract text with their entities. Needed for NER tasks. | bool | False | +| `--tokens` | Mapping between starting tokens and end tokens. Needed for NER tasks. | Path | | +| `--use-existing-split` | Use the specified folder IDs for the dataset split. | bool | | +| `--train-folder` | ID of the training folder to import from Arkindex. | uuid | | +| `--val-folder` | ID of the validation folder to import from Arkindex. | uuid | | +| `--test-folder` | ID of the testing folder to import from Arkindex. | uuid | | +| `--transcription-worker-version` | Filter transcriptions by worker_version. Use ‘manual’ for manual filtering. | str/uuid | | +| `--entity-worker-version` | Filter transcription entities by worker_version. Use ‘manual’ for manual filtering. | str/uuid | | +| `--train-prob` | Training set split size | float | 0.7 | +| `--val-prob` | Validation set split size | float | 0.15 | + +The `--tokens` argument expects a file with the following format.
+```yaml +--- +INTITULE: + start: ⓘ + end: Ⓘ +DATE: + start: ⓘ + end: Ⓓ +COTE_SERIE: + start: ⓘ + end: Ⓢ +ANALYSE_COMPL.: + start: ⓘ + end: Ⓒ +PRECISIONS_SUR_COTE: + start: ⓘ + end: Ⓟ +COTE_ARTICLE: + start: ⓘ + end: Ⓐ +CLASSEMENT: + start: ⓘ + end: Ⓛ +``` + + +To extract HTR+NER data from **pages** from [this folder](https://arkindex.teklia.com/element/665e84ea-bd97-4912-91b0-1f4a844287ff), use the following command: +```shell +teklia-dan extract \ + --parent 665e84ea-bd97-4912-91b0-1f4a844287ff \ + --element-type page \ + --output data \ + --load-entities \ + --tokens tokens.yml +``` +with `tokens.yml` having the content described just above. + + +To do the same but only use the data from three folders, the command becomes: +```shell +teklia-dan extract \ + --parent 2275529a-1ec5-40ce-a516-42ea7ada858c af9b38b5-5d95-417d-87ec-730537cb1898 6ff44957-0e65-48c5-9d77-a178116405b2 \ + --element-type page \ + --output data \ + --load-entities \ + --tokens tokens.yml +``` + +To use the data from three folders as **training**, **validation** and **testing** datasets, respectively, the command becomes: +```shell +teklia-dan extract \ + --use-existing-split \ + --train-folder 2275529a-1ec5-40ce-a516-42ea7ada858c \ + --val-folder af9b38b5-5d95-417d-87ec-730537cb1898 \ + --test-folder 6ff44957-0e65-48c5-9d77-a178116405b2 \ + --element-type page \ + --output data \ + --load-entities \ + --tokens tokens.yml +``` + +To extract HTR data from **annotations** and **text_zones** from [this folder](https://demo.arkindex.org/element/48852284-fc02-41bb-9a42-4458e5a51615), use the following command: +```shell +teklia-dan extract \ + --parent 48852284-fc02-41bb-9a42-4458e5a51615 \ + --element-type text_zone annotation \ + --output data +``` + +#### Model training +`teklia-dan train` with multiple arguments.
+ +#### Synthetic data generation +`teklia-dan generate` with multiple arguments + diff --git a/dan/__init__.py b/dan/__init__.py index e69de29b..b74e8889 100644 --- a/dan/__init__.py +++ b/dan/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +import logging + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s/%(name)s: %(message)s", +) +logger = logging.getLogger(__name__) diff --git a/dan/cli.py b/dan/cli.py index fff057b9..8b06e4e7 100644 --- a/dan/cli.py +++ b/dan/cli.py @@ -8,11 +8,11 @@ from dan.ocr.train import add_train_parser def get_parser(): - parser = argparse.ArgumentParser(prog="TEKLIA DAN training") + parser = argparse.ArgumentParser(prog="teklia-dan") subcommands = parser.add_subparsers(metavar="subcommand") - add_train_parser(subcommands) add_extract_parser(subcommands) + add_train_parser(subcommands) add_generate_parser(subcommands) return parser diff --git a/dan/datasets/extract/arkindex_utils.py b/dan/datasets/extract/arkindex_utils.py deleted file mode 100644 index d9d2e065..00000000 --- a/dan/datasets/extract/arkindex_utils.py +++ /dev/null @@ -1,66 +0,0 @@ -# -*- coding: utf-8 -*- - -""" - The arkindex_utils module - ====================== -""" - -import errno -import logging -import sys - -from apistar.exceptions import ErrorResponse - - -def retrieve_corpus(client, corpus_name: str) -> str: - """ - Retrieve the corpus id from the corpus name. - :param client: The arkindex client. - :param corpus_name: The name of the corpus to retrieve. - :return target_corpus: The id of the retrieved corpus. 
- """ - for corpus in client.request("ListCorpus"): - if corpus["name"] == corpus_name: - target_corpus = corpus["id"] - try: - logging.info(f"Corpus id retrieved: {target_corpus}") - except NameError: - logging.error(f"Corpus {corpus_name} not found") - sys.exit(errno.EINVAL) - - return target_corpus - - -def retrieve_subsets( - client, corpus: str, parents_types: list, parents_names: list -) -> list: - """ - Retrieve the requested subsets. - :param client: The arkindex client. - :param corpus: The id of the retrieved corpus. - :param parents_types: The types of parents of the elements to retrieve. - :param parents_names: The names of parents of the elements to retrieve. - :return subsets: The retrieved subsets. - """ - subsets = [] - for parent_type in parents_types: - try: - subsets.extend( - client.request("ListElements", corpus=corpus, type=parent_type)[ - "results" - ] - ) - except ErrorResponse as e: - logging.error(f"{e.content}: {parent_type}") - sys.exit(errno.EINVAL) - # Retrieve subsets with name in parents-names. If no parents-names given, keep all subsets. - if parents_names is not None: - logging.info(f"Retrieving {parents_names} subset(s)") - subsets = [subset for subset in subsets if subset["name"] in parents_names] - else: - logging.info("Retrieving all subsets") - - if len(subsets) == 0: - logging.info("No subset found") - - return subsets diff --git a/dan/datasets/extract/extract_from_arkindex.py b/dan/datasets/extract/extract_from_arkindex.py index c9c1c2c1..5c3818be 100644 --- a/dan/datasets/extract/extract_from_arkindex.py +++ b/dan/datasets/extract/extract_from_arkindex.py @@ -1,127 +1,386 @@ # -*- coding: utf-8 -*- """ - The extraction module - ====================== +Extract dataset from Arkindex using API. 
""" +from collections import defaultdict import logging import os +import pathlib +import random +import uuid -import cv2 import imageio.v2 as iio from arkindex import ArkindexClient, options_from_env from tqdm import tqdm -from dan.datasets.extract.arkindex_utils import retrieve_corpus, retrieve_subsets -from dan.datasets.extract.utils import get_cli_args - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +from dan.datasets.extract.utils import ( + insert_token, + parse_tokens, + save_image, + save_json, + save_text, ) +from dan import logger + -IMAGES_DIR = "./images/" # Path to the images directory. -LABELS_DIR = "./labels/" # Path to the labels directory. +IMAGES_DIR = "images" # Subpath to the images directory. +LABELS_DIR = "labels" # Subpath to the labels directory. +MANUAL_SOURCE = "manual" -# Layout string to token -SEM_MATCHING_TOKENS_STR = { - "INTITULE": "ⓘ", - "DATE": "â““", - "COTE_SERIE": "â“¢", - "ANALYSE_COMPL.": "â“’", - "PRECISIONS_SUR_COTE": "ⓟ", - "COTE_ARTICLE": "â“", -} -# Layout begin-token to end-token -SEM_MATCHING_TOKENS = {"ⓘ": "â’¾", "â““": "â’¹", "â“¢": "Ⓢ", "â“’": "â’¸", "ⓟ": "â“…", "â“": "â’¶"} +def parse_worker_version(worker_version_id): + if worker_version_id == MANUAL_SOURCE: + return False + return worker_version_id def add_extract_parser(subcommands) -> None: parser = subcommands.add_parser( "extract", description=__doc__, + help=__doc__, + ) + + # Required arguments. + parser.add_argument( + "--parent", + type=uuid.UUID, + nargs="+", + help="ID of the parent folder to import from Arkindex.", + required=False, + ) + parser.add_argument( + "--element-type", + nargs="+", + type=str, + help="Type of elements to retrieve", + required=True, + ) + parser.add_argument( + "--output", + type=pathlib.Path, + help="Path where the data will be generated.", + required=True, + ) + + # Optional arguments. 
+ parser.add_argument( + "--load-entities", action="store_true", help="Extract text with their entities" + ) + parser.add_argument( + "--tokens", + type=pathlib.Path, + help="Mapping between starting tokens and end tokens. Needed for entities.", + required=False, + ) + + parser.add_argument( + "--use-existing-split", + action="store_true", + help="Use the specified folder IDs for the dataset split.", + ) + + parser.add_argument( + "--train-folder", + type=uuid.UUID, + help="ID of the training folder to import from Arkindex.", + required=False, + ) + parser.add_argument( + "--val-folder", + type=uuid.UUID, + help="ID of the validation folder to import from Arkindex.", + required=False, + ) + parser.add_argument( + "--test-folder", + type=uuid.UUID, + help="ID of the testing folder to import from Arkindex.", + required=False, ) + + parser.add_argument( + "--transcription-worker-version", + type=parse_worker_version, + help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.", + required=False, + default=MANUAL_SOURCE, + ) + parser.add_argument( + "--entity-worker-version", + type=parse_worker_version, + help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.", + required=False, + default=MANUAL_SOURCE, + ) + + parser.add_argument( + "--train-prob", type=float, default=0.7, help="Training set split size." + ) + + parser.add_argument( + "--val-prob", type=float, default=0.15, help="Validation set split size" + ) + parser.set_defaults(func=run) -def run(): - args = get_cli_args() +class ArkindexExtractor: + """ + Extract data from Arkindex + """ - # Get and initialize the parameters. 
- os.makedirs(IMAGES_DIR, exist_ok=True) - os.makedirs(LABELS_DIR, exist_ok=True) + def __init__( + self, + client, + folders, + element_type, + split_names, + output, + load_entities, + tokens, + use_existing_split, + transcription_worker_version, + entity_worker_version, + train_prob, + val_prob, + ) -> None: + self.client = client + self.element_type = element_type + self.split_names = split_names + self.output = output + self.load_entities = load_entities + self.tokens = parse_tokens(tokens) if self.load_entities else None + self.use_existing_split = use_existing_split + self.transcription_worker_version = transcription_worker_version + self.entity_worker_version = entity_worker_version + self.train_prob = train_prob + self.val_prob = val_prob - # Login to arkindex. - client = ArkindexClient(**options_from_env()) + self.get_subsets(folders) - corpus = retrieve_corpus(client, args.corpus) - subsets = retrieve_subsets(client, corpus, args.parents_types, args.parents_names) + def get_subsets(self, folders): + if self.use_existing_split: + self.subsets = [ + (folder, split) for folder, split in zip(folders, self.split_names) + ] + else: + self.subsets = [(folder, None) for folder in folders] - # Iterate over the subsets to find the page images and labels. 
- for subset in subsets: + def assign_random_split(self): + """ + assuming train_prob + valid_prob + test_prob = 1 + """ + prob = random.random() + if prob <= self.train_prob: + return self.split_names[0] + elif prob <= self.train_prob + self.val_prob: + return self.split_names[1] + else: + return self.split_names[2] - os.makedirs( - os.path.join(args.output_dir, IMAGES_DIR, subset["name"]), exist_ok=True - ) - os.makedirs( - os.path.join(args.output_dir, LABELS_DIR, subset["name"]), exist_ok=True + def extract_transcription( + self, + element, + ): + transcriptions = self.client.request( + "ListTranscriptions", + id=element["id"], + worker_version=self.transcription_worker_version, ) - for page in tqdm( - client.paginate( - "ListElementChildren", id=subset["id"], type="page", recursive=True - ), - desc="Set " + subset["name"], - ): - - image = iio.imread(page["zone"]["url"]) - cv2.imwrite( - os.path.join( - args.output_dir, IMAGES_DIR, subset["name"], f"{page['id']}.jpg" - ), - cv2.cvtColor(image, cv2.COLOR_BGR2RGB), + if transcriptions["count"] != 1: + logger.warning( + f"More than one transcription found on element ({element['id']}) with this config." ) + return - tr = client.request( - "ListTranscriptions", id=page["id"], worker_version=None - )["results"] - tr = [one for one in tr if one["worker_version_id"] is None] - assert len(tr) == 1, page["id"] + transcription = transcriptions["results"].pop() + if self.load_entities: + entities = self.client.request( + "ListTranscriptionEntities", + id=transcription["id"], + worker_version=self.entity_worker_version, + ) + if entities["count"] == 0: + logger.warning( + f"No entities found on transcription ({transcription['id']})." 
+ ) + return + else: + text = transcription["text"] - for one_tr in tr: - ent = client.request("ListTranscriptionEntities", id=one_tr["id"])[ - "results" + count = 0 + for entity in entities["results"]: + start_token, end_token = self.tokens[ + entity["entity"]["metas"]["subtype"] ] - ent = [one for one in ent if one["worker_version_id"] is None] - if len(ent) == 0: - continue - else: - text = one_tr["text"] + text, count = insert_token( + text, + count, + start_token, + end_token, + offset=entity["offset"], + length=entity["length"], + ) + else: + text = transcription["text"].strip() + return text - new_text = text - count = 0 - for e in ent: - start_token = SEM_MATCHING_TOKENS_STR[e["entity"]["metas"]["subtype"]] - end_token = SEM_MATCHING_TOKENS[start_token] - new_text = ( - new_text[: count + e["offset"]] - + start_token - + new_text[count + e["offset"] :] + def process_element( + self, + element, + split, + ): + text = self.extract_transcription( + element, + ) + + if not text: + logging.warning( + f"Skipping {element['id']} (zero or multiple transcriptions with worker_version=None)" + ) + else: + logging.info(f"Processed {element['type']} {element['id']}") + + im_path = os.path.join( + self.output, IMAGES_DIR, split, f"{element['type']}_{element['id']}.jpg" + ) + txt_path = os.path.join( + self.output, LABELS_DIR, split, f"{element['type']}_{element['id']}.txt" + ) + + save_text(txt_path, text) + try: + image = iio.imread(element["zone"]["url"]) + save_image(im_path, image) + except Exception: + logger.error(f"Couldn't retrieve image of element ({element['id']}") + pass + return element["id"] + + def process_page( + self, + page, + split, + ): + # Extract only pages + data = defaultdict(list) + if self.element_type == ["page"]: + data["page"] = [ + self.process_element( + page, + split, ) - count += 1 - new_text = ( - new_text[: count + e["offset"] + e["length"]] - + end_token - + new_text[count + e["offset"] + e["length"] :] + ] + # Extract page's children 
elements (text_zone, text_line) + else: + for element_type in self.element_type: + for element in self.client.paginate( + "ListElementChildren", + id=page["id"], + type=element_type, + recursive=True, + ): + data[element_type].append( + self.process_element( + element, + split, + ) + ) + return data + + def run(self): + split_dict = defaultdict(dict) + # Iterate over the subsets to find the page images and labels. + for subset_id, subset_split in self.subsets: + page_idx = 0 + # Iterate over the pages to create splits at page level. + for page in tqdm( + self.client.paginate( + "ListElementChildren", id=subset_id, type="page", recursive=True + ) + ): + page_idx += 1 + split = subset_split or self.assign_random_split() + + split_dict[split][page["id"]] = self.process_page( + page=page, + split=split, ) - count += 1 - - with open( - os.path.join( - args.output_dir, LABELS_DIR, subset["name"], f"{page['id']}.txt" - ), - "w", - ) as f: - f.write(new_text) + + save_json(self.output / "split.json", split_dict) + + +def run( + parent, + element_type, + output, + load_entities, + tokens, + use_existing_split, + train_folder, + val_folder, + test_folder, + transcription_worker_version, + entity_worker_version, + train_prob, + val_prob, +): + assert ( + use_existing_split or parent + ), "One of `--use-existing-split` and `--parent` must be set" + + if use_existing_split: + assert ( + train_folder + ), "If you use an existing split, you must specify the training folder." + assert ( + val_folder + ), "If you use an existing split, you must specify the validation folder." + assert ( + test_folder + ), "If you use an existing split, you must specify the testing folder." + folders = [train_folder, val_folder, test_folder] + else: + folders = parent + + if load_entities: + assert tokens, "Please provide the entities to match." + + # Get and initialize the parameters. + os.makedirs(IMAGES_DIR, exist_ok=True) + os.makedirs(LABELS_DIR, exist_ok=True) + + # Login to arkindex. 
+ assert ( + "ARKINDEX_API_URL" in os.environ + ), "The ARKINDEX API URL was not found in your environment." + assert ( + "ARKINDEX_API_TOKEN" in os.environ + ), "Your API credentials was not found in your environment." + client = ArkindexClient(**options_from_env()) + + # Create out directories + split_names = ["train", "val", "test"] + for split in split_names: + os.makedirs(os.path.join(output, IMAGES_DIR, split), exist_ok=True) + os.makedirs(os.path.join(output, LABELS_DIR, split), exist_ok=True) + + ArkindexExtractor( + client=client, + folders=folders, + element_type=element_type, + split_names=split_names, + output=output, + load_entities=load_entities, + tokens=tokens, + use_existing_split=use_existing_split, + transcription_worker_version=transcription_worker_version, + entity_worker_version=entity_worker_version, + train_prob=train_prob, + val_prob=val_prob, + ).run() diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py index 77d15804..4b1aa93f 100644 --- a/dan/datasets/extract/utils.py +++ b/dan/datasets/extract/utils.py @@ -1,74 +1,56 @@ # -*- coding: utf-8 -*- +import yaml +import json +import random -""" - The utils module - ====================== -""" +import cv2 -import argparse +random.seed(42) -def get_cli_args(): +def assign_random_split(train_prob, val_prob): """ - Get the command-line arguments. - :return: The command-line arguments. + assuming train_prob + valid_prob + test_prob = 1 """ - parser = argparse.ArgumentParser( - description="Arkindex DAN Training Label Generation" - ) + prob = random.random() + if prob <= train_prob: + return "train" + elif prob <= train_prob + val_prob: + return "valid" + else: + return "test" - # Required arguments. 
- parser.add_argument( - "--corpus", - type=str, - help="Name of the corpus from which the data will be retrieved.", - required=True, - ) - parser.add_argument( - "--element-type", - nargs="+", - type=str, - help="Type of elements to retrieve", - required=True, - ) - parser.add_argument( - "--parents-types", - nargs="+", - type=str, - help="Type of parents of the elements.", - required=True, - ) - parser.add_argument( - "--output-dir", - type=str, - help="Path to the output directory.", - required=True, - ) - # Optional arguments. - parser.add_argument( - "--parents-names", - nargs="+", - type=str, - help="Names of parents of the elements.", - default=None, - ) - parser.add_argument( - "--no-entities", action="store_true", help="Extract text without entities" - ) +def save_text(path, text): + with open(path, "w") as f: + f.write(text) - parser.add_argument( - "--use-existing-split", - action="store_true", - help="Do not partition pages into train/val/test", - ) - parser.add_argument( - "--train-prob", type=float, default=0.7, help="Training set probability" - ) +def save_image(path, image): + cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + - parser.add_argument( - "--val-prob", type=float, default=0.15, help="Validation set probability" +def save_json(path, dict): + with open(path, "w") as outfile: + json.dump(dict, outfile, indent=4) + + +def insert_token(text, count, start_token, end_token, offset, length): + text = ( + # Text before entity + text[: count + offset] + # Starting token + + start_token + # Entity + + text[count + offset : count + 1 + offset + length] + # End token + + end_token + # Text after entity + + text[count + 1 + offset + length :] ) + return text, count + 2 + - return parser.parse_args() +def parse_tokens(filename): + with open(filename) as f: + return yaml.safe_load(f) diff --git a/dan/datasets/utils.py b/dan/datasets/utils.py index 6422357c..c911cc9b 100644 --- a/dan/datasets/utils.py +++ b/dan/datasets/utils.py @@ -1,12 +1,6 
@@ # -*- coding: utf-8 -*- -import json -import random import re -import cv2 - -random.seed(42) - def convert(text): return int(text) if text.isdigit() else text.lower() @@ -14,30 +8,3 @@ def convert(text): def natural_sort(data): return sorted(data, key=lambda key: [convert(c) for c in re.split("([0-9]+)", key)]) - - -def assign_random_split(train_prob, val_prob): - """ - assuming train_prob + val_prob + test_prob = 1 - """ - prob = random.random() - if prob <= train_prob: - return "train" - elif prob <= train_prob + val_prob: - return "val" - else: - return "test" - - -def save_text(path, text): - with open(path, "w") as f: - f.write(text) - - -def save_image(path, image): - cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - - -def save_json(path, dict): - with open(path, "w") as outfile: - json.dump(dict, outfile, indent=4) diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py index 1f8216a9..3166a543 100644 --- a/dan/ocr/document/train.py +++ b/dan/ocr/document/train.py @@ -1,4 +1,7 @@ # -*- coding: utf-8 -*- +""" +Train a DAN model at document level. 
+""" import random import numpy as np @@ -18,6 +21,7 @@ def add_document_parser(subcommands) -> None: parser = subcommands.add_parser( "document", description=__doc__, + help=__doc__, ) parser.set_defaults(func=run) diff --git a/dan/ocr/line/generate_synthetic.py b/dan/ocr/line/generate_synthetic.py index 435e19ed..67aa5a24 100644 --- a/dan/ocr/line/generate_synthetic.py +++ b/dan/ocr/line/generate_synthetic.py @@ -1,4 +1,7 @@ # -*- coding: utf-8 -*- +""" +Generate synthetic data to train DAN models +""" import random import numpy as np @@ -18,6 +21,7 @@ def add_generate_parser(subcommands) -> None: parser = subcommands.add_parser( "generate", description=__doc__, + help=__doc__, ) parser.set_defaults(func=run) diff --git a/dan/ocr/line/train.py b/dan/ocr/line/train.py index 9e092fd9..b49e3df0 100644 --- a/dan/ocr/line/train.py +++ b/dan/ocr/line/train.py @@ -1,4 +1,7 @@ # -*- coding: utf-8 -*- +""" +Train a DAN model at line level. +""" import random import numpy as np @@ -18,6 +21,7 @@ def add_line_parser(subcommands) -> None: parser = subcommands.add_parser( "line", description=__doc__, + help=__doc__, ) parser.set_defaults(func=run) diff --git a/dan/ocr/train.py b/dan/ocr/train.py index dda43ba8..375656b3 100644 --- a/dan/ocr/train.py +++ b/dan/ocr/train.py @@ -1,4 +1,7 @@ # -*- coding: utf-8 -*- +""" +Train a new DAN model. +""" from dan.ocr.document.train import add_document_parser from dan.ocr.line.train import add_line_parser @@ -8,6 +11,7 @@ def add_train_parser(subcommands) -> None: parser = subcommands.add_parser( "train", description=__doc__, + help=__doc__, ) subcommands = parser.add_subparsers(metavar="subcommand") diff --git a/requirements.txt b/requirements.txt index 27758aa6..6288bf12 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ networkx==2.6.3 numpy==1.22.3 opencv-python==4.5.5.64 PyYAML==6.0 -tensorboard==0.2.1 +tensorboard==2.8.0 torch==1.11.0 torchvision==0.12.0 tqdm==4.62.3 -- GitLab