Commit 8a519d32 authored by Yoann Schneider, committed by Mélodie Boillet

Remove data splitting

parent df420e45
1 merge request: !226 Remove data splitting
@@ -27,19 +27,6 @@ def parse_worker_version(worker_version_id):
return worker_version_id
def validate_probability(proba):
try:
proba = float(proba)
except ValueError:
raise argparse.ArgumentTypeError(f"`{proba}` is not a valid float.")
if proba > 1 or proba < 0:
raise argparse.ArgumentTypeError(
f"`{proba}` is not a valid probability. Must be between 0 and 1 (both exclusive)."
)
return proba
def validate_char(char):
if len(char) != 1:
raise argparse.ArgumentTypeError(
@@ -62,13 +49,6 @@ def add_extract_parser(subcommands) -> None:
type=pathlib.Path,
help="Path where the data were exported from Arkindex.",
)
parser.add_argument(
"--parent",
type=validate_uuid,
nargs="+",
help="ID of the parent folder to import from Arkindex.",
required=False,
)
parser.add_argument(
"--element-type",
nargs="+",
@@ -90,6 +70,25 @@ def add_extract_parser(subcommands) -> None:
required=True,
)
parser.add_argument(
"--train-folder",
type=validate_uuid,
help="ID of the training folder to extract from Arkindex.",
required=True,
)
parser.add_argument(
"--val-folder",
type=validate_uuid,
help="ID of the validation folder to extract from Arkindex.",
required=True,
)
parser.add_argument(
"--test-folder",
type=validate_uuid,
help="ID of the testing folder to extract from Arkindex.",
required=True,
)
# Optional arguments.
parser.add_argument(
"--load-entities",
@@ -114,31 +113,6 @@ def add_extract_parser(subcommands) -> None:
required=False,
)
parser.add_argument(
"--use-existing-split",
action="store_true",
help="Use the specified folder IDs for the dataset split.",
)
parser.add_argument(
"--train-folder",
type=validate_uuid,
help="ID of the training folder to import from Arkindex.",
required=False,
)
parser.add_argument(
"--val-folder",
type=validate_uuid,
help="ID of the validation folder to import from Arkindex.",
required=False,
)
parser.add_argument(
"--test-folder",
type=validate_uuid,
help="ID of the testing folder to import from Arkindex.",
required=False,
)
parser.add_argument(
"--transcription-worker-version",
type=parse_worker_version,
@@ -152,20 +126,6 @@ def add_extract_parser(subcommands) -> None:
required=False,
)
parser.add_argument(
"--train-prob",
type=validate_probability,
default=0.7,
help="Training set split size.",
)
parser.add_argument(
"--val-prob",
type=validate_probability,
default=0.15,
help="Validation set split size.",
)
parser.add_argument(
"--max-width",
type=int,
......
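For context, here is a minimal sketch of what the trimmed-down extraction CLI looks like after this change: the three split folders are always required, while `--parent`, `--use-existing-split`, `--train-prob` and `--val-prob` are gone. The `validate_uuid` helper below is an assumed stand-in, not the project's actual implementation.

```python
# Minimal sketch of the post-change CLI surface (illustrative, not the project's module).
import argparse
import uuid


def validate_uuid(value: str) -> uuid.UUID:
    # Assumed stand-in for the project's validator.
    try:
        return uuid.UUID(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"`{value}` is not a valid UUID.")


parser = argparse.ArgumentParser(prog="teklia-dan dataset extract")
parser.add_argument("database", type=str, help="Path to an Arkindex export database.")
for split in ("train", "val", "test"):
    parser.add_argument(
        f"--{split}-folder",
        type=validate_uuid,
        required=True,
        help=f"ID of the {split} folder to extract from Arkindex.",
    )

# Parsing now fails with "the following arguments are required" unless all three folders are set.
args = parser.parse_args(
    [
        "export.sqlite",
        "--train-folder", "11111111-1111-1111-1111-111111111111",
        "--val-folder", "22222222-2222-2222-2222-222222222222",
        "--test-folder", "33333333-3333-3333-3333-333333333333",
    ]
)
print(args.train_folder, args.val_folder, args.test_folder)
```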
@@ -58,7 +58,7 @@ class Element:
def get_elements(
parent_id: str,
element_type: str,
element_type: List[str],
max_width: Optional[int] = None,
max_height: Optional[int] = None,
) -> List[Element]:
@@ -69,7 +69,7 @@ def get_elements(
query = (
list_children(parent_id=parent_id)
.join(Image)
.where(ArkindexElement.type == element_type)
.where(ArkindexElement.type.in_(element_type))
.select(
ArkindexElement.id,
ArkindexElement.type,
......
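The `get_elements` change above swaps an equality filter for peewee's `in_` operator so a single query can match several element types. A self-contained sketch of that pattern, using an in-memory model rather than the Arkindex export schema:

```python
# Illustrative only: an in-memory peewee model, not the Arkindex export schema.
from peewee import CharField, Model, SqliteDatabase

db = SqliteDatabase(":memory:")


class Element(Model):
    type = CharField()

    class Meta:
        database = db


db.create_tables([Element])
for element_type in ("page", "text_zone", "annotation", "margin"):
    Element.create(type=element_type)

# Equality matches a single type; `.in_()` accepts the whole list at once,
# which is what lets the CLI take several --element-type values.
wanted = ["text_zone", "annotation"]
query = Element.select().where(Element.type.in_(wanted))
print(sorted(row.type for row in query))  # ['annotation', 'text_zone']
```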
@@ -23,12 +23,9 @@ from dan.datasets.extract.exceptions import (
)
from dan.datasets.extract.utils import (
EntityType,
Subset,
download_image,
insert_token,
parse_tokens,
save_json,
save_text,
)
IMAGES_DIR = "images" # Subpath to the images directory.
@@ -44,63 +41,29 @@ class ArkindexExtractor:
def __init__(
self,
folders: list = [],
element_type: list = [],
element_type: List[str] = [],
parent_element_type: str = None,
output: Path = None,
load_entities: bool = False,
entity_separators: list = [],
tokens: Path = None,
use_existing_split: bool = None,
transcription_worker_version: Optional[Union[str, bool]] = None,
entity_worker_version: Optional[Union[str, bool]] = None,
train_prob: float = None,
val_prob: float = None,
max_width: Optional[int] = None,
max_height: Optional[int] = None,
) -> None:
self.folders = folders
self.element_type = element_type
self.parent_element_type = parent_element_type
self.output = output
self.load_entities = load_entities
self.entity_separators = entity_separators
self.tokens = parse_tokens(tokens) if self.load_entities else None
self.use_existing_split = use_existing_split
self.transcription_worker_version = transcription_worker_version
self.entity_worker_version = entity_worker_version
self.train_prob = train_prob
self.val_prob = val_prob
self.max_width = max_width
self.max_height = max_height
self.subsets = self.get_subsets(folders)
def get_subsets(self, folders: list) -> List[Subset]:
"""
Assign each folder to its split if it's already known.
"""
if self.use_existing_split:
return [
Subset(folder, split) for folder, split in zip(folders, SPLIT_NAMES)
]
else:
return [Subset(folder) for folder in folders]
def _assign_random_split(self):
"""
Yields a randomly chosen split for an element.
Assumes that train_prob + valid_prob + test_prob = 1
"""
prob = random.random()
if prob <= self.train_prob:
yield SPLIT_NAMES[0]
elif prob <= self.train_prob + self.val_prob:
yield SPLIT_NAMES[1]
else:
yield SPLIT_NAMES[2]
def get_random_split(self):
return next(self._assign_random_split())
def _keep_char(self, char: str) -> bool:
# Keep all text by default if no separator was given
return not self.entity_separators or char in self.entity_separators
@@ -177,14 +140,14 @@ class ArkindexExtractor:
transcription = random.choice(transcriptions)
if self.load_entities:
entities = get_transcription_entities(
transcription.id, self.entity_worker_version
)
return self.reconstruct_text(transcription.text, entities)
else:
if not self.load_entities:
return transcription.text.strip()
entities = get_transcription_entities(
transcription.id, self.entity_worker_version
)
return self.reconstruct_text(transcription.text, entities)
def process_element(
self,
element: Element,
@@ -196,14 +159,12 @@ class ArkindexExtractor:
"""
text = self.extract_transcription(element)
txt_path = Path(
self.output, LABELS_DIR, split, f"{element.type}_{element.id}.txt"
)
save_text(txt_path, text)
im_path = Path(
self.output, IMAGES_DIR, split, f"{element.type}_{element.id}.jpg"
base_path = Path(split, f"{element.type}_{element.id}")
Path(self.output, LABELS_DIR, base_path).with_suffix(".txt").write_text(text)
download_image(
element, Path(self.output, IMAGES_DIR, base_path).with_suffix(".jpg")
)
download_image(element, im_path)
return element.id
def process_parent(
@@ -223,85 +184,59 @@ class ArkindexExtractor:
logger.warning(f"Skipping {parent.id}: {str(e)}")
# Extract children elements
else:
for element_type in self.element_type:
for element in get_elements(
parent.id,
element_type,
max_width=self.max_width,
max_height=self.max_height,
):
try:
data[element_type].append(self.process_element(element, split))
except ProcessingError as e:
logger.warning(f"Skipping {element.id}: {str(e)}")
for element in get_elements(
parent.id,
self.element_type,
max_width=self.max_width,
max_height=self.max_height,
):
try:
data[element.type].append(self.process_element(element, split))
except ProcessingError as e:
logger.warning(f"Skipping {element.id}: {str(e)}")
return data
def run(self):
split_dict = defaultdict(dict)
# Iterate over the folders to find the page images and labels.
for idx, subset in enumerate(self.subsets, start=1):
for idx, (folder_id, split) in enumerate(
zip(self.folders, SPLIT_NAMES), start=1
):
# Iterate over the pages to create splits at page level.
for parent in tqdm(
get_elements(
subset.id,
self.parent_element_type,
folder_id,
[self.parent_element_type],
max_width=self.max_width,
max_height=self.max_height,
),
desc=f"Processing {subset} {idx}/{len(self.subsets)}",
desc=f"Processing {folder_id} {idx}/{len(self.subsets)}",
):
split = subset.split or self.get_random_split()
split_dict[split][parent.id] = self.process_parent(
self.process_parent(
parent=parent,
split=split,
)
save_json(self.output / "split.json", split_dict)
def run(
database: Path,
parent: list,
element_type: str,
element_type: List[str],
parent_element_type: str,
output: Path,
load_entities: bool,
entity_separators: list,
tokens: Path,
use_existing_split: bool,
train_folder: UUID,
val_folder: UUID,
test_folder: UUID,
transcription_worker_version: Optional[Union[str, bool]],
entity_worker_version: Optional[Union[str, bool]],
train_prob,
val_prob,
max_width: Optional[int],
max_height: Optional[int],
):
assert (
use_existing_split or parent
), "One of `--use-existing-split` and `--parent` must be set"
assert use_existing_split ^ bool(
parent
), "Only one of `--use-existing-split` and `--parent` must be set"
assert database.exists(), f"No file found @ {database}"
open_database(path=database)
if use_existing_split:
assert (
train_folder
), "If you use an existing split, you must specify the training folder."
assert (
val_folder
), "If you use an existing split, you must specify the validation folder."
assert (
test_folder
), "If you use an existing split, you must specify the testing folder."
folders = [str(train_folder), str(val_folder), str(test_folder)]
else:
folders = [str(parent_id) for parent_id in parent]
folders = [str(train_folder), str(val_folder), str(test_folder)]
if load_entities:
assert tokens, "Please provide the entities to match."
@@ -319,11 +254,8 @@ def run(
load_entities=load_entities,
entity_separators=entity_separators,
tokens=tokens,
use_existing_split=use_existing_split,
transcription_worker_version=transcription_worker_version,
entity_worker_version=entity_worker_version,
train_prob=train_prob,
val_prob=val_prob,
max_width=max_width,
max_height=max_height,
).run()
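The core of the change is visible in `run()` above: instead of drawing a random split per page from `--train-prob` / `--val-prob`, each of the three folders is paired positionally with a split name. A small sketch of that pairing, assuming `SPLIT_NAMES` is `("train", "val", "test")` as suggested by the output directory tree in the documentation:

```python
# Sketch of the positional folder-to-split pairing now used in run();
# SPLIT_NAMES is assumed to be ("train", "val", "test").
from uuid import UUID

SPLIT_NAMES = ("train", "val", "test")

train_folder = UUID("11111111-1111-1111-1111-111111111111")
val_folder = UUID("22222222-2222-2222-2222-222222222222")
test_folder = UUID("33333333-3333-3333-3333-333333333333")

folders = [str(train_folder), str(val_folder), str(test_folder)]

for idx, (folder_id, split) in enumerate(zip(folders, SPLIT_NAMES), start=1):
    # Every page found under `folder_id` lands in the `split` subset,
    # instead of being drawn at random with --train-prob / --val-prob.
    print(f"{idx}/{len(folders)}: folder {folder_id} -> {split}")
```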
# -*- coding: utf-8 -*-
import json
import logging
import time
from pathlib import Path
@@ -18,18 +17,6 @@ logger = logging.getLogger(__name__)
MAX_RETRIES = 5
class Subset(NamedTuple):
id: str
split: str = None
def __str__(self) -> str:
return (
f"Subset(id='{self.id}', split='{self.split.capitalize()}')"
if self.split
else f"Subset(id='{self.id}')"
)
class EntityType(NamedTuple):
start: str
end: str = ""
@@ -60,20 +47,10 @@ def download_image(element: Element, im_path: Path):
raise ImageDownloadError(element.id, e)
def save_text(path: Path, text: str):
with path.open("w") as f:
f.write(text)
def save_image(path: Path, image: ndarray):
cv2.imwrite(str(path), cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
def save_json(path: Path, data: dict):
with path.open("w") as outfile:
json.dump(data, outfile, indent=4)
def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -> str:
"""
Insert the given tokens at the right position in the text
@@ -89,7 +66,7 @@ def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -> str:
def parse_tokens(filename: Path) -> dict:
with filename.open() as f:
return {
name: EntityType(**tokens) for name, tokens in yaml.safe_load(f).items()
}
return {
name: EntityType(**tokens)
for name, tokens in yaml.safe_load(filename.read_text()).items()
}
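The rewritten `parse_tokens` reads the token file with `Path.read_text()` and feeds the string straight to `yaml.safe_load`, still returning a name-to-`EntityType` mapping. A hedged, self-contained example of the expected file format (the entity names and tokens below are made up, not taken from the project):

```python
# Self-contained illustration; the entity names and tokens are made up.
from pathlib import Path
from typing import NamedTuple

import yaml


class EntityType(NamedTuple):
    start: str
    end: str = ""


def parse_tokens(filename: Path) -> dict:
    return {
        name: EntityType(**tokens)
        for name, tokens in yaml.safe_load(filename.read_text()).items()
    }


tokens_file = Path("tokens.yml")
tokens_file.write_text(
    "PERSON:\n"
    "  start: <person>\n"   # hypothetical start token
    "  end: </person>\n"    # hypothetical end token
    "DATE:\n"
    "  start: <date>\n"     # `end` falls back to the NamedTuple default ""
)
print(parse_tokens(tokens_file))
# {'PERSON': EntityType(start='<person>', end='</person>'), 'DATE': EntityType(start='<date>', end='')}
```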
# -*- coding: utf-8 -*-
import json
import os
import pickle
from itertools import pairwise
@@ -10,7 +11,7 @@ import torch
import yaml
from dan import logger
from dan.datasets.extract.utils import parse_tokens, save_json
from dan.datasets.extract.utils import parse_tokens
from dan.decoder import GlobalHTADecoder
from dan.encoder import FCN_Encoder
from dan.predict.attention import (
@@ -390,9 +391,9 @@ def process_batch(
)
result["attention_gif"] = gif_filename
json_filename = f"{output}/{image_path.stem}.json"
json_filename = Path(output, image_path.stem).with_suffix(".json")
logger.info(f"Saving JSON prediction in {json_filename}")
save_json(Path(json_filename), result)
json_filename.write_text(json.dumps(result, indent=2))
def run(
......
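The prediction output is now written through `pathlib` instead of the removed `save_json` helper; the JSON content stays the same, only the indentation changes from 4 to 2 spaces. A small sketch of the new path handling (the file names and result payload are made up):

```python
# Sketch of the new JSON output handling; file names and payload are made up.
import json
from pathlib import Path

output = Path("predictions")
output.mkdir(parents=True, exist_ok=True)
image_path = Path("images/page_0001.jpg")
result = {"text": "example prediction", "confidence": 0.98}

# Previously: json_filename = f"{output}/{image_path.stem}.json" followed by save_json(...).
json_filename = Path(output, image_path.stem).with_suffix(".json")
json_filename.write_text(json.dumps(result, indent=2))
print(json_filename.read_text())
```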
@@ -17,7 +17,6 @@ At the end, you should have a tree structure like this:
output/
├── charset.pkl
├── labels.json
├── split.json
├── images
│ ├── train
│ ├── val
......
@@ -7,21 +7,17 @@ Use the `teklia-dan dataset extract` command to extract a dataset from an Arkindex
| Parameter | Description | Type | Default |
| -------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------- | ------------------------------------ |
| `database` | Path to an Arkindex export database in SQLite format. | `Path` | |
| `--parent` | UUID of the folder to import from Arkindex. You may specify multiple UUIDs. | `str` or `uuid` | |
| `--element-type` | Type of the elements to extract. You may specify multiple types. | `str` | |
| `--parent-element-type` | Type of the parent element containing the data. | `str` | `page` |
| `--output` | Folder where the data will be generated. | `Path` | |
| `--load-entities` | Extract text with their entities. Needed for NER tasks. | `bool` | `False` |
| `--entity-separators` | Removes all text that does not appear in an entity or in the list of given ordered characters. If several separators follow each other, keep only the first to appear in the list. Do not give any arguments to keep the whole text. | `str` | (see [dedicated section](#examples)) |
| `--tokens` | Mapping between starting tokens and end tokens. Needed for NER tasks. | `Path` | |
| `--use-existing-split` | Use the specified folder IDs for the dataset split. | `bool` | |
| `--train-folder` | ID of the training folder to import from Arkindex. | `uuid` | |
| `--val-folder` | ID of the validation folder to import from Arkindex. | `uuid` | |
| `--test-folder` | ID of the testing folder to import from Arkindex. | `uuid` | |
| `--transcription-worker-version` | Filter transcriptions by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
| `--entity-worker-version` | Filter transcription entities by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
| `--train-prob` | Training set split size | `float` | `0.7` |
| `--val-prob` | Validation set split size | `float` | `0.15` |
| `--max-width` | Images larger than this width will be resized to this width. | `int` | |
| `--max-height` | Images larger than this height will be resized to this height. | `int` | |
@@ -53,44 +49,13 @@ CLASSEMENT:
## Examples
### HTR and NER data from one source
### HTR and NER data
To extract HTR+NER data from **pages** from a folder, you have to define an end token for each entity and use the following command:
To use the data from three folders as the **training**, **validation** and **testing** datasets respectively, use the following command:
```shell
teklia-dan dataset extract \
database.sqlite \
--parent folder_uuid \
--element-type page \
--output data \
--load-entities \
--tokens tokens.yml
```
with `tokens.yml` compliant with the format described before.
### HTR and NER data from multiple source
To do the same but only use the data from three folders, you have to define an end token for each entity and the commands becomes:
```shell
teklia-dan dataset extract \
database.sqlite \
--parent folder1_uuid folder2_uuid folder3_uuid \
--element-type page \
--output data \
--load-entities \
--tokens tokens.yml
```
### HTR and NER data with an existing split
To use the data from three folders as **training**, **validation** and **testing** dataset respectively, you have to define a end token for each entity and the commands becomes:
```shell
teklia-dan dataset extract \
database.sqlite \
--use-existing-split \
--train-folder train_folder_uuid \
--val-folder val_folder_uuid \
--test-folder test_folder_uuid \
@@ -100,22 +65,24 @@ teklia-dan dataset extract \
--tokens tokens.yml
```
### HTR from multiple element types with some parent filtering
### HTR from multiple element types
To extract HTR data from **annotations** and **text_zones** from a folder, but only keep those that are children of **single_pages**, you have to define an end token for each entity and use the following command:
To extract HTR data from **annotations** and **text_zones** from each folder, but only keep those that are children of **single_pages**, please use the following:
```shell
teklia-dan dataset extract \
database.sqlite \
--parent folder_uuid \
--train-folder train_folder_uuid \
--val-folder val_folder_uuid \
--test-folder test_folder_uuid \
--element-type text_zone annotation \
--parent-element-type single_page \
--output data
```
### NER data
### HTR + NER data
To extract NER data and keep breaklines and spaces between entities, use the following command:
To extract NER data and keep line breaks and spaces between entities, use the following command:
```shell
teklia-dan dataset extract \
@@ -125,4 +92,4 @@ teklia-dan dataset extract \
--tokens tokens.yml
```
If several separators follow each other, it will keep only one, ideally a breakline if there is one, otherwise a space. If you change the order of the `--entity-separators` parameters, then it will keep a space if there is one, otherwise a breakline.
If several separators follow each other, it will keep only one, ideally a line break if there is one, otherwise a space. If you change the order of the `--entity-separators` parameters, then it will keep a space if there is one, otherwise a line break.
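The separator behaviour described above can be summarised as: outside entities, only the characters listed in `--entity-separators` are kept, and a run of consecutive separators is collapsed to the one that appears first in the configured order. A hedged sketch of that collapse rule in isolation (this illustrates the documented behaviour, it is not the project's implementation):

```python
# Illustration of the documented --entity-separators collapse rule,
# not the project's actual implementation.
def collapse_separators(text: str, separators: list) -> str:
    """Collapse each run of separators, keeping the one that comes first
    in the configured `separators` order."""
    result, run = [], []
    for char in text + "\x00":  # sentinel to flush a trailing run
        if char in separators:
            run.append(char)
            continue
        if run:
            # e.g. separators=["\n", " "]: the run " \n " collapses to "\n".
            result.append(min(run, key=separators.index))
            run = []
        if char != "\x00":
            result.append(char)
    return "".join(result)


print(repr(collapse_separators("Paris \n 1889", ["\n", " "])))  # 'Paris\n1889'
print(repr(collapse_separators("Paris \n 1889", [" ", "\n"])))  # 'Paris 1889'
```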
@@ -16,7 +16,7 @@ def test_get_elements():
"""
elements = get_elements(
parent_id="d2b9fe93-3198-42de-8c07-f4ab67990e21",
element_type="page",
element_type=["page"],
)
# Check number of results
......