Commit 2607d3bc authored by Yoann Schneider

Add data formatters for Synist and Bessin

parent 9babb89c
Merge request !7: Add data formatters for Synist and Bessin
# contributors :
# - Denis Coquenet
#
#
# This software is a computer program written in Python whose purpose is to
# provide public implementation of deep learning works, in pytorch.
#
# This software is governed by the CeCILL-C license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL-C
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL-C license and that you accept its terms.
from Datasets.dataset_formatters.generic_dataset_formatter import OCRDatasetFormatter
import os
import numpy as np
from Datasets.dataset_formatters.utils_dataset import natural_sort
from PIL import Image
import xml.etree.ElementTree as ET
import re
from tqdm import tqdm
from collections import Counter
def remove_spaces(text):
# remove begin/ending spaces
text = text.strip()
# replace \t with regular space
text = re.sub("\t", " ", text)
# remove consecutive spaces
text = re.sub(" +", " ", text)
# text = text.encode('ascii', 'ignore').decode("utf-8")
return text
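# Illustrative behaviour of remove_spaces (added sketch, not part of the original file):
#   remove_spaces("  a\tb   c ")  ->  "a b c"
# Leading/trailing whitespace is stripped, tabs become spaces, and runs of
# spaces collapse to a single space.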
class SynistDatasetFormatter(OCRDatasetFormatter):
def __init__(self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=False):
super(SynistDatasetFormatter, self).__init__("bessin", level, "_sem" if sem_token else "", set_names)
self.dpi = dpi
self.counter = Counter()
self.map_datasets_files.update({
"bessin": {
# (1,050 for train, 100 for validation and 100 for test)
"line": {
"needed_files": [],
"arx_files": [],
"format_function": self.format_bessin_zone
}
}
})
def preformat_bessin_zone(self):
"""
Extract all information from the dataset and correct some annotations
"""
dataset = {
"train": list(),
"valid": list(),
"test": list()
}
img_folder_path = os.path.join("Datasets", "raw", "bessin", "images")
labels_folder_path = os.path.join("Datasets", "raw", "bessin", "labels")
train_files = [
os.path.join(labels_folder_path, 'train', name)
for name in os.listdir(os.path.join(labels_folder_path, 'train'))]
valid_files = [
os.path.join(labels_folder_path, 'valid', name)
for name in os.listdir(os.path.join(labels_folder_path, 'valid'))]
test_files = [
os.path.join(labels_folder_path, 'test', name)
for name in os.listdir(os.path.join(labels_folder_path, 'test'))]
for set_name, files in zip(self.set_names, [train_files, valid_files, test_files]):
for i, label_file in enumerate(tqdm(files, desc='Pre-formatting '+set_name)):
with open(label_file, 'r') as f:
text = remove_spaces(f.read())
dataset[set_name].append({
"img_path": os.path.join(
img_folder_path, set_name, label_file.split('/')[-1].replace('txt', 'jpg')),
"label": text.strip()
})
return dataset
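# Shape of the returned dict (illustrative sketch based on the paths used above):
# {
#     "train": [{"img_path": "Datasets/raw/bessin/images/train/<name>.jpg",
#                "label": "<cleaned transcription>"}, ...],
#     "valid": [...],
#     "test": [...],
# }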
def format_bessin_zone(self):
"""
Format bessin zone dataset
"""
dataset = self.preformat_bessin_zone()
for set_name in self.set_names:
fold = os.path.join(self.target_fold_path, set_name)
for sample in tqdm(dataset[set_name], desc='Formatting '+set_name):
new_name = sample['img_path'].split('/')[-1]
new_img_path = os.path.join(fold, new_name)
self.load_resize_save(sample["img_path"], new_img_path)
zone = {
"text": sample["label"],
}
self.charset = self.charset.union(set(zone["text"]))
self.gt[set_name][new_name] = zone
self.counter.update(zone['text'])
if __name__ == "__main__":
formatter = SynistDatasetFormatter("line", sem_token=False)
formatter.format()
print("Character freq: ")
for k,v in formatter.counter.items():
print(k, v)
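# Expected raw layout, as implied by the paths used above (sketch, not documentation):
#   Datasets/raw/bessin/labels/{train,valid,test}/<name>.txt
#   Datasets/raw/bessin/images/{train,valid,test}/<name>.jpg
# The formatter resizes each image into self.target_fold_path/<set_name>/ and
# records its transcription in self.gt[<set_name>].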
\ No newline at end of file
@@ -23,6 +23,7 @@ from tqdm import tqdm
from extraction_utils import arkindex_utils as ark
from extraction_utils import utils
from utils_dataset import assign_random_split, save_text, save_image, save_json
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@@ -53,6 +54,37 @@ SEM_MATCHING_TOKENS = {
}
def extract_transcription(element, no_entities=False):
tr = client.request('ListTranscriptions', id=element['id'], worker_version=None)['results']
tr = [one for one in tr if one['worker_version_id'] is None]
if len(tr) != 1:
return None
if no_entities:
text = tr[0]['text'].strip()
else:
for one_tr in tr:
ent = client.request('ListTranscriptionEntities', id=one_tr['id'])['results']
print(ent)
ent = [one for one in ent if one['worker_version_id'] is None]
if len(ent) == 0:
continue
else:
text = one_tr['text']
count = 0
for e in ent:
start_token = SEM_MATCHING_TOKENS_STR[e['entity']['metas']['subtype']]
end_token = SEM_MATCHING_TOKENS[start_token]
text = text[:count+e['offset']] + start_token + text[count+e['offset']:]
count += 1
text = text[:count+e['offset']+e['length']] + end_token + text[count+e['offset']+e['length']:]
count += 1
return text
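# Worked example of the offset bookkeeping above (illustrative; "Ⓢ"/"Ⓔ" are
# hypothetical single-character start/end tokens, not the project's real ones):
#   text = "Jean Dupont", one entity with offset=0, length=4
#   after inserting the start token: "ⓈJean Dupont"   (count = 1)
#   after inserting the end token:   "ⓈJeanⒺ Dupont"  (count = 2)
# Each inserted token shifts every later index by one character, which is what
# the running count compensates for.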
if __name__ == '__main__':
args = utils.get_cli_args()
@@ -62,52 +94,76 @@ if __name__ == '__main__':
# Login to arkindex.
client = ArkindexClient(**options_from_env())
corpus = ark.retrieve_corpus(client, args.corpus)
subsets = ark.retrieve_subsets(
client, corpus, args.parents_types, args.parents_names
)
# Create out directories
split_dict = {}
split_names = [subset["name"] for subset in subsets] if args.use_existing_split else ["train", "valid", "test"]
for split in split_names:
os.makedirs(os.path.join(args.output_dir, IMAGES_DIR, split), exist_ok=True)
os.makedirs(os.path.join(args.output_dir, LABELS_DIR, split), exist_ok=True)
split_dict[split] = {}
# Iterate over the subsets to find the page images and labels.
for subset in subsets:
os.makedirs(os.path.join(args.output_dir, IMAGES_DIR, subset["name"]), exist_ok=True)
os.makedirs(os.path.join(args.output_dir, LABELS_DIR, subset["name"]), exist_ok=True)
# Iterate over the pages to create splits at page level.
for page in tqdm(
client.paginate(
"ListElementChildren", id=subset["id"], type="page", recursive=True
),
desc="Set " + subset["name"],
):
image = iio.imread(page["zone"]["url"])
cv2.imwrite(
os.path.join(args.output_dir, IMAGES_DIR, subset['name'], f"{page['id']}.jpg"),
cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
)
tr = client.request('ListTranscriptions', id=page['id'], worker_version=None)['results']
tr = [one for one in tr if one['worker_version_id'] is None]
assert len(tr) == 1, page['id']
for one_tr in tr:
ent = client.request('ListTranscriptionEntities', id=one_tr['id'])['results']
ent = [one for one in ent if one['worker_version_id'] is None]
if len(ent) == 0:
client.paginate(
"ListElementChildren", id=subset["id"], type="page", recursive=True
)
):
if not args.use_existing_split:
split = assign_random_split(args.train_prob, args.val_prob)
else:
split = subset["name"]
# Extract only pages
if args.elements_types == ["page"]:
text = extract_transcription(element, no_entities=args.no_entities)
image = iio.imread(element["zone"]["url"])
if not text:
logging.warning(f"Skipping {element['id']} (zero or multiple transcriptions with worker_version=None)")
continue
else:
text = one_tr['text']
new_text = text
count = 0
for e in ent:
start_token = SEM_MATCHING_TOKENS_STR[e['entity']['metas']['subtype']]
end_token = SEM_MATCHING_TOKENS[start_token]
new_text = new_text[:count+e['offset']] + start_token + new_text[count+e['offset']:]
count += 1
new_text = new_text[:count+e['offset']+e['length']] + end_token + new_text[count+e['offset']+e['length']:]
count += 1
with open(os.path.join(args.output_dir, LABELS_DIR, subset['name'], f"{page['id']}.txt"), 'w') as f:
f.write(new_text)
logging.info(f"Processed {element_type} {element['id']}")
split_dict[split].append(page["id"])
im_path = os.path.join(args.output_dir, IMAGES_DIR, split, f"{element_type}_{element['id']}.jpg")
txt_path = os.path.join(args.output_dir, LABELS_DIR, split, f"{element_type}_{element['id']}.txt")
save_text(txt_path, text)
save_image(im_path, image)
# Extract page's children elements (text_zone, text_line)
else:
split_dict[split][page["id"]] = {}
for element_type in args.elements_types:
split_dict[split][page["id"]][element_type] = []
for element in client.paginate("ListElementChildren", id=page["id"], type=element_type, recursive=True):
text = extract_transcription(element, no_entities=args.no_entities)
image = iio.imread(element["zone"]["url"])
if not text:
logging.warning(f"Skipping {element_type} {element['id']} (zero or multiple transcriptions with worker_version=None)")
continue
else:
logging.info(f"Processed {element_type} {element['id']}")
split_dict[split][page["id"]][element_type].append(element['id'])
im_path = os.path.join(args.output_dir, IMAGES_DIR, split, f"{element_type}_{element['id']}.jpg")
txt_path = os.path.join(args.output_dir, LABELS_DIR, split, f"{element_type}_{element['id']}.txt")
save_text(txt_path, text)
save_image(im_path, image)
save_json(os.path.join(args.output_dir, "split.json"), split_dict)
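# Resulting split.json for element-level extraction (illustrative sketch; the
# page-level branch stores page ids per split instead):
# {
#     "train": {"<page_id>": {"text_line": ["<element_id>", ...]}, ...},
#     "valid": {...},
#     "test": {...}
# }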
# contributors :
# - Denis Coquenet
#
#
# This software is a computer program written in Python whose purpose is to
# provide public implementation of deep learning works, in pytorch.
#
# This software is governed by the CeCILL-C license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL-C
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL-C license and that you accept its terms.
from Datasets.dataset_formatters.generic_dataset_formatter import OCRDatasetFormatter
import os
import numpy as np
from Datasets.dataset_formatters.utils_dataset import natural_sort
from PIL import Image
import xml.etree.ElementTree as ET
import re
from tqdm import tqdm
from collections import Counter
def remove_spaces(text):
# remove begin/ending spaces
text = text.strip()
# replace \t with regular space
text = re.sub("\t", " ", text)
# remove consecutive spaces
text = re.sub(" +", " ", text)
# text = text.encode('ascii', 'ignore').decode("utf-8")
return text
class SynistDatasetFormatter(OCRDatasetFormatter):
def __init__(self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=False):
super(SynistDatasetFormatter, self).__init__("synist_synth", level, "_sem" if sem_token else "", set_names)
self.dpi = dpi
self.counter = Counter()
self.map_datasets_files.update({
"synist_synth": {
# (1,050 for train, 100 for validation and 100 for test)
"line": {
"needed_files": [],
"arx_files": [],
"format_function": self.format_synist_page
}
}
})
def preformat_synist_page(self):
"""
Extract all information from the dataset and correct some annotations
"""
dataset = {
"train": list(),
"valid": list(),
"test": list()
}
img_folder_path = os.path.join("Datasets", "raw", "synist_synth_lines", "images")
labels_folder_path = os.path.join("Datasets", "raw", "synist_synth_lines", "labels")
train_files = [
os.path.join(labels_folder_path, 'train', name)
for name in os.listdir(os.path.join(labels_folder_path, 'train'))]
valid_files = [
os.path.join(labels_folder_path, 'valid', name)
for name in os.listdir(os.path.join(labels_folder_path, 'valid'))]
test_files = [
os.path.join(labels_folder_path, 'test', name)
for name in os.listdir(os.path.join(labels_folder_path, 'test'))]
for set_name, files in zip(self.set_names, [train_files, valid_files, test_files]):
for i, label_file in enumerate(tqdm(files, desc='Pre-formatting '+set_name)):
with open(label_file, 'r') as f:
text = remove_spaces(f.read())
dataset[set_name].append({
"img_path": os.path.join(
img_folder_path, set_name, label_file.split('/')[-1].replace('txt', 'png')),
"label": text.strip()
})
return dataset
def format_synist_page(self):
"""
Format synist page dataset
"""
dataset = self.preformat_synist_page()
for set_name in self.set_names:
fold = os.path.join(self.target_fold_path, set_name)
for sample in tqdm(dataset[set_name], desc='Formatting '+set_name):
new_name = sample['img_path'].split('/')[-1]
new_img_path = os.path.join(fold, new_name)
#self.load_resize_save(sample["img_path"], new_img_path, 300, self.dpi)
self.load_resize_save(sample["img_path"], new_img_path)
#self.load_flip_save(new_img_path, new_img_path)
page = {
"text": sample["label"],
}
self.charset = self.charset.union(set(page["text"]))
self.gt[set_name][new_name] = page
self.counter.update(page['text'])
if __name__ == "__main__":
formatter = SynistDatasetFormatter("line", sem_token=False)
formatter.format()
print(formatter.counter)
print(formatter.counter.most_common(80))
for k,v in formatter.counter.items():
print(k)
print(k.encode('utf-8'), v)