diff --git a/README.md b/README.md
index 7c4b6b6c2daea093e1b86634a95f02b45a49f1b9..79540ed50a2f24814ed5c318d48564954db72289 100644
--- a/README.md
+++ b/README.md
@@ -190,6 +190,26 @@ teklia-dan dataset extract \
     --output data
 ```
 
+#### Dataset formatting
+Use the `teklia-dan dataset format` command to format an extracted dataset. This generates the two files required to train a DAN model:
+- `labels.json`
+- `charset.pkl`
+The available arguments are:
+
+| Parameter        | Description                                                                                  | Type   | Default |
+| ---------------- | -------------------------------------------------------------------------------------------- | ------ | ------- |
+| `--dataset`      | Path to the folder containing the dataset.                                                     | `Path` |         |
+| `--image-format` | Format under which the images were saved.                                                      | `str`  |         |
+| `--keep-spaces`  | By default, consecutive spaces and tabs in transcriptions are collapsed. Use this flag to disable this behaviour. | `bool` | `False` |
+
+```shell
+teklia-dan dataset format \
+    --dataset path/to/dataset \
+    --image-format png
+```
+The created files will be stored at the root of your dataset.
+
+
 #### Model training
 
 `teklia-dan train` with multiple arguments.
diff --git a/dan/datasets/__init__.py b/dan/datasets/__init__.py
index 889e11cf0ee55e3b247d4745a452d412a3e638fc..8e8f6fe419498542c2beae411408677543008541 100644
--- a/dan/datasets/__init__.py
+++ b/dan/datasets/__init__.py
@@ -4,6 +4,7 @@ Preprocess datasets for training.
 """
 
 from dan.datasets.extract import add_extract_parser
+from dan.datasets.format import add_format_parser
 
 
 def add_dataset_parser(subcommands) -> None:
@@ -15,3 +16,4 @@ def add_dataset_parser(subcommands) -> None:
     subcommands = parser.add_subparsers(metavar="subcommand")
 
     add_extract_parser(subcommands)
+    add_format_parser(subcommands)
diff --git a/dan/datasets/format/__init__.py b/dan/datasets/format/__init__.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..427d28b3b1622fd6e1403848c0b00727e9bd41a5
--- a/dan/datasets/format/__init__.py
+++ b/dan/datasets/format/__init__.py
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+"""
+Format datasets for training.
+"""
+
+from pathlib import Path
+
+from dan.datasets.format.atr import run
+
+
+def add_format_parser(subcommands) -> None:
+    parser = subcommands.add_parser(
+        "format",
+        description=__doc__,
+        help=__doc__,
+    )
+    parser.add_argument(
+        "--dataset",
+        type=Path,
+        help="Path to the exported dataset.",
+        required=True,
+    )
+    parser.add_argument(
+        "--image-format",
+        type=str,
+        help="Format under which the images were saved.",
+        required=True,
+    )
+    parser.add_argument(
+        "--keep-spaces",
+        action="store_true",
+        help="Do not remove spaces in transcriptions.",
+    )
+
+    parser.set_defaults(func=run)
+""" + +from pathlib import Path + +from dan.datasets.format.atr import run + + +def add_format_parser(subcommands) -> None: + parser = subcommands.add_parser( + "format", + description=__doc__, + help=__doc__, + ) + parser.add_argument( + "--dataset", + type=Path, + help="Path to the exported dataset.", + required=True, + ) + parser.add_argument( + "--image-format", + type=str, + help="Format under which the images were saved.", + required=True, + ) + parser.add_argument( + "--keep-spaces", + action="store_true", + help="Do not remove spaces in transcriptions.", + ) + + parser.set_defaults(func=run) diff --git a/dan/datasets/format/atr.py b/dan/datasets/format/atr.py new file mode 100644 index 0000000000000000000000000000000000000000..7da5a8da73f2efee25088ea58f3885e8b257f2b3 --- /dev/null +++ b/dan/datasets/format/atr.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- +import json +import os +import pickle +import re +from collections import defaultdict +from pathlib import Path + +from tqdm import tqdm + + +def remove_spaces(text): + # remove begin/ending spaces + text = text.strip() + # replace \t with regular space + text = re.sub("\t", " ", text) + # remove consecutive spaces + text = re.sub(" +", " ", text) + return text + + +class ATRDatasetFormatter: + """ + Global pipeline/functions for dataset formatting + """ + + def __init__(self, dataset: Path, image_format: str, remove_spaces: bool): + self.dataset = dataset + self.set_names = ["train", "val", "test"] + self.remove_spaces = remove_spaces + + self.image_folder = self.dataset / "images" + self.labels_folder = self.dataset / "labels" + + self.image_format = image_format + if self.image_format.startswith("."): + self.image_format = self.image_format[1:] + + def format(self): + """ + Format ATR dataset + """ + ground_truth = defaultdict(dict) + charset = set() + for set_name in self.set_names: + set_folder = self.labels_folder / set_name + for file_name in tqdm( + os.listdir(set_folder), desc="Formatting " + set_name + ): + data = self.parse_labels(set_name, file_name) + charset = charset.union(set(data["label"])) + ground_truth[set_name][data["img_path"]] = { + "text": data["label"], + } + return ground_truth, charset + + def read_file(self, file_name): + with open(file_name, "r") as f: + text = f.read() + if self.remove_spaces: + text = remove_spaces(text) + return text.strip() + + def parse_labels(self, set_name, file_name): + return { + "img_path": os.path.join( + self.image_folder, + set_name, + f"{os.path.splitext(file_name)[0]}.{self.image_format}", + ), + "label": self.read_file( + os.path.join(self.labels_folder, set_name, file_name) + ), + } + + def run(self): + ground_truth, charset = self.format() + + with open(self.dataset / "labels.json", "w") as f: + json.dump( + ground_truth, + f, + sort_keys=True, + indent=4, + ) + with open(self.dataset / "charset.pkl", "wb") as f: + pickle.dump(sorted(list(charset)), f) + + +def run(dataset, image_format, keep_spaces): + ATRDatasetFormatter( + dataset=dataset, image_format=image_format, remove_spaces=not keep_spaces + ).run() diff --git a/dan/datasets/format/bessin.py b/dan/datasets/format/bessin.py deleted file mode 100644 index b5b44a135931a747a39343e72f56972f2bac52a8..0000000000000000000000000000000000000000 --- a/dan/datasets/format/bessin.py +++ /dev/null @@ -1,114 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -import os -import re -from collections import Counter - -from tqdm import tqdm - -from dan.datasets.format.generic import OCRDatasetFormatter - - -def 
remove_spaces(text): - # remove begin/ending spaces - text = text.strip() - # replace \t with regular space - text = re.sub("\t", " ", text) - # remove consecutive spaces - text = re.sub(" +", " ", text) - # text = text.encode('ascii', 'ignore').decode("utf-8") - return text - - -class BessinDatasetFormatter(OCRDatasetFormatter): - def __init__( - self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=False - ): - super(BessinDatasetFormatter, self).__init__( - "bessin", level, "_sem" if sem_token else "", set_names - ) - - self.dpi = dpi - self.counter = Counter() - self.map_datasets_files.update( - { - "bessin": { - # (1,050 for train, 100 for validation and 100 for test) - "line": { - "needed_files": [], - "arx_files": [], - "format_function": self.format_bessin_zone, - } - } - } - ) - - def preformat_bessin_zone(self): - """ - Extract all information from dataset and correct some annotations - """ - - dataset = {"train": list(), "valid": list(), "test": list()} - img_folder_path = os.path.join("Datasets", "raw", "bessin", "images") - labels_folder_path = os.path.join("Datasets", "raw", "bessin", "labels") - - train_files = [ - os.path.join(labels_folder_path, "train", name) - for name in os.listdir(os.path.join(labels_folder_path, "train")) - ] - valid_files = [ - os.path.join(labels_folder_path, "valid", name) - for name in os.listdir(os.path.join(labels_folder_path, "valid")) - ] - test_files = [ - os.path.join(labels_folder_path, "test", name) - for name in os.listdir(os.path.join(labels_folder_path, "test")) - ] - - for set_name, files in zip( - self.set_names, [train_files, valid_files, test_files] - ): - for i, label_file in enumerate( - tqdm(files, desc="Pre-formatting " + set_name) - ): - with open(label_file, "r") as f: - text = remove_spaces(f.read()) - - dataset[set_name].append( - { - "img_path": os.path.join( - img_folder_path, - set_name, - label_file.split("/")[-1].replace("txt", "jpg"), - ), - "label": text.strip(), - } - ) - - return dataset - - def format_bessin_zone(self): - """ - Format synist page dataset - """ - dataset = self.preformat_bessin_zone() - for set_name in self.set_names: - fold = os.path.join(self.target_fold_path, set_name) - for sample in tqdm(dataset[set_name], desc="Formatting " + set_name): - new_name = sample["img_path"].split("/")[-1] - new_img_path = os.path.join(fold, new_name) - self.load_resize_save(sample["img_path"], new_img_path) - zone = { - "text": sample["label"], - } - self.charset = self.charset.union(set(zone["text"])) - self.gt[set_name][new_name] = zone - self.counter.update(zone["text"]) - - -if __name__ == "__main__": - formatter = BessinDatasetFormatter("line", sem_token=False) - formatter.format() - print("Character freq: ") - for k, v in formatter.counter.items(): - print(k, v) diff --git a/dan/datasets/format/generic.py b/dan/datasets/format/generic.py deleted file mode 100644 index 8bdc795157d37890039230108fc3611c865867d2..0000000000000000000000000000000000000000 --- a/dan/datasets/format/generic.py +++ /dev/null @@ -1,91 +0,0 @@ -# -*- coding: utf-8 -*- -import os -import pickle -import shutil - -import numpy as np -from PIL import Image - - -class DatasetFormatter: - """ - Global pipeline/functions for dataset formatting - """ - - def __init__( - self, dataset_name, level, extra_name="", set_names=["train", "valid", "test"] - ): - self.dataset_name = dataset_name - self.level = level - self.set_names = set_names - self.target_fold_path = os.path.join( - "Datasets", "formatted", 
"{}_{}{}".format(dataset_name, level, extra_name) - ) - self.map_datasets_files = dict() - - def format(self): - self.init_format() - self.map_datasets_files[self.dataset_name][self.level]["format_function"]() - self.end_format() - - def init_format(self): - """ - Load and extracts needed files - """ - os.makedirs(self.target_fold_path, exist_ok=True) - - for set_name in self.set_names: - os.makedirs(os.path.join(self.target_fold_path, set_name), exist_ok=True) - - -class OCRDatasetFormatter(DatasetFormatter): - """ - Specific pipeline/functions for OCR/HTR dataset formatting - """ - - def __init__( - self, source_dataset, level, extra_name="", set_names=["train", "valid", "test"] - ): - super(OCRDatasetFormatter, self).__init__( - source_dataset, level, extra_name, set_names - ) - self.charset = set() - self.gt = dict() - for set_name in set_names: - self.gt[set_name] = dict() - - def load_resize_save(self, source_path, target_path): - """ - Load image, apply resolution modification and save it - """ - shutil.copyfile(source_path, target_path) - - def resize(self, img, source_dpi, target_dpi): - """ - Apply resolution modification to image - """ - if source_dpi == target_dpi: - return img - if isinstance(img, np.ndarray): - h, w = img.shape[:2] - img = Image.fromarray(img) - else: - w, h = img.size - ratio = target_dpi / source_dpi - img = img.resize((int(w * ratio), int(h * ratio)), Image.BILINEAR) - return np.array(img) - - def end_format(self): - """ - Save label and charset files - """ - with open(os.path.join(self.target_fold_path, "labels.pkl"), "wb") as f: - pickle.dump( - { - "ground_truth": self.gt, - "charset": sorted(list(self.charset)), - }, - f, - ) - with open(os.path.join(self.target_fold_path, "charset.pkl"), "wb") as f: - pickle.dump(sorted(list(self.charset)), f) diff --git a/dan/datasets/format/simara.py b/dan/datasets/format/simara.py deleted file mode 100644 index 0e98e47ddfe5ba837fa5e9baf6b63c6129c338e4..0000000000000000000000000000000000000000 --- a/dan/datasets/format/simara.py +++ /dev/null @@ -1,114 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -import os -from collections import defaultdict - -from tqdm import tqdm - -from dan.datasets.format.generic import OCRDatasetFormatter - -# Layout string to token -SEM_MATCHING_TOKENS_STR = { - "INTITULE": "ⓘ", - "DATE": "â““", - "COTE_SERIE": "â“¢", - "ANALYSE_COMPL": "â“’", - "PRECISIONS_SUR_COTE": "â“Ÿ", - "COTE_ARTICLE": "â“", -} - -# Layout begin-token to end-token -SEM_MATCHING_TOKENS = {"ⓘ": "â’¾", "â““": "â’¹", "â“¢": "Ⓢ", "â“’": "â’¸", "â“Ÿ": "â“…", "â“": "â’¶"} - - -class SimaraDatasetFormatter(OCRDatasetFormatter): - def __init__( - self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=True - ): - super(SimaraDatasetFormatter, self).__init__( - "simara", level, "_sem" if sem_token else "", set_names - ) - - self.dpi = dpi - self.sem_token = sem_token - self.map_datasets_files.update( - { - "simara": { - # (1,050 for train, 100 for validation and 100 for test) - "page": { - "format_function": self.format_simara_page, - }, - } - } - ) - self.matching_tokens_str = SEM_MATCHING_TOKENS_STR - self.matching_tokens = SEM_MATCHING_TOKENS - - def preformat_simara_page(self): - """ - Extract all information from dataset and correct some annotations - """ - dataset = defaultdict(list) - img_folder_path = os.path.join("Datasets", "raw", "simara", "images") - labels_folder_path = os.path.join("Datasets", "raw", "simara", "labels") - sem_labels_folder_path = os.path.join("Datasets", "raw", "simara", 
"labels_sem") - train_files = [ - os.path.join(labels_folder_path, "train", name) - for name in os.listdir(os.path.join(sem_labels_folder_path, "train")) - ] - valid_files = [ - os.path.join(labels_folder_path, "valid", name) - for name in os.listdir(os.path.join(sem_labels_folder_path, "valid")) - ] - test_files = [ - os.path.join(labels_folder_path, "test", name) - for name in os.listdir(os.path.join(sem_labels_folder_path, "test")) - ] - for set_name, files in zip( - self.set_names, [train_files, valid_files, test_files] - ): - for i, label_file in enumerate( - tqdm(files, desc="Pre-formatting " + set_name) - ): - with open(label_file, "r") as f: - text = f.read() - with open(label_file.replace("labels", "labels_sem"), "r") as f: - sem_text = f.read() - dataset[set_name].append( - { - "img_path": os.path.join( - img_folder_path, - set_name, - label_file.split("/")[-1].replace("txt", "jpg"), - ), - "label": text, - "sem_label": sem_text, - } - ) - return dataset - - def format_simara_page(self): - """ - Format simara page dataset - """ - dataset = self.preformat_simara_page() - for set_name in self.set_names: - fold = os.path.join(self.target_fold_path, set_name) - for sample in tqdm(dataset[set_name], desc="Formatting " + set_name): - new_name = sample["img_path"].split("/")[-1] - new_img_path = os.path.join(fold, new_name) - self.load_resize_save( - sample["img_path"], new_img_path - ) # , 300, self.dpi) - page = { - "text": sample["label"] - if not self.sem_token - else sample["sem_label"], - } - self.charset = self.charset.union(set(page["text"])) - self.gt[set_name][new_name] = page - - -if __name__ == "__main__": - formatter = SimaraDatasetFormatter("page", sem_token=False) - formatter.format() diff --git a/dan/datasets/format/synist.py b/dan/datasets/format/synist.py deleted file mode 100644 index be76d0bd664d6fe8db84039f581ecf420b35c70f..0000000000000000000000000000000000000000 --- a/dan/datasets/format/synist.py +++ /dev/null @@ -1,121 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -import os -import re -from collections import Counter - -from tqdm import tqdm - -from dan.datasets.format.generic import OCRDatasetFormatter - - -def remove_spaces(text): - # remove begin/ending spaces - text = text.strip() - # replace \t with regular space - text = re.sub("\t", " ", text) - # remove consecutive spaces - text = re.sub(" +", " ", text) - return text - - -class SynistDatasetFormatter(OCRDatasetFormatter): - def __init__( - self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=False - ): - super(SynistDatasetFormatter, self).__init__( - "synist_synth", level, "_sem" if sem_token else "", set_names - ) - - self.dpi = dpi - self.counter = Counter() - self.map_datasets_files.update( - { - "synist_synth": { - # (1,050 for train, 100 for validation and 100 for test) - "line": { - "needed_files": [], - "arx_files": [], - "format_function": self.format_synist_page, - } - } - } - ) - - def preformat_synist_page(self): - """ - Extract all information from dataset and correct some annotations - """ - - dataset = {"train": list(), "valid": list(), "test": list()} - img_folder_path = os.path.join( - "Datasets", "raw", "synist_synth_lines", "images" - ) - labels_folder_path = os.path.join( - "Datasets", "raw", "synist_synth_lines", "labels" - ) - - train_files = [ - os.path.join(labels_folder_path, "train", name) - for name in os.listdir(os.path.join(labels_folder_path, "train")) - ] - valid_files = [ - os.path.join(labels_folder_path, "valid", name) - for name in 
os.listdir(os.path.join(labels_folder_path, "valid")) - ] - test_files = [ - os.path.join(labels_folder_path, "test", name) - for name in os.listdir(os.path.join(labels_folder_path, "test")) - ] - - for set_name, files in zip( - self.set_names, [train_files, valid_files, test_files] - ): - for i, label_file in enumerate( - tqdm(files, desc="Pre-formatting " + set_name) - ): - with open(label_file, "r") as f: - text = remove_spaces(f.read()) - - dataset[set_name].append( - { - "img_path": os.path.join( - img_folder_path, - set_name, - label_file.split("/")[-1].replace("txt", "png"), - ), - "label": text.strip(), - } - ) - - return dataset - - def format_synist_page(self): - """ - Format synist page dataset - """ - dataset = self.preformat_synist_page() - for set_name in self.set_names: - fold = os.path.join(self.target_fold_path, set_name) - for sample in tqdm(dataset[set_name], desc="Formatting " + set_name): - new_name = sample["img_path"].split("/")[-1] - new_img_path = os.path.join(fold, new_name) - # self.load_resize_save(sample["img_path"], new_img_path, 300, self.dpi) - self.load_resize_save(sample["img_path"], new_img_path) - # self.load_flip_save(new_img_path, new_img_path) - page = { - "text": sample["label"], - } - self.charset = self.charset.union(set(page["text"])) - self.gt[set_name][new_name] = page - self.counter.update(page["text"]) - - -if __name__ == "__main__": - formatter = SynistDatasetFormatter("line", sem_token=False) - formatter.format() - print(formatter.counter) - print(formatter.counter.most_common(80)) - for k, v in formatter.counter.items(): - print(k) - print(k.encode("utf-8"), v) diff --git a/dan/manager/dataset.py b/dan/manager/dataset.py index 64633ebe53445b4b1768e773ba76c9195e079f77..b423531cba90f7d1056b0abbc9f13246c50d0366 100644 --- a/dan/manager/dataset.py +++ b/dan/manager/dataset.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- +import json import os -import pickle import random import cv2 @@ -39,7 +39,7 @@ class DatasetManager: self.batch_size = { "train": self.params["batch_size"], - "valid": self.params["valid_batch_size"] + "val": self.params["valid_batch_size"] if "valid_batch_size" in self.params else self.params["batch_size"], "test": self.params["test_batch_size"] @@ -70,12 +70,12 @@ class DatasetManager: ) self.apply_specific_treatment_after_dataset_loading(self.train_dataset) - for custom_name in self.params["valid"].keys(): + for custom_name in self.params["val"].keys(): self.valid_datasets[custom_name] = self.dataset_class( self.params, - "valid", + "val", custom_name, - self.get_paths_and_sets(self.params["valid"][custom_name]), + self.get_paths_and_sets(self.params["val"][custom_name]), ) self.apply_specific_treatment_after_dataset_loading( self.valid_datasets[custom_name] @@ -124,7 +124,7 @@ class DatasetManager: for key in self.valid_datasets.keys(): self.valid_loaders[key] = DataLoader( self.valid_datasets[key], - batch_size=self.batch_size["valid"], + batch_size=self.batch_size["val"], sampler=self.valid_samplers[key], batch_sampler=self.valid_samplers[key], shuffle=False, @@ -250,39 +250,38 @@ class GenericDataset(Dataset): Load images and labels """ samples = list() + for path_and_set in paths_and_sets: path = path_and_set["path"] + with open(os.path.join(path, "labels.json"), "rb") as f: + gt_per_set = json.load(f) set_name = path_and_set["set_name"] - with open(os.path.join(path, "labels.pkl"), "rb") as f: - info = pickle.load(f) - gt = info["ground_truth"][set_name] - for filename in natural_sort(gt.keys()): - name = 
os.path.join(os.path.basename(path), set_name, filename) - full_path = os.path.join(path, set_name, filename) - if isinstance(gt[filename], dict) and "text" in gt[filename]: - label = gt[filename]["text"] - else: - label = gt[filename] - samples.append( - { - "name": name, - "label": label, - "unchanged_label": label, - "path": full_path, - "nb_cols": 1 - if "nb_cols" not in gt[filename] - else gt[filename]["nb_cols"], - } - ) - if load_in_memory: - samples[-1]["img"] = GenericDataset.load_image(full_path) - if type(gt[filename]) is dict: - if "lines" in gt[filename].keys(): - samples[-1]["raw_line_seg_label"] = gt[filename]["lines"] - if "paragraphs" in gt[filename].keys(): - samples[-1]["paragraphs_label"] = gt[filename]["paragraphs"] - if "pages" in gt[filename].keys(): - samples[-1]["pages_label"] = gt[filename]["pages"] + gt = gt_per_set[set_name] + for filename in natural_sort(gt.keys()): + if isinstance(gt[filename], dict) and "text" in gt[filename]: + label = gt[filename]["text"] + else: + label = gt[filename] + samples.append( + { + "name": os.path.basename(filename), + "label": label, + "unchanged_label": label, + "path": os.path.abspath(filename), + "nb_cols": 1 + if "nb_cols" not in gt[filename] + else gt[filename]["nb_cols"], + } + ) + if load_in_memory: + samples[-1]["img"] = GenericDataset.load_image(filename) + if type(gt[filename]) is dict: + if "lines" in gt[filename].keys(): + samples[-1]["raw_line_seg_label"] = gt[filename]["lines"] + if "paragraphs" in gt[filename].keys(): + samples[-1]["paragraphs_label"] = gt[filename]["paragraphs"] + if "pages" in gt[filename].keys(): + samples[-1]["pages_label"] = gt[filename]["pages"] return samples def apply_preprocessing(self, preprocessings): @@ -344,7 +343,7 @@ class GenericDataset(Dataset): self.params["config"][key] if key in self.params["config"].keys() else None for key in ["augmentation", "valid_augmentation", "test_augmentation"] ] - for aug, set_name in zip(augs, ["train", "valid", "test"]): + for aug, set_name in zip(augs, ["train", "val", "test"]): if aug and self.set_name == set_name: return apply_data_augmentation(img, aug) return img, list() diff --git a/dan/manager/metrics.py b/dan/manager/metrics.py index 3bb572b25054455aadf14fdb72f9273c8c57bd03..c9e3a64203008e833366274c7df8335cd9883e08 100644 --- a/dan/manager/metrics.py +++ b/dan/manager/metrics.py @@ -5,8 +5,8 @@ import editdistance import networkx as nx import numpy as np -from dan.datasets.format.simara import SEM_MATCHING_TOKENS as SIMARA_MATCHING_TOKENS from dan.post_processing import PostProcessingModuleSIMARA +from dan.utils import SEM_MATCHING_TOKENS as SIMARA_MATCHING_TOKENS class MetricManager: diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py index 71463042e14e381161385ba443433e69ca9b404a..199cf5dcced64738942c65fc4c14bcc49cbd0a66 100644 --- a/dan/manager/ocr.py +++ b/dan/manager/ocr.py @@ -78,9 +78,8 @@ class OCRDatasetManager(DatasetManager): datasets = self.params["datasets"] charset = set() for key in datasets.keys(): - with open(os.path.join(datasets[key], "labels.pkl"), "rb") as f: - info = pickle.load(f) - charset = charset.union(set(info["charset"])) + with open(os.path.join(datasets[key], "charset.pkl"), "rb") as f: + charset = charset.union(set(pickle.load(f))) if ( "\n" in charset and "remove_linebreaks" in self.params["config"]["constraints"] diff --git a/dan/manager/training.py b/dan/manager/training.py index 90f1362648ed96fc8b130d5039ccae98556a15e9..682e3091735c01244e9249568675eb1ad86c53d0 100644 --- a/dan/manager/training.py +++ 
b/dan/manager/training.py @@ -828,7 +828,6 @@ class GenericTrainingManager: pbar.set_postfix(values=str(display_values)) pbar.update(len(batch_data["names"])) - self.dataset.remove_test_dataset(custom_name) # output metrics values if requested if output: if "pred" in metric_names: @@ -981,14 +980,14 @@ class OCRManager(GenericTrainingManager): os.makedirs(path, exist_ok=True) charset = set() dataset = None - gt = {"train": dict(), "valid": dict(), "test": dict()} - for set_name in ["train", "valid", "test"]: + gt = {"train": dict(), "val": dict(), "test": dict()} + for set_name in ["train", "val", "test"]: set_path = os.path.join(path, set_name) os.makedirs(set_path, exist_ok=True) if set_name == "train": dataset = self.dataset.train_dataset - elif set_name == "valid": - dataset = self.dataset.valid_datasets["{}-valid".format(dataset_name)] + elif set_name == "val": + dataset = self.dataset.valid_datasets["{}-val".format(dataset_name)] elif set_name == "test": self.dataset.generate_test_loader( "{}-test".format(dataset_name), @@ -1028,14 +1027,15 @@ class OCRManager(GenericTrainingManager): if "line_label" in sample: gt[set_name][img_name]["lines"] = sample["line_label"] - with open(os.path.join(path, "labels.pkl"), "wb") as f: - pickle.dump( - { - "ground_truth": gt, - "charset": sorted(list(charset)), - }, + with open(os.path.join(path / "labels.json"), "w") as f: + json.dump( + gt, f, + sort_keys=True, + indent=4, ) + with open(os.path.join(path / "charset.pkl"), "wb") as f: + pickle.dump(sorted(list(charset)), f) class Manager(OCRManager): diff --git a/dan/manager/utils.py b/dan/manager/utils.py index c95ca7ca81a65f10e030f76286d3c71128d1137c..b5b852b7c009178ab85da51fd44be9c4eeca60e4 100644 --- a/dan/manager/utils.py +++ b/dan/manager/utils.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import json import os import pickle @@ -24,14 +25,14 @@ class OCRManager(GenericTrainingManager): os.makedirs(path, exist_ok=True) charset = set() dataset = None - gt = {"train": dict(), "valid": dict(), "test": dict()} - for set_name in ["train", "valid", "test"]: + gt = {"train": dict(), "val": dict(), "test": dict()} + for set_name in ["train", "val", "test"]: set_path = os.path.join(path, set_name) os.makedirs(set_path, exist_ok=True) if set_name == "train": dataset = self.dataset.train_dataset - elif set_name == "valid": - dataset = self.dataset.valid_datasets["{}-valid".format(dataset_name)] + elif set_name == "val": + dataset = self.dataset.valid_datasets["{}-val".format(dataset_name)] elif set_name == "test": self.dataset.generate_test_loader( "{}-test".format(dataset_name), @@ -71,11 +72,12 @@ class OCRManager(GenericTrainingManager): if "line_label" in sample: gt[set_name][img_name]["lines"] = sample["line_label"] - with open(os.path.join(path, "labels.pkl"), "wb") as f: - pickle.dump( - { - "ground_truth": gt, - "charset": sorted(list(charset)), - }, + with open(os.path.join(path / "labels.json"), "w") as f: + json.dump( + gt, f, + sort_keys=True, + indent=4, ) + with open(os.path.join(path / "charset.pkl"), "wb") as f: + pickle.dump(sorted(list(charset)), f) diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py index 64ca1d313ba8cca5ed7615a9104486208f8999a7..a2b5e5684580a4ee8fb9a27acb7c305215812a65 100644 --- a/dan/ocr/document/train.py +++ b/dan/ocr/document/train.py @@ -34,7 +34,7 @@ def train_and_test(rank, params): metrics = ["cer", "wer", "time", "map_cer", "loer"] for dataset_name in params["dataset_params"]["datasets"].keys(): - for set_name in ["test", "valid", "train"]: + 
for set_name in ["test", "val", "train"]: model.predict( "{}-{}".format(dataset_name, set_name), [ @@ -55,7 +55,7 @@ def run(): "dataset_manager": OCRDatasetManager, "dataset_class": OCRDataset, "datasets": { - dataset_name: "../../../Datasets/formatted/{}_{}{}".format( + dataset_name: "{}_{}{}".format( dataset_name, dataset_level, dataset_variant ), }, @@ -65,9 +65,9 @@ def run(): (dataset_name, "train"), ], }, - "valid": { - "{}-valid".format(dataset_name): [ - (dataset_name, "valid"), + "val": { + "{}-val".format(dataset_name): [ + (dataset_name, "val"), ], }, "config": { @@ -190,7 +190,7 @@ def run(): "eval_on_valid_interval": 5, # Interval (in epochs) to evaluate during training "focus_metric": "cer", # Metrics to focus on to determine best epoch "expected_metric_value": "low", # ["high", "low"] What is best for the focus metric value - "set_name_focus_metric": "{}-valid".format( + "set_name_focus_metric": "{}-val".format( dataset_name ), # Which dataset to focus on to select best weights "train_metrics": [ diff --git a/dan/ocr/line/generate_synthetic.py b/dan/ocr/line/generate_synthetic.py index de37803175c592b7a9f6c041666c2dae2e23742d..4efcd1129eb5f7e8d5e8185a27a05eeb49039368 100644 --- a/dan/ocr/line/generate_synthetic.py +++ b/dan/ocr/line/generate_synthetic.py @@ -48,9 +48,9 @@ def run(): (dataset_name, "train"), ], }, - "valid": { - "{}-valid".format(dataset_name): [ - (dataset_name, "valid"), + "val": { + "{}-val".format(dataset_name): [ + (dataset_name, "val"), ], }, "config": { @@ -135,7 +135,7 @@ def run(): "eval_on_valid_interval": 2, # Interval (in epochs) to evaluate during training "focus_metric": "cer", # Metrics to focus on to determine best epoch "expected_metric_value": "low", # ["high", "low"] What is best for the focus metric value - "set_name_focus_metric": "{}-valid".format(dataset_name), + "set_name_focus_metric": "{}-val".format(dataset_name), "train_metrics": ["loss_ctc", "cer", "wer"], # Metrics name for training "eval_metrics": [ "loss_ctc", diff --git a/dan/ocr/line/train.py b/dan/ocr/line/train.py index 7d2bab45505f995de8808636084f64e660c6c4b7..e76a54f3b588b83ae24fbb85cdc8446bb64a28ff 100644 --- a/dan/ocr/line/train.py +++ b/dan/ocr/line/train.py @@ -42,7 +42,7 @@ def train_and_test(rank, params): for dataset_name in params["dataset_params"]["datasets"].keys(): for set_name in [ "test", - "valid", + "val", "train", ]: model.predict( @@ -73,9 +73,9 @@ def run(): (dataset_name, "train"), ], }, - "valid": { - "{}-valid".format(dataset_name): [ - (dataset_name, "valid"), + "val": { + "{}-val".format(dataset_name): [ + (dataset_name, "val"), ], }, "config": { @@ -175,7 +175,7 @@ def run(): "eval_on_valid_interval": 2, # Interval (in epochs) to evaluate during training "focus_metric": "cer", # Metrics to focus on to determine best epoch "expected_metric_value": "low", # ["high", "low"] What is best for the focus metric value - "set_name_focus_metric": "{}-valid".format( + "set_name_focus_metric": "{}-val".format( dataset_name ), # Which dataset to focus on to select best weights "train_metrics": ["loss_ctc", "cer", "wer"], # Metrics name for training diff --git a/dan/post_processing.py b/dan/post_processing.py index 1242959c8c315aa643c624d686e366d8e0a35b9a..e2fdca5623e66caa534b0726b2df89546143b0cf 100644 --- a/dan/post_processing.py +++ b/dan/post_processing.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import numpy as np -from dan.datasets.format.simara import SEM_MATCHING_TOKENS as SIMARA_MATCHING_TOKENS +from dan.utils import SEM_MATCHING_TOKENS as 
SIMARA_MATCHING_TOKENS
 
 
 class PostProcessingModule:
diff --git a/dan/utils.py b/dan/utils.py
index 3c8ea8f7118122a2cf1c5563a94367ef6117945a..764bff9014689c50ae51126cab4d4447539525d5 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -4,6 +4,19 @@ import numpy as np
 import torch
 from torch.distributions.uniform import Uniform
 
+# Layout string to token
+SEM_MATCHING_TOKENS_STR = {
+    "INTITULE": "ⓘ",
+    "DATE": "ⓓ",
+    "COTE_SERIE": "ⓢ",
+    "ANALYSE_COMPL": "ⓒ",
+    "PRECISIONS_SUR_COTE": "ⓟ",
+    "COTE_ARTICLE": "ⓐ",
+}
+
+# Layout begin-token to end-token
+SEM_MATCHING_TOKENS = {"ⓘ": "Ⓘ", "ⓓ": "Ⓓ", "ⓢ": "Ⓢ", "ⓒ": "Ⓒ", "ⓟ": "Ⓟ", "ⓐ": "Ⓐ"}
+
 
 def randint(low, high):
     """
diff --git a/requirements.txt b/requirements.txt
index e850c9790e2b3159977949c3363f3b1a81239b79..7b01e372606d830a98d7bf420e533c93293e762d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,6 +6,7 @@ networkx==2.8.8
 numpy==1.23.5
 opencv-python==4.6.0.66
 PyYAML==6.0
+scipy==1.9.3
 tensorboard==2.11.0
 torch==1.13.0
 torchvision==0.14.0
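
Note: a minimal sketch of how the two files written by the new `format` subcommand could be inspected, assuming a hypothetical `path/to/dataset` folder previously passed to `teklia-dan dataset format` (the structure mirrors what `ATRDatasetFormatter.run` writes above; this is an illustration, not part of the patch):

```python
import json
import pickle
from pathlib import Path

# Hypothetical location: the --dataset folder passed to `teklia-dan dataset format`.
dataset = Path("path/to/dataset")

# labels.json maps each split ("train", "val", "test") to image paths and transcriptions.
with open(dataset / "labels.json") as f:
    labels = json.load(f)
for set_name, entries in labels.items():
    print(f"{set_name}: {len(entries)} samples")

# charset.pkl holds the sorted list of characters found in the transcriptions.
with open(dataset / "charset.pkl", "rb") as f:
    charset = pickle.load(f)
print(f"charset ({len(charset)} characters): {''.join(charset)}")
```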