From f9b83350ab94c067b1921768b80743fde694bc73 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Mon, 14 Nov 2022 17:40:51 +0100
Subject: [PATCH] remove not needed files, reorganize inside dan module

---
 .../dataset_formatters/bessin_formatter.py    | 133 -----------------
 .../dataset_formatters/synist_formatter.py    | 137 ------------------
 dan/cli.py                                    |   4 +-
 dan/datasets/extract/arkindex_utils.py        |   1 -
 dan/datasets/extract/extract_from_arkindex.py |   8 -
 dan/datasets/extract/utils.py                 |  26 +++-
 dan/datasets/format/bessin.py                 | 113 +++++++++++++++
 dan/datasets/format/synist.py                 | 120 +++++++++++++++
 dan/datasets/utils.py                         |  33 +++++
 dan/decoder.py                                |   1 -
 dan/ocr/train.py                              |   3 +-
 prediction-requirements.txt                   |   8 -
 requirements.txt                              |   8 +
 setup.py                                      |   1 -
 14 files changed, 303 insertions(+), 293 deletions(-)
 delete mode 100644 Datasets/dataset_formatters/bessin_formatter.py
 delete mode 100644 Datasets/dataset_formatters/synist_formatter.py
 create mode 100644 dan/datasets/format/bessin.py
 create mode 100644 dan/datasets/format/synist.py
 delete mode 100644 prediction-requirements.txt

diff --git a/Datasets/dataset_formatters/bessin_formatter.py b/Datasets/dataset_formatters/bessin_formatter.py
deleted file mode 100644
index c7c3db66..00000000
--- a/Datasets/dataset_formatters/bessin_formatter.py
+++ /dev/null
@@ -1,133 +0,0 @@
-# contributors :
-# - Denis Coquenet
-#
-#
-# This software is a computer program written in Python whose purpose is to
-# provide public implementation of deep learning works, in pytorch.
-#
-# This software is governed by the CeCILL-C license under French law and
-# abiding by the rules of distribution of free software. You can use,
-# modify and/ or redistribute the software under the terms of the CeCILL-C
-# license as circulated by CEA, CNRS and INRIA at the following URL
-# "http://www.cecill.info".
-#
-# As a counterpart to the access to the source code and rights to copy,
-# modify and redistribute granted by the license, users are provided only
-# with a limited warranty and the software's author, the holder of the
-# economic rights, and the successive licensors have only limited
-# liability.
-#
-# In this respect, the user's attention is drawn to the risks associated
-# with loading, using, modifying and/or developing or reproducing the
-# software by the user in light of its specific status of free software,
-# that may mean that it is complicated to manipulate, and that also
-# therefore means that it is reserved for developers and experienced
-# professionals having in-depth computer knowledge. Users are therefore
-# encouraged to load and test the software's suitability as regards their
-# requirements in conditions enabling the security of their systems and/or
-# data to be ensured and, more generally, to use and operate it in the
-# same conditions as regards security.
-#
-# The fact that you are presently reading this means that you have had
-# knowledge of the CeCILL-C license and that you accept its terms.
-
-from Datasets.dataset_formatters.generic_dataset_formatter import OCRDatasetFormatter
-import os
-import numpy as np
-from Datasets.dataset_formatters.utils_dataset import natural_sort
-from PIL import Image
-import xml.etree.ElementTree as ET
-import re
-from tqdm import tqdm
-from collections import Counter
-
-
-def remove_spaces(text):
-    # remove begin/ending spaces
-    text = text.strip()
-    # replace \t with regular space
-    text = re.sub("\t", " ", text)
-    # remove consecutive spaces
-    text = re.sub(" +", " ", text)
-# text = text.encode('ascii', 'ignore').decode("utf-8")
-    return text
-
-
-class SynistDatasetFormatter(OCRDatasetFormatter):
-    def __init__(self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=False):
-        super(SynistDatasetFormatter, self).__init__("bessin", level, "_sem" if sem_token else "", set_names)
-
-        self.dpi = dpi
-        self.counter = Counter()
-        self.map_datasets_files.update({
-            "bessin": {
-                # (1,050 for train, 100 for validation and 100 for test)
-                "line": {
-                    "needed_files": [],
-                    "arx_files": [],
-                    "format_function": self.format_bessin_zone
-                }
-            }
-        })
-
-    def preformat_bessin_zone(self):
-        """
-        Extract all information from dataset and correct some annotations
-        """
-
-        dataset = {
-            "train": list(),
-            "valid": list(),
-            "test": list()
-        }
-        img_folder_path = os.path.join("Datasets", "raw", "bessin", "images")
-        labels_folder_path = os.path.join("Datasets", "raw", "bessin", "labels")
-
-        train_files = [
-            os.path.join(labels_folder_path, 'train', name)
-            for name in os.listdir(os.path.join(labels_folder_path, 'train'))]
-        valid_files = [
-            os.path.join(labels_folder_path, 'valid', name)
-            for name in os.listdir(os.path.join(labels_folder_path, 'valid'))]
-        test_files = [
-            os.path.join(labels_folder_path, 'test', name)
-            for name in os.listdir(os.path.join(labels_folder_path, 'test'))]
-
-        for set_name, files in zip(self.set_names, [train_files, valid_files, test_files]):
-            for i, label_file in enumerate(tqdm(files, desc='Pre-formatting '+set_name)):
-                with open(label_file, 'r') as f:
-                    text = remove_spaces(f.read())
-
-                dataset[set_name].append({
-                    "img_path": os.path.join(
-                        img_folder_path, set_name, label_file.split('/')[-1].replace('txt', 'jpg')),
-                    "label": text.strip()
-                })
-
-        return dataset
-
-    def format_bessin_zone(self):
-        """
-        Format synist page dataset
-        """
-        dataset = self.preformat_bessin_zone()
-        for set_name in self.set_names:
-            fold = os.path.join(self.target_fold_path, set_name)
-            for sample in tqdm(dataset[set_name], desc='Formatting '+set_name):
-                new_name = sample['img_path'].split('/')[-1]
-                new_img_path = os.path.join(fold, new_name)
-                self.load_resize_save(sample["img_path"], new_img_path)
-                zone = {
-                    "text": sample["label"],
-                }
-                self.charset = self.charset.union(set(zone["text"]))
-                self.gt[set_name][new_name] = zone
-                self.counter.update(zone['text'])
-
-
-if __name__ == "__main__":
-    formatter = SynistDatasetFormatter("line", sem_token=False)
-    formatter.format()
-    print("Character freq: ")
-    for k,v in formatter.counter.items():
-        print(k, v)
\ No newline at end of file
diff --git a/Datasets/dataset_formatters/synist_formatter.py b/Datasets/dataset_formatters/synist_formatter.py
deleted file mode 100644
index c028ee73..00000000
--- a/Datasets/dataset_formatters/synist_formatter.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# contributors :
-# - Denis Coquenet
-#
-#
-# This software is a computer program written in Python whose purpose is to
-# provide public implementation of deep learning works, in pytorch.
-#
-# This software is governed by the CeCILL-C license under French law and
-# abiding by the rules of distribution of free software. You can use,
-# modify and/ or redistribute the software under the terms of the CeCILL-C
-# license as circulated by CEA, CNRS and INRIA at the following URL
-# "http://www.cecill.info".
-#
-# As a counterpart to the access to the source code and rights to copy,
-# modify and redistribute granted by the license, users are provided only
-# with a limited warranty and the software's author, the holder of the
-# economic rights, and the successive licensors have only limited
-# liability.
-#
-# In this respect, the user's attention is drawn to the risks associated
-# with loading, using, modifying and/or developing or reproducing the
-# software by the user in light of its specific status of free software,
-# that may mean that it is complicated to manipulate, and that also
-# therefore means that it is reserved for developers and experienced
-# professionals having in-depth computer knowledge. Users are therefore
-# encouraged to load and test the software's suitability as regards their
-# requirements in conditions enabling the security of their systems and/or
-# data to be ensured and, more generally, to use and operate it in the
-# same conditions as regards security.
-#
-# The fact that you are presently reading this means that you have had
-# knowledge of the CeCILL-C license and that you accept its terms.
-
-from Datasets.dataset_formatters.generic_dataset_formatter import OCRDatasetFormatter
-import os
-import numpy as np
-from Datasets.dataset_formatters.utils_dataset import natural_sort
-from PIL import Image
-import xml.etree.ElementTree as ET
-import re
-from tqdm import tqdm
-from collections import Counter
-
-
-def remove_spaces(text):
-    # remove begin/ending spaces
-    text = text.strip()
-    # replace \t with regular space
-    text = re.sub("\t", " ", text)
-    # remove consecutive spaces
-    text = re.sub(" +", " ", text)
-# text = text.encode('ascii', 'ignore').decode("utf-8")
-    return text
-
-
-class SynistDatasetFormatter(OCRDatasetFormatter):
-    def __init__(self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=False):
-        super(SynistDatasetFormatter, self).__init__("synist_synth", level, "_sem" if sem_token else "", set_names)
-
-        self.dpi = dpi
-        self.counter = Counter()
-        self.map_datasets_files.update({
-            "synist_synth": {
-                # (1,050 for train, 100 for validation and 100 for test)
-                "line": {
-                    "needed_files": [],
-                    "arx_files": [],
-                    "format_function": self.format_synist_page
-                }
-            }
-        })
-
-    def preformat_synist_page(self):
-        """
-        Extract all information from dataset and correct some annotations
-        """
-
-        dataset = {
-            "train": list(),
-            "valid": list(),
-            "test": list()
-        }
-        img_folder_path = os.path.join("Datasets", "raw", "synist_synth_lines", "images")
-        labels_folder_path = os.path.join("Datasets", "raw", "synist_synth_lines", "labels")
-
-        train_files = [
-            os.path.join(labels_folder_path, 'train', name)
-            for name in os.listdir(os.path.join(labels_folder_path, 'train'))]
-        valid_files = [
-            os.path.join(labels_folder_path, 'valid', name)
-            for name in os.listdir(os.path.join(labels_folder_path, 'valid'))]
-        test_files = [
-            os.path.join(labels_folder_path, 'test', name)
-            for name in os.listdir(os.path.join(labels_folder_path, 'test'))]
-
-        for set_name, files in zip(self.set_names, [train_files, valid_files, test_files]):
-            for i, label_file in enumerate(tqdm(files, desc='Pre-formatting '+set_name)):
-                with open(label_file, 'r') as f:
-                    text = remove_spaces(f.read())
-
-                dataset[set_name].append({
-                    "img_path": os.path.join(
-                        img_folder_path, set_name, label_file.split('/')[-1].replace('txt', 'png')),
-                    "label": text.strip()
-                })
-
-        return dataset
-
-    def format_synist_page(self):
-        """
-        Format synist page dataset
-        """
-        dataset = self.preformat_synist_page()
-        for set_name in self.set_names:
-            fold = os.path.join(self.target_fold_path, set_name)
-            for sample in tqdm(dataset[set_name], desc='Formatting '+set_name):
-                new_name = sample['img_path'].split('/')[-1]
-                new_img_path = os.path.join(fold, new_name)
-                #self.load_resize_save(sample["img_path"], new_img_path, 300, self.dpi)
-                self.load_resize_save(sample["img_path"], new_img_path)
-                #self.load_flip_save(new_img_path, new_img_path)
-                page = {
-                    "text": sample["label"],
-                }
-                self.charset = self.charset.union(set(page["text"]))
-                self.gt[set_name][new_name] = page
-                self.counter.update(page['text'])
-
-
-if __name__ == "__main__":
-    formatter = SynistDatasetFormatter("line", sem_token=False)
-    formatter.format()
-    print(formatter.counter)
-    print(formatter.counter.most_common(80))
-    for k,v in formatter.counter.items():
-        print(k)
-        print(k.encode('utf-8'), v)
diff --git a/dan/cli.py b/dan/cli.py
index b0ed2199..fff057b9 100644
--- a/dan/cli.py
+++ b/dan/cli.py
@@ -1,13 +1,12 @@
 # -*- coding: utf-8 -*-
 import argparse
 import errno
+
 from dan.datasets.extract.extract_from_arkindex import add_extract_parser
 from dan.ocr.line.generate_synthetic import add_generate_parser
-
 from dan.ocr.train import add_train_parser
-
 
 def get_parser():
     parser = argparse.ArgumentParser(prog="TEKLIA DAN training")
     subcommands = parser.add_subparsers(metavar="subcommand")
@@ -17,6 +16,7 @@ def get_parser():
     add_generate_parser(subcommands)
     return parser
 
+
 def main():
     parser = get_parser()
     args = vars(parser.parse_args())
diff --git a/dan/datasets/extract/arkindex_utils.py b/dan/datasets/extract/arkindex_utils.py
index 5216a7e0..d9d2e065 100644
--- a/dan/datasets/extract/arkindex_utils.py
+++ b/dan/datasets/extract/arkindex_utils.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
 """
diff --git a/dan/datasets/extract/extract_from_arkindex.py b/dan/datasets/extract/extract_from_arkindex.py
index 1a6e207d..c9c1c2c1 100644
--- a/dan/datasets/extract/extract_from_arkindex.py
+++ b/dan/datasets/extract/extract_from_arkindex.py
@@ -1,12 +1,4 @@
-#!/usr/bin/env python
 # -*- coding: utf-8 -*-
-#
-# Example of minimal usage:
-# python extract_from_arkindex.py
-#     --corpus "AN-Simara annotations E (2022-06-20)"
-#     --parents-types folder
-#     --parents-names FRAN_IR_032031_4538.pdf
-#     --output-dir ../
 
 """
 The extraction module
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
index 72ab461b..77d15804 100644
--- a/dan/datasets/extract/utils.py
+++ b/dan/datasets/extract/utils.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
 """
@@ -25,6 +24,13 @@ def get_cli_args():
         help="Name of the corpus from which the data will be retrieved.",
         required=True,
     )
+    parser.add_argument(
+        "--element-type",
+        nargs="+",
+        type=str,
+        help="Type of elements to retrieve",
+        required=True,
+    )
     parser.add_argument(
         "--parents-types",
         nargs="+",
@@ -47,4 +53,22 @@ def get_cli_args():
         help="Names of parents of the elements.",
         default=None,
     )
+    parser.add_argument(
+        "--no-entities", action="store_true", help="Extract text without entities"
+    )
+
+    parser.add_argument(
+        "--use-existing-split",
+        action="store_true",
+        help="Do not partition pages into train/val/test",
+    )
+
+    parser.add_argument(
+        "--train-prob", type=float, default=0.7, help="Training set probability"
+    )
+
+    parser.add_argument(
+        "--val-prob", type=float, default=0.15, help="Validation set probability"
+    )
+
     return parser.parse_args()
diff --git a/dan/datasets/format/bessin.py b/dan/datasets/format/bessin.py
new file mode 100644
index 00000000..1b2248b2
--- /dev/null
+++ b/dan/datasets/format/bessin.py
@@ -0,0 +1,113 @@
+# -*- coding: utf-8 -*-
+import os
+import re
+from collections import Counter
+
+from tqdm import tqdm
+
+from dan.datasets.format.generic import OCRDatasetFormatter
+
+
+def remove_spaces(text):
+    # remove begin/ending spaces
+    text = text.strip()
+    # replace \t with regular space
+    text = re.sub("\t", " ", text)
+    # remove consecutive spaces
+    text = re.sub(" +", " ", text)
+    # text = text.encode('ascii', 'ignore').decode("utf-8")
+    return text
+
+
+class BessinDatasetFormatter(OCRDatasetFormatter):
+    def __init__(
+        self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=False
+    ):
+        super(BessinDatasetFormatter, self).__init__(
+            "bessin", level, "_sem" if sem_token else "", set_names
+        )
+
+        self.dpi = dpi
+        self.counter = Counter()
+        self.map_datasets_files.update(
+            {
+                "bessin": {
+                    # (1,050 for train, 100 for validation and 100 for test)
+                    "line": {
+                        "needed_files": [],
+                        "arx_files": [],
+                        "format_function": self.format_bessin_zone,
+                    }
+                }
+            }
+        )
+
+    def preformat_bessin_zone(self):
+        """
+        Extract all information from dataset and correct some annotations
+        """
+
+        dataset = {"train": list(), "valid": list(), "test": list()}
+        img_folder_path = os.path.join("Datasets", "raw", "bessin", "images")
+        labels_folder_path = os.path.join("Datasets", "raw", "bessin", "labels")
+
+        train_files = [
+            os.path.join(labels_folder_path, "train", name)
+            for name in os.listdir(os.path.join(labels_folder_path, "train"))
+        ]
+        valid_files = [
+            os.path.join(labels_folder_path, "valid", name)
+            for name in os.listdir(os.path.join(labels_folder_path, "valid"))
+        ]
+        test_files = [
+            os.path.join(labels_folder_path, "test", name)
+            for name in os.listdir(os.path.join(labels_folder_path, "test"))
+        ]
+
+        for set_name, files in zip(
+            self.set_names, [train_files, valid_files, test_files]
+        ):
+            for i, label_file in enumerate(
+                tqdm(files, desc="Pre-formatting " + set_name)
+            ):
+                with open(label_file, "r") as f:
+                    text = remove_spaces(f.read())
+
+                dataset[set_name].append(
+                    {
+                        "img_path": os.path.join(
+                            img_folder_path,
+                            set_name,
+                            label_file.split("/")[-1].replace("txt", "jpg"),
+                        ),
+                        "label": text.strip(),
+                    }
+                )
+
+        return dataset
+
+    def format_bessin_zone(self):
+        """
+        Format synist page dataset
+        """
+        dataset = self.preformat_bessin_zone()
+        for set_name in self.set_names:
+            fold = os.path.join(self.target_fold_path, set_name)
+            for sample in tqdm(dataset[set_name], desc="Formatting " + set_name):
+                new_name = sample["img_path"].split("/")[-1]
+                new_img_path = os.path.join(fold, new_name)
+                self.load_resize_save(sample["img_path"], new_img_path)
+                zone = {
+                    "text": sample["label"],
+                }
+                self.charset = self.charset.union(set(zone["text"]))
+                self.gt[set_name][new_name] = zone
+                self.counter.update(zone["text"])
+
+
+if __name__ == "__main__":
+    formatter = BessinDatasetFormatter("line", sem_token=False)
+    formatter.format()
+    print("Character freq: ")
+    for k, v in formatter.counter.items():
+        print(k, v)
diff --git a/dan/datasets/format/synist.py b/dan/datasets/format/synist.py
new file mode 100644
index 00000000..87c18fea
--- /dev/null
+++ b/dan/datasets/format/synist.py
@@ -0,0 +1,120 @@
+# -*- coding: utf-8 -*-
+import os
+import re
+from collections import Counter
+
+from tqdm import tqdm
+
+from dan.datasets.format.generic import OCRDatasetFormatter
+
+
+def remove_spaces(text):
+    # remove begin/ending spaces
+    text = text.strip()
+    # replace \t with regular space
+    text = re.sub("\t", " ", text)
+    # remove consecutive spaces
+    text = re.sub(" +", " ", text)
+    return text
+
+
+class SynistDatasetFormatter(OCRDatasetFormatter):
+    def __init__(
+        self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=False
+    ):
+        super(SynistDatasetFormatter, self).__init__(
+            "synist_synth", level, "_sem" if sem_token else "", set_names
+        )
+
+        self.dpi = dpi
+        self.counter = Counter()
+        self.map_datasets_files.update(
+            {
+                "synist_synth": {
+                    # (1,050 for train, 100 for validation and 100 for test)
+                    "line": {
+                        "needed_files": [],
+                        "arx_files": [],
+                        "format_function": self.format_synist_page,
+                    }
+                }
+            }
+        )
+
+    def preformat_synist_page(self):
+        """
+        Extract all information from dataset and correct some annotations
+        """
+
+        dataset = {"train": list(), "valid": list(), "test": list()}
+        img_folder_path = os.path.join(
+            "Datasets", "raw", "synist_synth_lines", "images"
+        )
+        labels_folder_path = os.path.join(
+            "Datasets", "raw", "synist_synth_lines", "labels"
+        )
+
+        train_files = [
+            os.path.join(labels_folder_path, "train", name)
+            for name in os.listdir(os.path.join(labels_folder_path, "train"))
+        ]
+        valid_files = [
+            os.path.join(labels_folder_path, "valid", name)
+            for name in os.listdir(os.path.join(labels_folder_path, "valid"))
+        ]
+        test_files = [
+            os.path.join(labels_folder_path, "test", name)
+            for name in os.listdir(os.path.join(labels_folder_path, "test"))
+        ]
+
+        for set_name, files in zip(
+            self.set_names, [train_files, valid_files, test_files]
+        ):
+            for i, label_file in enumerate(
+                tqdm(files, desc="Pre-formatting " + set_name)
+            ):
+                with open(label_file, "r") as f:
+                    text = remove_spaces(f.read())
+
+                dataset[set_name].append(
+                    {
+                        "img_path": os.path.join(
+                            img_folder_path,
+                            set_name,
+                            label_file.split("/")[-1].replace("txt", "png"),
+                        ),
+                        "label": text.strip(),
+                    }
+                )
+
+        return dataset
+
+    def format_synist_page(self):
+        """
+        Format synist page dataset
+        """
+        dataset = self.preformat_synist_page()
+        for set_name in self.set_names:
+            fold = os.path.join(self.target_fold_path, set_name)
+            for sample in tqdm(dataset[set_name], desc="Formatting " + set_name):
+                new_name = sample["img_path"].split("/")[-1]
+                new_img_path = os.path.join(fold, new_name)
+                # self.load_resize_save(sample["img_path"], new_img_path, 300, self.dpi)
+                self.load_resize_save(sample["img_path"], new_img_path)
+                # self.load_flip_save(new_img_path, new_img_path)
+                page = {
+                    "text": sample["label"],
+                }
+                self.charset = self.charset.union(set(page["text"]))
+                self.gt[set_name][new_name] = page
+                self.counter.update(page["text"])
+
+
+if __name__ == "__main__":
+    formatter = SynistDatasetFormatter("line", sem_token=False)
+    formatter.format()
+    print(formatter.counter)
+    print(formatter.counter.most_common(80))
+    for k, v in formatter.counter.items():
+        print(k)
+        print(k.encode("utf-8"), v)
diff --git a/dan/datasets/utils.py b/dan/datasets/utils.py
index c911cc9b..6422357c 100644
--- a/dan/datasets/utils.py
+++ b/dan/datasets/utils.py
@@ -1,6 +1,12 @@
 # -*- coding: utf-8 -*-
+import json
+import random
 import re
 
+import cv2
+
+random.seed(42)
+
 
 def convert(text):
     return int(text) if text.isdigit() else text.lower()
@@ -8,3 +14,30 @@ def convert(text):
 
 def natural_sort(data):
     return sorted(data, key=lambda key: [convert(c) for c in re.split("([0-9]+)", key)])
+
+
+def assign_random_split(train_prob, val_prob):
+    """
+    assuming train_prob + val_prob + test_prob = 1
+    """
+    prob = random.random()
+    if prob <= train_prob:
+        return "train"
+    elif prob <= train_prob + val_prob:
+        return "val"
+    else:
+        return "test"
+
+
+def save_text(path, text):
+    with open(path, "w") as f:
+        f.write(text)
+
+
+def save_image(path, image):
+    cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+
+
+def save_json(path, dict):
+    with open(path, "w") as outfile:
+        json.dump(dict, outfile, indent=4)
diff --git a/dan/decoder.py b/dan/decoder.py
index d2070fc6..f69d8e9a 100644
--- a/dan/decoder.py
+++ b/dan/decoder.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
 import torch
diff --git a/dan/ocr/train.py b/dan/ocr/train.py
index 1bcbe1c7..dda43ba8 100644
--- a/dan/ocr/train.py
+++ b/dan/ocr/train.py
@@ -1,7 +1,8 @@
 # -*- coding: utf-8 -*-
-from dan.ocr.line.train import add_line_parser
 from dan.ocr.document.train import add_document_parser
+from dan.ocr.line.train import add_line_parser
+
 
 def add_train_parser(subcommands) -> None:
     parser = subcommands.add_parser(
diff --git a/prediction-requirements.txt b/prediction-requirements.txt
deleted file mode 100644
index 86bfaeb8..00000000
--- a/prediction-requirements.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-arkindex-client==1.0.11
-editdistance==0.6.0
-fontTools==4.29.1
-imageio==2.16.0
-networkx==2.6.3
-tensorboard==0.2.1
-torchvision==0.12.0
-tqdm==4.62.3
diff --git a/requirements.txt b/requirements.txt
index 1fe32d37..27758aa6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,12 @@
+arkindex-client==1.0.11
+editdistance==0.6.0
+fontTools==4.29.1
+imageio==2.16.0
+networkx==2.6.3
 numpy==1.22.3
 opencv-python==4.5.5.64
 PyYAML==6.0
+tensorboard==0.2.1
 torch==1.11.0
+torchvision==0.12.0
+tqdm==4.62.3
diff --git a/setup.py b/setup.py
index a0002964..51d55e0f 100755
--- a/setup.py
+++ b/setup.py
@@ -28,5 +28,4 @@ setup(
             "teklia-dan=dan.cli:main",
         ]
     },
-    extras_require={"predict": parse_requirements("prediction-requirements.txt")},
 )
--
GitLab
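
The split helpers added to `dan/datasets/utils.py` above are meant to be driven by the new `--train-prob` and `--val-prob` flags, but the patch does not show the calling code in `extract_from_arkindex.py`. The snippet below is only a minimal sketch of how the pieces could fit together once the patch is applied; the `elements` list and the output layout are assumptions made for illustration, not part of the commit.

```python
# Hypothetical wiring of the new helpers (assumes the dan package from this
# patch is installed). `elements` stands in for whatever the Arkindex
# extraction step actually returns.
import os

from dan.datasets.utils import assign_random_split, save_text

elements = [
    {"id": "element-1", "text": "premier exemple"},
    {"id": "element-2", "text": "second exemple"},
]

output_dir = "out"
for element in elements:
    # assign_random_split returns "train", "val" or "test" with
    # probabilities train_prob, val_prob and the remainder.
    split = assign_random_split(train_prob=0.7, val_prob=0.15)
    split_dir = os.path.join(output_dir, split)
    os.makedirs(split_dir, exist_ok=True)
    save_text(os.path.join(split_dir, f"{element['id']}.txt"), element["text"])
```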