Verified Commit 5ec47c7d authored by Yoann Schneider

remove Datasets folder

parent b2695212
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Example of minimal usage:
# python extract_from_arkindex.py
# --corpus "AN-Simara annotations E (2022-06-20)"
# --element-type page
# --parents-types folder
# --parents-names FRAN_IR_032031_4538.pdf
# --output-dir ../
"""
The extraction module
======================
"""
import logging
import os

import imageio.v2 as iio
from arkindex import ArkindexClient, options_from_env
from tqdm import tqdm

from extraction_utils import arkindex_utils as ark
from extraction_utils import utils
from utils_dataset import assign_random_split, save_text, save_image, save_json
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
IMAGES_DIR = "./images/" # Path to the images directory.
LABELS_DIR = "./labels/" # Path to the labels directory.
# Layout string to token
SEM_MATCHING_TOKENS_STR = {
    "INTITULE": "ⓘ",
    "DATE": "ⓓ",
    "COTE_SERIE": "ⓢ",
    "ANALYSE_COMPL.": "ⓒ",
    "PRECISIONS_SUR_COTE": "ⓟ",
    "COTE_ARTICLE": "ⓐ",
}

# Layout begin-token to end-token
SEM_MATCHING_TOKENS = {
    "ⓘ": "Ⓘ",
    "ⓓ": "Ⓓ",
    "ⓢ": "Ⓢ",
    "ⓒ": "Ⓒ",
    "ⓟ": "Ⓟ",
    "ⓐ": "Ⓐ",
}
def extract_transcription(element, no_entities=False):
    """
    Return the manual transcription of an element, with entities wrapped in
    layout tokens unless no_entities is set. Return None if the element does
    not have exactly one manual transcription.
    """
    # `client` is the global ArkindexClient created in the __main__ block below.
    tr = client.request('ListTranscriptions', id=element['id'], worker_version=None)['results']
    tr = [one for one in tr if one['worker_version_id'] is None]
    if len(tr) != 1:
        return None

    if no_entities:
        text = tr[0]['text'].strip()
    else:
        text = None
        for one_tr in tr:
            ent = client.request('ListTranscriptionEntities', id=one_tr['id'])['results']
            ent = [one for one in ent if one['worker_version_id'] is None]
            if len(ent) == 0:
                continue
            text = one_tr['text']
            count = 0
            for e in ent:
                start_token = SEM_MATCHING_TOKENS_STR[e['entity']['metas']['subtype']]
                end_token = SEM_MATCHING_TOKENS[start_token]
                # Insert the begin token before the entity and the end token after it,
                # shifting later offsets by the number of characters already inserted.
                text = text[:count + e['offset']] + start_token + text[count + e['offset']:]
                count += 1
                text = text[:count + e['offset'] + e['length']] + end_token + text[count + e['offset'] + e['length']:]
                count += 1
    return text
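
# Illustrative sketch of the serialized output (not part of the original pipeline;
# the transcription text, entity subtypes and offsets below are hypothetical):
# a transcription "Vente de terrain 1820" with an INTITULE entity at offset 0
# (length 16) and a DATE entity at offset 17 (length 4) would be returned by
# extract_transcription as:
#   "ⓘVente de terrainⒾ ⓓ1820Ⓓ"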
if __name__ == '__main__':
    args = utils.get_cli_args()

    # Get and initialize the parameters.
    os.makedirs(IMAGES_DIR, exist_ok=True)
    os.makedirs(LABELS_DIR, exist_ok=True)

    # Login to Arkindex.
    client = ArkindexClient(**options_from_env())

    corpus = ark.retrieve_corpus(client, args.corpus)
    subsets = ark.retrieve_subsets(
        client, corpus, args.parents_types, args.parents_names
    )

    # Create the output directories.
    split_dict = {}
    split_names = (
        [subset["name"] for subset in subsets]
        if args.use_existing_split
        else ["train", "val", "test"]
    )
    for split in split_names:
        os.makedirs(os.path.join(args.output_dir, IMAGES_DIR, split), exist_ok=True)
        os.makedirs(os.path.join(args.output_dir, LABELS_DIR, split), exist_ok=True)
        # Pages are stored in a flat list; children elements in a nested dict.
        split_dict[split] = [] if args.element_type == ["page"] else {}
    # Iterate over the subsets to find the page images and labels.
    for subset in subsets:
        # Iterate over the pages to create splits at page level.
        for page in tqdm(
            client.paginate(
                "ListElementChildren", id=subset["id"], type="page", recursive=True
            )
        ):
            if not args.use_existing_split:
                split = assign_random_split(args.train_prob, args.val_prob)
            else:
                split = subset["name"]

            # Extract only pages.
            if args.element_type == ["page"]:
                text = extract_transcription(page, no_entities=args.no_entities)
                if not text:
                    logging.warning(f"Skipping page {page['id']} (zero or multiple transcriptions with worker_version=None)")
                    continue
                image = iio.imread(page["zone"]["url"])
                logging.info(f"Processed page {page['id']}")
                split_dict[split].append(page["id"])
                im_path = os.path.join(args.output_dir, IMAGES_DIR, split, f"page_{page['id']}.jpg")
                txt_path = os.path.join(args.output_dir, LABELS_DIR, split, f"page_{page['id']}.txt")
                save_text(txt_path, text)
                save_image(im_path, image)
            # Extract the page's children elements (e.g. text_zone, text_line).
            else:
                split_dict[split][page["id"]] = {}
                for element_type in args.element_type:
                    split_dict[split][page["id"]][element_type] = []
                    for element in client.paginate("ListElementChildren", id=page["id"], type=element_type, recursive=True):
                        text = extract_transcription(element, no_entities=args.no_entities)
                        if not text:
                            logging.warning(f"Skipping {element_type} {element['id']} (zero or multiple transcriptions with worker_version=None)")
                            continue
                        image = iio.imread(element["zone"]["url"])
                        logging.info(f"Processed {element_type} {element['id']}")
                        split_dict[split][page["id"]][element_type].append(element['id'])
                        im_path = os.path.join(args.output_dir, IMAGES_DIR, split, f"{element_type}_{element['id']}.jpg")
                        txt_path = os.path.join(args.output_dir, LABELS_DIR, split, f"{element_type}_{element['id']}.txt")
                        save_text(txt_path, text)
                        save_image(im_path, image)

    save_json(os.path.join(args.output_dir, "split.json"), split_dict)
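
# For reference, split.json roughly follows this shape after a run that extracts
# text_line children (ids are illustrative; with --element-type page only, each
# split maps to a flat list of page ids instead):
#   {
#       "train": {"<page-id>": {"text_line": ["<element-id>", "..."]}},
#       "val":   {"<page-id>": {"text_line": ["..."]}},
#       "test":  {"<page-id>": {"text_line": ["..."]}}
#   }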
extraction_utils/utils.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
The utils module
======================
"""
import argparse
def get_cli_args():
    """
    Get the command-line arguments.
    :return: The command-line arguments.
    """
    parser = argparse.ArgumentParser(
        description="Arkindex DAN Training Label Generation"
    )

    # Required arguments.
    parser.add_argument(
        "--corpus",
        type=str,
        help="Name of the corpus from which the data will be retrieved.",
        required=True,
    )
    parser.add_argument(
        "--element-type",
        nargs="+",
        type=str,
        help="Types of elements to retrieve.",
        required=True,
    )
    parser.add_argument(
        "--parents-types",
        nargs="+",
        type=str,
        help="Types of the parents of the elements.",
        required=True,
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        help="Path to the output directory.",
        required=True,
    )

    # Optional arguments.
    parser.add_argument(
        "--parents-names",
        nargs="+",
        type=str,
        help="Names of the parents of the elements.",
        default=None,
    )
    parser.add_argument(
        "--no-entities",
        action="store_true",
        help="Extract the text without named entities.",
    )
    parser.add_argument(
        "--use-existing-split",
        action="store_true",
        help="Use the existing Arkindex subsets as splits instead of a random train/val/test partition.",
    )
    parser.add_argument(
        "--train-prob",
        type=float,
        default=0.7,
        help="Probability of assigning a page to the training set.",
    )
    parser.add_argument(
        "--val-prob",
        type=float,
        default=0.15,
        help="Probability of assigning a page to the validation set.",
    )
    return parser.parse_args()
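
# Illustrative parsed namespace for a typical invocation (values are hypothetical
# and only meant to show the attribute names used by extract_from_arkindex.py):
#   Namespace(corpus='AN-Simara annotations E (2022-06-20)', element_type=['page'],
#             parents_types=['folder'], parents_names=['FRAN_IR_032031_4538.pdf'],
#             output_dir='../', no_entities=False, use_existing_split=False,
#             train_prob=0.7, val_prob=0.15)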
utils_dataset.py
import re
import random
import json

import cv2

random.seed(42)
def natural_sort(l):
    convert = lambda text: int(text) if text.isdigit() else text.lower()
    alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
    return sorted(l, key=alphanum_key)
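
# Example (hypothetical filenames): natural_sort(["page_10.jpg", "page_2.jpg"])
# returns ["page_2.jpg", "page_10.jpg"], whereas plain sorted() would keep
# "page_10.jpg" first.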
def assign_random_split(train_prob, val_prob):
    """
    Draw a split at random, assuming train_prob + val_prob + test_prob = 1.
    """
    prob = random.random()
    if prob <= train_prob:
        return "train"
    elif prob <= train_prob + val_prob:
        return "val"
    else:
        return "test"
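
# Example: with the default train_prob=0.7 and val_prob=0.15, pages are assigned
# to "train"/"val"/"test" with probabilities of roughly 0.7/0.15/0.15;
# random.seed(42) above makes the assignment reproducible across runs.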
def save_text(path, text):
    with open(path, 'w') as f:
        f.write(text)


def save_image(path, image):
    # imageio returns RGB arrays; OpenCV expects BGR when writing to disk.
    cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))


def save_json(path, data):
    with open(path, "w") as outfile:
        json.dump(data, outfile, indent=4)