Verified commit 30524dfe authored by Yoann Schneider

implement extraction

parent 4710eb23
......@@ -21,9 +21,9 @@ We evaluate the DAN on two public datasets: RIMES and READ 2016 at single-page a
We obtained the following results:
|                         | CER (%) | WER (%) | LOER (%) | mAP_cer (%) |
| :---------------------: | ------- | :-----: | :------: | ----------- |
| RIMES (single page)     | 4.54    | 11.85   | 3.82     | 93.74       |
| READ 2016 (single page) | 3.53    | 13.33   | 5.94     | 92.57       |
| READ 2016 (double page) | 3.69    | 14.20   | 4.60     | 93.92       |
......@@ -92,3 +92,105 @@ To run the inference on a GPU, one can replace `cpu` by the name of the GPU. In
```python
text, confidence_scores = model.predict(image, confidences=True)
```
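As a minimal sketch of GPU inference (assuming the same `DAN` model object and loading steps shown earlier in this documentation; the `cuda:0` device name and the image path are only illustrative):
```python
import imageio.v2 as iio
import torch

# Pick a device name: "cpu", or a GPU identifier such as "cuda:0".
device = "cuda:0" if torch.cuda.is_available() else "cpu"

image = iio.imread("page.jpg")  # illustrative input image path

model = DAN(device)  # assumption: same class and constructor as in the CPU example
# ... load the trained weights exactly as in the CPU example ...
text, confidence_scores = model.predict(image, confidences=True)
```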
### Commands
This package provides three subcommands. To get more information about any subcommand, use the `--help` option.
#### Data extraction from Arkindex
Use the `teklia-dan extract` command to extract a dataset from Arkindex. This will generate the images and the labels needed to train a DAN model.
The available arguments are:
| Parameter | Description | Type | Default |
| ------------------------------ | ----------------------------------------------------------------------------------- | -------- | ------- |
| `--parent` | UUID of the folder to import from Arkindex. You may specify multiple UUIDs. | str/uuid | |
| `--element-type` | Type of the elements to extract. You may specify multiple types. | str | |
| `--output` | Folder where the data will be generated. Must exist. | Path | |
| `--load-entities` | Extract text with their entities. Needed for NER tasks. | bool | False |
| `--tokens` | Mapping between starting tokens and end tokens. Needed for NER tasks. | Path | |
| `--use-existing-split` | Use the specified folder IDs for the dataset split. | bool | |
| `--train-folder` | ID of the training folder to import from Arkindex. | uuid | |
| `--val-folder` | ID of the validation folder to import from Arkindex. | uuid | |
| `--test-folder`                  | ID of the testing folder to import from Arkindex.                                   | uuid     |         |
| `--transcription-worker-version` | Filter transcriptions by worker_version. Use `manual` for manual filtering.         | str/uuid |         |
| `--entity-worker-version`        | Filter transcription entities by worker_version. Use `manual` for manual filtering. | str/uuid |         |
| `--train-prob`                   | Training set split size.                                                             | float    | 0.7     |
| `--val-prob`                     | Validation set split size.                                                           | float    | 0.15    |
The `--tokens` argument expects a file with the following format.
```yaml
---
INTITULE:
  start: ⓘ
  end: Ⓘ
DATE:
  start: ⓓ
  end: Ⓓ
COTE_SERIE:
  start: ⓢ
  end: Ⓢ
ANALYSE_COMPL.:
  start: ⓒ
  end: Ⓒ
PRECISIONS_SUR_COTE:
  start: ⓟ
  end: Ⓟ
COTE_ARTICLE:
  start: ⓐ
  end: Ⓐ
CLASSEMENT:
  start: ⓛ
  end: Ⓛ
```
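This mapping is plain YAML, so it can be inspected with `yaml.safe_load` (the same call used by the extraction code); a minimal sketch, assuming the file is saved as `tokens.yml` as in the examples below:
```python
import yaml

# Load the entity-to-token mapping shown above.
with open("tokens.yml") as f:
    tokens = yaml.safe_load(f)

# Each entry maps an entity name to its start and end tokens.
print(tokens["INTITULE"]["start"], tokens["INTITULE"]["end"])  # ⓘ Ⓘ
```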
To extract HTR+NER data from **pages** from [this folder](https://arkindex.teklia.com/element/665e84ea-bd97-4912-91b0-1f4a844287ff), use the following command:
```shell
teklia-dan extract \
--parent 665e84ea-bd97-4912-91b0-1f4a844287ff \
--element-type page \
--output data \
--load-entities \
--tokens tokens.yml
```
with `tokens.yml` having the content described just above.
To do the same while restricting the import to three specific folders, the command becomes:
```shell
teklia-dan extract \
--parent 2275529a-1ec5-40ce-a516-42ea7ada858c af9b38b5-5d95-417d-87ec-730537cb1898 6ff44957-0e65-48c5-9d77-a178116405b2 \
--element-type page \
--output data \
--load-entities \
--tokens tokens.yml
```
To use the data from three folders as the **training**, **validation** and **testing** datasets respectively, the command becomes:
```shell
teklia-dan extract \
--use-existing-split \
--train-folder 2275529a-1ec5-40ce-a516-42ea7ada858c \
--val-folder af9b38b5-5d95-417d-87ec-730537cb1898 \
--test-folder 6ff44957-0e65-48c5-9d77-a178116405b2 \
--element-type page \
--output data \
--load-entities \
--tokens tokens.yml
```
To extract HTR data from **annotations** and **text_zones** from [this folder](https://demo.arkindex.org/element/48852284-fc02-41bb-9a42-4458e5a51615), use the following command:
```shell
teklia-dan extract \
--parent 48852284-fc02-41bb-9a42-4458e5a51615 \
--element-type text_zone annotation \
--output data
```
#### Model training
Use the `teklia-dan train` command to train a new DAN model. It supports multiple arguments; use `--help` for the full list.
#### Synthetic data generation
Use the `teklia-dan generate` command to generate synthetic data for training DAN models. It also supports multiple arguments; use `--help` for the full list.
# -*- coding: utf-8 -*-
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s/%(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
......@@ -8,11 +8,11 @@ from dan.ocr.train import add_train_parser
def get_parser():
parser = argparse.ArgumentParser(prog="TEKLIA DAN training")
parser = argparse.ArgumentParser(prog="teklia-dan")
subcommands = parser.add_subparsers(metavar="subcommand")
add_train_parser(subcommands)
add_extract_parser(subcommands)
add_train_parser(subcommands)
add_generate_parser(subcommands)
return parser
......
# -*- coding: utf-8 -*-
"""
The arkindex_utils module
======================
"""
import errno
import logging
import sys
from apistar.exceptions import ErrorResponse
def retrieve_corpus(client, corpus_name: str) -> str:
"""
Retrieve the corpus id from the corpus name.
:param client: The arkindex client.
:param corpus_name: The name of the corpus to retrieve.
:return target_corpus: The id of the retrieved corpus.
"""
for corpus in client.request("ListCorpus"):
if corpus["name"] == corpus_name:
target_corpus = corpus["id"]
try:
logging.info(f"Corpus id retrieved: {target_corpus}")
except NameError:
logging.error(f"Corpus {corpus_name} not found")
sys.exit(errno.EINVAL)
return target_corpus
def retrieve_subsets(
client, corpus: str, parents_types: list, parents_names: list
) -> list:
"""
Retrieve the requested subsets.
:param client: The arkindex client.
:param corpus: The id of the retrieved corpus.
:param parents_types: The types of parents of the elements to retrieve.
:param parents_names: The names of parents of the elements to retrieve.
:return subsets: The retrieved subsets.
"""
subsets = []
for parent_type in parents_types:
try:
subsets.extend(
client.request("ListElements", corpus=corpus, type=parent_type)[
"results"
]
)
except ErrorResponse as e:
logging.error(f"{e.content}: {parent_type}")
sys.exit(errno.EINVAL)
# Retrieve subsets with name in parents-names. If no parents-names given, keep all subsets.
if parents_names is not None:
logging.info(f"Retrieving {parents_names} subset(s)")
subsets = [subset for subset in subsets if subset["name"] in parents_names]
else:
logging.info("Retrieving all subsets")
if len(subsets) == 0:
logging.info("No subset found")
return subsets
# -*- coding: utf-8 -*-
"""
The extraction module
======================
Extract dataset from Arkindex using API.
"""
from collections import defaultdict
import logging
import os
import pathlib
import random
import uuid
import cv2
import imageio.v2 as iio
from arkindex import ArkindexClient, options_from_env
from tqdm import tqdm
from dan.datasets.extract.arkindex_utils import retrieve_corpus, retrieve_subsets
from dan.datasets.extract.utils import get_cli_args
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
from dan.datasets.extract.utils import (
insert_token,
parse_tokens,
save_image,
save_json,
save_text,
)
from dan import logger
IMAGES_DIR = "./images/" # Path to the images directory.
LABELS_DIR = "./labels/" # Path to the labels directory.
IMAGES_DIR = "images" # Subpath to the images directory.
LABELS_DIR = "labels" # Subpath to the labels directory.
MANUAL_SOURCE = "manual"
# Layout string to token
SEM_MATCHING_TOKENS_STR = {
"INTITULE": "ⓘ",
"DATE": "ⓓ",
"COTE_SERIE": "ⓢ",
"ANALYSE_COMPL.": "ⓒ",
"PRECISIONS_SUR_COTE": "ⓟ",
"COTE_ARTICLE": "ⓐ",
}
# Layout begin-token to end-token
SEM_MATCHING_TOKENS = {"ⓘ": "Ⓘ", "ⓓ": "Ⓓ", "ⓢ": "Ⓢ", "ⓒ": "Ⓒ", "ⓟ": "Ⓟ", "ⓐ": "Ⓐ"}
def parse_worker_version(worker_version_id):
if worker_version_id == MANUAL_SOURCE:
return False
return worker_version_id
def add_extract_parser(subcommands) -> None:
parser = subcommands.add_parser(
"extract",
description=__doc__,
help=__doc__,
)
# Required arguments.
parser.add_argument(
"--parent",
type=uuid.UUID,
nargs="+",
help="ID of the parent folder to import from Arkindex.",
required=False,
)
parser.add_argument(
"--element-type",
nargs="+",
type=str,
help="Type of elements to retrieve",
required=True,
)
parser.add_argument(
"--output",
type=pathlib.Path,
help="Path where the data will be generated.",
required=True,
)
# Optional arguments.
parser.add_argument(
"--load-entities", action="store_true", help="Extract text with their entities"
)
parser.add_argument(
"--tokens",
type=pathlib.Path,
help="Mapping between starting tokens and end tokens. Needed for entities.",
required=False,
)
parser.add_argument(
"--use-existing-split",
action="store_true",
help="Use the specified folder IDs for the dataset split.",
)
parser.add_argument(
"--train-folder",
type=uuid.UUID,
help="ID of the training folder to import from Arkindex.",
required=False,
)
parser.add_argument(
"--val-folder",
type=uuid.UUID,
help="ID of the validation folder to import from Arkindex.",
required=False,
)
parser.add_argument(
"--test-folder",
type=uuid.UUID,
help="ID of the testing folder to import from Arkindex.",
required=False,
)
parser.add_argument(
"--transcription-worker-version",
type=parse_worker_version,
help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
required=False,
default=MANUAL_SOURCE,
)
parser.add_argument(
"--entity-worker-version",
type=parse_worker_version,
help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
required=False,
default=MANUAL_SOURCE,
)
parser.add_argument(
"--train-prob", type=float, default=0.7, help="Training set split size."
)
parser.add_argument(
"--val-prob", type=float, default=0.15, help="Validation set split size"
)
parser.set_defaults(func=run)
def run():
args = get_cli_args()
class ArkindexExtractor:
"""
Extract data from Arkindex
"""
# Get and initialize the parameters.
os.makedirs(IMAGES_DIR, exist_ok=True)
os.makedirs(LABELS_DIR, exist_ok=True)
def __init__(
self,
client,
folders,
element_type,
split_names,
output,
load_entities,
tokens,
use_existing_split,
transcription_worker_version,
entity_worker_version,
train_prob,
val_prob,
) -> None:
self.client = client
self.element_type = element_type
self.split_names = split_names
self.output = output
self.load_entities = load_entities
self.tokens = parse_tokens(tokens) if self.load_entities else None
self.use_existing_split = use_existing_split
self.transcription_worker_version = transcription_worker_version
self.entity_worker_version = entity_worker_version
self.train_prob = train_prob
self.val_prob = val_prob
# Login to arkindex.
client = ArkindexClient(**options_from_env())
self.get_subsets(folders)
corpus = retrieve_corpus(client, args.corpus)
subsets = retrieve_subsets(client, corpus, args.parents_types, args.parents_names)
def get_subsets(self, folders):
if self.use_existing_split:
self.subsets = [
(folder, split) for folder, split in zip(folders, self.split_names)
]
else:
self.subsets = [(folder, None) for folder in folders]
# Iterate over the subsets to find the page images and labels.
for subset in subsets:
def assign_random_split(self):
"""
assuming train_prob + valid_prob + test_prob = 1
"""
prob = random.random()
if prob <= self.train_prob:
return self.split_names[0]
elif prob <= self.train_prob + self.val_prob:
return self.split_names[1]
else:
return self.split_names[2]
os.makedirs(
os.path.join(args.output_dir, IMAGES_DIR, subset["name"]), exist_ok=True
)
os.makedirs(
os.path.join(args.output_dir, LABELS_DIR, subset["name"]), exist_ok=True
def extract_transcription(
self,
element,
):
transcriptions = self.client.request(
"ListTranscriptions",
id=element["id"],
worker_version=self.transcription_worker_version,
)
for page in tqdm(
client.paginate(
"ListElementChildren", id=subset["id"], type="page", recursive=True
),
desc="Set " + subset["name"],
):
image = iio.imread(page["zone"]["url"])
cv2.imwrite(
os.path.join(
args.output_dir, IMAGES_DIR, subset["name"], f"{page['id']}.jpg"
),
cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
if transcriptions["count"] != 1:
logger.warning(
f"More than one transcription found on element ({element['id']}) with this config."
)
return
tr = client.request(
"ListTranscriptions", id=page["id"], worker_version=None
)["results"]
tr = [one for one in tr if one["worker_version_id"] is None]
assert len(tr) == 1, page["id"]
transcription = transcriptions["results"].pop()
if self.load_entities:
entities = self.client.request(
"ListTranscriptionEntities",
id=transcription["id"],
worker_version=self.entity_worker_version,
)
if entities["count"] == 0:
logger.warning(
f"No entities found on transcription ({transcription['id']})."
)
return
else:
text = transcription["text"]
for one_tr in tr:
ent = client.request("ListTranscriptionEntities", id=one_tr["id"])[
"results"
count = 0
for entity in entities["results"]:
start_token, end_token = self.tokens[
entity["entity"]["metas"]["subtype"]
]
ent = [one for one in ent if one["worker_version_id"] is None]
if len(ent) == 0:
continue
else:
text = one_tr["text"]
text, count = insert_token(
text,
count,
start_token,
end_token,
offset=entity["offset"],
length=entity["length"],
)
else:
text = transcription["text"].strip()
return text
new_text = text
count = 0
for e in ent:
start_token = SEM_MATCHING_TOKENS_STR[e["entity"]["metas"]["subtype"]]
end_token = SEM_MATCHING_TOKENS[start_token]
new_text = (
new_text[: count + e["offset"]]
+ start_token
+ new_text[count + e["offset"] :]
def process_element(
self,
element,
split,
):
text = self.extract_transcription(
element,
)
if not text:
logging.warning(
f"Skipping {element['id']} (zero or multiple transcriptions with worker_version=None)"
)
else:
logging.info(f"Processed {element['type']} {element['id']}")
im_path = os.path.join(
self.output, IMAGES_DIR, split, f"{element['type']}_{element['id']}.jpg"
)
txt_path = os.path.join(
self.output, LABELS_DIR, split, f"{element['type']}_{element['id']}.txt"
)
save_text(txt_path, text)
try:
image = iio.imread(element["zone"]["url"])
save_image(im_path, image)
except Exception:
logger.error(f"Couldn't retrieve image of element ({element['id']}")
pass
return element["id"]
def process_page(
self,
page,
split,
):
# Extract only pages
data = defaultdict(list)
if self.element_type == ["page"]:
data["page"] = [
self.process_element(
page,
split,
)
count += 1
new_text = (
new_text[: count + e["offset"] + e["length"]]
+ end_token
+ new_text[count + e["offset"] + e["length"] :]
]
# Extract page's children elements (text_zone, text_line)
else:
for element_type in self.element_type:
for element in self.client.paginate(
"ListElementChildren",
id=page["id"],
type=element_type,
recursive=True,
):
data[element_type].append(
self.process_element(
element,
split,
)
)
return data
def run(self):
split_dict = defaultdict(dict)
# Iterate over the subsets to find the page images and labels.
for subset_id, subset_split in self.subsets:
page_idx = 0
# Iterate over the pages to create splits at page level.
for page in tqdm(
self.client.paginate(
"ListElementChildren", id=subset_id, type="page", recursive=True
)
):
page_idx += 1
split = subset_split or self.assign_random_split()
split_dict[split][page["id"]] = self.process_page(
page=page,
split=split,
)
count += 1
with open(
os.path.join(
args.output_dir, LABELS_DIR, subset["name"], f"{page['id']}.txt"
),
"w",
) as f:
f.write(new_text)
save_json(self.output / "split.json", split_dict)
def run(
parent,
element_type,
output,
load_entities,
tokens,
use_existing_split,
train_folder,
val_folder,
test_folder,
transcription_worker_version,
entity_worker_version,
train_prob,
val_prob,
):
assert (
use_existing_split or parent
), "One of `--use-existing-split` and `--parent` must be set"
if use_existing_split:
assert (
train_folder
), "If you use an existing split, you must specify the training folder."
assert (
val_folder
), "If you use an existing split, you must specify the validation folder."
assert (
test_folder
), "If you use an existing split, you must specify the testing folder."
folders = [train_folder, val_folder, test_folder]
else:
folders = parent
if load_entities:
assert tokens, "Please provide the entities to match."
# Get and initialize the parameters.
os.makedirs(IMAGES_DIR, exist_ok=True)
os.makedirs(LABELS_DIR, exist_ok=True)
# Login to arkindex.
assert (
"ARKINDEX_API_URL" in os.environ
), "The ARKINDEX API URL was not found in your environment."
assert (
"ARKINDEX_API_TOKEN" in os.environ
), "Your API credentials was not found in your environment."
client = ArkindexClient(**options_from_env())
# Create out directories
split_names = ["train", "val", "test"]
for split in split_names:
os.makedirs(os.path.join(output, IMAGES_DIR, split), exist_ok=True)
os.makedirs(os.path.join(output, LABELS_DIR, split), exist_ok=True)
ArkindexExtractor(
client=client,
folders=folders,
element_type=element_type,
split_names=split_names,
output=output,
load_entities=load_entities,
tokens=tokens,
use_existing_split=use_existing_split,
transcription_worker_version=transcription_worker_version,
entity_worker_version=entity_worker_version,
train_prob=train_prob,
val_prob=val_prob,
).run()
# -*- coding: utf-8 -*-
import yaml
import json
import random
"""
The utils module
======================
"""
import cv2
import argparse
random.seed(42)
def get_cli_args():
def assign_random_split(train_prob, val_prob):
"""
Get the command-line arguments.
:return: The command-line arguments.
assuming train_prob + valid_prob + test_prob = 1
"""
parser = argparse.ArgumentParser(
description="Arkindex DAN Training Label Generation"
)
prob = random.random()
if prob <= train_prob:
return "train"
elif prob <= train_prob + val_prob:
return "valid"
else:
return "test"
# Required arguments.
parser.add_argument(
"--corpus",
type=str,
help="Name of the corpus from which the data will be retrieved.",
required=True,
)
parser.add_argument(
"--element-type",
nargs="+",
type=str,
help="Type of elements to retrieve",
required=True,
)
parser.add_argument(
"--parents-types",
nargs="+",
type=str,
help="Type of parents of the elements.",
required=True,
)
parser.add_argument(
"--output-dir",
type=str,
help="Path to the output directory.",
required=True,
)
# Optional arguments.
parser.add_argument(
"--parents-names",
nargs="+",
type=str,
help="Names of parents of the elements.",
default=None,
)
parser.add_argument(
"--no-entities", action="store_true", help="Extract text without entities"
)
def save_text(path, text):
with open(path, "w") as f:
f.write(text)
parser.add_argument(
"--use-existing-split",
action="store_true",
help="Do not partition pages into train/val/test",
)
parser.add_argument(
"--train-prob", type=float, default=0.7, help="Training set probability"
)
def save_image(path, image):
cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
parser.add_argument(
"--val-prob", type=float, default=0.15, help="Validation set probability"
def save_json(path, dict):
with open(path, "w") as outfile:
json.dump(dict, outfile, indent=4)
def insert_token(text, count, start_token, end_token, offset, length):
text = (
# Text before entity
text[: count + offset]
# Starting token
+ start_token
# Entity
+ text[count + offset : count + 1 + offset + length]
# End token
+ end_token
# Text after entity
+ text[count + 1 + offset + length :]
)
return text, count + 2
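# Usage sketch (illustrative values, not part of the original file): wrap two entities
# of a made-up transcription with the circled tokens from the README example. Because
# the entity slice above spans `length + 1` characters, `length` is passed as the
# entity length minus one; `count` carries the two characters inserted per entity.
#   text = "Paris, 12 mars 1832"
#   text, count = insert_token(text, 0, "ⓘ", "Ⓘ", offset=0, length=4)
#   text, count = insert_token(text, count, "ⓓ", "Ⓓ", offset=7, length=11)
#   text == "ⓘParisⒾ, ⓓ12 mars 1832Ⓓ" and count == 4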
return parser.parse_args()
def parse_tokens(filename):
with open(filename) as f:
return yaml.safe_load(f)
# -*- coding: utf-8 -*-
import json
import random
import re
import cv2
random.seed(42)
def convert(text):
return int(text) if text.isdigit() else text.lower()
......@@ -14,30 +8,3 @@ def convert(text):
def natural_sort(data):
return sorted(data, key=lambda key: [convert(c) for c in re.split("([0-9]+)", key)])
def assign_random_split(train_prob, val_prob):
"""
assuming train_prob + val_prob + test_prob = 1
"""
prob = random.random()
if prob <= train_prob:
return "train"
elif prob <= train_prob + val_prob:
return "val"
else:
return "test"
def save_text(path, text):
with open(path, "w") as f:
f.write(text)
def save_image(path, image):
cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
def save_json(path, dict):
with open(path, "w") as outfile:
json.dump(dict, outfile, indent=4)
# -*- coding: utf-8 -*-
"""
Train a DAN model at document level.
"""
import random
import numpy as np
......@@ -18,6 +21,7 @@ def add_document_parser(subcommands) -> None:
parser = subcommands.add_parser(
"document",
description=__doc__,
help=__doc__,
)
parser.set_defaults(func=run)
......
# -*- coding: utf-8 -*-
"""
Generate synthetic data to train DAN models
"""
import random
import numpy as np
......@@ -18,6 +21,7 @@ def add_generate_parser(subcommands) -> None:
parser = subcommands.add_parser(
"generate",
description=__doc__,
help=__doc__,
)
parser.set_defaults(func=run)
......
# -*- coding: utf-8 -*-
"""
Train a DAN model at line level.
"""
import random
import numpy as np
......@@ -18,6 +21,7 @@ def add_line_parser(subcommands) -> None:
parser = subcommands.add_parser(
"line",
description=__doc__,
help=__doc__,
)
parser.set_defaults(func=run)
......
# -*- coding: utf-8 -*-
"""
Train a new DAN model.
"""
from dan.ocr.document.train import add_document_parser
from dan.ocr.line.train import add_line_parser
......@@ -8,6 +11,7 @@ def add_train_parser(subcommands) -> None:
parser = subcommands.add_parser(
"train",
description=__doc__,
help=__doc__,
)
subcommands = parser.add_subparsers(metavar="subcommand")
......
......@@ -6,7 +6,7 @@ networkx==2.6.3
numpy==1.22.3
opencv-python==4.5.5.64
PyYAML==6.0
tensorboard==0.2.1
tensorboard==2.8.0
torch==1.11.0
torchvision==0.12.0
tqdm==4.62.3