From 30524dfe4eac089c75778ec3caadbd6c6035d1ef Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Wed, 16 Nov 2022 13:20:08 +0000
Subject: [PATCH] implement extraction

---
 README.md                                     | 108 ++++-
 dan/__init__.py                               |   8 +
 dan/cli.py                                    |   4 +-
 dan/datasets/extract/arkindex_utils.py        |  66 ---
 dan/datasets/extract/extract_from_arkindex.py | 431 ++++++++++++++----
 dan/datasets/extract/utils.py                 | 102 ++---
 dan/datasets/utils.py                         |  33 --
 dan/ocr/document/train.py                     |   4 +
 dan/ocr/line/generate_synthetic.py            |   4 +
 dan/ocr/line/train.py                         |   4 +
 dan/ocr/train.py                              |   4 +
 requirements.txt                              |   2 +-
 12 files changed, 519 insertions(+), 251 deletions(-)
 delete mode 100644 dan/datasets/extract/arkindex_utils.py

diff --git a/README.md b/README.md
index 9b29f204..af2f58c5 100644
--- a/README.md
+++ b/README.md
@@ -21,9 +21,9 @@ We evaluate the DAN on two public datasets: RIMES and READ 2016 at single-page a
 We obtained the following results:
 
 |                         | CER (%) | WER (%) | LOER (%) | mAP_cer (%) |
-|:-----------------------:|---------|:-------:|:--------:|-------------|
-|       RIMES (single page)      | 4.54    |  11.85  |   3.82   | 93.74       |
-|     READ 2016 (single page)    | 3.53    |  13.33  |   5.94   | 92.57       |
+| :---------------------: | ------- | :-----: | :------: | ----------- |
+|   RIMES (single page)   | 4.54    |  11.85  |   3.82   | 93.74       |
+| READ 2016 (single page) | 3.53    |  13.33  |   5.94   | 92.57       |
 | READ 2016 (double page) | 3.69    |  14.20  |   4.60   | 93.92       |
 
 
@@ -92,3 +92,105 @@ To run the inference on a GPU, one can replace `cpu` by the name of the GPU. In
 ```python
 text, confidence_scores = model.predict(image, confidences=True)
 ```
+
+### Commands
+
+This package provides three subcommands. To get more information about any subcommand, use the `--help` option.
+
+#### Data extraction from Arkindex
+
+Use the `teklia-dan extract` command to extract a dataset from Arkindex. This will generate the images and the labels needed to train a DAN model.
+The available arguments are:
+
+| Parameter                         | Description                                                                          | Type     | Default |
+| --------------------------------- | ------------------------------------------------------------------------------------ | -------- | ------- |
+| `--parent`                        | UUID of the folder to import from Arkindex. You may specify multiple UUIDs.           | str/uuid |         |
+| `--element-type`                  | Type of the elements to extract. You may specify multiple types.                      | str      |         |
+| `--output`                        | Folder where the data will be generated. Must exist.                                  | Path     |         |
+| `--load-entities`                 | Extract text with its entities. Needed for NER tasks.                                 | bool     | False   |
+| `--tokens`                        | Mapping between starting tokens and end tokens. Needed for NER tasks.                 | Path     |         |
+| `--use-existing-split`            | Use the specified folder IDs for the dataset split.                                   | bool     | False   |
+| `--train-folder`                  | ID of the training folder to import from Arkindex.                                    | uuid     |         |
+| `--val-folder`                    | ID of the validation folder to import from Arkindex.                                  | uuid     |         |
+| `--test-folder`                   | ID of the testing folder to import from Arkindex.                                     | uuid     |         |
+| `--transcription-worker-version`  | Filter transcriptions by worker_version. Use `manual` for manual filtering.           | str/uuid |         |
+| `--entity-worker-version`         | Filter transcription entities by worker_version. Use `manual` for manual filtering.   | str/uuid |         |
+| `--train-prob`                    | Training set split size.                                                               | float    | 0.7     |
+| `--val-prob`                      | Validation set split size.                                                             | float    | 0.15    |
+
+The `--tokens` argument expects a file with the following format.
+```yaml
+---
+INTITULE:
+  start: ⓘ
+  end: Ⓘ
+DATE:
+  start: ⓓ
+  end: Ⓓ
+COTE_SERIE:
+  start: ⓢ
+  end: Ⓢ
+ANALYSE_COMPL.:
+  start: ⓒ
+  end: Ⓒ
+PRECISIONS_SUR_COTE:
+  start: ⓟ
+  end: Ⓟ
+COTE_ARTICLE:
+  start: ⓐ
+  end: Ⓐ
+CLASSEMENT:
+  start: ⓛ
+  end: Ⓛ
+```
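+
+For instance, with the tokens above, an extracted transcription containing an INTITULE and a DATE entity is labelled by wrapping each entity between its start and end tokens (the sample text below is made up):
+
+```
+ⓘVente de terrainⒾ ⓓ16 novembre 1855Ⓓ
+```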
+
+
+To extract HTR+NER data from **pages** of [this folder](https://arkindex.teklia.com/element/665e84ea-bd97-4912-91b0-1f4a844287ff), use the following command:
+```shell
+teklia-dan extract \
+    --parent 665e84ea-bd97-4912-91b0-1f4a844287ff \
+    --element-type page \
+    --output data \
+    --load-entities \
+    --tokens tokens.yml
+```
+with `tokens.yml` having the content described just above.
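+
+Each extracted element is saved as one image and one text label, grouped by split, and a `split.json` file summarises which elements were assigned to each split. With `--output data`, the expected layout is roughly:
+
+```
+data/
+├── images/
+│   ├── train/
+│   ├── val/
+│   └── test/
+├── labels/
+│   ├── train/
+│   ├── val/
+│   └── test/
+└── split.json
+```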
+
+
+To do the same, but using only the data from three folders, the command becomes:
+```shell
+teklia-dan extract \
+    --parent 2275529a-1ec5-40ce-a516-42ea7ada858c af9b38b5-5d95-417d-87ec-730537cb1898 6ff44957-0e65-48c5-9d77-a178116405b2 \
+    --element-type page \
+    --output data \
+    --load-entities \
+    --tokens tokens.yml
+```
+
+To use the data from three folders as **training**, **validation** and **testing** datasets respectively, the command becomes:
+```shell
+teklia-dan extract \
+    --use-existing-split \
+    --train-folder 2275529a-1ec5-40ce-a516-42ea7ada858c \
+    --val-folder af9b38b5-5d95-417d-87ec-730537cb1898 \
+    --test-folder 6ff44957-0e65-48c5-9d77-a178116405b2 \
+    --element-type page \
+    --output data \
+    --load-entities \
+    --tokens tokens.yml
+```
+
+To extract HTR data from **annotations** and **text_zones** of [this folder](https://demo.arkindex.org/element/48852284-fc02-41bb-9a42-4458e5a51615), use the following command:
+```shell
+teklia-dan extract \
+    --parent 48852284-fc02-41bb-9a42-4458e5a51615 \
+    --element-type text_zone annotation \
+    --output data
+```
+
+#### Model training
+Use the `teklia-dan train` command to train a new DAN model. Two subcommands are available: `teklia-dan train line` to train a model at line level and `teklia-dan train document` to train a model at document level.
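+
+For example, to list the arguments available for document-level training:
+
+```shell
+teklia-dan train document --help
+```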
+
+#### Synthetic data generation
+Use the `teklia-dan generate` command to generate synthetic data to train DAN models. Use `--help` to list the available arguments.
+
diff --git a/dan/__init__.py b/dan/__init__.py
index e69de29b..b74e8889 100644
--- a/dan/__init__.py
+++ b/dan/__init__.py
@@ -0,0 +1,8 @@
+# -*- coding: utf-8 -*-
+import logging
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s/%(name)s: %(message)s",
+)
+logger = logging.getLogger(__name__)
diff --git a/dan/cli.py b/dan/cli.py
index fff057b9..8b06e4e7 100644
--- a/dan/cli.py
+++ b/dan/cli.py
@@ -8,11 +8,11 @@ from dan.ocr.train import add_train_parser
 
 
 def get_parser():
-    parser = argparse.ArgumentParser(prog="TEKLIA DAN training")
+    parser = argparse.ArgumentParser(prog="teklia-dan")
     subcommands = parser.add_subparsers(metavar="subcommand")
 
-    add_train_parser(subcommands)
     add_extract_parser(subcommands)
+    add_train_parser(subcommands)
     add_generate_parser(subcommands)
     return parser
 
diff --git a/dan/datasets/extract/arkindex_utils.py b/dan/datasets/extract/arkindex_utils.py
deleted file mode 100644
index d9d2e065..00000000
--- a/dan/datasets/extract/arkindex_utils.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# -*- coding: utf-8 -*-
-
-"""
-    The arkindex_utils module
-    ======================
-"""
-
-import errno
-import logging
-import sys
-
-from apistar.exceptions import ErrorResponse
-
-
-def retrieve_corpus(client, corpus_name: str) -> str:
-    """
-    Retrieve the corpus id from the corpus name.
-    :param client: The arkindex client.
-    :param corpus_name: The name of the corpus to retrieve.
-    :return target_corpus: The id of the retrieved corpus.
-    """
-    for corpus in client.request("ListCorpus"):
-        if corpus["name"] == corpus_name:
-            target_corpus = corpus["id"]
-    try:
-        logging.info(f"Corpus id retrieved: {target_corpus}")
-    except NameError:
-        logging.error(f"Corpus {corpus_name} not found")
-        sys.exit(errno.EINVAL)
-
-    return target_corpus
-
-
-def retrieve_subsets(
-    client, corpus: str, parents_types: list, parents_names: list
-) -> list:
-    """
-    Retrieve the requested subsets.
-    :param client: The arkindex client.
-    :param corpus: The id of the retrieved corpus.
-    :param parents_types: The types of parents of the elements to retrieve.
-    :param parents_names: The names of parents of the elements to retrieve.
-    :return subsets: The retrieved subsets.
-    """
-    subsets = []
-    for parent_type in parents_types:
-        try:
-            subsets.extend(
-                client.request("ListElements", corpus=corpus, type=parent_type)[
-                    "results"
-                ]
-            )
-        except ErrorResponse as e:
-            logging.error(f"{e.content}: {parent_type}")
-            sys.exit(errno.EINVAL)
-    # Retrieve subsets with name in parents-names. If no parents-names given, keep all subsets.
-    if parents_names is not None:
-        logging.info(f"Retrieving {parents_names} subset(s)")
-        subsets = [subset for subset in subsets if subset["name"] in parents_names]
-    else:
-        logging.info("Retrieving all subsets")
-
-    if len(subsets) == 0:
-        logging.info("No subset found")
-
-    return subsets
diff --git a/dan/datasets/extract/extract_from_arkindex.py b/dan/datasets/extract/extract_from_arkindex.py
index c9c1c2c1..5c3818be 100644
--- a/dan/datasets/extract/extract_from_arkindex.py
+++ b/dan/datasets/extract/extract_from_arkindex.py
@@ -1,127 +1,386 @@
 # -*- coding: utf-8 -*-
 
 """
-    The extraction module
-    ======================
+Extract a dataset from Arkindex using its API.
 """
 
+from collections import defaultdict
 import logging
 import os
+import pathlib
+import random
+import uuid
 
-import cv2
 import imageio.v2 as iio
 from arkindex import ArkindexClient, options_from_env
 from tqdm import tqdm
 
-from dan.datasets.extract.arkindex_utils import retrieve_corpus, retrieve_subsets
-from dan.datasets.extract.utils import get_cli_args
-
-logging.basicConfig(
-    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+from dan.datasets.extract.utils import (
+    insert_token,
+    parse_tokens,
+    save_image,
+    save_json,
+    save_text,
 )
 
+from dan import logger
+
 
-IMAGES_DIR = "./images/"  # Path to the images directory.
-LABELS_DIR = "./labels/"  # Path to the labels directory.
+IMAGES_DIR = "images"  # Subpath to the images directory.
+LABELS_DIR = "labels"  # Subpath to the labels directory.
+MANUAL_SOURCE = "manual"
 
-# Layout string to token
-SEM_MATCHING_TOKENS_STR = {
-    "INTITULE": "ⓘ",
-    "DATE": "â““",
-    "COTE_SERIE": "â“¢",
-    "ANALYSE_COMPL.": "â“’",
-    "PRECISIONS_SUR_COTE": "ⓟ",
-    "COTE_ARTICLE": "ⓐ",
-}
 
-# Layout begin-token to end-token
-SEM_MATCHING_TOKENS = {"ⓘ": "Ⓘ", "ⓓ": "Ⓓ", "ⓢ": "Ⓢ", "ⓒ": "Ⓒ", "ⓟ": "Ⓟ", "ⓐ": "Ⓐ"}
+def parse_worker_version(worker_version_id):
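+    # The Arkindex API expects `worker_version=False` to filter manually created results.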
+    if worker_version_id == MANUAL_SOURCE:
+        return False
+    return worker_version_id
 
 
 def add_extract_parser(subcommands) -> None:
     parser = subcommands.add_parser(
         "extract",
         description=__doc__,
+        help=__doc__,
+    )
+
+    # Required arguments.
+    parser.add_argument(
+        "--parent",
+        type=uuid.UUID,
+        nargs="+",
+        help="ID of the parent folder to import from Arkindex.",
+        required=False,
+    )
+    parser.add_argument(
+        "--element-type",
+        nargs="+",
+        type=str,
+        help="Type of elements to retrieve",
+        required=True,
+    )
+    parser.add_argument(
+        "--output",
+        type=pathlib.Path,
+        help="Path where the data will be generated.",
+        required=True,
+    )
+
+    # Optional arguments.
+    parser.add_argument(
+        "--load-entities", action="store_true", help="Extract text with their entities"
+    )
+    parser.add_argument(
+        "--tokens",
+        type=pathlib.Path,
+        help="Mapping between starting tokens and end tokens. Needed for entities.",
+        required=False,
+    )
+
+    parser.add_argument(
+        "--use-existing-split",
+        action="store_true",
+        help="Use the specified folder IDs for the dataset split.",
+    )
+
+    parser.add_argument(
+        "--train-folder",
+        type=uuid.UUID,
+        help="ID of the training folder to import from Arkindex.",
+        required=False,
+    )
+    parser.add_argument(
+        "--val-folder",
+        type=uuid.UUID,
+        help="ID of the validation folder to import from Arkindex.",
+        required=False,
+    )
+    parser.add_argument(
+        "--test-folder",
+        type=uuid.UUID,
+        help="ID of the testing folder to import from Arkindex.",
+        required=False,
     )
+
+    parser.add_argument(
+        "--transcription-worker-version",
+        type=parse_worker_version,
+        help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
+        required=False,
+        default=MANUAL_SOURCE,
+    )
+    parser.add_argument(
+        "--entity-worker-version",
+        type=parse_worker_version,
+        help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
+        required=False,
+        default=MANUAL_SOURCE,
+    )
+
+    parser.add_argument(
+        "--train-prob", type=float, default=0.7, help="Training set split size."
+    )
+
+    parser.add_argument(
+        "--val-prob", type=float, default=0.15, help="Validation set split size"
+    )
+
     parser.set_defaults(func=run)
 
 
-def run():
-    args = get_cli_args()
+class ArkindexExtractor:
+    """
+    Extract data from Arkindex
+    """
 
-    # Get and initialize the parameters.
-    os.makedirs(IMAGES_DIR, exist_ok=True)
-    os.makedirs(LABELS_DIR, exist_ok=True)
+    def __init__(
+        self,
+        client,
+        folders,
+        element_type,
+        split_names,
+        output,
+        load_entities,
+        tokens,
+        use_existing_split,
+        transcription_worker_version,
+        entity_worker_version,
+        train_prob,
+        val_prob,
+    ) -> None:
+        self.client = client
+        self.element_type = element_type
+        self.split_names = split_names
+        self.output = output
+        self.load_entities = load_entities
+        self.tokens = parse_tokens(tokens) if self.load_entities else None
+        self.use_existing_split = use_existing_split
+        self.transcription_worker_version = transcription_worker_version
+        self.entity_worker_version = entity_worker_version
+        self.train_prob = train_prob
+        self.val_prob = val_prob
 
-    # Login to arkindex.
-    client = ArkindexClient(**options_from_env())
+        self.get_subsets(folders)
 
-    corpus = retrieve_corpus(client, args.corpus)
-    subsets = retrieve_subsets(client, corpus, args.parents_types, args.parents_names)
+    def get_subsets(self, folders):
+        if self.use_existing_split:
+            self.subsets = [
+                (folder, split) for folder, split in zip(folders, self.split_names)
+            ]
+        else:
+            self.subsets = [(folder, None) for folder in folders]
 
-    # Iterate over the subsets to find the page images and labels.
-    for subset in subsets:
+    def assign_random_split(self):
+        """
+        Assign a random split, assuming train_prob + val_prob + test_prob = 1.
+        """
+        prob = random.random()
+        if prob <= self.train_prob:
+            return self.split_names[0]
+        elif prob <= self.train_prob + self.val_prob:
+            return self.split_names[1]
+        else:
+            return self.split_names[2]
 
-        os.makedirs(
-            os.path.join(args.output_dir, IMAGES_DIR, subset["name"]), exist_ok=True
-        )
-        os.makedirs(
-            os.path.join(args.output_dir, LABELS_DIR, subset["name"]), exist_ok=True
+    def extract_transcription(
+        self,
+        element,
+    ):
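+        """
+        Retrieve the element's transcription, filtered by the requested worker
+        version, and insert the NER tokens around its entities when requested.
+        """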
+        transcriptions = self.client.request(
+            "ListTranscriptions",
+            id=element["id"],
+            worker_version=self.transcription_worker_version,
         )
 
-        for page in tqdm(
-            client.paginate(
-                "ListElementChildren", id=subset["id"], type="page", recursive=True
-            ),
-            desc="Set " + subset["name"],
-        ):
-
-            image = iio.imread(page["zone"]["url"])
-            cv2.imwrite(
-                os.path.join(
-                    args.output_dir, IMAGES_DIR, subset["name"], f"{page['id']}.jpg"
-                ),
-                cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
+        if transcriptions["count"] != 1:
+            logger.warning(
+                f"More than one transcription found on element ({element['id']}) with this config."
             )
+            return
 
-            tr = client.request(
-                "ListTranscriptions", id=page["id"], worker_version=None
-            )["results"]
-            tr = [one for one in tr if one["worker_version_id"] is None]
-            assert len(tr) == 1, page["id"]
+        transcription = transcriptions["results"].pop()
+        if self.load_entities:
+            entities = self.client.request(
+                "ListTranscriptionEntities",
+                id=transcription["id"],
+                worker_version=self.entity_worker_version,
+            )
+            if entities["count"] == 0:
+                logger.warning(
+                    f"No entities found on transcription ({transcription['id']})."
+                )
+                return
+            else:
+                text = transcription["text"]
 
-            for one_tr in tr:
-                ent = client.request("ListTranscriptionEntities", id=one_tr["id"])[
-                    "results"
+            count = 0
+            for entity in entities["results"]:
+                # The tokens file maps each entity type to its start and end tokens.
+                token = self.tokens[entity["entity"]["metas"]["subtype"]]
+                start_token, end_token = token["start"], token["end"]
-                ent = [one for one in ent if one["worker_version_id"] is None]
-                if len(ent) == 0:
-                    continue
-                else:
-                    text = one_tr["text"]
+                text, count = insert_token(
+                    text,
+                    count,
+                    start_token,
+                    end_token,
+                    offset=entity["offset"],
+                    length=entity["length"],
+                )
+        else:
+            text = transcription["text"].strip()
+        return text
 
-            new_text = text
-            count = 0
-            for e in ent:
-                start_token = SEM_MATCHING_TOKENS_STR[e["entity"]["metas"]["subtype"]]
-                end_token = SEM_MATCHING_TOKENS[start_token]
-                new_text = (
-                    new_text[: count + e["offset"]]
-                    + start_token
-                    + new_text[count + e["offset"] :]
+    def process_element(
+        self,
+        element,
+        split,
+    ):
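+        """
+        Extract the element's transcription and image, save them under the given
+        split, and return the element ID (None if no transcription was found).
+        """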
+        text = self.extract_transcription(
+            element,
+        )
+
+        if not text:
+            logging.warning(
+                f"Skipping {element['id']} (zero or multiple transcriptions with worker_version=None)"
+            )
+        else:
+            logging.info(f"Processed {element['type']} {element['id']}")
+
+            im_path = os.path.join(
+                self.output, IMAGES_DIR, split, f"{element['type']}_{element['id']}.jpg"
+            )
+            txt_path = os.path.join(
+                self.output, LABELS_DIR, split, f"{element['type']}_{element['id']}.txt"
+            )
+
+            save_text(txt_path, text)
+            try:
+                image = iio.imread(element["zone"]["url"])
+                save_image(im_path, image)
+            except Exception:
+                logger.error(f"Couldn't retrieve image of element ({element['id']}")
+                pass
+            return element["id"]
+
+    def process_page(
+        self,
+        page,
+        split,
+    ):
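+        """
+        Process the page itself, or its children of the requested element types,
+        and return a mapping from element type to processed element IDs.
+        """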
+        # Extract only pages
+        data = defaultdict(list)
+        if self.element_type == ["page"]:
+            data["page"] = [
+                self.process_element(
+                    page,
+                    split,
                 )
-                count += 1
-                new_text = (
-                    new_text[: count + e["offset"] + e["length"]]
-                    + end_token
-                    + new_text[count + e["offset"] + e["length"] :]
+            ]
+        # Extract page's children elements (text_zone, text_line)
+        else:
+            for element_type in self.element_type:
+                for element in self.client.paginate(
+                    "ListElementChildren",
+                    id=page["id"],
+                    type=element_type,
+                    recursive=True,
+                ):
+                    data[element_type].append(
+                        self.process_element(
+                            element,
+                            split,
+                        )
+                    )
+        return data
+
+    def run(self):
+        split_dict = defaultdict(dict)
+        # Iterate over the subsets to find the page images and labels.
+        for subset_id, subset_split in self.subsets:
+            page_idx = 0
+            # Iterate over the pages to create splits at page level.
+            for page in tqdm(
+                self.client.paginate(
+                    "ListElementChildren", id=subset_id, type="page", recursive=True
+                )
+            ):
+                page_idx += 1
+                split = subset_split or self.assign_random_split()
+
+                split_dict[split][page["id"]] = self.process_page(
+                    page=page,
+                    split=split,
                 )
-                count += 1
-
-            with open(
-                os.path.join(
-                    args.output_dir, LABELS_DIR, subset["name"], f"{page['id']}.txt"
-                ),
-                "w",
-            ) as f:
-                f.write(new_text)
+
+        save_json(self.output / "split.json", split_dict)
+
+
+def run(
+    parent,
+    element_type,
+    output,
+    load_entities,
+    tokens,
+    use_existing_split,
+    train_folder,
+    val_folder,
+    test_folder,
+    transcription_worker_version,
+    entity_worker_version,
+    train_prob,
+    val_prob,
+):
+    assert (
+        use_existing_split or parent
+    ), "One of `--use-existing-split` and `--parent` must be set"
+
+    if use_existing_split:
+        assert (
+            train_folder
+        ), "If you use an existing split, you must specify the training folder."
+        assert (
+            val_folder
+        ), "If you use an existing split, you must specify the validation folder."
+        assert (
+            test_folder
+        ), "If you use an existing split, you must specify the testing folder."
+        folders = [train_folder, val_folder, test_folder]
+    else:
+        folders = parent
+
+    if load_entities:
+        assert tokens, "Please provide the entities to match."
+
+    # Get and initialize the parameters.
+    os.makedirs(IMAGES_DIR, exist_ok=True)
+    os.makedirs(LABELS_DIR, exist_ok=True)
+
+    # Login to arkindex.
+    assert (
+        "ARKINDEX_API_URL" in os.environ
+    ), "The ARKINDEX API URL was not found in your environment."
+    assert (
+        "ARKINDEX_API_TOKEN" in os.environ
+    ), "Your API credentials was not found in your environment."
+    client = ArkindexClient(**options_from_env())
+
+    # Create out directories
+    split_names = ["train", "val", "test"]
+    for split in split_names:
+        os.makedirs(os.path.join(output, IMAGES_DIR, split), exist_ok=True)
+        os.makedirs(os.path.join(output, LABELS_DIR, split), exist_ok=True)
+
+    ArkindexExtractor(
+        client=client,
+        folders=folders,
+        element_type=element_type,
+        split_names=split_names,
+        output=output,
+        load_entities=load_entities,
+        tokens=tokens,
+        use_existing_split=use_existing_split,
+        transcription_worker_version=transcription_worker_version,
+        entity_worker_version=entity_worker_version,
+        train_prob=train_prob,
+        val_prob=val_prob,
+    ).run()
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
index 77d15804..4b1aa93f 100644
--- a/dan/datasets/extract/utils.py
+++ b/dan/datasets/extract/utils.py
@@ -1,74 +1,56 @@
 # -*- coding: utf-8 -*-
+import yaml
+import json
+import random
 
-"""
-    The utils module
-    ======================
-"""
+import cv2
 
-import argparse
+random.seed(42)
 
 
-def get_cli_args():
+def assign_random_split(train_prob, val_prob):
     """
-    Get the command-line arguments.
-    :return: The command-line arguments.
+    Assign a random split, assuming train_prob + val_prob + test_prob = 1.
     """
-    parser = argparse.ArgumentParser(
-        description="Arkindex DAN Training Label Generation"
-    )
+    prob = random.random()
+    if prob <= train_prob:
+        return "train"
+    elif prob <= train_prob + val_prob:
+        return "valid"
+    else:
+        return "test"
 
-    # Required arguments.
-    parser.add_argument(
-        "--corpus",
-        type=str,
-        help="Name of the corpus from which the data will be retrieved.",
-        required=True,
-    )
-    parser.add_argument(
-        "--element-type",
-        nargs="+",
-        type=str,
-        help="Type of elements to retrieve",
-        required=True,
-    )
-    parser.add_argument(
-        "--parents-types",
-        nargs="+",
-        type=str,
-        help="Type of parents of the elements.",
-        required=True,
-    )
-    parser.add_argument(
-        "--output-dir",
-        type=str,
-        help="Path to the output directory.",
-        required=True,
-    )
 
-    # Optional arguments.
-    parser.add_argument(
-        "--parents-names",
-        nargs="+",
-        type=str,
-        help="Names of parents of the elements.",
-        default=None,
-    )
-    parser.add_argument(
-        "--no-entities", action="store_true", help="Extract text without entities"
-    )
+def save_text(path, text):
+    with open(path, "w") as f:
+        f.write(text)
 
-    parser.add_argument(
-        "--use-existing-split",
-        action="store_true",
-        help="Do not partition pages into train/val/test",
-    )
 
-    parser.add_argument(
-        "--train-prob", type=float, default=0.7, help="Training set probability"
-    )
+def save_image(path, image):
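+    # Images are loaded as RGB by imageio while cv2.imwrite expects BGR, hence the channel swap.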
+    cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+
 
-    parser.add_argument(
-        "--val-prob", type=float, default=0.15, help="Validation set probability"
+def save_json(path, data):
+    with open(path, "w") as outfile:
+        json.dump(data, outfile, indent=4)
+
+
+def insert_token(text, count, start_token, end_token, offset, length):
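+    # `count` is the number of characters already inserted in `text` by
+    # previously processed entities (two tokens per entity).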
+    text = (
+        # Text before entity
+        text[: count + offset]
+        # Starting token
+        + start_token
+        # Entity
+        + text[count + offset : count + offset + length]
+        # End token
+        + end_token
+        # Text after entity
+        + text[count + offset + length :]
     )
+    return text, count + 2
+
 
-    return parser.parse_args()
+def parse_tokens(filename):
+    with open(filename) as f:
+        return yaml.safe_load(f)
diff --git a/dan/datasets/utils.py b/dan/datasets/utils.py
index 6422357c..c911cc9b 100644
--- a/dan/datasets/utils.py
+++ b/dan/datasets/utils.py
@@ -1,12 +1,6 @@
 # -*- coding: utf-8 -*-
-import json
-import random
 import re
 
-import cv2
-
-random.seed(42)
-
 
 def convert(text):
     return int(text) if text.isdigit() else text.lower()
@@ -14,30 +8,3 @@ def convert(text):
 
 def natural_sort(data):
     return sorted(data, key=lambda key: [convert(c) for c in re.split("([0-9]+)", key)])
-
-
-def assign_random_split(train_prob, val_prob):
-    """
-    assuming train_prob + val_prob + test_prob = 1
-    """
-    prob = random.random()
-    if prob <= train_prob:
-        return "train"
-    elif prob <= train_prob + val_prob:
-        return "val"
-    else:
-        return "test"
-
-
-def save_text(path, text):
-    with open(path, "w") as f:
-        f.write(text)
-
-
-def save_image(path, image):
-    cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
-
-
-def save_json(path, dict):
-    with open(path, "w") as outfile:
-        json.dump(dict, outfile, indent=4)
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index 1f8216a9..3166a543 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -1,4 +1,7 @@
 # -*- coding: utf-8 -*-
+"""
+Train a DAN model at document level.
+"""
 import random
 
 import numpy as np
@@ -18,6 +21,7 @@ def add_document_parser(subcommands) -> None:
     parser = subcommands.add_parser(
         "document",
         description=__doc__,
+        help=__doc__,
     )
     parser.set_defaults(func=run)
 
diff --git a/dan/ocr/line/generate_synthetic.py b/dan/ocr/line/generate_synthetic.py
index 435e19ed..67aa5a24 100644
--- a/dan/ocr/line/generate_synthetic.py
+++ b/dan/ocr/line/generate_synthetic.py
@@ -1,4 +1,7 @@
 # -*- coding: utf-8 -*-
+"""
+Generate synthetic data to train DAN models
+"""
 import random
 
 import numpy as np
@@ -18,6 +21,7 @@ def add_generate_parser(subcommands) -> None:
     parser = subcommands.add_parser(
         "generate",
         description=__doc__,
+        help=__doc__,
     )
     parser.set_defaults(func=run)
 
diff --git a/dan/ocr/line/train.py b/dan/ocr/line/train.py
index 9e092fd9..b49e3df0 100644
--- a/dan/ocr/line/train.py
+++ b/dan/ocr/line/train.py
@@ -1,4 +1,7 @@
 # -*- coding: utf-8 -*-
+"""
+Train a DAN model at line level.
+"""
 import random
 
 import numpy as np
@@ -18,6 +21,7 @@ def add_line_parser(subcommands) -> None:
     parser = subcommands.add_parser(
         "line",
         description=__doc__,
+        help=__doc__,
     )
     parser.set_defaults(func=run)
 
diff --git a/dan/ocr/train.py b/dan/ocr/train.py
index dda43ba8..375656b3 100644
--- a/dan/ocr/train.py
+++ b/dan/ocr/train.py
@@ -1,4 +1,7 @@
 # -*- coding: utf-8 -*-
+"""
+Train a new DAN model.
+"""
 
 from dan.ocr.document.train import add_document_parser
 from dan.ocr.line.train import add_line_parser
@@ -8,6 +11,7 @@ def add_train_parser(subcommands) -> None:
     parser = subcommands.add_parser(
         "train",
         description=__doc__,
+        help=__doc__,
     )
     subcommands = parser.add_subparsers(metavar="subcommand")
 
diff --git a/requirements.txt b/requirements.txt
index 27758aa6..6288bf12 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,7 @@ networkx==2.6.3
 numpy==1.22.3
 opencv-python==4.5.5.64
 PyYAML==6.0
-tensorboard==0.2.1
+tensorboard==2.8.0
 torch==1.11.0
 torchvision==0.12.0
 tqdm==4.62.3
-- 
GitLab