Commit bb5b087f authored by manonBlanco

Remove 'split' command

parent b3dbf692
Merge request !85: Remove 'split' command
Pipeline #163386 failed
@@ -2,7 +2,6 @@
 import argparse
 
 from atr_data_generator.extract import add_extract_subparser
-from atr_data_generator.split import add_split_subparser
 
 
 def main():
@@ -13,7 +12,6 @@ def main():
     )
     subcommands = parser.add_subparsers(metavar="subcommand")
 
     add_extract_subparser(subcommands)
-    add_split_subparser(subcommands)
 
     args = vars(parser.parse_args())
# -*- coding: utf-8 -*-
"""
Split dataset
"""
from pathlib import Path

from teklia_toolbox.config import ConfigParser

from atr_data_generator.arguments import CommonArgs
from atr_data_generator.split.arguments import SplitArgs
from atr_data_generator.split.main import main


def add_split_subparser(subcommands):
    parser = subcommands.add_parser(
        "split",
        description=__doc__,
        help=__doc__,
    )
    parser.add_argument("--config", type=Path, help="Configuration file")
    parser.set_defaults(func=main, config_parser=config_parser)


def get_parser():
    parser = ConfigParser()

    # Common
    common = parser.add_subparser("common")
    common.add_option("dataset_name", type=str)
    common.add_option("output_dir", type=Path, default=None)
    common.add_option("cache_dir", type=Path, default=None)
    common.add_option("log_parameters", type=bool, default=True)

    # Split
    split = parser.add_subparser("split", default={})
    split.add_option("train_ratio", type=float, default=0.8)
    split.add_option("val_ratio", type=float, default=0.1)
    split.add_option("test_ratio", type=float, default=0.1)
    split.add_option("use_existing_split", type=bool, default=False)

    return parser


def config_parser(configuration_path: Path):
    """
    Parse the configuration file and return the parsed arguments:
    - CommonArgs
    - SplitArgs
    """
    config_data = get_parser().parse(configuration_path)
    return {
        "common": CommonArgs(**config_data["common"]),
        "split": SplitArgs(**config_data["split"]),
    }
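
A minimal sketch of a configuration file accepted by config_parser above, assuming teklia_toolbox's ConfigParser reads YAML (all names and values here are hypothetical; omitted options fall back to the defaults declared in get_parser):

    common:
      dataset_name: my_dataset
      output_dir: /tmp/atr_output
    split:
      train_ratio: 0.8
      val_ratio: 0.1
      test_ratio: 0.1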
# -*- coding: utf-8 -*-
import logging
import random
from enum import Enum
from pathlib import Path

import numpy as np

from atr_data_generator.arguments import CommonArgs
from atr_data_generator.split.arguments import SplitArgs
from atr_data_generator.utils import export_parameters

logger = logging.getLogger(__name__)

SEED = 42
random.seed(SEED)


class Split(Enum):
    Train: str = "train"
    Test: str = "test"
    Validation: str = "val"
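
# Note the round trip used further down: Split("train") is Split.Train, so
# Split("train").name == "Train", which is how the partition file names
# (e.g. TrainLines.lst) are derived from the split values.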


class PartitionSplitter:
    def __init__(
        self,
        common: CommonArgs,
        split: SplitArgs,
    ):
        self.output_dir = common.output_dir
        self.split_train_ratio = split.train_ratio
        self.split_test_ratio = split.test_ratio
        self.split_val_ratio = split.val_ratio
    def page_level_split(self, line_ids: list) -> dict:
        """
        Split pages into train, validation and test subsets.
        Lines are never split individually: all lines from a page go to the
        same subset, to avoid data leakage.

        line_ids (list): a list of line ids formatted as {page_id}_{line_number}_{line_id}
        """
        # Get page ids from line ids to create splits at page level
        page_ids = ["_".join(line_id.split("_")[:-2]) for line_id in line_ids]
        # Remove duplicates and sort for reproducibility
        page_ids = sorted(set(page_ids))
        random.Random(SEED).shuffle(page_ids)
        page_count = len(page_ids)

        # Use np.split to split in three sets
        stop_train_idx = round(page_count * self.split_train_ratio)
        stop_val_idx = stop_train_idx + round(page_count * self.split_val_ratio)
        train_page_ids, val_page_ids, test_page_ids = np.split(
            page_ids, [stop_train_idx, stop_val_idx]
        )

        # Build the dictionary that will be used to split lines {page_id: split}
        page_dict = {page_id: Split.Train.value for page_id in train_page_ids}
        page_dict.update({page_id: Split.Validation.value for page_id in val_page_ids})
        page_dict.update({page_id: Split.Test.value for page_id in test_page_ids})
        return page_dict
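
    # A minimal sketch of the split arithmetic above, with the default
    # 0.8/0.1/0.1 ratios and ten hypothetical pages:
    #
    #   >>> pages = [f"page{i:02d}" for i in range(10)]
    #   >>> train, val, test = np.split(pages, [8, 9])  # stops: round(10 * 0.8) = 8, 8 + round(10 * 0.1) = 9
    #   >>> len(train), len(val), len(test)
    #   (8, 1, 1)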

    def parse_data(self):
        # Get all image ids (and remove the extension)
        lines_path = self.output_dir / "Lines"
        return [
            str(file.relative_to(lines_path).with_suffix(""))
            for file in sorted(lines_path.rglob("*.jpg"))
        ]
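
    # For example, a hypothetical file Lines/page01_0001_abcd.jpg yields the
    # line id "page01_0001_abcd", from which page_level_split recovers the
    # page id "page01".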

    def create_partitions(self):
        """
        Data should be in output_dir / "Lines"
        """
        logger.info(f"Creating {[split.value for split in Split]} partitions")
        line_ids = self.parse_data()
        page_dict = self.page_level_split(line_ids)

        # Extend the page-level split to lines
        datasets = {s.value: [] for s in Split}
        for line_id in line_ids:
            page_id = "_".join(line_id.split("_")[:-2])
            split_id = page_dict[page_id]
            datasets[split_id].append(line_id)

        partitions_dir = self.output_dir / "Partitions"
        partitions_dir.mkdir(exist_ok=True)
        for split, split_line_ids in datasets.items():
            if not split_line_ids:
                logger.info(f"Partition {split} is empty! Skipping...")
                continue
            partition_file = partitions_dir / f"{Split(split).name}Lines.lst"
            partition_file.write_text("\n".join(split_line_ids) + "\n")
        return datasets
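
    # With the Split enum above, a non-empty "train" partition is written to
    # Partitions/TrainLines.lst, one line id per line, e.g. (hypothetical ids):
    #
    #   page01_0001_abcd
    #   page01_0002_efgh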


def main(common: CommonArgs, split: SplitArgs):
    data_partitioner = PartitionSplitter(common=common, split=split)

    # Create partitions from all the extracted data
    datasets = data_partitioner.create_partitions()

    export_parameters(
        common=common,
        split=split,
        datasets=datasets,
    )
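
A minimal sketch of calling this entry point directly, mirroring the arguments built in the test below (the output directory is hypothetical):

    from pathlib import Path
    main(
        common=CommonArgs(dataset_name="test", output_dir=Path("out")),
        split=SplitArgs(),
    )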
# -*- coding: utf-8 -*-
import random

from atr_data_generator.arguments import CommonArgs
from atr_data_generator.split.arguments import SplitArgs
from atr_data_generator.split.main import PartitionSplitter


def test_create_partitions(fake_expected_partitions, tmp_path):
    splitter = PartitionSplitter(
        common=CommonArgs(dataset_name="test", output_dir=tmp_path), split=SplitArgs()
    )
    all_ids = []
    for split_name, expected_split_ids in fake_expected_partitions.items():
        all_ids += expected_split_ids

    # Shuffle to show that splits are reproducible
    random.shuffle(all_ids)

    # Just to check that there's no problem with the data
    assert len(all_ids) > 0

    partitions_dir = splitter.output_dir / "Partitions"
    lines_dir = splitter.output_dir / "Lines"
    trans_dir = splitter.output_dir / "Transcriptions"

    # Create fake data
    for line_id in all_ids:
        trans_file = trans_dir / f"{line_id}.txt"
        trans_file.parent.mkdir(parents=True, exist_ok=True)
        trans_file.touch()
        img_file = lines_dir / f"{line_id}.jpg"
        img_file.parent.mkdir(parents=True, exist_ok=True)
        img_file.touch()

    splitter.create_partitions()

    for split_name, expected_split_ids in fake_expected_partitions.items():
        part_file = partitions_dir / split_name
        assert part_file.read_text().strip().split("\n") == expected_split_ids