Skip to content
Snippets Groups Projects
Commit c2258f4e authored by Yoann Schneider's avatar Yoann Schneider :tennis:
Browse files

Training code upgrades

parent 2607d3bc
No related branches found
No related tags found
1 merge request!5Training code upgrades
......@@ -15,8 +15,8 @@
import logging
import os
from re import S
import cv2
import imageio.v2 as iio
from arkindex import ArkindexClient, options_from_env
from tqdm import tqdm
......@@ -33,6 +33,7 @@ logging.basicConfig(
IMAGES_DIR = "./images/" # Path to the images directory.
LABELS_DIR = "./labels/" # Path to the labels directory.
# Layout string to token
SEM_MATCHING_TOKENS_STR = {
"INTITULE": "",
......
......@@ -24,6 +24,13 @@ def get_cli_args():
help="Name of the corpus from which the data will be retrieved.",
required=True,
)
parser.add_argument(
"--element-type",
nargs="+",
type=str,
help="Type of elements to retrieve",
required=True,
)
parser.add_argument(
"--parents-types",
nargs="+",
......@@ -46,4 +53,26 @@ def get_cli_args():
help="Names of parents of the elements.",
default=None,
)
parser.add_argument(
"--no-entities",
action="store_true",
help="Extract text without entities")
parser.add_argument(
"--use-existing-split",
action="store_true",
help="Do not partition pages into train/val/test")
parser.add_argument(
"--train-prob",
type=float,
default=0.7,
help="Training set probability")
parser.add_argument(
"--val-prob",
type=float,
default=0.15,
help="Validation set probability")
return parser.parse_args()
......@@ -35,9 +35,37 @@
# knowledge of the CeCILL-C license and that you accept its terms.
import re
import random
import cv2
import json
random.seed(42)
def natural_sort(l):
convert = lambda text: int(text) if text.isdigit() else text.lower()
alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key)]
return sorted(l, key=alphanum_key)
\ No newline at end of file
return sorted(l, key=alphanum_key)
def assign_random_split(train_prob, val_prob):
"""
assuming train_prob + val_prob + test_prob = 1
"""
prob = random.random()
if prob <= train_prob:
return "train"
elif prob <= train_prob + val_prob:
return "val"
else:
return "test"
def save_text(path, text):
with open(path, 'w') as f:
f.write(text)
def save_image(path, image):
cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
def save_json(path, dict):
with open(path, "w") as outfile:
json.dump(dict, outfile, indent=4)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment