From e316fcc13af905df79325fc211119add6a686d67 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Tue, 22 Nov 2022 11:01:53 +0000
Subject: [PATCH] Implement extraction command

---
 .gitlab-ci.yml                                |   2 +-
 MANIFEST.in                                   |   2 +
 README.md                                     | 109 ++++-
 dan/__init__.py                               |   8 +
 dan/cli.py                                    |   8 +-
 dan/datasets/__init__.py                      |  17 +
 dan/datasets/extract/__init__.py              | 115 +++++
 dan/datasets/extract/arkindex_utils.py        |  66 ---
 dan/datasets/extract/extract_from_arkindex.py | 453 ++++++++++++------
 dan/datasets/extract/utils.py                 |  43 ++
 dan/datasets/utils.py                         |  33 --
 dan/ocr/__init__.py                           |  19 +
 dan/ocr/document/__init__.py                  |  15 +
 dan/ocr/document/train.py                     |   9 -
 dan/ocr/line/__init__.py                      |  22 +
 dan/ocr/line/generate_synthetic.py            |   9 -
 dan/ocr/line/train.py                         |   9 -
 dan/ocr/train.py                              |  16 -
 requirements.txt                              |   2 +-
 tests/conftest.py                             |  20 +
 tests/test_extract.py                         | 120 +++++
 tox.ini                                       |  12 +
 22 files changed, 799 insertions(+), 310 deletions(-)
 create mode 100644 MANIFEST.in
 delete mode 100644 dan/datasets/extract/arkindex_utils.py
 create mode 100644 dan/datasets/extract/utils.py
 delete mode 100644 dan/ocr/train.py
 create mode 100644 tests/conftest.py
 create mode 100644 tests/test_extract.py
 create mode 100644 tox.ini

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index e35b054d..81e95c76 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -32,7 +32,7 @@ bump-python-deps:
     - schedules
 
   script:
-    - devops python-deps requirements.txt tests-requirements.txt
+    - devops python-deps requirements.txt
 
 release-notes:
   stage: deploy
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 00000000..fd959fa8
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,2 @@
+include requirements.txt
+include VERSION
diff --git a/README.md b/README.md
index 9b29f204..7c4b6b6c 100644
--- a/README.md
+++ b/README.md
@@ -21,9 +21,9 @@ We evaluate the DAN on two public datasets: RIMES and READ 2016 at single-page a
 We obtained the following results:
 
 |                         | CER (%) | WER (%) | LOER (%) | mAP_cer (%) |
-|:-----------------------:|---------|:-------:|:--------:|-------------|
-|       RIMES (single page)      | 4.54    |  11.85  |   3.82   | 93.74       |
-|     READ 2016 (single page)    | 3.53    |  13.33  |   5.94   | 92.57       |
+| :---------------------: | ------- | :-----: | :------: | ----------- |
+|   RIMES (single page)   | 4.54    |  11.85  |   3.82   | 93.74       |
+| READ 2016 (single page) | 3.53    |  13.33  |   5.94   | 92.57       |
 | READ 2016 (double page) | 3.69    |  14.20  |   4.60   | 93.92       |
 
 
@@ -92,3 +92,106 @@ To run the inference on a GPU, one can replace `cpu` by the name of the GPU. In
 ```python
 text, confidence_scores = model.predict(image, confidences=True)
 ```
