From e316fcc13af905df79325fc211119add6a686d67 Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Tue, 22 Nov 2022 11:01:53 +0000 Subject: [PATCH] Implement extraction command --- .gitlab-ci.yml | 2 +- MANIFEST.in | 2 + README.md | 109 ++++- dan/__init__.py | 8 + dan/cli.py | 8 +- dan/datasets/__init__.py | 17 + dan/datasets/extract/__init__.py | 115 +++++ dan/datasets/extract/arkindex_utils.py | 66 --- dan/datasets/extract/extract_from_arkindex.py | 453 ++++++++++++------ dan/datasets/extract/utils.py | 43 ++ dan/datasets/utils.py | 33 -- dan/ocr/__init__.py | 19 + dan/ocr/document/__init__.py | 15 + dan/ocr/document/train.py | 9 - dan/ocr/line/__init__.py | 22 + dan/ocr/line/generate_synthetic.py | 9 - dan/ocr/line/train.py | 9 - dan/ocr/train.py | 16 - requirements.txt | 2 +- tests/conftest.py | 20 + tests/test_extract.py | 120 +++++ tox.ini | 12 + 22 files changed, 799 insertions(+), 310 deletions(-) create mode 100644 MANIFEST.in delete mode 100644 dan/datasets/extract/arkindex_utils.py create mode 100644 dan/datasets/extract/utils.py delete mode 100644 dan/ocr/train.py create mode 100644 tests/conftest.py create mode 100644 tests/test_extract.py create mode 100644 tox.ini diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index e35b054d..81e95c76 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -32,7 +32,7 @@ bump-python-deps: - schedules script: - - devops python-deps requirements.txt tests-requirements.txt + - devops python-deps requirements.txt release-notes: stage: deploy diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..fd959fa8 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +include requirements.txt +include VERSION diff --git a/README.md b/README.md index 9b29f204..7c4b6b6c 100644 --- a/README.md +++ b/README.md @@ -21,9 +21,9 @@ We evaluate the DAN on two public datasets: RIMES and READ 2016 at single-page a We obtained the following results: | | CER (%) | WER (%) | LOER (%) | mAP_cer (%) | -|:-----------------------:|---------|:-------:|:--------:|-------------| -| RIMES (single page) | 4.54 | 11.85 | 3.82 | 93.74 | -| READ 2016 (single page) | 3.53 | 13.33 | 5.94 | 92.57 | +| :---------------------: | ------- | :-----: | :------: | ----------- | +| RIMES (single page) | 4.54 | 11.85 | 3.82 | 93.74 | +| READ 2016 (single page) | 3.53 | 13.33 | 5.94 | 92.57 | | READ 2016 (double page) | 3.69 | 14.20 | 4.60 | 93.92 | @@ -92,3 +92,106 @@ To run the inference on a GPU, one can replace `cpu` by the name of the GPU. In ```python text, confidence_scores = model.predict(image, confidences=True) ``` + +### Commands + +This package provides three subcommands. To get more information about any subcommand, use the `--help` option. + +#### Data extraction from Arkindex + +Use the `teklia-dan dataset extract` command to extract a dataset from Arkindex. This will generate the images and the labels needed to train a DAN model. +The available arguments are + +| Parameter | Description | Type | Default | +| ------------------------------ | ----------------------------------------------------------------------------------- | -------- | ------- | +| `--parent` | UUID of the folder to import from Arkindex. You may specify multiple UUIDs. | `str/uuid` | | +| `--element-type` | Type of the elements to extract. You may specify multiple types. | `str` | | +| `--parent-element-type` | Type of the parent element containing the data. | `str` | `page` | +| `--output` | Folder where the data will be generated. 
| `Path` | |
+| `--load-entities` | Extract text with their entities. Needed for NER tasks. | `bool` | `False` |
+| `--tokens` | Mapping between starting tokens and end tokens. Needed for NER tasks. | `Path` | |
+| `--use-existing-split` | Use the specified folder IDs for the dataset split. | `bool` | |
+| `--train-folder` | ID of the training folder to import from Arkindex. | `uuid` | |
+| `--val-folder` | ID of the validation folder to import from Arkindex. | `uuid` | |
+| `--test-folder` | ID of the testing folder to import from Arkindex. | `uuid` | |
+| `--transcription-worker-version` | Filter transcriptions by worker_version. Use `manual` for manual filtering. | `str/uuid` | |
+| `--entity-worker-version` | Filter transcription entities by worker_version. Use `manual` for manual filtering. | `str/uuid` | |
+| `--train-prob` | Training set split size. | `float` | `0.7` |
+| `--val-prob` | Validation set split size. | `float` | `0.15` |
+
+The `--tokens` argument expects a file with the following format:
+```yaml
+---
+INTITULE:
+  start: ⓘ
+  end: Ⓘ
+DATE:
+  start: ⓓ
+  end: Ⓓ
+COTE_SERIE:
+  start: ⓢ
+  end: Ⓢ
+ANALYSE_COMPL.:
+  start: ⓒ
+  end: Ⓒ
+PRECISIONS_SUR_COTE:
+  start: ⓟ
+  end: Ⓟ
+COTE_ARTICLE:
+  start: ⓐ
+  end: Ⓐ
+CLASSEMENT:
+  start: ⓛ
+  end: Ⓛ
+```
+
+
+To extract HTR+NER data from **pages** from [this folder](https://arkindex.teklia.com/element/665e84ea-bd97-4912-91b0-1f4a844287ff), use the following command:
+```shell
+teklia-dan dataset extract \
+    --parent 665e84ea-bd97-4912-91b0-1f4a844287ff \
+    --element-type page \
+    --output data \
+    --load-entities \
+    --tokens tokens.yml
+```
+with `tokens.yml` having the content described just above.
+
+
+To do the same, but using only the data from three specific folders, the command becomes:
+```shell
+teklia-dan dataset extract \
+    --parent 2275529a-1ec5-40ce-a516-42ea7ada858c af9b38b5-5d95-417d-87ec-730537cb1898 6ff44957-0e65-48c5-9d77-a178116405b2 \
+    --element-type page \
+    --output data \
+    --load-entities \
+    --tokens tokens.yml
+```
+
+To use the data from three folders as the **training**, **validation** and **testing** datasets respectively, the command becomes:
+```shell
+teklia-dan dataset extract \
+    --use-existing-split \
+    --train-folder 2275529a-1ec5-40ce-a516-42ea7ada858c \
+    --val-folder af9b38b5-5d95-417d-87ec-730537cb1898 \
+    --test-folder 6ff44957-0e65-48c5-9d77-a178116405b2 \
+    --element-type page \
+    --output data \
+    --load-entities \
+    --tokens tokens.yml
+```
+
+To extract HTR data from **annotations** and **text_zones** from [this folder](https://demo.arkindex.org/element/48852284-fc02-41bb-9a42-4458e5a51615) that are children of **single_pages**, use the following command:
+```shell
+teklia-dan dataset extract \
+    --parent 48852284-fc02-41bb-9a42-4458e5a51615 \
+    --element-type text_zone annotation \
+    --parent-element-type single_page \
+    --output data
+```
+
+#### Model training
+`teklia-dan train` with multiple arguments.
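+
+As a minimal sketch (not an exhaustive reference; pass `--help` to list the options each subcommand actually accepts), training is launched through the `document` or `line` subcommand:
+```shell
+# Train a DAN model at document level
+teklia-dan train document
+
+# Train a DAN model at line level
+teklia-dan train line
+```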
+ +#### Synthetic data generation +`teklia-dan generate` with multiple arguments diff --git a/dan/__init__.py b/dan/__init__.py index e69de29b..b74e8889 100644 --- a/dan/__init__.py +++ b/dan/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +import logging + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s/%(name)s: %(message)s", +) +logger = logging.getLogger(__name__) diff --git a/dan/cli.py b/dan/cli.py index e633ceb4..ddf244f4 100644 --- a/dan/cli.py +++ b/dan/cli.py @@ -2,17 +2,17 @@ import argparse import errno -from dan.datasets.extract.extract_from_arkindex import add_extract_parser -from dan.ocr.line.generate_synthetic import add_generate_parser -from dan.ocr.train import add_train_parser +from dan.datasets import add_dataset_parser +from dan.ocr import add_train_parser +from dan.ocr.line import add_generate_parser def get_parser(): parser = argparse.ArgumentParser(prog="teklia-dan") subcommands = parser.add_subparsers(metavar="subcommand") + add_dataset_parser(subcommands) add_train_parser(subcommands) - add_extract_parser(subcommands) add_generate_parser(subcommands) return parser diff --git a/dan/datasets/__init__.py b/dan/datasets/__init__.py index e69de29b..889e11cf 100644 --- a/dan/datasets/__init__.py +++ b/dan/datasets/__init__.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- +""" +Preprocess datasets for training. +""" + +from dan.datasets.extract import add_extract_parser + + +def add_dataset_parser(subcommands) -> None: + parser = subcommands.add_parser( + "dataset", + description=__doc__, + help=__doc__, + ) + subcommands = parser.add_subparsers(metavar="subcommand") + + add_extract_parser(subcommands) diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py index e69de29b..76521846 100644 --- a/dan/datasets/extract/__init__.py +++ b/dan/datasets/extract/__init__.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- +""" +Extract dataset from Arkindex using API. +""" + +import pathlib +import uuid + +from dan.datasets.extract.extract_from_arkindex import run + +MANUAL_SOURCE = "manual" + + +def parse_worker_version(worker_version_id): + if worker_version_id == MANUAL_SOURCE: + return False + return worker_version_id + + +def add_extract_parser(subcommands) -> None: + parser = subcommands.add_parser( + "extract", + description=__doc__, + help=__doc__, + ) + + # Required arguments. + parser.add_argument( + "--parent", + type=uuid.UUID, + nargs="+", + help="ID of the parent folder to import from Arkindex.", + required=False, + ) + parser.add_argument( + "--element-type", + nargs="+", + type=str, + help="Type of elements to retrieve", + required=True, + ) + parser.add_argument( + "--parent-element-type", + type=str, + help="Type of the parent element containing the data.", + required=False, + default="page", + ) + parser.add_argument( + "--output", + type=pathlib.Path, + help="Path where the data will be generated.", + required=True, + ) + + # Optional arguments. + parser.add_argument( + "--load-entities", action="store_true", help="Extract text with their entities" + ) + parser.add_argument( + "--tokens", + type=pathlib.Path, + help="Mapping between starting tokens and end tokens. 
Needed for entities.", + required=False, + ) + + parser.add_argument( + "--use-existing-split", + action="store_true", + help="Use the specified folder IDs for the dataset split.", + ) + + parser.add_argument( + "--train-folder", + type=uuid.UUID, + help="ID of the training folder to import from Arkindex.", + required=False, + ) + parser.add_argument( + "--val-folder", + type=uuid.UUID, + help="ID of the validation folder to import from Arkindex.", + required=False, + ) + parser.add_argument( + "--test-folder", + type=uuid.UUID, + help="ID of the testing folder to import from Arkindex.", + required=False, + ) + + parser.add_argument( + "--transcription-worker-version", + type=parse_worker_version, + help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.", + required=False, + default=MANUAL_SOURCE, + ) + parser.add_argument( + "--entity-worker-version", + type=parse_worker_version, + help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.", + required=False, + default=MANUAL_SOURCE, + ) + + parser.add_argument( + "--train-prob", type=float, default=0.7, help="Training set split size." + ) + + parser.add_argument( + "--val-prob", type=float, default=0.15, help="Validation set split size" + ) + + parser.set_defaults(func=run) diff --git a/dan/datasets/extract/arkindex_utils.py b/dan/datasets/extract/arkindex_utils.py deleted file mode 100644 index d9d2e065..00000000 --- a/dan/datasets/extract/arkindex_utils.py +++ /dev/null @@ -1,66 +0,0 @@ -# -*- coding: utf-8 -*- - -""" - The arkindex_utils module - ====================== -""" - -import errno -import logging -import sys - -from apistar.exceptions import ErrorResponse - - -def retrieve_corpus(client, corpus_name: str) -> str: - """ - Retrieve the corpus id from the corpus name. - :param client: The arkindex client. - :param corpus_name: The name of the corpus to retrieve. - :return target_corpus: The id of the retrieved corpus. - """ - for corpus in client.request("ListCorpus"): - if corpus["name"] == corpus_name: - target_corpus = corpus["id"] - try: - logging.info(f"Corpus id retrieved: {target_corpus}") - except NameError: - logging.error(f"Corpus {corpus_name} not found") - sys.exit(errno.EINVAL) - - return target_corpus - - -def retrieve_subsets( - client, corpus: str, parents_types: list, parents_names: list -) -> list: - """ - Retrieve the requested subsets. - :param client: The arkindex client. - :param corpus: The id of the retrieved corpus. - :param parents_types: The types of parents of the elements to retrieve. - :param parents_names: The names of parents of the elements to retrieve. - :return subsets: The retrieved subsets. - """ - subsets = [] - for parent_type in parents_types: - try: - subsets.extend( - client.request("ListElements", corpus=corpus, type=parent_type)[ - "results" - ] - ) - except ErrorResponse as e: - logging.error(f"{e.content}: {parent_type}") - sys.exit(errno.EINVAL) - # Retrieve subsets with name in parents-names. If no parents-names given, keep all subsets. 
- if parents_names is not None: - logging.info(f"Retrieving {parents_names} subset(s)") - subsets = [subset for subset in subsets if subset["name"] in parents_names] - else: - logging.info("Retrieving all subsets") - - if len(subsets) == 0: - logging.info("No subset found") - - return subsets diff --git a/dan/datasets/extract/extract_from_arkindex.py b/dan/datasets/extract/extract_from_arkindex.py index a135cb7d..c4e1fe48 100644 --- a/dan/datasets/extract/extract_from_arkindex.py +++ b/dan/datasets/extract/extract_from_arkindex.py @@ -1,184 +1,319 @@ # -*- coding: utf-8 -*- -""" - The extraction module - ====================== -""" - import logging import os +import random +from collections import defaultdict +from pathlib import Path +from typing import List, NamedTuple -import cv2 import imageio.v2 as iio from arkindex import ArkindexClient, options_from_env from tqdm import tqdm -from dan.datasets.extract.arkindex_utils import retrieve_corpus, retrieve_subsets - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +from dan import logger +from dan.datasets.extract.utils import ( + insert_token, + parse_tokens, + save_image, + save_json, + save_text, ) +IMAGES_DIR = "images" # Subpath to the images directory. +LABELS_DIR = "labels" # Subpath to the labels directory. + +Entity = NamedTuple("Entity", offset=int, length=int, label=str) + + +class ArkindexExtractor: + """ + Extract data from Arkindex + """ + + def __init__( + self, + client: ArkindexClient, + folders: list = [], + element_type: list = [], + parent_element_type: list = ["page"], + split_names: list = [], + output: Path = None, + load_entities: bool = None, + tokens: Path = None, + use_existing_split: bool = None, + transcription_worker_version: str = None, + entity_worker_version: str = None, + train_prob: float = None, + val_prob: float = None, + ) -> None: + self.client = client + self.element_type = element_type + self.parent_element_type = parent_element_type + self.split_names = split_names + self.output = output + self.load_entities = load_entities + self.tokens = parse_tokens(tokens) if self.load_entities else None + self.use_existing_split = use_existing_split + self.transcription_worker_version = transcription_worker_version + self.entity_worker_version = entity_worker_version + self.train_prob = train_prob + self.val_prob = val_prob + + self.get_subsets(folders) + + def get_subsets(self, folders: list): + """ + Assign each folder to its split if it's already known. + Assign None if it's unknown. + """ + if self.use_existing_split: + self.subsets = [ + (folder, split) for folder, split in zip(folders, self.split_names) + ] + else: + self.subsets = [(folder, None) for folder in folders] + + def _assign_random_split(self): + """ + Yields a randomly chosen split for an element. + Assumes that train_prob + valid_prob + test_prob = 1 + """ + prob = random.random() + if prob <= self.train_prob: + yield self.split_names[0] + elif prob <= self.train_prob + self.val_prob: + yield self.split_names[1] + else: + yield self.split_names[2] + + def get_random_split(self): + return next(self._assign_random_split()) + + def extract_entities(self, transcription: dict): + entities = self.client.paginate( + "ListTranscriptionEntities", + id=transcription["id"], + worker_version=self.entity_worker_version, + ) + if entities is None: + logger.warning( + f"No entities found on transcription ({transcription['id']})." 
+ ) + return + return [ + Entity( + offset=entity["offset"], + length=entity["length"], + label=entity["entity"]["metas"]["subtype"], + ) + for entity in entities + ] + + def reconstruct_text(self, text: str, entities: List[Entity]): + """ + Insert tokens delimiting the start/end of each entity on the transcription. + """ + count = 0 + for entity in entities: + matching_tokens = self.tokens[entity.label] + start_token, end_token = ( + matching_tokens["start"], + matching_tokens["end"], + ) + text, count = insert_token( + text, + count, + start_token, + end_token, + offset=entity.offset, + length=entity.length, + ) + return text + + def extract_transcription( + self, + element: dict, + ): + """ + Extract the element's transcription. + If the entities are needed, they are added to the transcription using tokens. + """ + transcriptions = self.client.request( + "ListTranscriptions", + id=element["id"], + worker_version=self.transcription_worker_version, + ) + if transcriptions["count"] != 1: + logger.warning( + f"More than one transcription found on element ({element['id']}) with this config." + ) + return + + transcription = transcriptions["results"].pop() + if self.load_entities: + entities = self.extract_entities(transcription) + return self.reconstruct_text(transcription["text"], entities) + else: + return transcription["text"].strip() + + def process_element( + self, + element: dict, + split: str, + ): + """ + Extract an element's data and save it to disk. + The output path is directly related to the split of the element. + """ + text = self.extract_transcription( + element, + ) -IMAGES_DIR = "./images/" # Path to the images directory. -LABELS_DIR = "./labels/" # Path to the labels directory. - -# Layout string to token -SEM_MATCHING_TOKENS_STR = { - "INTITULE": "ⓘ", - "DATE": "â““", - "COTE_SERIE": "â“¢", - "ANALYSE_COMPL.": "â“’", - "PRECISIONS_SUR_COTE": "â“Ÿ", - "COTE_ARTICLE": "â“", -} - -# Layout begin-token to end-token -SEM_MATCHING_TOKENS = {"ⓘ": "â’¾", "â““": "â’¹", "â“¢": "Ⓢ", "â“’": "â’¸", "â“Ÿ": "â“…", "â“": "â’¶"} - - -def add_extract_parser(subcommands) -> None: - parser = subcommands.add_parser( - "extract", - description=__doc__, - help=__doc__, - ) - # Required arguments. - parser.add_argument( - "--corpus", - type=str, - help="Name of the corpus from which the data will be retrieved.", - required=True, - ) - parser.add_argument( - "--element-type", - nargs="+", - type=str, - help="Type of elements to retrieve", - required=True, - ) - parser.add_argument( - "--parents-types", - nargs="+", - type=str, - help="Type of parents of the elements.", - required=True, - ) - parser.add_argument( - "--output-dir", - type=str, - help="Path to the output directory.", - required=True, - ) - - # Optional arguments. 
- parser.add_argument( - "--parents-names", - nargs="+", - type=str, - help="Names of parents of the elements.", - default=None, - ) - parser.add_argument( - "--no-entities", action="store_true", help="Extract text without entities" - ) - - parser.add_argument( - "--use-existing-split", - action="store_true", - help="Do not partition pages into train/val/test", - ) - - parser.add_argument( - "--train-prob", type=float, default=0.7, help="Training set probability" - ) - - parser.add_argument( - "--val-prob", type=float, default=0.15, help="Validation set probability" - ) - parser.set_defaults(func=run) + if not text: + logging.warning(f"Skipping {element['id']}") + else: + logging.info(f"Processed {element['type']} {element['id']}") + + im_path = os.path.join( + self.output, IMAGES_DIR, split, f"{element['type']}_{element['id']}.jpg" + ) + txt_path = os.path.join( + self.output, LABELS_DIR, split, f"{element['type']}_{element['id']}.txt" + ) + + save_text(txt_path, text) + try: + image = iio.imread(element["zone"]["url"]) + save_image(im_path, image) + except Exception: + logger.error(f"Couldn't retrieve image of element ({element['id']}") + raise + return element["id"] + + def process_parent( + self, + parent: dict, + split: str, + ): + """ + Extract data from a parent element. + Depending on the given types, + """ + data = defaultdict(list) + if self.element_type == [parent["type"]]: + data[self.element_type[0]] = [ + self.process_element( + parent, + split, + ) + ] + # Extract children elements + else: + for element_type in self.element_type: + for element in self.client.paginate( + "ListElementChildren", + id=parent["id"], + type=element_type, + recursive=True, + ): + data[element_type].append( + self.process_element( + element, + split, + ) + ) + return data + + def run(self): + split_dict = defaultdict(dict) + # Iterate over the subsets to find the page images and labels. + for subset_id, subset_split in self.subsets: + page_idx = 0 + # Iterate over the pages to create splits at page level. + for parent in tqdm( + self.client.paginate( + "ListElementChildren", + id=subset_id, + type=self.parent_element_type, + recursive=True, + ) + ): + page_idx += 1 + split = subset_split or self.get_random_split() + + split_dict[split][parent["id"]] = self.process_parent( + parent=parent, + split=split, + ) + + save_json(self.output / "split.json", split_dict) def run( - corpus, + parent, element_type, - parents_types, - output_dir, - parents_names, - no_entities, + parent_element_type, + output, + load_entities, + tokens, use_existing_split, + train_folder, + val_folder, + test_folder, + transcription_worker_version, + entity_worker_version, train_prob, val_prob, ): - # Get and initialize the parameters. - os.makedirs(IMAGES_DIR, exist_ok=True) - os.makedirs(LABELS_DIR, exist_ok=True) - - # Login to arkindex. - client = ArkindexClient(**options_from_env()) - - corpus = retrieve_corpus(client, corpus) - subsets = retrieve_subsets(client, corpus, parents_types, parents_names) + assert ( + use_existing_split or parent + ), "One of `--use-existing-split` and `--parent` must be set" - # Iterate over the subsets to find the page images and labels. - for subset in subsets: + if use_existing_split: + assert ( + train_folder + ), "If you use an existing split, you must specify the training folder." + assert ( + val_folder + ), "If you use an existing split, you must specify the validation folder." + assert ( + test_folder + ), "If you use an existing split, you must specify the testing folder." 
+ folders = [train_folder, val_folder, test_folder] + else: + folders = parent - os.makedirs(os.path.join(output_dir, IMAGES_DIR, subset["name"]), exist_ok=True) - os.makedirs(os.path.join(output_dir, LABELS_DIR, subset["name"]), exist_ok=True) + if load_entities: + assert tokens, "Please provide the entities to match." - for page in tqdm( - client.paginate( - "ListElementChildren", id=subset["id"], type="page", recursive=True - ), - desc="Set " + subset["name"], - ): + # Login to arkindex. + assert ( + "ARKINDEX_API_URL" in os.environ + ), "The ARKINDEX API URL was not found in your environment." + assert ( + "ARKINDEX_API_TOKEN" in os.environ + ), "Your API credentials was not found in your environment." + client = ArkindexClient(**options_from_env()) - image = iio.imread(page["zone"]["url"]) - cv2.imwrite( - os.path.join( - output_dir, IMAGES_DIR, subset["name"], f"{page['id']}.jpg" - ), - cv2.cvtColor(image, cv2.COLOR_BGR2RGB), - ) + # Create out directories + split_names = ["train", "val", "test"] + for split in split_names: + os.makedirs(os.path.join(output, IMAGES_DIR, split), exist_ok=True) + os.makedirs(os.path.join(output, LABELS_DIR, split), exist_ok=True) - tr = client.request( - "ListTranscriptions", id=page["id"], worker_version=None - )["results"] - tr = [one for one in tr if one["worker_version_id"] is None] - assert len(tr) == 1, page["id"] - - for one_tr in tr: - ent = client.request("ListTranscriptionEntities", id=one_tr["id"])[ - "results" - ] - ent = [one for one in ent if one["worker_version_id"] is None] - if len(ent) == 0: - continue - else: - text = one_tr["text"] - - new_text = text - count = 0 - for e in ent: - start_token = SEM_MATCHING_TOKENS_STR[e["entity"]["metas"]["subtype"]] - end_token = SEM_MATCHING_TOKENS[start_token] - new_text = ( - new_text[: count + e["offset"]] - + start_token - + new_text[count + e["offset"] :] - ) - count += 1 - new_text = ( - new_text[: count + e["offset"] + e["length"]] - + end_token - + new_text[count + e["offset"] + e["length"] :] - ) - count += 1 - - with open( - os.path.join( - output_dir, LABELS_DIR, subset["name"], f"{page['id']}.txt" - ), - "w", - ) as f: - f.write(new_text) + ArkindexExtractor( + client=client, + folders=folders, + element_type=element_type, + parent_element_type=parent_element_type, + split_names=split_names, + output=output, + load_entities=load_entities, + tokens=tokens, + use_existing_split=use_existing_split, + transcription_worker_version=transcription_worker_version, + entity_worker_version=entity_worker_version, + train_prob=train_prob, + val_prob=val_prob, + ).run() diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py new file mode 100644 index 00000000..4a582228 --- /dev/null +++ b/dan/datasets/extract/utils.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +import json + +import cv2 +import yaml + + +def save_text(path, text): + with open(path, "w") as f: + f.write(text) + + +def save_image(path, image): + cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + +def save_json(path, dict): + with open(path, "w") as outfile: + json.dump(dict, outfile, indent=4) + + +def insert_token(text, count, start_token, end_token, offset, length): + """ + Insert the given tokens at the right position in the text + """ + text = ( + # Text before entity + text[: count + offset] + # Starting token + + start_token + # Entity + + text[count + offset : count + 1 + offset + length] + # End token + + end_token + # Text after entity + + text[count + 1 + offset + length :] + ) + return text, 
count + 2 + + +def parse_tokens(filename): + with open(filename) as f: + return yaml.safe_load(f) diff --git a/dan/datasets/utils.py b/dan/datasets/utils.py index 6422357c..c911cc9b 100644 --- a/dan/datasets/utils.py +++ b/dan/datasets/utils.py @@ -1,12 +1,6 @@ # -*- coding: utf-8 -*- -import json -import random import re -import cv2 - -random.seed(42) - def convert(text): return int(text) if text.isdigit() else text.lower() @@ -14,30 +8,3 @@ def convert(text): def natural_sort(data): return sorted(data, key=lambda key: [convert(c) for c in re.split("([0-9]+)", key)]) - - -def assign_random_split(train_prob, val_prob): - """ - assuming train_prob + val_prob + test_prob = 1 - """ - prob = random.random() - if prob <= train_prob: - return "train" - elif prob <= train_prob + val_prob: - return "val" - else: - return "test" - - -def save_text(path, text): - with open(path, "w") as f: - f.write(text) - - -def save_image(path, image): - cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - - -def save_json(path, dict): - with open(path, "w") as outfile: - json.dump(dict, outfile, indent=4) diff --git a/dan/ocr/__init__.py b/dan/ocr/__init__.py index e69de29b..3d18b6fe 100644 --- a/dan/ocr/__init__.py +++ b/dan/ocr/__init__.py @@ -0,0 +1,19 @@ +# -*- coding: utf-8 -*- +""" +Train a new DAN model. +""" + +from dan.ocr.document import add_document_parser +from dan.ocr.line import add_line_parser + + +def add_train_parser(subcommands) -> None: + parser = subcommands.add_parser( + "train", + description=__doc__, + help=__doc__, + ) + subcommands = parser.add_subparsers(metavar="subcommand") + + add_line_parser(subcommands) + add_document_parser(subcommands) diff --git a/dan/ocr/document/__init__.py b/dan/ocr/document/__init__.py index e69de29b..375a1327 100644 --- a/dan/ocr/document/__init__.py +++ b/dan/ocr/document/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +""" +Train a DAN model at document level. 
+""" + +from dan.ocr.document.train import run + + +def add_document_parser(subcommands) -> None: + parser = subcommands.add_parser( + "document", + description=__doc__, + help=__doc__, + ) + parser.set_defaults(func=run) diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py index 09df425a..64ca1d31 100644 --- a/dan/ocr/document/train.py +++ b/dan/ocr/document/train.py @@ -14,15 +14,6 @@ from dan.schedulers import exponential_dropout_scheduler from dan.transforms import aug_config -def add_document_parser(subcommands) -> None: - parser = subcommands.add_parser( - "document", - description=__doc__, - help=__doc__, - ) - parser.set_defaults(func=run) - - def train_and_test(rank, params): torch.manual_seed(0) torch.cuda.manual_seed(0) diff --git a/dan/ocr/line/__init__.py b/dan/ocr/line/__init__.py index e69de29b..603a060d 100644 --- a/dan/ocr/line/__init__.py +++ b/dan/ocr/line/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- + +from dan.ocr.line.generate_synthetic import run as run_generate +from dan.ocr.line.train import run as run_train + + +def add_generate_parser(subcommands) -> None: + parser = subcommands.add_parser( + "generate", + description=__doc__, + help="Generate synthetic data to train DAN models.", + ) + parser.set_defaults(func=run_generate) + + +def add_line_parser(subcommands) -> None: + parser = subcommands.add_parser( + "line", + description=__doc__, + help="Train a DAN model at line level.", + ) + parser.set_defaults(func=run_train) diff --git a/dan/ocr/line/generate_synthetic.py b/dan/ocr/line/generate_synthetic.py index b034d5b0..de378031 100644 --- a/dan/ocr/line/generate_synthetic.py +++ b/dan/ocr/line/generate_synthetic.py @@ -14,15 +14,6 @@ from dan.schedulers import exponential_dropout_scheduler, exponential_scheduler from dan.transforms import line_aug_config -def add_generate_parser(subcommands) -> None: - parser = subcommands.add_parser( - "generate", - description=__doc__, - help=__doc__, - ) - parser.set_defaults(func=run) - - def train_and_test(rank, params): torch.manual_seed(0) torch.cuda.manual_seed(0) diff --git a/dan/ocr/line/train.py b/dan/ocr/line/train.py index ba3d0b4f..7d2bab45 100644 --- a/dan/ocr/line/train.py +++ b/dan/ocr/line/train.py @@ -14,15 +14,6 @@ from dan.schedulers import exponential_dropout_scheduler, exponential_scheduler from dan.transforms import line_aug_config -def add_line_parser(subcommands) -> None: - parser = subcommands.add_parser( - "line", - description=__doc__, - help=__doc__, - ) - parser.set_defaults(func=run) - - def train_and_test(rank, params): torch.manual_seed(0) torch.cuda.manual_seed(0) diff --git a/dan/ocr/train.py b/dan/ocr/train.py deleted file mode 100644 index 37176069..00000000 --- a/dan/ocr/train.py +++ /dev/null @@ -1,16 +0,0 @@ -# -*- coding: utf-8 -*- - -from dan.ocr.document.train import add_document_parser -from dan.ocr.line.train import add_line_parser - - -def add_train_parser(subcommands) -> None: - parser = subcommands.add_parser( - "train", - description=__doc__, - help=__doc__, - ) - subcommands = parser.add_subparsers(metavar="subcommand") - - add_line_parser(subcommands) - add_document_parser(subcommands) diff --git a/requirements.txt b/requirements.txt index 27758aa6..6288bf12 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ networkx==2.6.3 numpy==1.22.3 opencv-python==4.5.5.64 PyYAML==6.0 -tensorboard==0.2.1 +tensorboard==2.8.0 torch==1.11.0 torchvision==0.12.0 tqdm==4.62.3 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 
00000000..21a29dfa --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +import os + +import pytest + + +@pytest.fixture(autouse=True) +def setup_environment(responses): + """Setup needed environment variables""" + + # Allow accessing remote API schemas + # defaulting to the prod environment + schema_url = os.environ.get( + "ARKINDEX_API_SCHEMA_URL", + "https://arkindex.teklia.com/api/v1/openapi/?format=openapi-json", + ) + responses.add_passthru(schema_url) + + # Set schema url in environment + os.environ["ARKINDEX_API_SCHEMA_URL"] = schema_url diff --git a/tests/test_extract.py b/tests/test_extract.py new file mode 100644 index 00000000..3e3552d2 --- /dev/null +++ b/tests/test_extract.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- + +import pytest +from arkindex.mock import MockApiClient + +from dan.datasets.extract.extract_from_arkindex import ArkindexExtractor, Entity +from dan.datasets.extract.utils import insert_token + + +@pytest.fixture +def arkindex_extractor(): + return ArkindexExtractor( + client=MockApiClient(), split_names=["train", "val", "test"] + ) + + +@pytest.mark.parametrize( + "text,count,offset,length,expected", + ( + ("n°1 16 janvier 1611", 0, 0, 3, "ⓘn°1 â’¾16 janvier 1611"), + ("ⓘn°1 â’¾16 janvier 1611", 2, 4, 15, "ⓘn°1 Ⓘⓘ16 janvier 1611â’¾"), + ), +) +def test_insert_token(text, count, offset, length, expected): + start_token, end_token = "ⓘ", "â’¾" + assert ( + insert_token(text, count, start_token, end_token, offset, length)[0] == expected + ) + + +@pytest.mark.parametrize( + "text,entities,expected", + ( + ( + "n°1 16 janvier 1611", + [ + Entity(offset=0, length=3, label="P"), + Entity(offset=4, length=15, label="D"), + ], + "â“Ÿn°1 â“…â““16 janvier 1611â’¹", + ), + ), +) +def test_reconstruct_text(arkindex_extractor, text, entities, expected): + arkindex_extractor.tokens = { + "P": {"start": "â“Ÿ", "end": "â“…"}, + "D": {"start": "â““", "end": "â’¹"}, + } + assert arkindex_extractor.reconstruct_text(text, entities) == expected + + +@pytest.mark.parametrize( + "text,offset,length,label,expected", + ( + (" n°1 16 janvier 1611 ", None, None, None, "n°1 16 janvier 1611"), + ("n°1 16 janvier 1611", 0, 3, "P", "â“Ÿn°1 â“…16 janvier 1611"), + ), +) +def test_extract_transcription( + arkindex_extractor, text, offset, length, label, expected +): + element = {"id": "element_id"} + transcription = {"id": "transcription_id", "text": text} + arkindex_extractor.client.add_response( + "ListTranscriptions", + id="element_id", + worker_version=None, + response={"count": 1, "results": [transcription]}, + ) + + if label: + arkindex_extractor.load_entities = True + arkindex_extractor.tokens = { + "P": {"start": "â“Ÿ", "end": "â“…"}, + } + arkindex_extractor.client.add_response( + "ListTranscriptionEntities", + id="transcription_id", + worker_version=None, + response=[ + { + "entity": {"id": "entity_id", "metas": {"subtype": label}}, + "offset": offset, + "length": length, + "worker_version": None, + "worker_run_id": None, + } + ], + ) + + assert arkindex_extractor.extract_transcription(element) == expected + + +@pytest.mark.parametrize( + "offset,length,label", + ((0, 3, "P"),), +) +def test_extract_entities(arkindex_extractor, offset, length, label): + transcription = {"id": "transcription_id"} + arkindex_extractor.tokens = { + "P": {"start": "â“Ÿ", "end": "â“…"}, + } + arkindex_extractor.client.add_response( + "ListTranscriptionEntities", + id="transcription_id", + worker_version=None, + response=[ + { + "entity": {"id": "entity_id", "metas": 
{"subtype": label}}, + "offset": offset, + "length": length, + "worker_version": None, + "worker_run_id": None, + } + ], + ) + + assert arkindex_extractor.extract_entities(transcription) == [ + Entity(offset=offset, length=length, label=label) + ] diff --git a/tox.ini b/tox.ini new file mode 100644 index 00000000..aedc8cf0 --- /dev/null +++ b/tox.ini @@ -0,0 +1,12 @@ +[tox] +envlist = teklia-dan + +[testenv] +passenv = ARKINDEX_API_SCHEMA_URL +commands = + pytest {posargs} + +deps = + pytest + pytest-responses + -rrequirements.txt -- GitLab