Commit 825bb4ab authored by Yoann Schneider, committed by Solene Tarride
Remove code from original repo, small refactoring and packaging

parent f3e30e2d
1 merge request: !9 Remove code from original repo, small refactoring and packaging
Showing 218 additions and 3215 deletions
# Only run on the DAN python module
files: '^dan'
repos:
- repo: https://github.com/PyCQA/isort
rev: 5.10.1
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Example of minimal usage:
# python extract_from_arkindex.py
# --corpus "AN-Simara annotations E (2022-06-20)"
# --parents-types folder
# --parents-names FRAN_IR_032031_4538.pdf
# --output-dir ../
"""
The extraction module
======================
"""
import logging
import os

import imageio.v2 as iio
from arkindex import ArkindexClient, options_from_env
from tqdm import tqdm

from extraction_utils import arkindex_utils as ark
from extraction_utils import utils
from utils_dataset import assign_random_split, save_image, save_json, save_text

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
IMAGES_DIR = "./images/" # Path to the images directory.
LABELS_DIR = "./labels/" # Path to the labels directory.
# Layout string to token
SEM_MATCHING_TOKENS_STR = {
"INTITULE": "",
"DATE": "",
"COTE_SERIE": "",
"ANALYSE_COMPL.": "",
"PRECISIONS_SUR_COTE": "",
"COTE_ARTICLE": ""
}
# Layout begin-token to end-token
SEM_MATCHING_TOKENS = {
"": "",
"": "",
"": "",
"": "",
"": "",
"": ""
}
def extract_transcription(element, no_entities=False):
    tr = client.request("ListTranscriptions", id=element["id"], worker_version=None)["results"]
    # Keep only manual transcriptions (no worker version).
    tr = [one for one in tr if one["worker_version_id"] is None]
    if len(tr) != 1:
        return None

    text = None
    if no_entities:
        text = tr[0]["text"].strip()
    else:
        for one_tr in tr:
            ent = client.request("ListTranscriptionEntities", id=one_tr["id"])["results"]
            ent = [one for one in ent if one["worker_version_id"] is None]
            if len(ent) == 0:
                continue
            text = one_tr["text"]
            count = 0
            for e in ent:
                start_token = SEM_MATCHING_TOKENS_STR[e["entity"]["metas"]["subtype"]]
                end_token = SEM_MATCHING_TOKENS[start_token]
                # Each inserted token shifts every later offset by one.
                text = text[: count + e["offset"]] + start_token + text[count + e["offset"] :]
                count += 1
                text = text[: count + e["offset"] + e["length"]] + end_token + text[count + e["offset"] + e["length"] :]
                count += 1
    return text
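
# Illustration (hypothetical data, not an Arkindex API response): entity
# offsets refer to the raw text, so `count` compensates for the tokens
# already inserted. Assuming single-character layout tokens "A"/"a" and
# "B"/"b" for two entities of text = "JEAN 1204":
#   entity 1 (offset=0, length=4) -> "AJEANa 1204"   (count becomes 2)
#   entity 2 (offset=5, length=4) -> "AJEANa B1204b" (offsets shifted by 2)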
if __name__ == "__main__":
    args = utils.get_cli_args()

    # Get and initialize the parameters.
    os.makedirs(IMAGES_DIR, exist_ok=True)
    os.makedirs(LABELS_DIR, exist_ok=True)

    # Login to arkindex.
    client = ArkindexClient(**options_from_env())

    corpus = ark.retrieve_corpus(client, args.corpus)
    subsets = ark.retrieve_subsets(
        client, corpus, args.parents_types, args.parents_names
    )

    # Create out directories
    split_dict = {}
    split_names = (
        [subset["name"] for subset in subsets]
        if args.use_existing_split
        else ["train", "valid", "test"]
    )
    for split in split_names:
        os.makedirs(os.path.join(args.output_dir, IMAGES_DIR, split), exist_ok=True)
        os.makedirs(os.path.join(args.output_dir, LABELS_DIR, split), exist_ok=True)
        split_dict[split] = {}

    # Iterate over the subsets to find the page images and labels.
    for subset in subsets:
        # Iterate over the pages to create splits at page level.
        for page in tqdm(
            client.paginate(
                "ListElementChildren", id=subset["id"], type="page", recursive=True
            )
        ):
            if not args.use_existing_split:
                split = assign_random_split(args.train_prob, args.val_prob)
            else:
                split = subset["name"]

            # Extract only pages (argparse stores --element-type as element_type).
            if args.element_type == ["page"]:
                element, element_type = page, "page"
                text = extract_transcription(element, no_entities=args.no_entities)
                if not text:
                    logging.warning(
                        f"Skipping {element['id']} (zero or multiple transcriptions with worker_version=None)"
                    )
                    continue
                logging.info(f"Processed {element_type} {element['id']}")
                image = iio.imread(element["zone"]["url"])
                split_dict[split].setdefault(element_type, []).append(page["id"])
                im_path = os.path.join(
                    args.output_dir, IMAGES_DIR, split, f"{element_type}_{element['id']}.jpg"
                )
                txt_path = os.path.join(
                    args.output_dir, LABELS_DIR, split, f"{element_type}_{element['id']}.txt"
                )
                save_text(txt_path, text)
                save_image(im_path, image)

            # Extract page's children elements (text_zone, text_line)
            else:
                split_dict[split][page["id"]] = {}
                for element_type in args.element_type:
                    split_dict[split][page["id"]][element_type] = []
                    for element in client.paginate(
                        "ListElementChildren", id=page["id"], type=element_type, recursive=True
                    ):
                        text = extract_transcription(element, no_entities=args.no_entities)
                        if not text:
                            logging.warning(
                                f"Skipping {element_type} {element['id']} (zero or multiple transcriptions with worker_version=None)"
                            )
                            continue
                        logging.info(f"Processed {element_type} {element['id']}")
                        image = iio.imread(element["zone"]["url"])
                        split_dict[split][page["id"]][element_type].append(element["id"])
                        im_path = os.path.join(
                            args.output_dir, IMAGES_DIR, split, f"{element_type}_{element['id']}.jpg"
                        )
                        txt_path = os.path.join(
                            args.output_dir, LABELS_DIR, split, f"{element_type}_{element['id']}.txt"
                        )
                        save_text(txt_path, text)
                        save_image(im_path, image)

    save_json(os.path.join(args.output_dir, "split.json"), split_dict)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
The utils module
======================
"""

import argparse


def get_cli_args():
    """
    Get the command-line arguments.
    :return: The command-line arguments.
    """
    parser = argparse.ArgumentParser(
        description="Arkindex DAN Training Label Generation"
    )

    # Required arguments.
    parser.add_argument(
        "--corpus",
        type=str,
        help="Name of the corpus from which the data will be retrieved.",
        required=True,
    )
    parser.add_argument(
        "--element-type",
        nargs="+",
        type=str,
        help="Type of elements to retrieve.",
        required=True,
    )
    parser.add_argument(
        "--parents-types",
        nargs="+",
        type=str,
        help="Type of parents of the elements.",
        required=True,
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        help="Path to the output directory.",
        required=True,
    )

    # Optional arguments.
    parser.add_argument(
        "--parents-names",
        nargs="+",
        type=str,
        help="Names of parents of the elements.",
        default=None,
    )
    parser.add_argument(
        "--no-entities", action="store_true", help="Extract text without entities."
    )
    parser.add_argument(
        "--use-existing-split",
        action="store_true",
        help="Do not partition pages into train/val/test.",
    )
    parser.add_argument(
        "--train-prob", type=float, default=0.7, help="Training set probability."
    )
    parser.add_argument(
        "--val-prob", type=float, default=0.15, help="Validation set probability."
    )
    return parser.parse_args()
This diff is collapsed.
# Copyright Université de Rouen Normandie (1), INSA Rouen (2),
# tutelles du laboratoire LITIS (1 et 2)
# contributors :
# - Denis Coquenet
#
#
# This software is a computer program written in Python whose purpose is to
# provide public implementation of deep learning works, in pytorch.
#
# This software is governed by the CeCILL-C license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL-C
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL-C license and that you accept its terms.
import os
import re
import xml.etree.ElementTree as ET

import numpy as np

from Datasets.dataset_formatters.generic_dataset_formatter import OCRDatasetFormatter
from Datasets.dataset_formatters.utils_dataset import natural_sort
# Layout string to token
SEM_MATCHING_TOKENS_STR = {
'Ouverture': "", # opening
'Corps de texte': "", # body
'PS/PJ': "", # post scriptum
'Coordonnées Expéditeur': "", # sender
'Reference': "", # also counted as sender information
'Objet': "", # why
'Date, Lieu': "", # where, when
'Coordonnées Destinataire': "", # recipient
}
# Layout begin-token to end-token
SEM_MATCHING_TOKENS = {
"": "",
"": "",
"": "",
"": "",
"": "",
"": "",
"": ""
}
class RIMESDatasetFormatter(OCRDatasetFormatter):

    def __init__(self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=True):
        super(RIMESDatasetFormatter, self).__init__("RIMES", level, "_sem" if sem_token else "", set_names)

        self.source_fold_path = os.path.join("../raw", "RIMES")
        self.dpi = dpi
        self.sem_token = sem_token
        self.map_datasets_files.update({
            "RIMES": {
                # (1,050 for train, 100 for validation and 100 for test)
                "page": {
                    "arx_files": ["RIMES_page.tar.gz", ],
                    "needed_files": [],
                    "format_function": self.format_rimes_page,
                },
            }
        })
        self.matching_tokens_str = SEM_MATCHING_TOKENS_STR
        self.matching_tokens = SEM_MATCHING_TOKENS
        self.ordering_function = order_text_regions
    def preformat_rimes_page(self):
        """
        Extract all information from dataset and correct some annotations
        """
        dataset = {
            "train": list(),
            "valid": list(),
            "test": list()
        }
        img_folder_path = os.path.join(self.temp_fold, "RIMES page", "Images")
        xml_folder_path = os.path.join(self.temp_fold, "RIMES page", "XML")
        xml_files = natural_sort([os.path.join(xml_folder_path, name) for name in os.listdir(xml_folder_path)])
        train_xml = xml_files[:1050]
        valid_xml = xml_files[1050:1150]
        test_xml = xml_files[1150:]
        for set_name, xml_files in zip(self.set_names, [train_xml, valid_xml, test_xml]):
            for i, xml_path in enumerate(xml_files):
                text_regions = list()
                root = ET.parse(xml_path).getroot()
                img_name = root.find("source").text
                if img_name == "01160_L.png":
                    text_regions.append({
                        "label": "LETTRE RECOMMANDEE\nAVEC ACCUSE DE RECEPTION",
                        "type": "",
                        "coords": {
                            "left": 88,
                            "right": 1364,
                            "top": 1224,
                            "bottom": 1448,
                        }
                    })
                for text_region in root.findall("box"):
                    type = text_region.find("type").text
                    label = text_region.find("text").text
                    if label is None or len(label.strip()) <= 0:
                        continue
                    if label == "Ref : QVLCP¨65":
                        label = label.replace("¨", "")
                    if img_name == "01094_L.png" and type == "Corps de texte":
                        label = "Suite à la tempête du 19.11.06, un\narbre est tombé sur mon toît et l'a endommagé.\nJe d'eplore une cinquantaine de tuiles à changer,\nune poutre à réparer et une gouttière à\nremplacer. Veuillez trouver ci-joint le devis\nde réparation. Merci de m'envoyer votre\nexpert le plus rapidement possible.\nEn esperant une réponse rapide de votre\npart, veuillez accepter, madame, monsieur,\nmes salutations distinguées."
                    elif img_name == "01111_L.png" and type == "Corps de texte":
                        label = "Je vous ai envoyé un courrier le 20 octobre 2006\nvous signalant un sinistre survenu dans ma\nmaison, un dégât des eaux consécutif aux\nfortes pluis.\nVous deviez envoyer un expert pour constater\nles dégâts. Personne n'est venu à ce jour\nJe vous prie donc de faire le nécessaire\nafin que les réparations nécessaires puissent\nêtre commencés.\nDans l'attente, veuillez agréer, Monsieur,\nmes sincères salutations"
                    label = self.convert_label_accent(label)
                    label = self.convert_label(label)
                    label = self.format_text_label(label)
                    coords = {
                        "left": int(text_region.attrib["top_left_x"]),
                        "right": int(text_region.attrib["bottom_right_x"]),
                        "top": int(text_region.attrib["top_left_y"]),
                        "bottom": int(text_region.attrib["bottom_right_y"]),
                    }
                    text_regions.append({
                        "label": label,
                        "type": type,
                        "coords": coords
                    })
                text_regions = self.ordering_function(text_regions)
                dataset[set_name].append({
                    "text_regions": text_regions,
                    "img_path": os.path.join(img_folder_path, img_name),
                    "label": "\n".join([tr["label"] for tr in text_regions]),
                    "sem_label": "".join([self.sem_label(tr["label"], tr["type"]) for tr in text_regions]),
                })
        return dataset
    def convert_label_accent(self, label):
        """
        Solve encoding issues
        """
        # The left-hand sides below are mojibake byte sequences found in the
        # raw annotations; some may display identically to their replacements.
        return label.replace("\\n", "\n").replace("<euro>", "").replace(">euro>", "").replace(">fligne>", " ")\
            .replace("¤", "¤").replace("û", "û").replace("", "").replace("ï¿©", "é").replace("ç", "ç")\
            .replace("é", "é").replace("ô", "ô").replace(u'\xa0', " ").replace("è", "è").replace("°", "°")\
            .replace("À", "À").replace("ì", "À").replace("ê", "ê").replace("î", "î").replace("â", "â")\
            .replace("²", "²").replace("ù", "ù").replace("Ã", "à").replace("¬", "")
    def format_rimes_page(self):
        """
        Format RIMES page dataset
        """
        dataset = self.preformat_rimes_page()
        for set_name in self.set_names:
            fold = os.path.join(self.target_fold_path, set_name)
            for sample in dataset[set_name]:
                new_name = "{}_{}.png".format(set_name, len(os.listdir(fold)))
                new_img_path = os.path.join(fold, new_name)
                self.load_resize_save(sample["img_path"], new_img_path, 300, self.dpi)
                for tr in sample["text_regions"]:
                    tr["coords"] = self.adjust_coord_ratio(tr["coords"], self.dpi / 300)
                page = {
                    "text": sample["label"] if not self.sem_token else sample["sem_label"],
                    "paragraphs": sample["text_regions"],
                    "nb_cols": 1,
                }
                self.charset = self.charset.union(set(page["text"]))
                self.gt[set_name][new_name] = page
    def convert_label(self, label):
        """
        Some annotations offer several options for a given text part; always keep the first one only
        """
        if "¤" in label:
            label = re.sub('¤{([^¤]*)[/|]([^¤]*)}¤', r'\1', label, flags=re.DOTALL)
            label = re.sub('¤{([^¤]*)[/|]([^¤]*)[/|]([^¤]*)>', r'\1', label, flags=re.DOTALL)
            label = re.sub('¤([^¤]*)[/|]([^¤]*)¤', r'\1', label, flags=re.DOTALL)
            label = re.sub('¤{}¤([^¤]*)[/|]([^ ]*)', r'\1', label, flags=re.DOTALL)
            label = re.sub('¤{/([^¤]*)/([^ ]*)', r'\1', label, flags=re.DOTALL)
            label = re.sub('¤{([^¤]*)[/|]([^ ]*)', r'\1', label, flags=re.DOTALL)
            label = re.sub('([^¤]*)/(.*)[¤}{]+', r'\1', label, flags=re.DOTALL)
            label = re.sub('[¤}{]+([^¤}{]*)[¤}{]+', r'\1', label, flags=re.DOTALL)
            label = re.sub('¤([^¤]*)¤', r'\1', label, flags=re.DOTALL)
        label = re.sub('[ ]+', " ", label, flags=re.DOTALL)
        label = label.strip()
        return label
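
    # Illustration (hypothetical label, using the "¤{option1/option2}¤" syntax
    # matched by the first pattern above):
    #   convert_label("le ¤{toit/toît}¤ endommagé") -> "le toit endommagé"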
    def sem_label(self, label, type):
        """
        Add layout tokens
        """
        if type == "":
            return label
        begin_token = self.matching_tokens_str[type]
        end_token = self.matching_tokens[begin_token]
        return begin_token + label + end_token
def order_text_regions(text_regions):
    """
    Establish reading order based on text region pixel positions
    """
    sorted_text_regions = list()
    for tr in text_regions:
        added = False
        if len(sorted_text_regions) == 0:
            sorted_text_regions.append(tr)
            added = True
        else:
            for i, sorted_tr in enumerate(sorted_text_regions):
                tr_height = tr["coords"]["bottom"] - tr["coords"]["top"]
                sorted_tr_height = sorted_tr["coords"]["bottom"] - sorted_tr["coords"]["top"]
                tr_is_totally_above = tr["coords"]["bottom"] < sorted_tr["coords"]["top"]
                tr_is_top_above = tr["coords"]["top"] < sorted_tr["coords"]["top"]
                is_same_level = sorted_tr["coords"]["top"] <= tr["coords"]["bottom"] <= sorted_tr["coords"]["bottom"] or \
                    sorted_tr["coords"]["top"] <= tr["coords"]["top"] <= sorted_tr["coords"]["bottom"] or \
                    tr["coords"]["top"] <= sorted_tr["coords"]["bottom"] <= tr["coords"]["bottom"] or \
                    tr["coords"]["top"] <= sorted_tr["coords"]["top"] <= tr["coords"]["bottom"]
                vertical_shared_space = tr["coords"]["bottom"] - sorted_tr["coords"]["top"] if tr_is_top_above else sorted_tr["coords"]["bottom"] - tr["coords"]["top"]
                reach_same_level_limit = vertical_shared_space > 0.3 * min(tr_height, sorted_tr_height)
                is_more_at_left = tr["coords"]["left"] < sorted_tr["coords"]["left"]
                equivalent_height = abs(tr_height - sorted_tr_height) < 0.3 * min(tr_height, sorted_tr_height)
                is_middle_above_top = np.mean([tr["coords"]["top"], tr["coords"]["bottom"]]) < sorted_tr["coords"]["top"]
                if tr_is_totally_above or \
                        (is_same_level and equivalent_height and is_more_at_left and reach_same_level_limit) or \
                        (is_same_level and equivalent_height and tr_is_top_above and not reach_same_level_limit) or \
                        (is_same_level and not equivalent_height and is_middle_above_top):
                    sorted_text_regions.insert(i, tr)
                    added = True
                    break
        if not added:
            sorted_text_regions.append(tr)
    return sorted_text_regions
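
# Illustration (hypothetical coordinates): with two overlapping regions of
# equal height, tr1 = {"coords": {"top": 0, "bottom": 50, "left": 400, "right": 600}}
# and tr2 = {"coords": {"top": 10, "bottom": 60, "left": 20, "right": 200}},
# order_text_regions([tr1, tr2]) returns [tr2, tr1]: the regions share a
# vertical level with equivalent heights, and tr2 is further left, so it is
# read first.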
if __name__ == "__main__":
    RIMESDatasetFormatter("page", sem_token=True).format()
    RIMESDatasetFormatter("page", sem_token=False).format()
# Copyright Université de Rouen Normandie (1), INSA Rouen (2),
# tutelles du laboratoire LITIS (1 et 2)
# contributors :
# - Denis Coquenet
#
#
# This software is a computer program written in Python whose purpose is to
# provide public implementation of deep learning works, in pytorch.
#
# This software is governed by the CeCILL-C license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL-C
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL-C license and that you accept its terms.
import os

from tqdm import tqdm

from Datasets.dataset_formatters.generic_dataset_formatter import OCRDatasetFormatter
# Layout string to token
SEM_MATCHING_TOKENS_STR = {
"INTITULE": "",
"DATE": "",
"COTE_SERIE": "",
"ANALYSE_COMPL": "",
"PRECISIONS_SUR_COTE": "",
"COTE_ARTICLE": ""
}
# Layout begin-token to end-token
SEM_MATCHING_TOKENS = {
"": "",
"": "",
"": "",
"": "",
"": "",
"": ""
}
class SimaraDatasetFormatter(OCRDatasetFormatter):
    def __init__(self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=True):
        super(SimaraDatasetFormatter, self).__init__("simara", level, "_sem" if sem_token else "", set_names)

        self.dpi = dpi
        self.sem_token = sem_token
        self.map_datasets_files.update({
            "simara": {
                "page": {
                    "format_function": self.format_simara_page,
                },
            }
        })
        self.matching_tokens_str = SEM_MATCHING_TOKENS_STR
        self.matching_tokens = SEM_MATCHING_TOKENS
    def preformat_simara_page(self):
        """
        Extract all information from dataset and correct some annotations
        """
        dataset = {
            "train": list(),
            "valid": list(),
            "test": list()
        }
        img_folder_path = os.path.join("Datasets", "raw", "simara", "images")
        labels_folder_path = os.path.join("Datasets", "raw", "simara", "labels")
        sem_labels_folder_path = os.path.join("Datasets", "raw", "simara", "labels_sem")
        train_files = [
            os.path.join(labels_folder_path, 'train', name)
            for name in os.listdir(os.path.join(sem_labels_folder_path, 'train'))]
        valid_files = [
            os.path.join(labels_folder_path, 'valid', name)
            for name in os.listdir(os.path.join(sem_labels_folder_path, 'valid'))]
        test_files = [
            os.path.join(labels_folder_path, 'test', name)
            for name in os.listdir(os.path.join(sem_labels_folder_path, 'test'))]
        for set_name, files in zip(self.set_names, [train_files, valid_files, test_files]):
            for i, label_file in enumerate(tqdm(files, desc='Pre-formatting ' + set_name)):
                with open(label_file, 'r') as f:
                    text = f.read()
                with open(label_file.replace('labels', 'labels_sem'), 'r') as f:
                    sem_text = f.read()
                dataset[set_name].append({
                    "img_path": os.path.join(
                        img_folder_path, set_name, label_file.split('/')[-1].replace('txt', 'jpg')),
                    "label": text,
                    "sem_label": sem_text,
                })
        return dataset
    def format_simara_page(self):
        """
        Format simara page dataset
        """
        dataset = self.preformat_simara_page()
        for set_name in self.set_names:
            fold = os.path.join(self.target_fold_path, set_name)
            for sample in tqdm(dataset[set_name], desc='Formatting ' + set_name):
                new_name = sample['img_path'].split('/')[-1]
                new_img_path = os.path.join(fold, new_name)
                self.load_resize_save(sample["img_path"], new_img_path)  # , 300, self.dpi)
                page = {
                    "text": sample["label"] if not self.sem_token else sample["sem_label"],
                }
                self.charset = self.charset.union(set(page["text"]))
                self.gt[set_name][new_name] = page
if __name__ == "__main__":
    SimaraDatasetFormatter("page", sem_token=True).format()
    # SimaraDatasetFormatter("page", sem_token=False).format()
# Copyright Université de Rouen Normandie (1), INSA Rouen (2),
# tutelles du laboratoire LITIS (1 et 2)
# contributors :
# - Denis Coquenet
#
#
# This software is a computer program written in Python whose purpose is to
# provide public implementation of deep learning works, in pytorch.
#
# This software is governed by the CeCILL-C license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL-C
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL-C license and that you accept its terms.
import json
import random
import re

import cv2

random.seed(42)


def natural_sort(l):
    convert = lambda text: int(text) if text.isdigit() else text.lower()
    alphanum_key = lambda key: [convert(c) for c in re.split("([0-9]+)", key)]
    return sorted(l, key=alphanum_key)


def assign_random_split(train_prob, val_prob):
    """
    Assuming train_prob + val_prob + test_prob = 1
    """
    prob = random.random()
    if prob <= train_prob:
        return "train"
    elif prob <= train_prob + val_prob:
        # "valid" matches the split directories created by the extraction script.
        return "valid"
    else:
        return "test"
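
# Illustration (expected behaviour, not a test from this repo): with the
# defaults train_prob=0.7 and val_prob=0.15, a draw of prob=0.65 yields
# "train", prob=0.80 yields "valid" (0.7 < 0.80 <= 0.85) and prob=0.90
# yields "test", so splits follow a 70/15/15 distribution in expectation.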
def save_text(path, text):
    with open(path, "w") as f:
        f.write(text)


def save_image(path, image):
    # imageio reads images as RGB; cv2.imwrite expects BGR.
    cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))


def save_json(path, data):
    with open(path, "w") as outfile:
        json.dump(data, outfile, indent=4)
\ No newline at end of file
"../../../Fonts/lato/Lato-HairlineItalic.ttf",
"../../../Fonts/lato/Lato-HeavyItalic.ttf",
"../../../Fonts/lato/Lato-BoldItalic.ttf",
"../../../Fonts/lato/Lato-Black.ttf",
"../../../Fonts/lato/Lato-Heavy.ttf",
"../../../Fonts/lato/Lato-Regular.ttf",
"../../../Fonts/lato/Lato-LightItalic.ttf",
"../../../Fonts/lato/Lato-Italic.ttf",
"../../../Fonts/lato/Lato-ThinItalic.ttf",
"../../../Fonts/lato/Lato-Bold.ttf",
"../../../Fonts/lato/Lato-Hairline.ttf",
"../../../Fonts/lato/Lato-Medium.ttf",
"../../../Fonts/lato/Lato-SemiboldItalic.ttf",
"../../../Fonts/lato/Lato-BlackItalic.ttf",
"../../../Fonts/lato/Lato-MediumItalic.ttf",
"../../../Fonts/lato/Lato-Semibold.ttf",
"../../../Fonts/lato/Lato-Thin.ttf",
"../../../Fonts/lato/Lato-Light.ttf",
"../../../Fonts/gentiumplus/GentiumPlus-I.ttf",
"../../../Fonts/gentiumplus/GentiumPlus-R.ttf",
"../../../Fonts/dejavu/DejaVuSansMono-BoldOblique.ttf",
"../../../Fonts/dejavu/DejaVuSerifCondensed.ttf",
"../../../Fonts/dejavu/DejaVuSans-BoldOblique.ttf",
"../../../Fonts/dejavu/DejaVuSans-ExtraLight.ttf",
"../../../Fonts/dejavu/DejaVuSansCondensed-Oblique.ttf",
"../../../Fonts/dejavu/DejaVuSerifCondensed-BoldItalic.ttf",
"../../../Fonts/dejavu/DejaVuSansCondensed-Bold.ttf",
"../../../Fonts/dejavu/DejaVuSerif-Italic.ttf",
"../../../Fonts/dejavu/DejaVuSansCondensed.ttf",
"../../../Fonts/dejavu/DejaVuSerifCondensed-Italic.ttf",
"../../../Fonts/dejavu/DejaVuSerifCondensed-Bold.ttf",
"../../../Fonts/dejavu/DejaVuSansMono.ttf",
"../../../Fonts/dejavu/DejaVuSerif-Bold.ttf",
"../../../Fonts/dejavu/DejaVuSans-Bold.ttf",
"../../../Fonts/dejavu/DejaVuSerif.ttf",
"../../../Fonts/dejavu/DejaVuSansMono-Bold.ttf",
"../../../Fonts/dejavu/DejaVuSerif-BoldItalic.ttf",
"../../../Fonts/dejavu/DejaVuSansMono-Oblique.ttf",
"../../../Fonts/dejavu/DejaVuSans.ttf",
"../../../Fonts/dejavu/DejaVuSans-Oblique.ttf",
"../../../Fonts/dejavu/DejaVuSansCondensed-BoldOblique.ttf"
\ No newline at end of file
"../../../Fonts/handwritten-mix/Parisienne-Regular.ttf",
"../../../Fonts/handwritten-mix/A little sunshine.ttf",
"../../../Fonts/handwritten-mix/Massillo.ttf",
"../../../Fonts/handwritten-mix/Cursive standard Bold.ttf",
"../../../Fonts/handwritten-mix/Merveille-mj8j.ttf",
"../../../Fonts/handwritten-mix/Cursive standard.ttf",
"../../../Fonts/handwritten-mix/Roustel.ttf",
"../../../Fonts/handwritten-mix/Baby Doll.ttf",
"../../../Fonts/handwritten-mix/flashback Demo.ttf",
"../../../Fonts/handwritten-mix/CreamShoes.ttf",
"../../../Fonts/handwritten-mix/Gentle Remind.ttf",
"../../../Fonts/handwritten-mix/Alexandria Rose.ttf",
"../../../Fonts/lato/Lato-HairlineItalic.ttf",
"../../../Fonts/lato/Lato-HeavyItalic.ttf",
"../../../Fonts/lato/Lato-BoldItalic.ttf",
"../../../Fonts/lato/Lato-Black.ttf",
"../../../Fonts/lato/Lato-Heavy.ttf",
"../../../Fonts/lato/Lato-Regular.ttf",
"../../../Fonts/lato/Lato-LightItalic.ttf",
"../../../Fonts/lato/Lato-Italic.ttf",
"../../../Fonts/lato/Lato-ThinItalic.ttf",
"../../../Fonts/lato/Lato-Bold.ttf",
"../../../Fonts/lato/Lato-Hairline.ttf",
"../../../Fonts/lato/Lato-Medium.ttf",
"../../../Fonts/lato/Lato-SemiboldItalic.ttf",
"../../../Fonts/lato/Lato-BlackItalic.ttf",
"../../../Fonts/lato/Lato-MediumItalic.ttf",
"../../../Fonts/lato/Lato-Semibold.ttf",
"../../../Fonts/lato/Lato-Thin.ttf",
"../../../Fonts/lato/Lato-Light.ttf",
"../../../Fonts/gentiumplus/GentiumPlus-I.ttf",
"../../../Fonts/gentiumplus/GentiumPlus-R.ttf",
"../../../Fonts/dejavu/DejaVuSansMono-BoldOblique.ttf",
"../../../Fonts/dejavu/DejaVuSerifCondensed.ttf",
"../../../Fonts/dejavu/DejaVuSans-BoldOblique.ttf",
"../../../Fonts/dejavu/DejaVuSans-ExtraLight.ttf",
"../../../Fonts/dejavu/DejaVuSansCondensed-Oblique.ttf",
"../../../Fonts/dejavu/DejaVuSerifCondensed-BoldItalic.ttf",
"../../../Fonts/dejavu/DejaVuSansCondensed-Bold.ttf",
"../../../Fonts/dejavu/DejaVuSerif-Italic.ttf",
"../../../Fonts/dejavu/DejaVuSansCondensed.ttf",
"../../../Fonts/dejavu/DejaVuSerifCondensed-Italic.ttf",
"../../../Fonts/dejavu/DejaVuSerifCondensed-Bold.ttf",
"../../../Fonts/dejavu/DejaVuSansMono.ttf",
"../../../Fonts/dejavu/DejaVuSerif-Bold.ttf",
"../../../Fonts/dejavu/DejaVuSans-Bold.ttf",
"../../../Fonts/dejavu/DejaVuSerif.ttf",
"../../../Fonts/dejavu/DejaVuSansMono-Bold.ttf",
"../../../Fonts/dejavu/DejaVuSerif-BoldItalic.ttf",
"../../../Fonts/dejavu/DejaVuSansMono-Oblique.ttf",
"../../../Fonts/dejavu/DejaVuSans.ttf",
"../../../Fonts/dejavu/DejaVuSans-Oblique.ttf",
"../../../Fonts/dejavu/DejaVuSansCondensed-BoldOblique.ttf",
"../../../Fonts/open-sans/OpenSans-SemiboldItalic.ttf",
"../../../Fonts/open-sans/OpenSans-CondLight.ttf",
"../../../Fonts/open-sans/OpenSans-Light.ttf",
"../../../Fonts/open-sans/OpenSans-Italic.ttf",
"../../../Fonts/open-sans/OpenSans-CondBold.ttf",
"../../../Fonts/open-sans/OpenSans-Bold.ttf",
"../../../Fonts/open-sans/OpenSans-CondLightItalic.ttf",
"../../../Fonts/open-sans/OpenSans-ExtraBold.ttf",
"../../../Fonts/open-sans/OpenSans-Semibold.ttf",
"../../../Fonts/open-sans/OpenSans-Regular.ttf",
"../../../Fonts/open-sans/OpenSans-BoldItalic.ttf",
"../../../Fonts/open-sans/OpenSans-LightItalic.ttf",
"../../../Fonts/open-sans/OpenSans-ExtraBoldItalic.ttf",
"../../../Fonts/msttcorefonts/Arial.ttf",
"../../../Fonts/msttcorefonts/Verdana_Italic.ttf",
"../../../Fonts/msttcorefonts/Georgia_Bold_Italic.ttf",
"../../../Fonts/msttcorefonts/Andale_Mono.ttf",
"../../../Fonts/msttcorefonts/Courier_New_Italic.ttf",
"../../../Fonts/msttcorefonts/Georgia_Italic.ttf",
"../../../Fonts/msttcorefonts/Arial_Black.ttf",
"../../../Fonts/msttcorefonts/Trebuchet_MS_Italic.ttf",
"../../../Fonts/msttcorefonts/Verdana.ttf",
"../../../Fonts/msttcorefonts/Courier_New.ttf",
"../../../Fonts/msttcorefonts/Verdana_Bold.ttf",
"../../../Fonts/msttcorefonts/Arial_Bold_Italic.ttf",
"../../../Fonts/msttcorefonts/Georgia.ttf",
"../../../Fonts/msttcorefonts/Trebuchet_MS_Bold_Italic.ttf",
"../../../Fonts/msttcorefonts/Impact.ttf",
"../../../Fonts/msttcorefonts/Courier_New_Bold.ttf",
"../../../Fonts/msttcorefonts/Times_New_Roman_Italic.ttf",
"../../../Fonts/msttcorefonts/Georgia_Bold.ttf",
"../../../Fonts/msttcorefonts/Times_New_Roman_Bold.ttf",
"../../../Fonts/msttcorefonts/Times_New_Roman.ttf",
"../../../Fonts/msttcorefonts/Comic_Sans_MS.ttf",
"../../../Fonts/msttcorefonts/Trebuchet_MS_Bold.ttf",
"../../../Fonts/msttcorefonts/Trebuchet_MS.ttf",
"../../../Fonts/msttcorefonts/Arial_Italic.ttf",
"../../../Fonts/msttcorefonts/Courier_New_Bold_Italic.ttf",
"../../../Fonts/msttcorefonts/Verdana_Bold_Italic.ttf",
"../../../Fonts/msttcorefonts/Arial_Bold.ttf",
"../../../Fonts/msttcorefonts/Times_New_Roman_Bold_Italic.ttf",
"../../../Fonts/msttcorefonts/Comic_Sans_MS_Bold.ttf"
\ No newline at end of file
This diff is collapsed.
from OCR.ocr_manager import OCRManager
from torch.nn import CrossEntropyLoss
import torch
from dan.ocr_utils import LM_ind_to_str
import numpy as np
from torch.cuda.amp import autocast
import time
class Manager(OCRManager):

    def __init__(self, params):
        super(Manager, self).__init__(params)

    def load_save_info(self, info_dict):
        if "curriculum_config" in info_dict.keys():
            if self.dataset.train_dataset is not None:
                self.dataset.train_dataset.curriculum_config = info_dict["curriculum_config"]

    def add_save_info(self, info_dict):
        info_dict["curriculum_config"] = self.dataset.train_dataset.curriculum_config
        return info_dict

    def get_init_hidden(self, batch_size):
        num_layers = 1
        hidden_size = self.params["model_params"]["enc_dim"]
        return torch.zeros(num_layers, batch_size, hidden_size), torch.zeros(num_layers, batch_size, hidden_size)

    def apply_teacher_forcing(self, y, y_len, error_rate):
        y_error = y.clone()
        for b in range(len(y_len)):
            for i in range(1, y_len[b]):
                if np.random.rand() < error_rate and y[b][i] != self.dataset.tokens["pad"]:
                    y_error[b][i] = np.random.randint(0, len(self.dataset.charset) + 2)
        return y_error, y_len
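
    # Illustration (hypothetical numbers): with error_rate=0.2, roughly 20% of
    # the non-padding ground-truth tokens (position 0, the start token, is
    # never touched) are replaced by a random index in [0, len(charset) + 2)
    # before being fed to the decoder, so the model learns to recover from its
    # own prediction errors during training.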
    def train_batch(self, batch_data, metric_names):
        loss_func = CrossEntropyLoss(ignore_index=self.dataset.tokens["pad"])

        sum_loss = 0
        x = batch_data["imgs"].to(self.device)
        y = batch_data["labels"].to(self.device)
        reduced_size = [s[:2] for s in batch_data["imgs_reduced_shape"]]
        y_len = batch_data["labels_len"]

        # add errors in teacher forcing
        if "teacher_forcing_error_rate" in self.params["training_params"] and self.params["training_params"]["teacher_forcing_error_rate"] is not None:
            error_rate = self.params["training_params"]["teacher_forcing_error_rate"]
            simulated_y_pred, y_len = self.apply_teacher_forcing(y, y_len, error_rate)
        elif "teacher_forcing_scheduler" in self.params["training_params"]:
            # Linear schedule from min_error_rate to max_error_rate over
            # total_num_steps training steps.
            scheduler = self.params["training_params"]["teacher_forcing_scheduler"]
            progress = min(self.latest_step, scheduler["total_num_steps"]) / scheduler["total_num_steps"]
            error_rate = scheduler["min_error_rate"] + progress * (scheduler["max_error_rate"] - scheduler["min_error_rate"])
            simulated_y_pred, y_len = self.apply_teacher_forcing(y, y_len, error_rate)
        else:
            simulated_y_pred = y

        with autocast(enabled=self.params["training_params"]["use_amp"]):
            hidden_predict = None
            cache = None

            raw_features = self.models["encoder"](x)
            features_size = raw_features.size()
            b, c, h, w = features_size

            pos_features = self.models["decoder"].features_updater.get_pos_features(raw_features)
            features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(2, 0, 1)
            enhanced_features = pos_features
            enhanced_features = torch.flatten(enhanced_features, start_dim=2, end_dim=3).permute(2, 0, 1)
            output, pred, hidden_predict, cache, weights = self.models["decoder"](
                features,
                enhanced_features,
                simulated_y_pred[:, :-1],
                reduced_size,
                [max(y_len) for _ in range(b)],
                features_size,
                start=0,
                hidden_predict=hidden_predict,
                cache=cache,
                keep_all_weights=True,
            )

            loss_ce = loss_func(pred, y[:, 1:])
            sum_loss += loss_ce

        with autocast(enabled=False):
            self.backward_loss(sum_loss)
            self.step_optimizers()
            self.zero_optimizers()

        predicted_tokens = torch.argmax(pred, dim=1).detach().cpu().numpy()
        predicted_tokens = [predicted_tokens[i, :y_len[i]] for i in range(b)]
        str_x = [LM_ind_to_str(self.dataset.charset, t, oov_symbol="") for t in predicted_tokens]

        values = {
            "nb_samples": b,
            "str_y": batch_data["raw_labels"],
            "str_x": str_x,
            "loss": sum_loss.item(),
            "loss_ce": loss_ce.item(),
            "syn_max_lines": self.dataset.train_dataset.get_syn_max_lines() if self.params["dataset_params"]["config"]["synthetic_data"] else 0,
        }
        return values
    def evaluate_batch(self, batch_data, metric_names):
        x = batch_data["imgs"].to(self.device)
        reduced_size = [s[:2] for s in batch_data["imgs_reduced_shape"]]

        max_chars = self.params["training_params"]["max_char_prediction"]

        start_time = time.time()
        with autocast(enabled=self.params["training_params"]["use_amp"]):
            b = x.size(0)
            reached_end = torch.zeros((b,), dtype=torch.bool, device=self.device)
            prediction_len = torch.zeros((b,), dtype=torch.int, device=self.device)
            predicted_tokens = torch.ones((b, 1), dtype=torch.long, device=self.device) * self.dataset.tokens["start"]
            predicted_tokens_len = torch.ones((b,), dtype=torch.int, device=self.device)

            whole_output = list()
            confidence_scores = list()
            cache = None
            hidden_predict = None
            if b > 1:
                features_list = list()
                for i in range(b):
                    pos = batch_data["imgs_position"]
                    features_list.append(self.models["encoder"](x[i:i + 1, :, pos[i][0][0]:pos[i][0][1], pos[i][1][0]:pos[i][1][1]]))
                max_height = max([f.size(2) for f in features_list])
                max_width = max([f.size(3) for f in features_list])
                features = torch.zeros((b, features_list[0].size(1), max_height, max_width), device=self.device, dtype=features_list[0].dtype)
                for i in range(b):
                    features[i, :, :features_list[i].size(2), :features_list[i].size(3)] = features_list[i]
            else:
                features = self.models["encoder"](x)
            features_size = features.size()
            coverage_vector = torch.zeros((features.size(0), 1, features.size(2), features.size(3)), device=self.device)
            pos_features = self.models["decoder"].features_updater.get_pos_features(features)
            features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(2, 0, 1)
            enhanced_features = pos_features
            enhanced_features = torch.flatten(enhanced_features, start_dim=2, end_dim=3).permute(2, 0, 1)

            for i in range(0, max_chars):
                output, pred, hidden_predict, cache, weights = self.models["decoder"](
                    features,
                    enhanced_features,
                    predicted_tokens,
                    reduced_size,
                    predicted_tokens_len,
                    features_size,
                    start=0,
                    hidden_predict=hidden_predict,
                    cache=cache,
                    num_pred=1,
                )
                whole_output.append(output)
                confidence_scores.append(torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values)
                coverage_vector = torch.clamp(coverage_vector + weights, 0, 1)
                predicted_tokens = torch.cat([predicted_tokens, torch.argmax(pred[:, :, -1], dim=1, keepdim=True)], dim=1)
                reached_end = torch.logical_or(reached_end, torch.eq(predicted_tokens[:, -1], self.dataset.tokens["end"]))
                predicted_tokens_len += 1

                prediction_len[reached_end == False] = i + 1
                if torch.all(reached_end):
                    break

            confidence_scores = torch.cat(confidence_scores, dim=1).cpu().detach().numpy()
            predicted_tokens = predicted_tokens[:, 1:]
            prediction_len[torch.eq(reached_end, False)] = max_chars - 1
            predicted_tokens = [predicted_tokens[i, :prediction_len[i]] for i in range(b)]
            confidence_scores = [confidence_scores[i, :prediction_len[i]].tolist() for i in range(b)]
            str_x = [LM_ind_to_str(self.dataset.charset, t, oov_symbol="") for t in predicted_tokens]

        process_time = time.time() - start_time

        values = {
            "nb_samples": b,
            "str_y": batch_data["raw_labels"],
            "str_x": str_x,
            "confidence_score": confidence_scores,
            "time": process_time,
        }
        return values
# Copyright Université de Rouen Normandie (1), INSA Rouen (2),
# tutelles du laboratoire LITIS (1 et 2)
# contributors :
# - Denis Coquenet
#
#
# This software is a computer program written in XXX whose purpose is XXX.
#
# This software is governed by the CeCILL-C license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL-C
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL-C license and that you accept its terms.
from torch.nn import AdaptiveMaxPool2d, Conv1d, Module
from torch.nn.functional import log_softmax


class Decoder(Module):
    def __init__(self, params):
        super(Decoder, self).__init__()

        self.vocab_size = params["vocab_size"]

        self.ada_pool = AdaptiveMaxPool2d((1, None))
        self.end_conv = Conv1d(in_channels=params["enc_size"], out_channels=self.vocab_size + 1, kernel_size=1)

    def forward(self, x):
        x = self.ada_pool(x).squeeze(2)
        x = self.end_conv(x)
        return log_softmax(x, dim=1)
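
# Illustration (hypothetical sizes): for an encoder feature map of shape
# (batch=2, enc_size=256, height=8, width=100), AdaptiveMaxPool2d((1, None))
# collapses the height, squeeze(2) yields (2, 256, 100), and the 1x1 Conv1d
# maps each of the 100 horizontal positions to vocab_size + 1
# log-probabilities (the extra class is presumably the CTC blank used for the
# line-level CTC pre-training described in the README).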
This diff is collapsed.
......@@ -2,10 +2,10 @@
This repository is a public implementation of the paper: "DAN: a Segmentation-free Document Attention Network for Handwritten Document Recognition".
![Prediction visualization](visual.png)
![Prediction visualization](images/visual.png)
The model uses a character-level attention to handle slanted lines:
![Prediction visualization on slanted lines](visual_slanted_lines.png)
![Prediction visualization on slanted lines](images/visual_slanted_lines.png)
The paper is available at https://arxiv.org/abs/2203.12273.
......@@ -49,68 +49,6 @@ Install the dependencies:
pip install -r requirements.txt
```
## Datasets
This section is dedicated to the datasets used in the paper: download and formatting instructions are provided
for experiment replication purposes.
The RIMES dataset at page level was distributed during the [evaluation campaign of 2009](https://ieeexplore.ieee.org/document/5277557).

The READ 2016 dataset corresponds to the one used in the [ICFHR 2016 competition on handwritten text recognition](https://ieeexplore.ieee.org/document/7814136).
It can be found [here](https://zenodo.org/record/1164045#.YiINkBvjKEA).

Raw dataset files must be placed in Datasets/raw/{dataset_name} \
where {dataset_name} is "READ 2016" or "RIMES".
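
For instance, after downloading the RIMES page archive (named `RIMES_page.tar.gz` in the formatter above), the expected layout is (illustrative):

```
Datasets/
└── raw/
    └── RIMES/
        └── RIMES_page.tar.gz
```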
## Training And Evaluation
### Step 1: Download the dataset
### Step 2: Format the dataset
```
python3 Datasets/dataset_formatters/read2016_formatter.py
python3 Datasets/dataset_formatters/rimes_formatter.py
```
### Step 3: Add any font you want as a .ttf file in the folder Fonts

### Step 4: Generate a synthetic line dataset for pre-training
```
python3 OCR/line_OCR/ctc/main_syn_line.py
```
Two lines in this script must be adapted to the dataset used:
```
model.generate_syn_line_dataset("READ_2016_syn_line")
dataset_name = "READ_2016"
```
### Step 5: Pre-training on synthetic lines
```
python3 OCR/line_OCR/ctc/main_line_ctc.py
```
Two lines in this script must be adapted to the dataset used:
```
dataset_name = "READ_2016"
"output_folder": "FCN_read_line_syn"
```
Weights and evaluation results are stored in OCR/line_OCR/ctc/outputs.
### Step 6: Training the DAN
```
python3 OCR/document_OCR/dan/main_dan.py
```
The following lines must be adapted to the dataset used and pre-training folder names:
```
dataset_name = "READ_2016"
"transfer_learning": {
# model_name: [state_dict_name, checkpoint_path, learnable, strict]
"encoder": ["encoder", "../../line_OCR/ctc/outputs/FCN_read_2016_line_syn/checkpoints/best.pt", True, True],
"decoder": ["decoder", "../../line_OCR/ctc/outputs/FCN_read_2016_line_syn/best.pt", True, False],
},
```
Weights and evaluation results are stored in OCR/document_OCR/dan/outputs.
### Remarks (for pre-training and training)
All hyperparameters are specified and editable in the training scripts (their meaning is documented in comments).\
Evaluation is performed just after training ends (training stops when the maximum elapsed time is reached or after the maximum number of epochs specified in the training script).\
......@@ -154,20 +92,3 @@ To run the inference on a GPU, one can replace `cpu` by the name of the GPU. In
```python
text, confidence_scores = model.predict(image, confidences=True)
```
## Citation
```bibtex
@misc{Coquenet2022b,
  author = {Coquenet, Denis and Chatelain, Clément and Paquet, Thierry},
  title = {DAN: a Segmentation-free Document Attention Network for Handwritten Document Recognition},
  doi = {10.48550/ARXIV.2203.12273},
  url = {https://arxiv.org/abs/2203.12273},
  publisher = {arXiv},
  year = {2022},
}
```
## License
This whole project is under the CeCILL-C license.
# -*- coding: utf-8 -*-
import argparse
import errno

from dan.datasets.extract.extract_from_arkindex import add_extract_parser
from dan.ocr.line.generate_synthetic import add_generate_parser
from dan.ocr.train import add_train_parser


def get_parser():
    parser = argparse.ArgumentParser(prog="teklia-dan")
    subcommands = parser.add_subparsers(metavar="subcommand")

    add_train_parser(subcommands)
    add_extract_parser(subcommands)
    add_generate_parser(subcommands)
    return parser


def main():
    parser = get_parser()
    args = vars(parser.parse_args())

    if "func" in args:
        # Run the subcommand's function
        try:
            status = args.pop("func")(**args)
            parser.exit(status=status)
        except KeyboardInterrupt:
            # Just quit silently on ^C instead of displaying a long traceback
            parser.exit(status=errno.EOWNERDEAD)
    else:
        parser.error("A subcommand is required.")
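
# Example invocation (illustrative; the "extract" subcommand and its flags are
# defined by the extract parser shown later in this diff, and the corpus name
# is hypothetical):
#   teklia-dan extract --corpus "AN-Simara annotations E (2022-06-20)" \
#       --element-type page --parents-types folder --output-dir ./extracted/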
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
......
# -*- coding: utf-8 -*-

"""
The extraction module
======================
"""

import logging
import os

import cv2
import imageio.v2 as iio
from arkindex import ArkindexClient, options_from_env
from tqdm import tqdm

from dan.datasets.extract.arkindex_utils import retrieve_corpus, retrieve_subsets

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
IMAGES_DIR = "./images/" # Path to the images directory.
LABELS_DIR = "./labels/" # Path to the labels directory.
# Layout string to token
SEM_MATCHING_TOKENS_STR = {
"INTITULE": "",
"DATE": "",
"COTE_SERIE": "",
"ANALYSE_COMPL.": "",
"PRECISIONS_SUR_COTE": "",
"COTE_ARTICLE": "",
}
# Layout begin-token to end-token
SEM_MATCHING_TOKENS = {"": "", "": "", "": "", "": "", "": "", "": ""}
def add_extract_parser(subcommands) -> None:
    parser = subcommands.add_parser(
        "extract",
        description=__doc__,
        help=__doc__,
    )

    # Required arguments.
    parser.add_argument(
        "--corpus",
        type=str,
        help="Name of the corpus from which the data will be retrieved.",
        required=True,
    )
    parser.add_argument(
        "--element-type",
        nargs="+",
        type=str,
        help="Type of elements to retrieve.",
        required=True,
    )
    parser.add_argument(
        "--parents-types",
        nargs="+",
        type=str,
        help="Type of parents of the elements.",
        required=True,
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        help="Path to the output directory.",
        required=True,
    )

    # Optional arguments.
    parser.add_argument(
        "--parents-names",
        nargs="+",
        type=str,
        help="Names of parents of the elements.",
        default=None,
    )
    parser.add_argument(
        "--no-entities", action="store_true", help="Extract text without entities."
    )
    parser.add_argument(
        "--use-existing-split",
        action="store_true",
        help="Do not partition pages into train/val/test.",
    )
    parser.add_argument(
        "--train-prob", type=float, default=0.7, help="Training set probability."
    )
    parser.add_argument(
        "--val-prob", type=float, default=0.15, help="Validation set probability."
    )
    parser.set_defaults(func=run)
def run(
    corpus,
    element_type,
    parents_types,
    output_dir,
    parents_names,
    no_entities,
    use_existing_split,
    train_prob,
    val_prob,
):
    # Get and initialize the parameters.
    os.makedirs(IMAGES_DIR, exist_ok=True)
    os.makedirs(LABELS_DIR, exist_ok=True)

    # Login to arkindex.
    client = ArkindexClient(**options_from_env())

    corpus = retrieve_corpus(client, corpus)
    subsets = retrieve_subsets(client, corpus, parents_types, parents_names)

    # Iterate over the subsets to find the page images and labels.
    for subset in subsets:
        os.makedirs(os.path.join(output_dir, IMAGES_DIR, subset["name"]), exist_ok=True)
        os.makedirs(os.path.join(output_dir, LABELS_DIR, subset["name"]), exist_ok=True)

        for page in tqdm(
            client.paginate(
                "ListElementChildren", id=subset["id"], type="page", recursive=True
            ),
            desc="Set " + subset["name"],
        ):
            image = iio.imread(page["zone"]["url"])
            # imageio reads RGB; cv2.imwrite expects BGR.
            cv2.imwrite(
                os.path.join(
                    output_dir, IMAGES_DIR, subset["name"], f"{page['id']}.jpg"
                ),
                cv2.cvtColor(image, cv2.COLOR_RGB2BGR),
            )

            tr = client.request(
                "ListTranscriptions", id=page["id"], worker_version=None
            )["results"]
            tr = [one for one in tr if one["worker_version_id"] is None]
            assert len(tr) == 1, page["id"]

            for one_tr in tr:
                ent = client.request("ListTranscriptionEntities", id=one_tr["id"])[
                    "results"
                ]
                ent = [one for one in ent if one["worker_version_id"] is None]
                if len(ent) == 0:
                    continue
                else:
                    text = one_tr["text"]

            new_text = text
            count = 0
            for e in ent:
                start_token = SEM_MATCHING_TOKENS_STR[e["entity"]["metas"]["subtype"]]
                end_token = SEM_MATCHING_TOKENS[start_token]
                new_text = (
                    new_text[: count + e["offset"]]
                    + start_token
                    + new_text[count + e["offset"] :]
                )
                count += 1
                new_text = (
                    new_text[: count + e["offset"] + e["length"]]
                    + end_token
                    + new_text[count + e["offset"] + e["length"] :]
                )
                count += 1

            with open(
                os.path.join(
                    output_dir, LABELS_DIR, subset["name"], f"{page['id']}.txt"
                ),
                "w",
            ) as f:
                f.write(new_text)