Commit 8aabca8d authored by Eva Bardou, committed by Yoann Schneider

Support Arkindex datasets during extraction

parent 543d1e23
1 merge request: !328 Support Arkindex datasets during extraction
......@@ -46,6 +46,12 @@ def add_extract_parser(subcommands) -> None:
type=pathlib.Path,
help="Path where the data were exported from Arkindex.",
)
parser.add_argument(
"--dataset-id",
type=UUID,
help="ID of the dataset to extract from Arkindex.",
required=True,
)
parser.add_argument(
"--element-type",
nargs="+",
......@@ -53,13 +59,6 @@ def add_extract_parser(subcommands) -> None:
help="Type of elements to retrieve.",
required=True,
)
parser.add_argument(
"--parent-element-type",
type=str,
help="Type of the parent element containing the data.",
required=False,
default="page",
)
parser.add_argument(
"--output",
type=pathlib.Path,
......@@ -67,25 +66,6 @@ def add_extract_parser(subcommands) -> None:
required=True,
)
parser.add_argument(
"--train-folder",
type=UUID,
help="ID of the training folder to extract from Arkindex.",
required=True,
)
parser.add_argument(
"--val-folder",
type=UUID,
help="ID of the validation folder to extract from Arkindex.",
required=True,
)
parser.add_argument(
"--test-folder",
type=UUID,
help="ID of the testing folder to extract from Arkindex.",
required=True,
)
# Optional arguments.
parser.add_argument(
"--entity-separators",
......
......@@ -11,9 +11,10 @@ from uuid import UUID
from tqdm import tqdm
from arkindex_export import open_database
from arkindex_export import Dataset, open_database
from dan.datasets.extract.db import (
Element,
get_dataset_elements,
get_elements,
get_transcription_entities,
get_transcriptions,
......@@ -36,7 +37,9 @@ from dan.utils import LMTokenMapping, parse_tokens
LANGUAGE_DIR = "language_model" # Subpath to the language model directory.
TRAIN_NAME = "train"
SPLIT_NAMES = [TRAIN_NAME, "val", "test"]
VAL_NAME = "val"
TEST_NAME = "test"
SPLIT_NAMES = [TRAIN_NAME, VAL_NAME, TEST_NAME]
logger = logging.getLogger(__name__)
......@@ -48,9 +51,8 @@ class ArkindexExtractor:
def __init__(
self,
folders: list = [],
dataset_id: UUID | None = None,
element_type: List[str] = [],
parent_element_type: str | None = None,
output: Path | None = None,
entity_separators: List[str] = ["\n", " "],
unknown_token: str = "",
......@@ -61,9 +63,8 @@ class ArkindexExtractor:
allow_empty: bool = False,
subword_vocab_size: int = 1000,
) -> None:
self.folders = folders
self.dataset_id = dataset_id
self.element_type = element_type
self.parent_element_type = parent_element_type
self.output = output
self.entity_separators = entity_separators
self.unknown_token = unknown_token
......@@ -272,20 +273,24 @@ class ArkindexExtractor:
)
def run(self):
# Retrieve the Dataset and its splits from the cache
dataset = Dataset.get_by_id(self.dataset_id)
splits = dataset.sets.split(",")
assert set(splits).issubset(
set(SPLIT_NAMES)
), f'Dataset must have "{TRAIN_NAME}", "{VAL_NAME}" and "{TEST_NAME}" steps'
# Iterate over the subsets to find the page images and labels.
for folder_id, split in zip(self.folders, SPLIT_NAMES):
for split in splits:
with tqdm(
get_elements(
folder_id,
[self.parent_element_type],
),
desc=f"Extracting data from ({folder_id}) for split ({split})",
get_dataset_elements(dataset, split),
desc=f"Extracting data from ({self.dataset_id}) for split ({split})",
) as pbar:
# Iterate over the pages to create splits at page level.
for parent in pbar:
self.process_parent(
pbar=pbar,
parent=parent,
parent=parent.element,
split=split,
)
# Progress bar updates
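For context, the new `run()` relies on `Dataset.sets` being stored as a comma-separated string of split names, which is also how the test fixture further down builds it. A minimal illustration of that validation, with assumed values:

```python
# Illustration with assumed values (not taken from the diff): the Arkindex
# export stores Dataset.sets as a comma-separated string, which run() splits
# and checks against the expected split names before iterating.
SPLIT_NAMES = ["train", "val", "test"]

sets = ",".join(SPLIT_NAMES)   # "train,val,test", as in the test fixture
splits = sets.split(",")
assert set(splits).issubset(set(SPLIT_NAMES))
```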
......@@ -303,15 +308,12 @@ class ArkindexExtractor:
def run(
database: Path,
dataset_id: UUID,
element_type: List[str],
parent_element_type: str,
output: Path,
entity_separators: List[str],
unknown_token: str,
tokens: Path,
train_folder: UUID,
val_folder: UUID,
test_folder: UUID,
transcription_worker_version: str | bool | None,
entity_worker_version: str | bool | None,
keep_spaces: bool,
......@@ -321,15 +323,12 @@ def run(
assert database.exists(), f"No file found @ {database}"
open_database(path=database)
folders = [str(train_folder), str(val_folder), str(test_folder)]
# Create directories
Path(output, LANGUAGE_DIR).mkdir(parents=True, exist_ok=True)
ArkindexExtractor(
folders=folders,
dataset_id=dataset_id,
element_type=element_type,
parent_element_type=parent_element_type,
output=output,
entity_separators=entity_separators,
unknown_token=unknown_token,
......
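To make the new entry point concrete, here is a hedged sketch of driving the extractor programmatically after this change, mirroring the updated tests below; the database path, dataset ID and output folder are placeholders, not values from this commit:

```python
# Hedged sketch (placeholder paths and IDs): open an Arkindex SQLite export,
# then extract a dataset by ID instead of passing three folder UUIDs.
from pathlib import Path

from arkindex_export import open_database
from dan.datasets.extract.arkindex import ArkindexExtractor

open_database(path=Path("database.sqlite"))  # placeholder export path
output = Path("data")
Path(output, "language_model").mkdir(parents=True, exist_ok=True)  # as run() does

ArkindexExtractor(
    dataset_id="dataset_uuid",    # placeholder; a real UUID in practice
    element_type=["text_line"],
    output=output,
    entity_separators=None,       # keep the whole text, as in the tests
    tokens=None,                  # no named entities, as in the second docs example
).run()
```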
# -*- coding: utf-8 -*-
from typing import List
from arkindex_export import Image
from arkindex_export.models import (
Dataset,
DatasetElement,
Element,
Entity,
EntityType,
......@@ -13,6 +14,26 @@ from arkindex_export.models import (
from arkindex_export.queries import list_children
def get_dataset_elements(
dataset: Dataset,
split: str,
):
"""
Retrieve dataset elements in a specific split from an SQLite export of an Arkindex corpus
"""
query = (
DatasetElement.select(DatasetElement.element)
.join(Element)
.join(Image, on=(DatasetElement.element.image == Image.id))
.where(
DatasetElement.dataset == dataset,
DatasetElement.set_name == split,
)
)
return query
def get_elements(
parent_id: str,
element_type: List[str],
......
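The new query is exercised directly in the test further down; a short, hedged usage sketch, assuming the SQLite export has already been opened with `open_database`:

```python
# Hedged usage sketch: list the element IDs of the "train" split of the first
# dataset found in an already-opened Arkindex SQLite export.
from dan.datasets.extract.db import Dataset, get_dataset_elements

dataset = Dataset.select().get()
for dataset_element in get_dataset_elements(dataset=dataset, split="train"):
    print(dataset_element.element.id)
```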
......@@ -6,7 +6,7 @@ There are several steps to follow when training a DAN model.
To extract the data, DAN uses an Arkindex export database in SQLite format. You will need to:
1. Structure the data into folders (`train` / `val` / `test`) in [Arkindex](https://demo.arkindex.org/).
1. Structure the data into splits (`train` / `val` / `test`) in a project dataset in [Arkindex](https://demo.arkindex.org/).
1. [Export the project](https://doc.arkindex.org/howto/export/) in SQLite format.
1. Extract the data with the [extract command](../usage/datasets/extract.md).
1. Download images with the [download command](../usage/datasets/download.md).
......
......@@ -11,15 +11,12 @@ Use the `teklia-dan dataset extract` command to extract a dataset from an Arkind
| Parameter | Description | Type | Default |
| -------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------- | ------- |
| `database` | Path to an Arkindex export database in SQLite format. | `pathlib.Path` | |
| `--dataset-id` | ID of the dataset to extract from Arkindex. | `uuid` | |
| `--element-type` | Type of the elements to extract. You may specify multiple types. | `str` | |
| `--parent-element-type` | Type of the parent element containing the data. | `str` | `page` |
| `--output` | Folder where the data will be generated. | `pathlib.Path` | |
| `--entity-separators` | Removes all text that does not appear in an entity or in the list of given ordered characters. If several separators follow each other, keep only the first to appear in the list. Do not give any arguments to keep the whole text (see [dedicated section](#examples)). | `str` | |
| `--unknown-token` | Token used to replace characters in the validation/test sets that are not included in the training set. | `str` | `⁇` |
| `--tokens` | Mapping between starting tokens and end tokens to extract text with their entities. | `pathlib.Path` | |
| `--train-folder` | ID of the training folder to extract from Arkindex. | `uuid` | |
| `--val-folder` | ID of the validation folder to extract from Arkindex. | `uuid` | |
| `--test-folder` | ID of the testing folder to extract from Arkindex. | `uuid` | |
| `--transcription-worker-version` | Filter transcriptions by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
| `--entity-worker-version` | Filter transcription entities by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
| `--keep-spaces` | Transcriptions are trimmed by default. Use this flag to disable this behaviour. | `bool` | `False` |
......@@ -61,9 +58,7 @@ To use the data from three folders as **training**, **validation** and **testing
```shell
teklia-dan dataset extract \
database.sqlite \
--train-folder train_folder_uuid \
--val-folder val_folder_uuid \
--test-folder test_folder_uuid \
--dataset-id dataset_uuid \
--element-type page \
--output data \
--tokens tokens.yml
......@@ -115,10 +110,7 @@ To extract HTR data from **annotations** and **text_zones** from each folder, bu
```shell
teklia-dan dataset extract \
database.sqlite \
--train-folder train_folder_uuid \
--val-folder val_folder_uuid \
--test-folder test_folder_uuid \
--dataset-id dataset_uuid \
--element-type text_zone annotation \
--parent-element-type single_page \
--output data
```
......@@ -7,6 +7,8 @@ from typing import List
import pytest
from arkindex_export import (
Dataset,
DatasetElement,
Element,
ElementPath,
Entity,
......@@ -19,6 +21,7 @@ from arkindex_export import (
WorkerVersion,
database,
)
from dan.datasets.extract.arkindex import SPLIT_NAMES
from tests import FIXTURES
......@@ -133,6 +136,8 @@ def mock_database(tmp_path_factory):
WorkerRun,
ImageServer,
Image,
Dataset,
DatasetElement,
Element,
ElementPath,
EntityType,
......@@ -175,8 +180,26 @@ def mock_database(tmp_path_factory):
type="worker",
)
# Create folders
create_element(id="root")
# Create dataset
dataset = Dataset.create(
id="dataset", name="Dataset", state="complete", sets=",".join(SPLIT_NAMES)
)
# Create dataset elements
for split in SPLIT_NAMES:
element_path = (FIXTURES / "extraction" / "elements" / split).with_suffix(
".json"
)
element_json = json.loads(element_path.read_text())
# Recursive function to create children
for child in element_json.get("children", []):
create_element(id=child)
# Linking the element to the dataset split
DatasetElement.create(
id=child, element_id=child, dataset=dataset, set_name=split
)
# Create data for entities extraction tests
# Create transcription
......
......@@ -4,26 +4,55 @@ from operator import itemgetter
import pytest
from dan.datasets.extract.arkindex import TRAIN_NAME
from dan.datasets.extract.db import (
Dataset,
DatasetElement,
Element,
get_dataset_elements,
get_elements,
get_transcription_entities,
get_transcriptions,
)
def test_get_dataset_elements(mock_database):
"""
Assert dataset elements retrieval output against verified results
"""
dataset_elements = get_dataset_elements(
dataset=Dataset.select().get(),
split=TRAIN_NAME,
)
# ID verification
assert all(
isinstance(dataset_element, DatasetElement)
for dataset_element in dataset_elements
)
assert [dataset_element.element.id for dataset_element in dataset_elements] == [
"train-page_1",
"train-page_2",
]
def test_get_elements(mock_database):
"""
Assert elements retrieval output against verified results
"""
elements = get_elements(
parent_id="train",
element_type=["double_page"],
parent_id="train-page_1",
element_type=["text_line"],
)
# ID verification
assert all(isinstance(element, Element) for element in elements)
assert [element.id for element in elements] == ["train-page_1", "train-page_2"]
assert [element.id for element in elements] == [
"train-page_1-line_1",
"train-page_1-line_2",
"train-page_1-line_3",
"train-page_1-line_4",
]
@pytest.mark.parametrize("worker_version", (False, "worker_version_id", None))
......
......@@ -253,9 +253,8 @@ def test_extract(
]
extractor = ArkindexExtractor(
folders=["train", "val", "test"],
dataset_id="dataset",
element_type=["text_line"],
parent_element_type="double_page",
output=output,
# Keep the whole text
entity_separators=None,
......@@ -419,9 +418,7 @@ def test_extract(
@pytest.mark.parametrize("allow_empty", (True, False))
def test_empty_transcription(allow_empty, mock_database):
extractor = ArkindexExtractor(
folders=["train", "val", "test"],
element_type=["text_line"],
parent_element_type="double_page",
output=None,
entity_separators=None,
tokens=None,
......