+
+### Commands
+
+This package provides three subcommands. To get more information about any subcommand, use the `--help` option.
+
+#### Data extraction from Arkindex
+
+Use the `teklia-dan dataset extract` command to extract a dataset from Arkindex. This will generate the images and the labels needed to train a DAN model.
+The available arguments are
+
+| Parameter                      | Description                                                                         | Type     | Default |
+| ------------------------------ | ----------------------------------------------------------------------------------- | -------- | ------- |
+| `--parent`                       | UUID of the folder to import from Arkindex. You may specify multiple UUIDs.         | `str/uuid` |         |
+| `--element-type`                 | Type of the elements to extract. You may specify multiple types.                    | `str`      |         |
+| `--parent-element-type`                 | Type of the parent element containing the data.                    | `str`      |  `page`       |
+| `--output`                       | Folder where the data will be generated.                                | `Path`     |         |
+| `--load-entities`                | Extract text with their entities. Needed for NER tasks.                             | `bool`     | `False`   |
+| `--tokens`                       | Mapping between starting tokens and end tokens. Needed for NER tasks.               | `Path`    |         |
+| `--use-existing-split`           | Use the specified folder IDs for the dataset split.                                 | `bool`     |         |
+| `--train-folder`                 | ID of the training folder to import from Arkindex.                                  | `uuid`     |         |
+| `--val-folder`                   | ID of the validation folder to import from Arkindex.                                | `uuid`     |         |
+| `--test-folder`                  | ID of the training folder to import from Arkindex.                                  | `uuid`     |         |
+| `--transcription-worker-version` | Filter transcriptions by worker_version. Use ‘manual’ for manual filtering.         | `str/uuid` |         |
+| `--entity-worker-version`        | Filter transcriptions entities by worker_version. Use ‘manual’ for manual filtering | `str/uuid` |         |
+| `--train-prob`                   | Training set split size                                                             | `float`    | `0.7`     |
+| `--val-prob`                     | Validation set split size                                                           | `float`    | `0.15`    |
+
+The `--tokens` argument expects a file with the following format.
+```yaml
+---
+INTITULE:
+  start: ⓘ
+  end: â’¾
+DATE:
+  start: â““
+  end: â’¹
+COTE_SERIE:
+  start: â“¢
+  end: Ⓢ
+ANALYSE_COMPL.:
+  start: â“’
+  end: â’¸
+PRECISIONS_SUR_COTE:
+  start: â“Ÿ
+  end: â“…
+COTE_ARTICLE:
+  start: ⓐ
+  end: â’¶
+CLASSEMENT:
+  start: â“›
+  end: Ⓛ
+```
+
+
+To extract HTR+NER data from **pages** from [this folder](https://arkindex.teklia.com/element/665e84ea-bd97-4912-91b0-1f4a844287ff), use the following command:
+```shell
+teklia-dan dataset extract \
+    --parent 665e84ea-bd97-4912-91b0-1f4a844287ff \
+    --element-type page \
+    --output data \
+    --load-entities \
+    --tokens tokens.yml
+```
+with `tokens.yml` having the content described just above.
+
+
+To do the same but only use the data from three folders, the commands becomes:
+```shell
+teklia-dan dataset extract \
+    --parent 2275529a-1ec5-40ce-a516-42ea7ada858c af9b38b5-5d95-417d-87ec-730537cb1898 6ff44957-0e65-48c5-9d77-a178116405b2 \
+    --element-type page \
+    --output data \
+    --load-entities \
+    --tokens tokens.yml
+```
+
+To use the data from three folders as **training**, **validation** and **testing** dataset respectively, the commands becomes:
+```shell
+teklia-dan dataset extract \
+    --use-existing-split \
+    --train-folder 2275529a-1ec5-40ce-a516-42ea7ada858c
+    --val-folder af9b38b5-5d95-417d-87ec-730537cb1898 \
+    --test-folder 6ff44957-0e65-48c5-9d77-a178116405b2 \
+    --element-type page \
+    --output data \
+    --load-entities \
+    --tokens tokens.yml
+```
+
+To extract HTR data from **annotations** and **text_zones** from [this folder](https://demo.arkindex.org/element/48852284-fc02-41bb-9a42-4458e5a51615) that are children of **single_pages**, use the following command:
+```shell
+teklia-dan dataset extract \
+    --parent 48852284-fc02-41bb-9a42-4458e5a51615 \
+    --element-type text_zone annotation \
+    --parent-element-type single_page \
+    --output data
+```
+
+#### Model training
+`teklia-dan train` with multiple arguments.
+
+#### Synthetic data generation
+`teklia-dan generate` with multiple arguments
diff --git a/dan/__init__.py b/dan/__init__.py
index e69de29b..b74e8889 100644
--- a/dan/__init__.py
+++ b/dan/__init__.py
@@ -0,0 +1,8 @@
+# -*- coding: utf-8 -*-
+import logging
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s/%(name)s: %(message)s",
+)
+logger = logging.getLogger(__name__)
diff --git a/dan/cli.py b/dan/cli.py
index e633ceb4..ddf244f4 100644
--- a/dan/cli.py
+++ b/dan/cli.py
@@ -2,17 +2,17 @@
 import argparse
 import errno
 
-from dan.datasets.extract.extract_from_arkindex import add_extract_parser
-from dan.ocr.line.generate_synthetic import add_generate_parser
-from dan.ocr.train import add_train_parser
+from dan.datasets import add_dataset_parser
+from dan.ocr import add_train_parser
+from dan.ocr.line import add_generate_parser
 
 
 def get_parser():
     parser = argparse.ArgumentParser(prog="teklia-dan")
     subcommands = parser.add_subparsers(metavar="subcommand")
 
+    add_dataset_parser(subcommands)
     add_train_parser(subcommands)
-    add_extract_parser(subcommands)
     add_generate_parser(subcommands)
     return parser
 
diff --git a/dan/datasets/__init__.py b/dan/datasets/__init__.py
index e69de29b..889e11cf 100644
--- a/dan/datasets/__init__.py
+++ b/dan/datasets/__init__.py
@@ -0,0 +1,17 @@
+# -*- coding: utf-8 -*-
+"""
+Preprocess datasets for training.
+"""
+
+from dan.datasets.extract import add_extract_parser
+
+
+def add_dataset_parser(subcommands) -> None:
+    parser = subcommands.add_parser(
+        "dataset",
+        description=__doc__,
+        help=__doc__,
+    )
+    subcommands = parser.add_subparsers(metavar="subcommand")
+
+    add_extract_parser(subcommands)
diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py
index e69de29b..76521846 100644
--- a/dan/datasets/extract/__init__.py
+++ b/dan/datasets/extract/__init__.py
@@ -0,0 +1,115 @@
+# -*- coding: utf-8 -*-
+"""
+Extract dataset from Arkindex using API.
+"""
+
+import pathlib
+import uuid
+
+from dan.datasets.extract.extract_from_arkindex import run
+
+MANUAL_SOURCE = "manual"
+
+
+def parse_worker_version(worker_version_id):
+    if worker_version_id == MANUAL_SOURCE:
+        return False
+    return worker_version_id
+
+
+def add_extract_parser(subcommands) -> None:
+    parser = subcommands.add_parser(
+        "extract",
+        description=__doc__,
+        help=__doc__,
+    )
+
+    # Required arguments.
+    parser.add_argument(
+        "--parent",
+        type=uuid.UUID,
+        nargs="+",
+        help="ID of the parent folder to import from Arkindex.",
+        required=False,
+    )
+    parser.add_argument(
+        "--element-type",
+        nargs="+",
+        type=str,
+        help="Type of elements to retrieve",
+        required=True,
+    )
+    parser.add_argument(
+        "--parent-element-type",
+        type=str,
+        help="Type of the parent element containing the data.",
+        required=False,
+        default="page",
+    )
+    parser.add_argument(
+        "--output",
+        type=pathlib.Path,
+        help="Path where the data will be generated.",
+        required=True,
+    )
+
+    # Optional arguments.
+    parser.add_argument(
+        "--load-entities", action="store_true", help="Extract text with their entities"
+    )
+    parser.add_argument(
+        "--tokens",
+        type=pathlib.Path,
+        help="Mapping between starting tokens and end tokens. Needed for entities.",
+        required=False,
+    )
+
+    parser.add_argument(
+        "--use-existing-split",
+        action="store_true",
+        help="Use the specified folder IDs for the dataset split.",
+    )
+
+    parser.add_argument(
+        "--train-folder",
+        type=uuid.UUID,
+        help="ID of the training folder to import from Arkindex.",
+        required=False,
+    )
+    parser.add_argument(
+        "--val-folder",
+        type=uuid.UUID,
+        help="ID of the validation folder to import from Arkindex.",
+        required=False,
+    )
+    parser.add_argument(
+        "--test-folder",
+        type=uuid.UUID,
+        help="ID of the testing folder to import from Arkindex.",
+        required=False,
+    )
+
+    parser.add_argument(
+        "--transcription-worker-version",
+        type=parse_worker_version,
+        help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
+        required=False,
+        default=MANUAL_SOURCE,
+    )
+    parser.add_argument(
+        "--entity-worker-version",
+        type=parse_worker_version,
+        help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
+        required=False,
+        default=MANUAL_SOURCE,
+    )
+
+    parser.add_argument(
+        "--train-prob", type=float, default=0.7, help="Training set split size."
+    )
+
+    parser.add_argument(
+        "--val-prob", type=float, default=0.15, help="Validation set split size"
+    )
+
+    parser.set_defaults(func=run)
diff --git a/dan/datasets/extract/arkindex_utils.py b/dan/datasets/extract/arkindex_utils.py
deleted file mode 100644
index d9d2e065..00000000
--- a/dan/datasets/extract/arkindex_utils.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# -*- coding: utf-8 -*-
-
-"""
-    The arkindex_utils module
-    ======================
-"""
-
-import errno
-import logging
-import sys
-
-from apistar.exceptions import ErrorResponse
-
-
-def retrieve_corpus(client, corpus_name: str) -> str:
-    """
-    Retrieve the corpus id from the corpus name.
-    :param client: The arkindex client.
-    :param corpus_name: The name of the corpus to retrieve.
-    :return target_corpus: The id of the retrieved corpus.
-    """
-    for corpus in client.request("ListCorpus"):
-        if corpus["name"] == corpus_name:
-            target_corpus = corpus["id"]
-    try:
-        logging.info(f"Corpus id retrieved: {target_corpus}")
-    except NameError:
-        logging.error(f"Corpus {corpus_name} not found")
-        sys.exit(errno.EINVAL)
-
-    return target_corpus
-
-
-def retrieve_subsets(
-    client, corpus: str, parents_types: list, parents_names: list
-) -> list:
-    """
-    Retrieve the requested subsets.
-    :param client: The arkindex client.
-    :param corpus: The id of the retrieved corpus.
-    :param parents_types: The types of parents of the elements to retrieve.
-    :param parents_names: The names of parents of the elements to retrieve.
-    :return subsets: The retrieved subsets.
-    """
-    subsets = []
-    for parent_type in parents_types:
-        try:
-            subsets.extend(
-                client.request("ListElements", corpus=corpus, type=parent_type)[
-                    "results"
-                ]
-            )
-        except ErrorResponse as e:
-            logging.error(f"{e.content}: {parent_type}")
-            sys.exit(errno.EINVAL)
-    # Retrieve subsets with name in parents-names. If no parents-names given, keep all subsets.
-    if parents_names is not None:
-        logging.info(f"Retrieving {parents_names} subset(s)")
-        subsets = [subset for subset in subsets if subset["name"] in parents_names]
-    else:
-        logging.info("Retrieving all subsets")
-
-    if len(subsets) == 0:
-        logging.info("No subset found")
-
-    return subsets
diff --git a/dan/datasets/extract/extract_from_arkindex.py b/dan/datasets/extract/extract_from_arkindex.py
index a135cb7d..c4e1fe48 100644
--- a/dan/datasets/extract/extract_from_arkindex.py
+++ b/dan/datasets/extract/extract_from_arkindex.py
@@ -1,184 +1,319 @@
 # -*- coding: utf-8 -*-
 
-"""
-    The extraction module
-    ======================
-"""
-
 import logging
 import os
+import random
+from collections import defaultdict
+from pathlib import Path
+from typing import List, NamedTuple
 
-import cv2
 import imageio.v2 as iio
 from arkindex import ArkindexClient, options_from_env
 from tqdm import tqdm
 
-from dan.datasets.extract.arkindex_utils import retrieve_corpus, retrieve_subsets
-
-logging.basicConfig(
-    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+from dan import logger
+from dan.datasets.extract.utils import (
+    insert_token,
+    parse_tokens,
+    save_image,
+    save_json,
+    save_text,
 )
 
+IMAGES_DIR = "images"  # Subpath to the images directory.
+LABELS_DIR = "labels"  # Subpath to the labels directory.
+
+Entity = NamedTuple("Entity", offset=int, length=int, label=str)
+
+
+class ArkindexExtractor:
+    """
+    Extract data from Arkindex
+    """
+
+    def __init__(
+        self,
+        client: ArkindexClient,
+        folders: list = [],
+        element_type: list = [],
+        parent_element_type: list = ["page"],
+        split_names: list = [],
+        output: Path = None,
+        load_entities: bool = None,
+        tokens: Path = None,
+        use_existing_split: bool = None,
+        transcription_worker_version: str = None,
+        entity_worker_version: str = None,
+        train_prob: float = None,
+        val_prob: float = None,
+    ) -> None:
+        self.client = client
+        self.element_type = element_type
+        self.parent_element_type = parent_element_type
+        self.split_names = split_names
+        self.output = output
+        self.load_entities = load_entities
+        self.tokens = parse_tokens(tokens) if self.load_entities else None
+        self.use_existing_split = use_existing_split
+        self.transcription_worker_version = transcription_worker_version
+        self.entity_worker_version = entity_worker_version
+        self.train_prob = train_prob
+        self.val_prob = val_prob
+
+        self.get_subsets(folders)
+
+    def get_subsets(self, folders: list):
+        """
+        Assign each folder to its split if it's already known.
+        Assign None if it's unknown.
+        """
+        if self.use_existing_split:
+            self.subsets = [
+                (folder, split) for folder, split in zip(folders, self.split_names)
+            ]
+        else:
+            self.subsets = [(folder, None) for folder in folders]
+
+    def _assign_random_split(self):
+        """
+        Yields a randomly chosen split for an element.
+        Assumes that train_prob + valid_prob + test_prob = 1
+        """
+        prob = random.random()
+        if prob <= self.train_prob:
+            yield self.split_names[0]
+        elif prob <= self.train_prob + self.val_prob:
+            yield self.split_names[1]
+        else:
+            yield self.split_names[2]
+
+    def get_random_split(self):
+        return next(self._assign_random_split())
+
+    def extract_entities(self, transcription: dict):
+        entities = self.client.paginate(
+            "ListTranscriptionEntities",
+            id=transcription["id"],
+            worker_version=self.entity_worker_version,
+        )
+        if entities is None:
+            logger.warning(
+                f"No entities found on transcription ({transcription['id']})."
+            )
+            return
+        return [
+            Entity(
+                offset=entity["offset"],
+                length=entity["length"],
+                label=entity["entity"]["metas"]["subtype"],
+            )
+            for entity in entities
+        ]
+
+    def reconstruct_text(self, text: str, entities: List[Entity]):
+        """
+        Insert tokens delimiting the start/end of each entity on the transcription.
+        """
+        count = 0
+        for entity in entities:
+            matching_tokens = self.tokens[entity.label]
+            start_token, end_token = (
+                matching_tokens["start"],
+                matching_tokens["end"],
+            )
+            text, count = insert_token(
+                text,
+                count,
+                start_token,
+                end_token,
+                offset=entity.offset,
+                length=entity.length,
+            )
+        return text
+
+    def extract_transcription(
+        self,
+        element: dict,
+    ):
+        """
+        Extract the element's transcription.
+        If the entities are needed, they are added to the transcription using tokens.
+        """
+        transcriptions = self.client.request(
+            "ListTranscriptions",
+            id=element["id"],
+            worker_version=self.transcription_worker_version,
+        )
+        if transcriptions["count"] != 1:
+            logger.warning(
+                f"More than one transcription found on element ({element['id']}) with this config."
+            )
+            return
+
+        transcription = transcriptions["results"].pop()
+        if self.load_entities:
+            entities = self.extract_entities(transcription)
+            return self.reconstruct_text(transcription["text"], entities)
+        else:
+            return transcription["text"].strip()
+
+    def process_element(
+        self,
+        element: dict,
+        split: str,
+    ):
+        """
+        Extract an element's data and save it to disk.
+        The output path is directly related to the split of the element.
+        """
+        text = self.extract_transcription(
+            element,
+        )
 
-IMAGES_DIR = "./images/"  # Path to the images directory.
-LABELS_DIR = "./labels/"  # Path to the labels directory.
-
-# Layout string to token
-SEM_MATCHING_TOKENS_STR = {
-    "INTITULE": "ⓘ",
-    "DATE": "â““",
-    "COTE_SERIE": "â“¢",
-    "ANALYSE_COMPL.": "â“’",
-    "PRECISIONS_SUR_COTE": "â“Ÿ",
-    "COTE_ARTICLE": "ⓐ",
-}
-
-# Layout begin-token to end-token
-SEM_MATCHING_TOKENS = {"ⓘ": "Ⓘ", "ⓓ": "Ⓓ", "ⓢ": "Ⓢ", "ⓒ": "Ⓒ", "ⓟ": "Ⓟ", "ⓐ": "Ⓐ"}
-
-
-def add_extract_parser(subcommands) -> None:
-    parser = subcommands.add_parser(
-        "extract",
-        description=__doc__,
-        help=__doc__,
-    )
-    # Required arguments.
-    parser.add_argument(
-        "--corpus",
-        type=str,
-        help="Name of the corpus from which the data will be retrieved.",
-        required=True,
-    )
-    parser.add_argument(
-        "--element-type",
-        nargs="+",
-        type=str,
-        help="Type of elements to retrieve",
-        required=True,
-    )
-    parser.add_argument(
-        "--parents-types",
-        nargs="+",
-        type=str,
-        help="Type of parents of the elements.",
-        required=True,
-    )
-    parser.add_argument(
-        "--output-dir",
-        type=str,
-        help="Path to the output directory.",
-        required=True,
-    )
-
-    # Optional arguments.
-    parser.add_argument(
-        "--parents-names",
-        nargs="+",
-        type=str,
-        help="Names of parents of the elements.",
-        default=None,
-    )
-    parser.add_argument(
-        "--no-entities", action="store_true", help="Extract text without entities"
-    )
-
-    parser.add_argument(
-        "--use-existing-split",
-        action="store_true",
-        help="Do not partition pages into train/val/test",
-    )
-
-    parser.add_argument(
-        "--train-prob", type=float, default=0.7, help="Training set probability"
-    )
-
-    parser.add_argument(
-        "--val-prob", type=float, default=0.15, help="Validation set probability"
-    )
-    parser.set_defaults(func=run)
+        if not text:
+            logging.warning(f"Skipping {element['id']}")
+        else:
+            logging.info(f"Processed {element['type']} {element['id']}")
+
+            im_path = os.path.join(
+                self.output, IMAGES_DIR, split, f"{element['type']}_{element['id']}.jpg"
+            )
+            txt_path = os.path.join(
+                self.output, LABELS_DIR, split, f"{element['type']}_{element['id']}.txt"
+            )
+
+            save_text(txt_path, text)
+            try:
+                image = iio.imread(element["zone"]["url"])
+                save_image(im_path, image)
+            except Exception:
+                logger.error(f"Couldn't retrieve image of element ({element['id']}")
+                raise
+            return element["id"]
+
+    def process_parent(
+        self,
+        parent: dict,
+        split: str,
+    ):
+        """
+        Extract data from a parent element.
+        Depending on the given types,
+        """
+        data = defaultdict(list)
+        if self.element_type == [parent["type"]]:
+            data[self.element_type[0]] = [
+                self.process_element(
+                    parent,
+                    split,
+                )
+            ]
+        # Extract children elements
+        else:
+            for element_type in self.element_type:
+                for element in self.client.paginate(
+                    "ListElementChildren",
+                    id=parent["id"],
+                    type=element_type,
+                    recursive=True,
+                ):
+                    data[element_type].append(
+                        self.process_element(
+                            element,
+                            split,
+                        )
+                    )
+        return data
+
+    def run(self):
+        split_dict = defaultdict(dict)
+        # Iterate over the subsets to find the page images and labels.
+        for subset_id, subset_split in self.subsets:
+            page_idx = 0
+            # Iterate over the pages to create splits at page level.
+            for parent in tqdm(
+                self.client.paginate(
+                    "ListElementChildren",
+                    id=subset_id,
+                    type=self.parent_element_type,
+                    recursive=True,
+                )
+            ):
+                page_idx += 1
+                split = subset_split or self.get_random_split()
+
+                split_dict[split][parent["id"]] = self.process_parent(
+                    parent=parent,
+                    split=split,
+                )
+
+        save_json(self.output / "split.json", split_dict)
 
 
 def run(
-    corpus,
+    parent,
     element_type,
-    parents_types,
-    output_dir,
-    parents_names,
-    no_entities,
+    parent_element_type,
+    output,
+    load_entities,
+    tokens,
     use_existing_split,
+    train_folder,
+    val_folder,
+    test_folder,
+    transcription_worker_version,
+    entity_worker_version,
     train_prob,
     val_prob,
 ):
-    # Get and initialize the parameters.
-    os.makedirs(IMAGES_DIR, exist_ok=True)
-    os.makedirs(LABELS_DIR, exist_ok=True)
-
-    # Login to arkindex.
-    client = ArkindexClient(**options_from_env())
-
-    corpus = retrieve_corpus(client, corpus)
-    subsets = retrieve_subsets(client, corpus, parents_types, parents_names)
+    assert (
+        use_existing_split or parent
+    ), "One of `--use-existing-split` and `--parent` must be set"
 
-    # Iterate over the subsets to find the page images and labels.
-    for subset in subsets:
+    if use_existing_split:
+        assert (
+            train_folder
+        ), "If you use an existing split, you must specify the training folder."
+        assert (
+            val_folder
+        ), "If you use an existing split, you must specify the validation folder."
+        assert (
+            test_folder
+        ), "If you use an existing split, you must specify the testing folder."
+        folders = [train_folder, val_folder, test_folder]
+    else:
+        folders = parent
 
-        os.makedirs(os.path.join(output_dir, IMAGES_DIR, subset["name"]), exist_ok=True)
-        os.makedirs(os.path.join(output_dir, LABELS_DIR, subset["name"]), exist_ok=True)
+    if load_entities:
+        assert tokens, "Please provide the entities to match."
 
-        for page in tqdm(
-            client.paginate(
-                "ListElementChildren", id=subset["id"], type="page", recursive=True
-            ),
-            desc="Set " + subset["name"],
-        ):
+    # Login to arkindex.
+    assert (
+        "ARKINDEX_API_URL" in os.environ
+    ), "The ARKINDEX API URL was not found in your environment."
+    assert (
+        "ARKINDEX_API_TOKEN" in os.environ
+    ), "Your API credentials was not found in your environment."
+    client = ArkindexClient(**options_from_env())
 
-            image = iio.imread(page["zone"]["url"])
-            cv2.imwrite(
-                os.path.join(
-                    output_dir, IMAGES_DIR, subset["name"], f"{page['id']}.jpg"
-                ),
-                cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
-            )
+    # Create out directories
+    split_names = ["train", "val", "test"]
+    for split in split_names:
+        os.makedirs(os.path.join(output, IMAGES_DIR, split), exist_ok=True)
+        os.makedirs(os.path.join(output, LABELS_DIR, split), exist_ok=True)
 
-            tr = client.request(
-                "ListTranscriptions", id=page["id"], worker_version=None
-            )["results"]
-            tr = [one for one in tr if one["worker_version_id"] is None]
-            assert len(tr) == 1, page["id"]
-
-            for one_tr in tr:
-                ent = client.request("ListTranscriptionEntities", id=one_tr["id"])[
-                    "results"
-                ]
-                ent = [one for one in ent if one["worker_version_id"] is None]
-                if len(ent) == 0:
-                    continue
-                else:
-                    text = one_tr["text"]
-
-            new_text = text
-            count = 0
-            for e in ent:
-                start_token = SEM_MATCHING_TOKENS_STR[e["entity"]["metas"]["subtype"]]
-                end_token = SEM_MATCHING_TOKENS[start_token]
-                new_text = (
-                    new_text[: count + e["offset"]]
-                    + start_token
-                    + new_text[count + e["offset"] :]
-                )
-                count += 1
-                new_text = (
-                    new_text[: count + e["offset"] + e["length"]]
-                    + end_token
-                    + new_text[count + e["offset"] + e["length"] :]
-                )
-                count += 1
-
-            with open(
-                os.path.join(
-                    output_dir, LABELS_DIR, subset["name"], f"{page['id']}.txt"
-                ),
-                "w",
-            ) as f:
-                f.write(new_text)
+    ArkindexExtractor(
+        client=client,
+        folders=folders,
+        element_type=element_type,
+        parent_element_type=parent_element_type,
+        split_names=split_names,
+        output=output,
+        load_entities=load_entities,
+        tokens=tokens,
+        use_existing_split=use_existing_split,
+        transcription_worker_version=transcription_worker_version,
+        entity_worker_version=entity_worker_version,
+        train_prob=train_prob,
+        val_prob=val_prob,
+    ).run()
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
new file mode 100644
index 00000000..4a582228
--- /dev/null
+++ b/dan/datasets/extract/utils.py
@@ -0,0 +1,43 @@
+# -*- coding: utf-8 -*-
+import json
+
+import cv2
+import yaml
+
+
+def save_text(path, text):
+    with open(path, "w") as f:
+        f.write(text)
+
+
+def save_image(path, image):
+    cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+
+
+def save_json(path, dict):
+    with open(path, "w") as outfile:
+        json.dump(dict, outfile, indent=4)
+
+
+def insert_token(text, count, start_token, end_token, offset, length):
+    """
+    Insert the given tokens at the right position in the text
+    """
+    text = (
+        # Text before entity
+        text[: count + offset]
+        # Starting token
+        + start_token
+        # Entity
+        + text[count + offset : count + 1 + offset + length]
+        # End token
+        + end_token
+        # Text after entity
+        + text[count + 1 + offset + length :]
+    )
+    return text, count + 2
+
+
+def parse_tokens(filename):
+    with open(filename) as f:
+        return yaml.safe_load(f)
diff --git a/dan/datasets/utils.py b/dan/datasets/utils.py
index 6422357c..c911cc9b 100644
--- a/dan/datasets/utils.py
+++ b/dan/datasets/utils.py
@@ -1,12 +1,6 @@
 # -*- coding: utf-8 -*-
-import json
-import random
 import re
 
-import cv2
-
-random.seed(42)
-
 
 def convert(text):
     return int(text) if text.isdigit() else text.lower()
@@ -14,30 +8,3 @@ def convert(text):
 
 def natural_sort(data):
     return sorted(data, key=lambda key: [convert(c) for c in re.split("([0-9]+)", key)])
-
-
-def assign_random_split(train_prob, val_prob):
-    """
-    assuming train_prob + val_prob + test_prob = 1
-    """
-    prob = random.random()
-    if prob <= train_prob:
-        return "train"
-    elif prob <= train_prob + val_prob:
-        return "val"
-    else:
-        return "test"
-
-
-def save_text(path, text):
-    with open(path, "w") as f:
-        f.write(text)
-
-
-def save_image(path, image):
-    cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
-
-
-def save_json(path, dict):
-    with open(path, "w") as outfile:
-        json.dump(dict, outfile, indent=4)
diff --git a/dan/ocr/__init__.py b/dan/ocr/__init__.py
index e69de29b..3d18b6fe 100644
--- a/dan/ocr/__init__.py
+++ b/dan/ocr/__init__.py
@@ -0,0 +1,19 @@
+# -*- coding: utf-8 -*-
+"""
+Train a new DAN model.
+"""
+
+from dan.ocr.document import add_document_parser
+from dan.ocr.line import add_line_parser
+
+
+def add_train_parser(subcommands) -> None:
+    parser = subcommands.add_parser(
+        "train",
+        description=__doc__,
+        help=__doc__,
+    )
+    subcommands = parser.add_subparsers(metavar="subcommand")
+
+    add_line_parser(subcommands)
+    add_document_parser(subcommands)
diff --git a/dan/ocr/document/__init__.py b/dan/ocr/document/__init__.py
index e69de29b..375a1327 100644
--- a/dan/ocr/document/__init__.py
+++ b/dan/ocr/document/__init__.py
@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+"""
+Train a DAN model at document level.
+"""
+
+from dan.ocr.document.train import run
+
+
+def add_document_parser(subcommands) -> None:
+    parser = subcommands.add_parser(
+        "document",
+        description=__doc__,
+        help=__doc__,
+    )
+    parser.set_defaults(func=run)
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index 09df425a..64ca1d31 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -14,15 +14,6 @@ from dan.schedulers import exponential_dropout_scheduler
 from dan.transforms import aug_config
 
 
-def add_document_parser(subcommands) -> None:
-    parser = subcommands.add_parser(
-        "document",
-        description=__doc__,
-        help=__doc__,
-    )
-    parser.set_defaults(func=run)
-
-
 def train_and_test(rank, params):
     torch.manual_seed(0)
     torch.cuda.manual_seed(0)
diff --git a/dan/ocr/line/__init__.py b/dan/ocr/line/__init__.py
index e69de29b..603a060d 100644
--- a/dan/ocr/line/__init__.py
+++ b/dan/ocr/line/__init__.py
@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+
+from dan.ocr.line.generate_synthetic import run as run_generate
+from dan.ocr.line.train import run as run_train
+
+
+def add_generate_parser(subcommands) -> None:
+    parser = subcommands.add_parser(
+        "generate",
+        description=__doc__,
+        help="Generate synthetic data to train DAN models.",
+    )
+    parser.set_defaults(func=run_generate)
+
+
+def add_line_parser(subcommands) -> None:
+    parser = subcommands.add_parser(
+        "line",
+        description=__doc__,
+        help="Train a DAN model at line level.",
+    )
+    parser.set_defaults(func=run_train)
diff --git a/dan/ocr/line/generate_synthetic.py b/dan/ocr/line/generate_synthetic.py
index b034d5b0..de378031 100644
--- a/dan/ocr/line/generate_synthetic.py
+++ b/dan/ocr/line/generate_synthetic.py
@@ -14,15 +14,6 @@ from dan.schedulers import exponential_dropout_scheduler, exponential_scheduler
 from dan.transforms import line_aug_config
 
 
-def add_generate_parser(subcommands) -> None:
-    parser = subcommands.add_parser(
-        "generate",
-        description=__doc__,
-        help=__doc__,
-    )
-    parser.set_defaults(func=run)
-
-
 def train_and_test(rank, params):
     torch.manual_seed(0)
     torch.cuda.manual_seed(0)
diff --git a/dan/ocr/line/train.py b/dan/ocr/line/train.py
index ba3d0b4f..7d2bab45 100644
--- a/dan/ocr/line/train.py
+++ b/dan/ocr/line/train.py
@@ -14,15 +14,6 @@ from dan.schedulers import exponential_dropout_scheduler, exponential_scheduler
 from dan.transforms import line_aug_config
 
 
-def add_line_parser(subcommands) -> None:
-    parser = subcommands.add_parser(
-        "line",
-        description=__doc__,
-        help=__doc__,
-    )
-    parser.set_defaults(func=run)
-
-
 def train_and_test(rank, params):
     torch.manual_seed(0)
     torch.cuda.manual_seed(0)
diff --git a/dan/ocr/train.py b/dan/ocr/train.py
deleted file mode 100644
index 37176069..00000000
--- a/dan/ocr/train.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# -*- coding: utf-8 -*-
-
-from dan.ocr.document.train import add_document_parser
-from dan.ocr.line.train import add_line_parser
-
-
-def add_train_parser(subcommands) -> None:
-    parser = subcommands.add_parser(
-        "train",
-        description=__doc__,
-        help=__doc__,
-    )
-    subcommands = parser.add_subparsers(metavar="subcommand")
-
-    add_line_parser(subcommands)
-    add_document_parser(subcommands)
diff --git a/requirements.txt b/requirements.txt
index 27758aa6..6288bf12 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,7 @@ networkx==2.6.3
 numpy==1.22.3
 opencv-python==4.5.5.64
 PyYAML==6.0
-tensorboard==0.2.1
+tensorboard==2.8.0
 torch==1.11.0
 torchvision==0.12.0
 tqdm==4.62.3
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..21a29dfa
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,20 @@
+# -*- coding: utf-8 -*-
+import os
+
+import pytest
+
+
+@pytest.fixture(autouse=True)
+def setup_environment(responses):
+    """Setup needed environment variables"""
+
+    # Allow accessing remote API schemas
+    # defaulting to the prod environment
+    schema_url = os.environ.get(
+        "ARKINDEX_API_SCHEMA_URL",
+        "https://arkindex.teklia.com/api/v1/openapi/?format=openapi-json",
+    )
+    responses.add_passthru(schema_url)
+
+    # Set schema url in environment
+    os.environ["ARKINDEX_API_SCHEMA_URL"] = schema_url
diff --git a/tests/test_extract.py b/tests/test_extract.py
new file mode 100644
index 00000000..3e3552d2
--- /dev/null
+++ b/tests/test_extract.py
@@ -0,0 +1,120 @@
+# -*- coding: utf-8 -*-
+
+import pytest
+from arkindex.mock import MockApiClient
+
+from dan.datasets.extract.extract_from_arkindex import ArkindexExtractor, Entity
+from dan.datasets.extract.utils import insert_token
+
+
+@pytest.fixture
+def arkindex_extractor():
+    return ArkindexExtractor(
+        client=MockApiClient(), split_names=["train", "val", "test"]
+    )
+
+
+@pytest.mark.parametrize(
+    "text,count,offset,length,expected",
+    (
+        ("n°1 16 janvier 1611", 0, 0, 3, "ⓘn°1 Ⓘ16 janvier 1611"),
+        ("ⓘn°1 Ⓘ16 janvier 1611", 2, 4, 15, "ⓘn°1 Ⓘⓘ16 janvier 1611Ⓘ"),
+    ),
+)
+def test_insert_token(text, count, offset, length, expected):
+    start_token, end_token = "ⓘ", "Ⓘ"
+    assert (
+        insert_token(text, count, start_token, end_token, offset, length)[0] == expected
+    )
+
+
+@pytest.mark.parametrize(
+    "text,entities,expected",
+    (
+        (
+            "n°1 16 janvier 1611",
+            [
+                Entity(offset=0, length=3, label="P"),
+                Entity(offset=4, length=15, label="D"),
+            ],
+            "ⓟn°1 Ⓟⓓ16 janvier 1611Ⓓ",
+        ),
+    ),
+)
+def test_reconstruct_text(arkindex_extractor, text, entities, expected):
+    arkindex_extractor.tokens = {
+        "P": {"start": "â“Ÿ", "end": "â“…"},
+        "D": {"start": "â““", "end": "â’¹"},
+    }
+    assert arkindex_extractor.reconstruct_text(text, entities) == expected
+
+
+@pytest.mark.parametrize(
+    "text,offset,length,label,expected",
+    (
+        ("   n°1 16 janvier 1611   ", None, None, None, "n°1 16 janvier 1611"),
+        ("n°1 16 janvier 1611", 0, 3, "P", "ⓟn°1 Ⓟ16 janvier 1611"),
+    ),
+)
+def test_extract_transcription(
+    arkindex_extractor, text, offset, length, label, expected
+):
+    element = {"id": "element_id"}
+    transcription = {"id": "transcription_id", "text": text}
+    arkindex_extractor.client.add_response(
+        "ListTranscriptions",
+        id="element_id",
+        worker_version=None,
+        response={"count": 1, "results": [transcription]},
+    )
+
+    if label:
+        arkindex_extractor.load_entities = True
+        arkindex_extractor.tokens = {
+            "P": {"start": "â“Ÿ", "end": "â“…"},
+        }
+        arkindex_extractor.client.add_response(
+            "ListTranscriptionEntities",
+            id="transcription_id",
+            worker_version=None,
+            response=[
+                {
+                    "entity": {"id": "entity_id", "metas": {"subtype": label}},
+                    "offset": offset,
+                    "length": length,
+                    "worker_version": None,
+                    "worker_run_id": None,
+                }
+            ],
+        )
+
+    assert arkindex_extractor.extract_transcription(element) == expected
+
+
+@pytest.mark.parametrize(
+    "offset,length,label",
+    ((0, 3, "P"),),
+)
+def test_extract_entities(arkindex_extractor, offset, length, label):
+    transcription = {"id": "transcription_id"}
+    arkindex_extractor.tokens = {
+        "P": {"start": "â“Ÿ", "end": "â“…"},
+    }
+    arkindex_extractor.client.add_response(
+        "ListTranscriptionEntities",
+        id="transcription_id",
+        worker_version=None,
+        response=[
+            {
+                "entity": {"id": "entity_id", "metas": {"subtype": label}},
+                "offset": offset,
+                "length": length,
+                "worker_version": None,
+                "worker_run_id": None,
+            }
+        ],
+    )
+
+    assert arkindex_extractor.extract_entities(transcription) == [
+        Entity(offset=offset, length=length, label=label)
+    ]
diff --git a/tox.ini b/tox.ini
new file mode 100644
index 00000000..aedc8cf0
--- /dev/null
+++ b/tox.ini
@@ -0,0 +1,12 @@
+[tox]
+envlist = teklia-dan
+
+[testenv]
+passenv = ARKINDEX_API_SCHEMA_URL
+commands =
+  pytest {posargs}
+
+deps =
+  pytest
+  pytest-responses
+  -rrequirements.txt
-- 
GitLab