Verified commit f9b83350 authored by Yoann Schneider

remove unneeded files, reorganize inside the dan module

parent 2a289354
# -*- coding: utf-8 -*-
import argparse
import errno
from dan.datasets.extract.extract_from_arkindex import add_extract_parser
from dan.ocr.line.generate_synthetic import add_generate_parser
from dan.ocr.train import add_train_parser
def get_parser():
parser = argparse.ArgumentParser(prog="TEKLIA DAN training")
subcommands = parser.add_subparsers(metavar="subcommand")
@@ -17,6 +16,7 @@ def get_parser():
add_generate_parser(subcommands)
return parser
def main():
parser = get_parser()
args = vars(parser.parse_args())
......
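For orientation, here is a minimal, self-contained sketch of how an argparse subcommand CLI like this one typically dispatches. The `func` default and the body of main() below the elision are assumptions (the standard argparse pattern), not the verbatim module:

import argparse


def get_parser():
    parser = argparse.ArgumentParser(prog="TEKLIA DAN training")
    subcommands = parser.add_subparsers(metavar="subcommand")
    # Each add_*_parser call above would register a subparser like this
    # and attach its entry point via set_defaults (an assumption here).
    train = subcommands.add_parser("train")
    train.set_defaults(func=lambda **kwargs: print("training with", kwargs))
    return parser


def main():
    parser = get_parser()
    args = vars(parser.parse_args())
    if "func" in args:
        args.pop("func")(**args)  # call the selected subcommand
    else:
        parser.print_help()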
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Example of minimal usage:
# python extract_from_arkindex.py
# --corpus "AN-Simara annotations E (2022-06-20)"
# --parents-types folder
# --parents-names FRAN_IR_032031_4538.pdf
# --output-dir ../
"""
The extraction module
......
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@@ -25,6 +24,13 @@ def get_cli_args():
help="Name of the corpus from which the data will be retrieved.",
required=True,
)
parser.add_argument(
"--element-type",
nargs="+",
type=str,
help="Type of elements to retrieve",
required=True,
)
parser.add_argument(
"--parents-types",
nargs="+",
@@ -47,4 +53,22 @@ def get_cli_args():
help="Names of parents of the elements.",
default=None,
)
parser.add_argument(
"--no-entities", action="store_true", help="Extract text without entities"
)
parser.add_argument(
"--use-existing-split",
action="store_true",
help="Do not partition pages into train/val/test",
)
parser.add_argument(
"--train-prob", type=float, default=0.7, help="Training set probability"
)
parser.add_argument(
"--val-prob", type=float, default=0.15, help="Validation set probability"
)
return parser.parse_args()
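To see how the new options combine, here is a quick stand-alone check against a parser mirroring the arguments above. The element type and probability values are made up; only --corpus reuses the name from the example header:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--corpus", required=True)
parser.add_argument("--element-type", nargs="+", type=str, required=True)
parser.add_argument("--no-entities", action="store_true")
parser.add_argument("--use-existing-split", action="store_true")
parser.add_argument("--train-prob", type=float, default=0.7)
parser.add_argument("--val-prob", type=float, default=0.15)

args = parser.parse_args(
    [
        "--corpus", "AN-Simara annotations E (2022-06-20)",
        "--element-type", "text_line",
        "--train-prob", "0.8",
        "--val-prob", "0.1",
    ]
)
# Whatever is left over (here 0.1) implicitly becomes the test share.
assert args.train_prob + args.val_prob <= 1.0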
# contributors :
# - Denis Coquenet
#
#
# This software is a computer program written in Python whose purpose is to
# provide public implementation of deep learning works, in pytorch.
#
# This software is governed by the CeCILL-C license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL-C
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL-C license and that you accept its terms.
from Datasets.dataset_formatters.generic_dataset_formatter import OCRDatasetFormatter
# -*- coding: utf-8 -*-
import os
import numpy as np
from Datasets.dataset_formatters.utils_dataset import natural_sort
from PIL import Image
import xml.etree.ElementTree as ET
import re
from tqdm import tqdm
from collections import Counter
from tqdm import tqdm
from dan.datasets.format.generic import OCRDatasetFormatter
def remove_spaces(text):
# remove begin/ending spaces
@@ -49,60 +15,74 @@ def remove_spaces(text):
text = re.sub("\t", " ", text)
# remove consecutive spaces
text = re.sub(" +", " ", text)
# text = text.encode('ascii', 'ignore').decode("utf-8")
# text = text.encode('ascii', 'ignore').decode("utf-8")
return text
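As a quick illustration of the visible part of remove_spaces (a sketch: the elided beginning of the function is assumed to strip leading/trailing whitespace, as its comment says):

import re


def remove_spaces(text):
    # assumed from the "remove begin/ending spaces" comment above
    text = text.strip()
    # replace tabs with single spaces
    text = re.sub("\t", " ", text)
    # collapse runs of consecutive spaces
    text = re.sub(" +", " ", text)
    return text


assert remove_spaces(" a\tb   c ") == "a b c"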
class SynistDatasetFormatter(OCRDatasetFormatter):
def __init__(self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=False):
super(SynistDatasetFormatter, self).__init__("bessin", level, "_sem" if sem_token else "", set_names)
class BessinDatasetFormatter(OCRDatasetFormatter):
def __init__(
self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=False
):
super(BessinDatasetFormatter, self).__init__(
"bessin", level, "_sem" if sem_token else "", set_names
)
self.dpi = dpi
self.counter = Counter()
self.map_datasets_files.update({
"bessin": {
# (1,050 for train, 100 for validation and 100 for test)
"line": {
"needed_files": [],
"arx_files": [],
"format_function": self.format_bessin_zone
}
self.map_datasets_files.update(
{
"bessin": {
# (1,050 for train, 100 for validation and 100 for test)
"line": {
"needed_files": [],
"arx_files": [],
"format_function": self.format_bessin_zone,
}
}
}
})
)
def preformat_bessin_zone(self):
"""
Extract all information from dataset and correct some annotations
"""
dataset = {
"train": list(),
"valid": list(),
"test": list()
}
dataset = {"train": list(), "valid": list(), "test": list()}
img_folder_path = os.path.join("Datasets", "raw", "bessin", "images")
labels_folder_path = os.path.join("Datasets", "raw", "bessin", "labels")
train_files = [
os.path.join(labels_folder_path, 'train', name)
for name in os.listdir(os.path.join(labels_folder_path, 'train'))]
os.path.join(labels_folder_path, "train", name)
for name in os.listdir(os.path.join(labels_folder_path, "train"))
]
valid_files = [
os.path.join(labels_folder_path, 'valid', name)
for name in os.listdir(os.path.join(labels_folder_path, 'valid'))]
os.path.join(labels_folder_path, "valid", name)
for name in os.listdir(os.path.join(labels_folder_path, "valid"))
]
test_files = [
os.path.join(labels_folder_path, 'test', name)
for name in os.listdir(os.path.join(labels_folder_path, 'test'))]
os.path.join(labels_folder_path, "test", name)
for name in os.listdir(os.path.join(labels_folder_path, "test"))
]
for set_name, files in zip(self.set_names, [train_files, valid_files, test_files]):
for i, label_file in enumerate(tqdm(files, desc='Pre-formatting '+set_name)):
with open(label_file, 'r') as f:
for set_name, files in zip(
self.set_names, [train_files, valid_files, test_files]
):
for i, label_file in enumerate(
tqdm(files, desc="Pre-formatting " + set_name)
):
with open(label_file, "r") as f:
text = remove_spaces(f.read())
dataset[set_name].append({
"img_path": os.path.join(
img_folder_path, set_name, label_file.split('/')[-1].replace('txt', 'jpg')),
"label": text.strip()
})
dataset[set_name].append(
{
"img_path": os.path.join(
img_folder_path,
set_name,
label_file.split("/")[-1].replace("txt", "jpg"),
),
"label": text.strip(),
}
)
return dataset
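For reference, the returned structure pairs each image with its stripped transcription, per split. An illustrative entry (the file name and label are hypothetical):

dataset = {
    "train": [
        {
            "img_path": "Datasets/raw/bessin/images/train/0001.jpg",
            "label": "a transcribed line",
        }
    ],
    "valid": [],
    "test": [],
}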
@@ -113,8 +93,8 @@ class SynistDatasetFormatter(OCRDatasetFormatter):
dataset = self.preformat_bessin_zone()
for set_name in self.set_names:
fold = os.path.join(self.target_fold_path, set_name)
for sample in tqdm(dataset[set_name], desc='Formatting '+set_name):
new_name = sample['img_path'].split('/')[-1]
for sample in tqdm(dataset[set_name], desc="Formatting " + set_name):
new_name = sample["img_path"].split("/")[-1]
new_img_path = os.path.join(fold, new_name)
self.load_resize_save(sample["img_path"], new_img_path)
zone = {
@@ -122,12 +102,12 @@ class SynistDatasetFormatter(OCRDatasetFormatter):
}
self.charset = self.charset.union(set(zone["text"]))
self.gt[set_name][new_name] = zone
self.counter.update(zone['text'])
self.counter.update(zone["text"])
if __name__ == "__main__":
formatter = SynistDatasetFormatter("line", sem_token=False)
formatter = BessinDatasetFormatter("line", sem_token=False)
formatter.format()
print("Character freq: ")
for k,v in formatter.counter.items():
print(k, v)
\ No newline at end of file
for k, v in formatter.counter.items():
print(k, v)
# contributors :
# - Denis Coquenet
#
#
# This software is a computer program written in Python whose purpose is to
# provide public implementation of deep learning works, in pytorch.
#
# This software is governed by the CeCILL-C license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL-C
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL-C license and that you accept its terms.
from Datasets.dataset_formatters.generic_dataset_formatter import OCRDatasetFormatter
# -*- coding: utf-8 -*-
import os
import numpy as np
from Datasets.dataset_formatters.utils_dataset import natural_sort
from PIL import Image
import xml.etree.ElementTree as ET
import re
from tqdm import tqdm
from collections import Counter
from tqdm import tqdm
from dan.datasets.format.generic import OCRDatasetFormatter
def remove_spaces(text):
# remove begin/ending spaces
@@ -49,60 +15,77 @@ def remove_spaces(text):
text = re.sub("\t", " ", text)
# remove consecutive spaces
text = re.sub(" +", " ", text)
# text = text.encode('ascii', 'ignore').decode("utf-8")
return text
class SynistDatasetFormatter(OCRDatasetFormatter):
def __init__(self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=False):
super(SynistDatasetFormatter, self).__init__("synist_synth", level, "_sem" if sem_token else "", set_names)
def __init__(
self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=False
):
super(SynistDatasetFormatter, self).__init__(
"synist_synth", level, "_sem" if sem_token else "", set_names
)
self.dpi = dpi
self.counter = Counter()
self.map_datasets_files.update({
"synist_synth": {
# (1,050 for train, 100 for validation and 100 for test)
"line": {
"needed_files": [],
"arx_files": [],
"format_function": self.format_synist_page
}
self.map_datasets_files.update(
{
"synist_synth": {
# (1,050 for train, 100 for validation and 100 for test)
"line": {
"needed_files": [],
"arx_files": [],
"format_function": self.format_synist_page,
}
}
}
})
)
def preformat_synist_page(self):
"""
Extract all information from dataset and correct some annotations
"""
dataset = {
"train": list(),
"valid": list(),
"test": list()
}
img_folder_path = os.path.join("Datasets", "raw", "synist_synth_lines", "images")
labels_folder_path = os.path.join("Datasets", "raw", "synist_synth_lines", "labels")
dataset = {"train": list(), "valid": list(), "test": list()}
img_folder_path = os.path.join(
"Datasets", "raw", "synist_synth_lines", "images"
)
labels_folder_path = os.path.join(
"Datasets", "raw", "synist_synth_lines", "labels"
)
train_files = [
os.path.join(labels_folder_path, 'train', name)
for name in os.listdir(os.path.join(labels_folder_path, 'train'))]
os.path.join(labels_folder_path, "train", name)
for name in os.listdir(os.path.join(labels_folder_path, "train"))
]
valid_files = [
os.path.join(labels_folder_path, 'valid', name)
for name in os.listdir(os.path.join(labels_folder_path, 'valid'))]
os.path.join(labels_folder_path, "valid", name)
for name in os.listdir(os.path.join(labels_folder_path, "valid"))
]
test_files = [
os.path.join(labels_folder_path, 'test', name)
for name in os.listdir(os.path.join(labels_folder_path, 'test'))]
os.path.join(labels_folder_path, "test", name)
for name in os.listdir(os.path.join(labels_folder_path, "test"))
]
for set_name, files in zip(self.set_names, [train_files, valid_files, test_files]):
for i, label_file in enumerate(tqdm(files, desc='Pre-formatting '+set_name)):
with open(label_file, 'r') as f:
for set_name, files in zip(
self.set_names, [train_files, valid_files, test_files]
):
for i, label_file in enumerate(
tqdm(files, desc="Pre-formatting " + set_name)
):
with open(label_file, "r") as f:
text = remove_spaces(f.read())
dataset[set_name].append({
"img_path": os.path.join(
img_folder_path, set_name, label_file.split('/')[-1].replace('txt', 'png')),
"label": text.strip()
})
dataset[set_name].append(
{
"img_path": os.path.join(
img_folder_path,
set_name,
label_file.split("/")[-1].replace("txt", "png"),
),
"label": text.strip(),
}
)
return dataset
@@ -113,18 +96,18 @@ class SynistDatasetFormatter(OCRDatasetFormatter):
dataset = self.preformat_synist_page()
for set_name in self.set_names:
fold = os.path.join(self.target_fold_path, set_name)
for sample in tqdm(dataset[set_name], desc='Formatting '+set_name):
new_name = sample['img_path'].split('/')[-1]
for sample in tqdm(dataset[set_name], desc="Formatting " + set_name):
new_name = sample["img_path"].split("/")[-1]
new_img_path = os.path.join(fold, new_name)
#self.load_resize_save(sample["img_path"], new_img_path, 300, self.dpi)
# self.load_resize_save(sample["img_path"], new_img_path, 300, self.dpi)
self.load_resize_save(sample["img_path"], new_img_path)
#self.load_flip_save(new_img_path, new_img_path)
# self.load_flip_save(new_img_path, new_img_path)
page = {
"text": sample["label"],
}
self.charset = self.charset.union(set(page["text"]))
self.gt[set_name][new_name] = page
self.counter.update(page['text'])
self.counter.update(page["text"])
if __name__ == "__main__":
@@ -132,6 +115,6 @@ if __name__ == "__main__":
formatter.format()
print(formatter.counter)
print(formatter.counter.most_common(80))
for k,v in formatter.counter.items():
for k, v in formatter.counter.items():
print(k)
print(k.encode('utf-8'), v)
print(k.encode("utf-8"), v)
# -*- coding: utf-8 -*-
import json
import random
import re
import cv2
random.seed(42)
def convert(text):
return int(text) if text.isdigit() else text.lower()
@@ -8,3 +14,30 @@ def convert(text):
def natural_sort(data):
return sorted(data, key=lambda key: [convert(c) for c in re.split("([0-9]+)", key)])
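For example, natural_sort orders embedded numbers numerically rather than lexicographically (run alongside the function above):

# "page2" sorts before "page10", unlike plain sorted()
assert natural_sort(["page10", "page2", "page1"]) == ["page1", "page2", "page10"]
assert sorted(["page10", "page2", "page1"]) == ["page1", "page10", "page2"]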
def assign_random_split(train_prob, val_prob):
"""
assuming train_prob + val_prob + test_prob = 1
"""
prob = random.random()
if prob <= train_prob:
return "train"
elif prob <= train_prob + val_prob:
return "val"
else:
return "test"
def save_text(path, text):
with open(path, "w") as f:
f.write(text)
def save_image(path, image):
cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
def save_json(path, dict):
with open(path, "w") as outfile:
json.dump(dict, outfile, indent=4)
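And a small round-trip of the three save helpers (file names are hypothetical). Note that save_image expects an RGB array: cv2.imwrite writes BGR, and the cvtColor call swaps the channels first:

import numpy as np

save_text("sample.txt", "some transcription")
save_json("sample.json", {"train": 0.7, "val": 0.15, "test": 0.15})

# A red 8x8 image in RGB channel order; save_image converts to BGR
# before writing, so the file comes out red as expected.
red_rgb = np.zeros((8, 8, 3), dtype=np.uint8)
red_rgb[..., 0] = 255
save_image("sample.png", red_rgb)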
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
......
# -*- coding: utf-8 -*-
from dan.ocr.line.train import add_line_parser
from dan.ocr.document.train import add_document_parser
from dan.ocr.line.train import add_line_parser
def add_train_parser(subcommands) -> None:
parser = subcommands.add_parser(
......
arkindex-client==1.0.11
editdistance==0.6.0
fontTools==4.29.1
imageio==2.16.0
networkx==2.6.3
tensorboard==0.2.1
torchvision==0.12.0
tqdm==4.62.3
arkindex-client==1.0.11
editdistance==0.6.0
fontTools==4.29.1
imageio==2.16.0
networkx==2.6.3
numpy==1.22.3
opencv-python==4.5.5.64
PyYAML==6.0
tensorboard==0.2.1
torch==1.11.0
torchvision==0.12.0
tqdm==4.62.3
@@ -28,5 +28,4 @@ setup(
"teklia-dan=dan.cli:main",
]
},
extras_require={"predict": parse_requirements("prediction-requirements.txt")},
)