Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • atr/dan
1 result
Show changes
Commits on Source (3)
Showing
with 5 additions and 1324 deletions
......@@ -4,7 +4,6 @@ import errno
from dan.datasets import add_dataset_parser
from dan.ocr import add_train_parser
from dan.ocr.line import add_generate_parser
from dan.predict import add_predict_parser
......@@ -14,7 +13,6 @@ def get_parser():
add_dataset_parser(subcommands)
add_train_parser(subcommands)
add_generate_parser(subcommands)
add_predict_parser(subcommands)
return parser
......
......@@ -2,7 +2,6 @@
import re
import editdistance
import networkx as nx
import numpy as np
from dan.post_processing import PostProcessingModuleSIMARA
......@@ -31,19 +30,6 @@ class MetricManager:
"cer": ["edit_chars", "nb_chars"],
"wer": ["edit_words", "nb_words"],
"wer_no_punct": ["edit_words_no_punct", "nb_words_no_punct"],
"loer": [
"edit_graph",
"nb_nodes_and_edges",
"nb_pp_op_layout",
"nb_gt_layout_token",
],
"precision": ["precision", "weights"],
"map_cer_per_class": [
"map_cer",
],
"layout_precision_per_class_per_threshold": [
"map_cer",
],
}
self.init_metrics()
......@@ -84,41 +70,12 @@ class MetricManager:
for metric_name in metric_names:
value = None
if output:
if metric_name in ["nb_samples", "weights"]:
if metric_name == "nb_samples":
value = int(np.sum(self.epoch_metrics[metric_name]))
elif metric_name in [
"time",
]:
elif metric_name == "time":
value = int(np.sum(self.epoch_metrics[metric_name]))
sample_time = value / np.sum(self.epoch_metrics["nb_samples"])
display_values["sample_time"] = float(round(sample_time, 4))
elif metric_name == "loer":
display_values["pper"] = float(
round(
np.sum(self.epoch_metrics["nb_pp_op_layout"])
/ np.sum(self.epoch_metrics["nb_gt_layout_token"]),
4,
)
)
elif metric_name == "map_cer_per_class":
value = float(
compute_global_mAP_per_class(self.epoch_metrics["map_cer"])
)
for key in value.keys():
display_values["map_cer_" + key] = float(round(value[key], 4))
continue
elif metric_name == "layout_precision_per_class_per_threshold":
value = float(
compute_global_precision_per_class_per_threshold(
self.epoch_metrics["map_cer"]
)
)
for key_class in value.keys():
for threshold in value[key_class].keys():
display_values[
"map_cer_{}_{}".format(key_class, threshold)
] = float(round(value[key_class][threshold], 4))
continue
if metric_name == "cer":
value = float(
np.sum(self.epoch_metrics["edit_chars"])
......@@ -156,13 +113,6 @@ class MetricManager:
weights=np.array(self.epoch_metrics["nb_samples"]),
)
)
elif metric_name == "map_cer":
value = float(compute_global_mAP(self.epoch_metrics[metric_name]))
elif metric_name == "loer":
value = float(
np.sum(self.epoch_metrics["edit_graph"])
/ np.sum(self.epoch_metrics["nb_nodes_and_edges"])
)
elif value is None:
continue
......@@ -175,9 +125,8 @@ class MetricManager:
values["nb_samples"],
],
}
for v in ["weights", "time"]:
if v in values:
metrics[v] = [values[v]]
if "time" in values:
metrics["time"] = [values["time"]]
for metric_name in metric_names:
if metric_name == "cer":
metrics["edit_chars"] = [
......@@ -223,35 +172,6 @@ class MetricManager:
metrics[metric_name] = [
values[metric_name],
]
elif metric_name == "map_cer":
pp_pred = list()
pp_score = list()
for pred, score in zip(values["str_x"], values["confidence_score"]):
pred_score = self.post_processing_module().post_process(pred, score)
pp_pred.append(pred_score[0])
pp_score.append(pred_score[1])
metrics[metric_name] = [
compute_layout_mAP_per_class(y, x, conf, self.matching_tokens)
for x, conf, y in zip(pp_pred, pp_score, values["str_y"])
]
elif metric_name == "loer":
pp_pred = list()
metrics["nb_pp_op_layout"] = list()
for pred in values["str_x"]:
pp_module = self.post_processing_module()
pp_pred.append(pp_module.post_process(pred))
metrics["nb_pp_op_layout"].append(pp_module.num_op)
metrics["nb_gt_layout_token"] = [
len(keep_only_ner_tokens(str_x, self.layout_tokens))
for str_x in values["str_x"]
]
edit_and_num_items = [
self.edit_and_num_edge_nodes(y, x)
for x, y in zip(pp_pred, values["str_y"])
]
metrics["edit_graph"], metrics["nb_nodes_and_edges"] = [
ei[0] for ei in edit_and_num_items
], [ei[1] for ei in edit_and_num_items]
return metrics
def get(self, name):
......@@ -331,217 +251,3 @@ def edit_wer_from_formatted_split_text(gt, pred):
Compute edit distance at word level from formatted string as list
"""
return editdistance.eval(gt, pred)
def extract_by_tokens(
    input_str, begin_token, end_token, associated_score=None, order_by_score=False
):
    """
    Extract the text regions enclosed between a begin token and an end token.

    A region starts at ``begin_token``, contains no character of ``end_token``
    and stops at the first ``end_token``. The tokens themselves are stripped
    from the returned contents.

    :param input_str: String in which the regions are searched.
    :param begin_token: Token opening a region.
    :param end_token: Token closing a region.
    :param associated_score: Sequence of scores indexed like ``input_str``.
        Required when ``order_by_score`` is True.
    :param order_by_score: When True, regions are returned by decreasing
        confidence, the confidence of a region being the mean of the scores
        found at its first and last matched positions.
    :return: List of region contents (strings), in order of appearance or
        ordered by confidence.
    """
    if order_by_score:
        assert (
            associated_score is not None
        ), "associated_score is required when order_by_score is set"

    # Tokens are escaped so that regex metacharacters used as tokens
    # (parentheses, brackets, ...) are matched literally.
    pattern = "{0}[^{1}]*{1}".format(re.escape(begin_token), re.escape(end_token))

    regions = []
    for match in re.finditer(pattern, input_str):
        begin, end = match.span()
        content = input_str[begin + len(begin_token) : end - len(end_token)]
        if order_by_score:
            confidence = np.mean([associated_score[begin], associated_score[end - 1]])
            regions.append({"confidence": confidence, "content": content})
        else:
            regions.append(content)

    if order_by_score:
        # sorted() is stable: regions with equal confidence keep their order
        regions = sorted(regions, key=lambda r: r["confidence"], reverse=True)
        regions = [r["content"] for r in regions]
    return regions
def compute_layout_precision_per_threshold(
    gt, pred, score, begin_token, end_token, layout_tokens, return_weight=True
):
    """
    Compute the average precision of one class for CER thresholds going
    from 5% to 50% by steps of 5%.

    Predicted regions are extracted ordered by confidence, ground-truth
    regions in order of appearance; both are filtered with
    ``keep_all_but_ner_tokens`` before being compared.

    :return: The list of 10 precisions, followed by the number of
        ground-truth regions when ``return_weight`` is True.
    """
    predicted_regions = extract_by_tokens(
        pred, begin_token, end_token, associated_score=score, order_by_score=True
    )
    expected_regions = extract_by_tokens(gt, begin_token, end_token)

    predicted_regions = [
        keep_all_but_ner_tokens(region, layout_tokens) for region in predicted_regions
    ]
    expected_regions = [
        keep_all_but_ner_tokens(region, layout_tokens) for region in expected_regions
    ]

    precisions = []
    for percent in range(5, 51, 5):
        precisions.append(
            compute_layout_AP_for_given_threshold(
                expected_regions, predicted_regions, percent / 100
            )
        )

    if not return_weight:
        return precisions
    return precisions, len(expected_regions)
def compute_layout_AP_for_given_threshold(gt_list, pred_list, threshold):
    """
    Compute the average precision of a given class for a given CER threshold.

    Predictions are processed in order. Each one is matched with the
    remaining ground-truth entry giving the lowest CER; the match is accepted
    (and the ground-truth entry consumed) when this CER is at most
    ``threshold``.

    :param gt_list: Ground-truth strings of the class.
    :param pred_list: Predicted strings of the class, in evaluation order.
    :param threshold: Maximum CER (as a ratio) for a prediction to be correct.
    :return: Average precision; 0 when there is no ground truth or no
        prediction.
    """
    num_true = len(gt_list)
    if num_true == 0:
        # Without ground truth the recall is undefined (division by zero);
        # report a null precision instead of propagating NaN.
        return 0.0

    remaining_gt_list = list(gt_list)
    # np.bool / np.int aliases were removed from NumPy: use the builtins
    correct = np.zeros(len(pred_list), dtype=bool)
    for i, pred in enumerate(pred_list):
        if not remaining_gt_list:
            break
        # NOTE(review): presumably nb_chars_cer_from_string is non-zero for
        # every ground-truth entry; confirm for empty regions.
        cer_with_gt = [
            edit_cer_from_string(gt, pred) / nb_chars_cer_from_string(gt)
            for gt in remaining_gt_list
        ]
        ind = int(np.argmin(cer_with_gt))
        if cer_with_gt[ind] <= threshold:
            correct[i] = True
            del remaining_gt_list[ind]

    num_correct = np.cumsum(correct, dtype=int)
    precision = num_correct / np.arange(1, len(pred_list) + 1)
    recall = num_correct / num_true
    # Interpolated precision: best precision reachable at this recall or higher
    max_precision_from_recall = np.maximum.accumulate(precision[::-1])[::-1]
    recall_diff = recall - np.concatenate([np.array([0]), recall[:-1]])
    return np.sum(recall_diff * max_precision_from_recall)
def compute_layout_mAP_per_class(gt, pred, score, tokens):
    """
    Compute, for one sample, the mAP_cer of every class found in the ground truth.

    ``tokens`` maps each begin token to its end token. Only the classes whose
    begin token occurs in ``gt`` are evaluated.
    """
    all_begin_tokens = "".join(tokens)
    return {
        begin: compute_layout_precision_per_threshold(
            gt, pred, score, begin, end, layout_tokens=all_begin_tokens
        )
        for begin, end in tokens.items()
        if begin in gt
    }
def compute_global_mAP(list_AP_per_class):
    """
    Compute the global mAP_cer over several samples.

    Each sample is a dict mapping a class to ``(precisions, weight)``. The
    per-sample mAP is the weighted average of the mean class precisions, and
    the global value is the average of the per-sample mAPs weighted by the
    total weight of each sample. Returns 0 when every weight is null.
    """
    doc_scores = []
    doc_weights = []
    for per_class in list_AP_per_class:
        entries = list(per_class.values())
        class_aps = np.array([np.mean(entry[0]) for entry in entries])
        class_weights = np.array([entry[1] for entry in entries])
        total_weight = np.sum(class_weights)
        if total_weight == 0:
            doc_scores.append(0)
        else:
            doc_scores.append(np.average(class_aps, weights=class_weights))
        doc_weights.append(total_weight)

    if np.sum(doc_weights) == 0:
        return 0
    return np.average(doc_scores, weights=doc_weights)
def compute_global_mAP_per_class(list_AP_per_class):
    """
    Compute the mAP_cer of each class over several samples.

    For every class, the mean precision of each sample is averaged using the
    sample weights. Returns a dict mapping each class to this average.
    """
    collected = {}
    for per_class in list_AP_per_class:
        for label, entry in per_class.items():
            bucket = collected.setdefault(label, {"AP": [], "weights": []})
            bucket["AP"].append(np.mean(entry[0]))
            bucket["weights"].append(entry[1])

    return {
        label: np.average(bucket["AP"], weights=bucket["weights"])
        for label, bucket in collected.items()
    }
def compute_global_precision_per_class_per_threshold(list_AP_per_class):
    """
    Compute the mAP_cer per class and per CER threshold over several samples.

    Thresholds are 5, 10, ..., 50; the i-th precision of a sample entry is
    attached to the i-th threshold. Returns ``{class: {threshold: average}}``
    where the average is weighted by the sample weights.
    """
    thresholds = range(5, 51, 5)
    collected = {}
    for per_class in list_AP_per_class:
        for label, entry in per_class.items():
            if label not in collected:
                collected[label] = {
                    threshold: {"precision": [], "weights": []}
                    for threshold in thresholds
                }
            for position, threshold in enumerate(thresholds):
                slot = collected[label][threshold]
                slot["precision"].append(np.mean(entry[0][position]))
                slot["weights"].append(entry[1])

    return {
        label: {
            threshold: np.average(slot["precision"], weights=slot["weights"])
            for threshold, slot in per_threshold.items()
        }
        for label, per_threshold in collected.items()
    }
def str_to_graph_simara(str):
"""
Build a directed graph from the layout tokens of a SIMARA page-level string.

A root node "D" is created. Every layout token kept by
``keep_only_ner_tokens`` becomes a node named "<letter>_<occurrence number>",
linked from the root and from the node of the previous token.

:param str: Input string. NOTE(review): this name shadows the ``str`` builtin
    inside the function.
:return: The ``nx.DiGraph`` that was built.
"""
# Begin tokens are the keys of SIMARA_MATCHING_TOKENS joined into one string
begin_layout_tokens = "".join(list(SIMARA_MATCHING_TOKENS.keys()))
layout_token_sequence = keep_only_ner_tokens(str, begin_layout_tokens)
g = nx.DiGraph()
g.add_node("D", type="document", level=2, page=0)
# NOTE(review): every key shows up as an empty string in this view; presumably
# the layout token characters were lost in extraction - confirm against the
# repository before relying on this mapping.
token_name_dict = {"": "I", "": "D", "": "S", "": "C", "": "P", "": "A"}
# Per-token occurrence counters, used to give each node a unique name
num = dict()
previous_node = None
for token in begin_layout_tokens:
num[token] = 0
for ind, c in enumerate(layout_token_sequence):
num[c] += 1
node_name = "{}_{}".format(token_name_dict[c], num[c])
g.add_node(node_name, type=token_name_dict[c], level=1, page=0)
# Each token node hangs from the document root ...
g.add_edge("D", node_name)
# ... and is chained to the previous token node, when there is one
if previous_node:
g.add_edge(previous_node, node_name)
previous_node = node_name
return g
def graph_edit_distance(g1, g2):
    """
    Compute the graph edit distance between two graphs.

    Every insertion or deletion costs 1. Substituting two nodes is free when
    their "type" attributes are equal, substituting two edges is free when
    they compare equal. The last value yielded by
    ``nx.optimize_graph_edit_distance`` is returned.
    """

    def unit_cost(_element):
        return 1

    def node_substitution_cost(first, second):
        return 0 if first["type"] == second["type"] else 1

    def edge_substitution_cost(first, second):
        return 0 if first == second else 1

    approximations = nx.optimize_graph_edit_distance(
        g1,
        g2,
        node_ins_cost=unit_cost,
        node_del_cost=unit_cost,
        node_subst_cost=node_substitution_cost,
        edge_ins_cost=unit_cost,
        edge_del_cost=unit_cost,
        edge_subst_cost=edge_substitution_cost,
    )
    # Exhaust the generator, keeping only its final value
    for last_distance in approximations:
        pass
    return last_distance
......@@ -6,20 +6,10 @@ import pickle
import cv2
import numpy as np
import torch
from fontTools.ttLib import TTFont
from PIL import Image, ImageDraw, ImageFont
from dan.manager.dataset import DatasetManager, GenericDataset, apply_preprocessing
from dan.ocr.utils import LM_str_to_ind
from dan.utils import (
pad_image,
pad_image_width_random,
pad_images,
pad_sequences_1D,
rand,
rand_uniform,
randint,
)
from dan.utils import pad_image, pad_images, pad_sequences_1D, randint
class OCRDatasetManager(DatasetManager):
......@@ -35,12 +25,6 @@ class OCRDatasetManager(DatasetManager):
params["charset"] if "charset" in params else self.get_merged_charsets()
)
if (
"synthetic_data" in self.params["config"]
and self.params["config"]["synthetic_data"]
):
self.synthetic_data = self.params["config"]["synthetic_data"]
self.tokens = {
"pad": params["config"]["padding_token"],
}
......@@ -102,14 +86,6 @@ class OCRDataset(GenericDataset):
[params["config"]["height_divisor"], params["config"]["width_divisor"], 1]
)
self.collate_function = OCRCollateFunction
self.synthetic_id = 0
if (
"synthetic_data" in self.params["config"]
and self.params["config"]["synthetic_data"]
):
self.synthetic_data = self.params["config"]["synthetic_data"]
else:
self.synthetic_data = None
def __getitem__(self, idx):
sample = copy.deepcopy(self.samples[idx])
......@@ -120,13 +96,6 @@ class OCRDataset(GenericDataset):
sample, self.params["config"]["preprocessings"]
)
if (
"synthetic_data" in self.params["config"]
and self.params["config"]["synthetic_data"]
and self.set_name == "train"
):
sample = self.generate_synthetic_data(sample)
# Data augmentation
sample["img"] = self.apply_data_augmentation(sample["img"])
......@@ -213,237 +182,6 @@ class OCRDataset(GenericDataset):
sample["token_label"].insert(0, self.tokens["start"])
return sample
def generate_synthetic_data(self, sample):
    """
    Randomly replace a sample by synthetic data.

    The sample is returned untouched when the random draw exceeds the current
    synthetic probability. Otherwise, in "line_hw_to_printed" mode only the
    image is replaced by a typed rendering of the label; in any other case a
    whole synthetic page sample is generated.
    """
    synthetic_proba = self.get_syn_proba_lines()
    if rand() > synthetic_proba:
        return sample

    if self.synthetic_data.get("mode") != "line_hw_to_printed":
        return self.generate_synthetic_page_sample()

    sample["img"] = self.generate_typed_text_line_image(sample["label"])
    return sample
def get_syn_max_lines(self):
    """
    Return the current maximum number of lines of a synthetic page.

    Without curriculum this is ``max_nb_lines``. With curriculum, one more
    line is allowed every ``curr_step`` samples seen after ``curr_start``,
    bounded by ``min_nb_lines`` and ``max_nb_lines``.
    """
    settings = self.synthetic_data
    if not settings["curriculum"]:
        return settings["max_nb_lines"]

    seen_samples = self.training_info["step"] * self.params["batch_size"]
    unlocked_lines = (seen_samples - settings["curr_start"]) // settings[
        "curr_step"
    ] + 1
    capped_lines = min(settings["max_nb_lines"], unlocked_lines)
    return max(settings["min_nb_lines"], capped_lines)
def get_syn_proba_lines(self):
    """
    Return the current probability of using synthetic data.

    When the initial and final probabilities are equal, that value is
    returned. Otherwise the configured scheduler function is called with the
    initial probability, the final probability, the current step and the
    number of steps. With ``start_scheduler_at_max_line``, the initial
    probability is kept while the line curriculum is below ``max_nb_lines``.
    """
    settings = self.synthetic_data
    start_proba = settings["init_proba"]
    final_proba = settings["end_proba"]
    if start_proba == final_proba:
        return start_proba

    seen_samples = self.training_info["step"] * self.params["batch_size"]

    if not settings["start_scheduler_at_max_line"]:
        total_steps = settings["num_steps_proba"]
        return settings["proba_scheduler_function"](
            start_proba,
            final_proba,
            min(seen_samples, total_steps),
            total_steps,
        )

    total_steps = settings["num_steps_proba"]
    # Samples needed by the line curriculum to go from min to max lines
    curriculum_samples = settings["curr_step"] * (
        settings["max_nb_lines"] - settings["min_nb_lines"]
    )
    elapsed_steps = max(0, min(seen_samples - curriculum_samples, total_steps))
    if self.get_syn_max_lines() < settings["max_nb_lines"]:
        return start_proba
    return settings["proba_scheduler_function"](
        start_proba, final_proba, elapsed_steps, total_steps
    )
def generate_synthetic_page_sample(self):
"""
Build one synthetic page (or double-page) sample.

The number of lines per page is drawn between ``min_nb_lines`` and the
current curriculum maximum. Pages are generated with a generator chosen
from the dataset names in ``self.params["datasets"]`` (READ_2016, RIMES,
or typed text paragraphs otherwise). The sample is returned after going
through ``self.convert_sample_labels``.
"""
max_nb_lines_per_page = self.get_syn_max_lines()
# Cropping is only enabled while the curriculum is below max_nb_lines
crop = (
self.synthetic_data["crop_curriculum"]
and max_nb_lines_per_page < self.synthetic_data["max_nb_lines"]
)
sample = {"name": "synthetic_data_{}".format(self.synthetic_id), "path": None}
self.synthetic_id += 1
nb_pages = 2 if "double" in self.synthetic_data["dataset_level"] else 1
# A randomly picked existing sample; below, only its image shape and
# (for READ_2016) its "pages_label" entries are read.
# NOTE(review): presumably randint excludes its upper bound - confirm.
background_sample = copy.deepcopy(self.samples[randint(0, len(self))])
pages = list()
backgrounds = list()
h, w, c = background_sample["img"].shape
page_width = w // 2 if nb_pages == 2 else w
for i in range(nb_pages):
nb_lines_per_page = randint(
self.synthetic_data["min_nb_lines"], max_nb_lines_per_page + 1
)
# Blank background filled with 255, same height as the picked sample
background = (
np.ones((h, page_width, c), dtype=background_sample["img"].dtype) * 255
)
# First page of a double page: its last two columns are set to 0
if i == 0 and nb_pages == 2:
background[:, -2:, :] = 0
backgrounds.append(background)
if "READ_2016" in self.params["datasets"].keys():
side = background_sample["pages_label"][i]["side"]
# Horizontal bounds depend on the page side; vertical ones are fixed ratios
coords = {
"left": int(0.15 * page_width)
if side == "left"
else int(0.05 * page_width),
"right": int(0.95 * page_width)
if side == "left"
else int(0.85 * page_width),
"top": int(0.05 * h),
"bottom": int(0.85 * h),
}
pages.append(
self.generate_synthetic_read2016_page(
background,
coords,
side=side,
crop=crop,
nb_lines=nb_lines_per_page,
)
)
elif "RIMES" in self.params["datasets"].keys():
pages.append(
self.generate_synthetic_rimes_page(
background, nb_lines=nb_lines_per_page, crop=crop
)
)
else:
# Get a page-level transcription and split it by lines
texts = self.samples[randint(0, len(self))]["label"].split("\n")
# Select some lines to be generated
n_lines = min(len(texts), nb_lines_per_page)
# NOTE(review): this reassigns `i`, the page index of the enclosing loop
i = randint(0, len(texts) - n_lines + 1)
texts = texts[i : i + n_lines]
# Generate the synthetic document (of n_lines)
pages.append(
self.generate_typed_text_paragraph_image(
texts=texts,
same_font_size=True,
)
)
# Each entry of `pages` is indexed as [image, label dict, ...]
if nb_pages == 1:
sample["img"] = pages[0][0]
sample["label_raw"] = pages[0][1]["raw"]
sample["label_begin"] = pages[0][1]["begin"]
sample["label_sem"] = pages[0][1]["sem"]
# NOTE(review): this stores the whole label dict; it appears to be
# overwritten by the label selection further down - confirm.
sample["label"] = pages[0][1]
else:
# Pages of different heights are pasted at the top of their backgrounds,
# cut to the tallest page, so that they can be concatenated side by side
if pages[0][0].shape[0] != pages[1][0].shape[0]:
max_height = max(pages[0][0].shape[0], pages[1][0].shape[0])
backgrounds[0] = backgrounds[0][:max_height]
backgrounds[0][: pages[0][0].shape[0]] = pages[0][0]
backgrounds[1] = backgrounds[1][:max_height]
backgrounds[1][: pages[1][0].shape[0]] = pages[1][0]
pages[0][0] = backgrounds[0]
pages[1][0] = backgrounds[1]
sample["label_raw"] = pages[0][1]["raw"] + "\n" + pages[1][1]["raw"]
sample["label_begin"] = pages[0][1]["begin"] + pages[1][1]["begin"]
sample["label_sem"] = pages[0][1]["sem"] + pages[1][1]["sem"]
sample["img"] = np.concatenate([pages[0][0], pages[1][0]], axis=1)
# Label selection: raw by default, replaced depending on the charset content.
# NOTE(review): the two tested literals show up as empty strings in this
# view; presumably they were layout token characters lost in extraction.
sample["label"] = sample["label_raw"]
if "" in self.charset:
sample["label"] = sample["label_begin"]
if "" in self.charset:
sample["label"] = sample["label_sem"]
sample["unchanged_label"] = sample["label"]
sample = self.convert_sample_labels(sample)
return sample
def generate_typed_text_line_image(self, text):
    """
    Render ``text`` as a typed line image, using the "config" entry of the
    synthetic data settings of this dataset.
    """
    rendering_config = self.synthetic_data["config"]
    return generate_typed_text_line_image(text, rendering_config)
def generate_typed_text_paragraph_image(
    self, texts, padding_value=255, max_pad_left_ratio=0.1, same_font_size=False
):
    """
    Generate a synthetic paragraph image from a list of text lines.

    Each line is rendered on its own with a randomly picked font, padded to
    the width of the widest line, then the lines are stacked vertically.

    :param texts: List of strings, one per line of the paragraph.
    :param padding_value: Value forwarded to ``pad_image_width_random``.
    :param max_pad_left_ratio: Ratio forwarded to ``pad_image_width_random``.
    :param same_font_size: When True, one font size is drawn and shared by all
        lines, and the default colors of the configuration are used. When
        False, each line is rendered independently through
        ``self.generate_typed_text_line_image``.
    :return: ``[image, label, 1]`` where ``label`` holds the lines joined by
        line breaks under the "sem", "begin" and "raw" keys.
    """
    if same_font_size:
        config = self.synthetic_data["config"]
        txt_color = config["text_color_default"]
        bg_color = config["background_color_default"]
        font_size = randint(config["font_size_min"], config["font_size_max"] + 1)
        images = list()
        for text in texts:
            font_path = config["valid_fonts"][randint(0, len(config["valid_fonts"]))]
            fnt = ImageFont.truetype(font_path, font_size)
            # NOTE(review): FreeTypeFont.getsize is unavailable in recent
            # Pillow releases - confirm the Pillow version pinned by the project.
            text_width, text_height = fnt.getsize(text)
            # Order matters for the random generator: top, bottom, left, right
            padding = [
                get_random_padding(
                    config["padding_top_ratio_min"],
                    config["padding_top_ratio_max"],
                    text_height,
                ),
                get_random_padding(
                    config["padding_bottom_ratio_min"],
                    config["padding_bottom_ratio_max"],
                    text_height,
                ),
                get_random_padding(
                    config["padding_left_ratio_min"],
                    config["padding_left_ratio_max"],
                    text_width,
                ),
                get_random_padding(
                    config["padding_right_ratio_min"],
                    config["padding_right_ratio_max"],
                    text_width,
                ),
            ]
            images.append(
                generate_typed_text_line_image_from_params(
                    text,
                    fnt,
                    bg_color,
                    txt_color,
                    config["color_mode"],
                    padding,
                )
            )
    else:
        # Go through the method so that the rendering configuration is
        # supplied: the module-level function requires a `config` argument.
        images = [self.generate_typed_text_line_image(t) for t in texts]

    max_width = max([img.shape[1] for img in images])

    padded_images = [
        pad_image_width_random(
            img,
            max_width,
            padding_value=padding_value,
            max_pad_left_ratio=max_pad_left_ratio,
        )
        for img in images
    ]

    joined_text = "\n".join(texts)
    label = {
        "sem": joined_text,
        "begin": joined_text,
        "raw": joined_text,
    }
    # image, label, n_col
    return [np.concatenate(padded_images, axis=0), label, 1]
class OCRCollateFunction:
"""
......@@ -491,73 +229,3 @@ class OCRCollateFunction:
)
return formatted_batch_data
def get_random_padding(min_ratio, max_ratio, text_size):
    """
    Draw a padding, in pixels, as a random ratio of the text width or height.

    The ratio comes from ``rand_uniform(min_ratio, max_ratio)`` and the
    result is truncated to an integer.
    """
    ratio = rand_uniform(min_ratio, max_ratio)
    return int(ratio * text_size)
def generate_typed_text_line_image(
    text, config, bg_color=(255, 255, 255), txt_color=(0, 0, 0)
):
    """
    Render a text line with a random font, font size and paddings.

    An empty text is rendered as a single space. Colors found in ``config``
    ("text_color_default", "background_color_default") take precedence over
    the ``txt_color`` and ``bg_color`` arguments. Paddings are drawn as random
    ratios of the text height (top, bottom) and width (left, right).
    """
    rendered_text = text if text != "" else " "
    txt_color = config.get("text_color_default", txt_color)
    bg_color = config.get("background_color_default", bg_color)

    fonts = config["valid_fonts"]
    font = ImageFont.truetype(
        fonts[randint(0, len(fonts))],
        randint(config["font_size_min"], config["font_size_max"] + 1),
    )
    width, height = font.getsize(rendered_text)

    # Drawn in this order: top, bottom, left, right
    sides = (("top", height), ("bottom", height), ("left", width), ("right", width))
    padding = [
        get_random_padding(
            config["padding_{}_ratio_min".format(side)],
            config["padding_{}_ratio_max".format(side)],
            extent,
        )
        for side, extent in sides
    ]
    return generate_typed_text_line_image_from_params(
        rendered_text, font, bg_color, txt_color, config["color_mode"], padding
    )
def generate_typed_text_line_image_from_params(
    text, font, bg_color, txt_color, color_mode, padding
):
    """
    Render ``text`` on a plain background and return it as a numpy array.

    :param text: Text to draw.
    :param font: Font object providing ``getsize`` and usable by ``ImageDraw``.
    :param bg_color: Background color of the image.
    :param txt_color: Fill color of the text.
    :param color_mode: Mode given to ``Image.new``.
    :param padding: Paddings in pixels, ordered as top, bottom, left, right.
    :return: The rendered image, converted with ``np.array``.
    """
    padding_top, padding_bottom, padding_left, padding_right = padding
    text_width, text_height = font.getsize(text)
    img_height = padding_top + padding_bottom + text_height
    img_width = padding_left + padding_right + text_width
    img = Image.new(color_mode, (img_width, img_height), color=bg_color)
    d = ImageDraw.Draw(img)
    # The drawing position is the top-left corner of the text: the vertical
    # offset is therefore the top padding (the bottom padding was used before,
    # which swapped the two vertical margins).
    d.text((padding_left, padding_top), text, font=font, fill=txt_color, spacing=0)
    return np.array(img)
def char_in_font(unicode_char, font_path):
    """
    Tell whether the font file at ``font_path`` maps ``unicode_char`` in at
    least one of its Unicode cmap tables.
    """
    with TTFont(font_path) as font:
        return any(
            table.isUnicode() and ord(unicode_char) in table.cmap
            for table in font["cmap"].tables
        )
......@@ -2,7 +2,6 @@
import copy
import json
import os
import pickle
import random
from time import time
......@@ -11,7 +10,6 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import yaml
from PIL import Image
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.nn.init import kaiming_uniform_
......@@ -975,73 +973,6 @@ class OCRManager(GenericTrainingManager):
super(OCRManager, self).__init__(params)
self.params["model_params"]["vocab_size"] = len(self.dataset.charset)
def generate_syn_line_dataset(self, name):
    """
    Generate a synthetic line dataset from the currently loaded dataset.

    For the train, val and test sets, every sample label is split on line
    breaks, then in chunks of at most 100 characters. Each chunk is rendered
    with ``generate_typed_text_line_image`` of the set's dataset and saved as
    an image in ``<output folder>/<set name>/``. The transcriptions are
    written to ``labels.json`` and the sorted charset to ``charset.pkl``.

    :param name: Name of the output folder, created next to the folder of the
        first configured dataset.
    """
    datasets = self.params["dataset_params"]["datasets"]
    dataset_name = next(iter(datasets))
    path = os.path.join(os.path.dirname(datasets[dataset_name]), name)
    os.makedirs(path, exist_ok=True)

    charset = set()
    gt = {"train": dict(), "val": dict(), "test": dict()}
    for set_name in ["train", "val", "test"]:
        set_path = os.path.join(path, set_name)
        os.makedirs(set_path, exist_ok=True)

        if set_name == "train":
            dataset = self.dataset.train_dataset
        elif set_name == "val":
            dataset = self.dataset.valid_datasets["{}-val".format(dataset_name)]
        else:
            # The test dataset only exists once its loader has been generated
            self.dataset.generate_test_loader(
                "{}-test".format(dataset_name),
                [
                    (dataset_name, "test"),
                ],
            )
            dataset = self.dataset.test_datasets["{}-test".format(dataset_name)]

        samples = list()
        for sample in dataset.samples:
            for line_label in sample["label"].split("\n"):
                # Chunks of at most 100 characters; an empty line gives none
                for start in range(0, len(line_label), 100):
                    chunk = line_label[start : start + 100]
                    charset.update(chunk)
                    samples.append({"path": sample["path"], "label": chunk})

        for i, sample in enumerate(samples):
            # Keep the extension of the source image
            ext = sample["path"].split(".")[-1]
            img_name = "{}_{}.{}".format(set_name, i, ext)
            img_path = os.path.join(set_path, img_name)
            img = dataset.generate_typed_text_line_image(sample["label"])
            Image.fromarray(img).save(img_path)
            gt[set_name][img_name] = {
                "text": sample["label"],
            }

    # `path` is a plain string: it must be joined with os.path.join, the `/`
    # operator is only defined for pathlib paths.
    with open(os.path.join(path, "labels.json"), "w") as f:
        json.dump(
            gt,
            f,
            sort_keys=True,
            indent=4,
        )

    with open(os.path.join(path, "charset.pkl"), "wb") as f:
        pickle.dump(sorted(charset), f)
class Manager(OCRManager):
def __init__(self, params):
......
# -*- coding: utf-8 -*-
import json
import os
import pickle
from PIL import Image
from dan.manager.training import GenericTrainingManager
class OCRManager(GenericTrainingManager):
    """
    Training manager for OCR models, setting the vocabulary size of the model
    from the charset of the loaded dataset.
    """

    def __init__(self, params):
        super(OCRManager, self).__init__(params)
        self.params["model_params"]["vocab_size"] = len(self.dataset.charset)

    def generate_syn_line_dataset(self, name):
        """
        Generate a synthetic line dataset from the currently loaded dataset.

        For the train, val and test sets, every sample label is split on line
        breaks, then in chunks of at most 100 characters. Each chunk is
        rendered with ``generate_typed_text_line_image`` of the set's dataset
        and saved as an image in ``<output folder>/<set name>/``. The
        transcriptions are written to ``labels.json`` and the sorted charset
        to ``charset.pkl``.

        :param name: Name of the output folder, created next to the folder of
            the first configured dataset.
        """
        datasets = self.params["dataset_params"]["datasets"]
        dataset_name = next(iter(datasets))
        path = os.path.join(os.path.dirname(datasets[dataset_name]), name)
        os.makedirs(path, exist_ok=True)

        charset = set()
        gt = {"train": dict(), "val": dict(), "test": dict()}
        for set_name in ["train", "val", "test"]:
            set_path = os.path.join(path, set_name)
            os.makedirs(set_path, exist_ok=True)

            if set_name == "train":
                dataset = self.dataset.train_dataset
            elif set_name == "val":
                dataset = self.dataset.valid_datasets["{}-val".format(dataset_name)]
            else:
                # The test dataset only exists once its loader has been generated
                self.dataset.generate_test_loader(
                    "{}-test".format(dataset_name),
                    [
                        (dataset_name, "test"),
                    ],
                )
                dataset = self.dataset.test_datasets["{}-test".format(dataset_name)]

            samples = list()
            for sample in dataset.samples:
                for line_label in sample["label"].split("\n"):
                    # Chunks of at most 100 characters; an empty line gives none
                    for start in range(0, len(line_label), 100):
                        chunk = line_label[start : start + 100]
                        charset.update(chunk)
                        samples.append({"path": sample["path"], "label": chunk})

            for i, sample in enumerate(samples):
                # Keep the extension of the source image
                ext = sample["path"].split(".")[-1]
                img_name = "{}_{}.{}".format(set_name, i, ext)
                img_path = os.path.join(set_path, img_name)
                img = dataset.generate_typed_text_line_image(sample["label"])
                Image.fromarray(img).save(img_path)
                gt[set_name][img_name] = {
                    "text": sample["label"],
                }

        # `path` is a plain string: it must be joined with os.path.join, the
        # `/` operator is only defined for pathlib paths.
        with open(os.path.join(path, "labels.json"), "w") as f:
            json.dump(
                gt,
                f,
                sort_keys=True,
                indent=4,
            )

        with open(os.path.join(path, "charset.pkl"), "wb") as f:
            pickle.dump(sorted(charset), f)
......@@ -4,7 +4,6 @@ Train a new DAN model.
"""
from dan.ocr.document import add_document_parser
from dan.ocr.line import add_line_parser
def add_train_parser(subcommands) -> None:
......@@ -15,5 +14,4 @@ def add_train_parser(subcommands) -> None:
)
subcommands = parser.add_subparsers(metavar="subcommand")
add_line_parser(subcommands)
add_document_parser(subcommands)
......@@ -126,7 +126,6 @@ def get_config():
},
],
"augmentation": aug_config(0.9, 0.1),
"synthetic_data": None,
},
},
"model_params": {
......@@ -269,19 +268,6 @@ def serialize_config(config):
serialized_config["training_params"]["nb_gpu"]
)
if (
"synthetic_data" in config["dataset_params"]["config"]
and config["dataset_params"]["config"]["synthetic_data"]
):
# The Probability scheduler is a function and needs to be casted to string
serialized_config["dataset_params"]["config"]["synthetic_data"][
"proba_scheduler_function"
] = str(
serialized_config["dataset_params"]["config"]["synthetic_data"][
"proba_scheduler_function"
]
)
return serialized_config
......
# -*- coding: utf-8 -*-
from dan.ocr.line.generate_synthetic import run as run_generate
from dan.ocr.line.train import run as run_train
def add_generate_parser(subcommands) -> None:
    """
    Register the "generate" subcommand, which runs the synthetic data generation.
    """
    summary = "Generate synthetic data to train DAN models."
    generate_parser = subcommands.add_parser(
        "generate", description=summary, help=summary
    )
    generate_parser.set_defaults(func=run_generate)
def add_line_parser(subcommands) -> None:
    """
    Register the "line" subcommand, which runs the line-level training.
    """
    summary = "Train a DAN model at line level."
    line_parser = subcommands.add_parser("line", description=summary, help=summary)
    line_parser.set_defaults(func=run_train)
# -*- coding: utf-8 -*-
import random
import numpy as np
import torch
import torch.multiprocessing as mp
from torch.optim import Adam
from dan.models import FCN_Encoder
from dan.ocr.line.model_utils import Decoder
from dan.ocr.line.utils import TrainerLineCTC
from dan.schedulers import exponential_dropout_scheduler, exponential_scheduler
from dan.transforms import line_aug_config
def train_and_test(rank, params):
    """
    Seed every random generator with 0, build the line trainer for the given
    rank and generate the synthetic line dataset.
    """
    for set_seed in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    ):
        set_seed(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    params["training_params"]["ddp_rank"] = rank
    trainer = TrainerLineCTC(params)

    # ["RIMES_syn_line", "READ_2016_syn_line"]
    trainer.generate_syn_line_dataset("READ_2016_syn_line")
def run():
"""
Build the hard-coded parameters of the synthetic line generation for the
READ_2016 page-level dataset, then call ``train_and_test``.

One process per GPU is spawned when ``use_ddp`` is set and ``force_cpu`` is
not; otherwise ``train_and_test`` is called directly with rank 0.
"""
dataset_name = "READ_2016"
dataset_level = "page"
params = {
"dataset_params": {
# Dataset name -> dataset folder, relative to the working directory
"datasets": {
dataset_name: "../../../Datasets/formatted/{}_{}".format(
dataset_name, dataset_level
),
},
"train": {
"name": "{}-train".format(dataset_name),
"datasets": [
(dataset_name, "train"),
],
},
"val": {
"{}-val".format(dataset_name): [
(dataset_name, "val"),
],
},
"config": {
"load_in_memory": False, # Load all images in CPU memory
"worker_per_gpu": 4,
"width_divisor": 8, # Image width will be divided by 8
"height_divisor": 32, # Image height will be divided by 32
"padding_value": 0, # Image padding value
"padding_token": 1000, # Label padding value (None: default value is chosen)
"padding_mode": "br", # Padding at bottom and right
"charset_mode": "CTC", # add blank token
"constraints": [], # Padding for CTC requirements if necessary
"normalize": True, # Normalize with mean and variance of training dataset
"preprocessings": [],
# Augmentation techniques to use at training time
"augmentation": line_aug_config(0.9, 0.1),
#
# Synthetic data settings: init_proba equals end_proba here
# NOTE(review): no "valid_fonts" entry is defined in the nested
# "config" below - confirm where the font list is filled in.
"synthetic_data": {
"mode": "line_hw_to_printed",
"init_proba": 1,
"end_proba": 1,
"num_steps_proba": 1e5,
"proba_scheduler_function": exponential_scheduler,
"config": {
"background_color_default": (255, 255, 255),
"background_color_eps": 15,
"text_color_default": (0, 0, 0),
"text_color_eps": 15,
"font_size_min": 30,
"font_size_max": 50,
"color_mode": "RGB",
"padding_left_ratio_min": 0.02,
"padding_left_ratio_max": 0.1,
"padding_right_ratio_min": 0.02,
"padding_right_ratio_max": 0.1,
"padding_top_ratio_min": 0.02,
"padding_top_ratio_max": 0.2,
"padding_bottom_ratio_min": 0.02,
"padding_bottom_ratio_max": 0.2,
},
},
},
},
"model_params": {
# Model classes to use for each module
"models": {
"encoder": FCN_Encoder,
"decoder": Decoder,
},
"transfer_learning": None,
"input_channels": 3, # 1 for grayscale images, 3 for RGB ones (or grayscale as RGB)
"enc_size": 256,
"dropout_scheduler": {
"function": exponential_dropout_scheduler,
"T": 5e4,
},
"dropout": 0.5,
},
"training_params": {
"output_folder": "FCN_Encoder_read_syn_line_all_pad_max_cursive", # folder names for logs and weights
"max_nb_epochs": 10000, # max number of epochs for the training
"max_training_time": 3600
* 24
* 1.9, # max training time limit (in seconds)
"load_epoch": "last", # ["best", "last"], to load weights from best epoch or last trained epoch
"interval_save_weights": None, # None: keep best and last only
"use_ddp": False, # Use DistributedDataParallel
"use_amp": True, # Enable automatic mix-precision
"nb_gpu": torch.cuda.device_count(),
"batch_size": 1, # mini-batch size per GPU
"optimizers": {
"all": {
"class": Adam,
"args": {
"lr": 0.0001,
"amsgrad": False,
},
}
},
"lr_schedulers": None,
"eval_on_valid": True, # Whether to eval and logs metrics on validation set during training or not
"eval_on_valid_interval": 2, # Interval (in epochs) to evaluate during training
"focus_metric": "cer", # Metrics to focus on to determine best epoch
"expected_metric_value": "low", # ["high", "low"] What is best for the focus metric value
"set_name_focus_metric": "{}-val".format(dataset_name),
"train_metrics": [
"loss_ctc",
"cer",
"wer",
"wer_no_punct",
], # Metrics name for training
"eval_metrics": [
"loss_ctc",
"cer",
"wer",
"wer_no_punct",
], # Metrics name for evaluation on validation set during training
"force_cpu": False, # True for debug purposes to run on cpu only
},
}
# Multi-process launch only with DDP enabled and CPU not forced
if (
params["training_params"]["use_ddp"]
and not params["training_params"]["force_cpu"]
):
mp.spawn(
train_and_test, args=(params,), nprocs=params["training_params"]["nb_gpu"]
)
else:
train_and_test(0, params)
# -*- coding: utf-8 -*-
from torch.nn import AdaptiveMaxPool2d, Conv1d, Module
from torch.nn.functional import log_softmax
class Decoder(Module):
    """
    Prediction head applied on top of the encoder features.

    The feature map is max-pooled along its height down to a single row,
    projected with a kernel-size-1 convolution onto ``vocab_size + 1``
    channels and normalised with a log-softmax over those channels.
    """

    def __init__(self, params):
        """
        :param params: mapping providing ``"vocab_size"`` and ``"enc_size"``
            (the number of channels expected as input of the projection).
        """
        super().__init__()
        vocab_size = params["vocab_size"]
        self.vocab_size = vocab_size
        # Height is reduced to 1, the width is left untouched (None)
        self.ada_pool = AdaptiveMaxPool2d(output_size=(1, None))
        # One output channel more than the vocabulary
        # (presumably the CTC blank token - confirm against the charset setup)
        self.end_conv = Conv1d(params["enc_size"], vocab_size + 1, 1)

    def forward(self, x):
        """
        Return log-probabilities with the class scores on dimension 1.

        Assumes x is (batch, enc_size, height, width) - TODO confirm against the encoder.
        """
        pooled = self.ada_pool(x)
        # Drop the height dimension that pooling reduced to size 1
        sequence = pooled.squeeze(2)
        scores = self.end_conv(sequence)
        return log_softmax(scores, dim=1)
# -*- coding: utf-8 -*-
import random
import numpy as np
import torch
import torch.multiprocessing as mp
from torch.optim import Adam
from dan.models import FCN_Encoder
from dan.ocr.line.model_utils import Decoder
from dan.ocr.line.utils import TrainerLineCTC
from dan.schedulers import exponential_dropout_scheduler, exponential_scheduler
from dan.transforms import line_aug_config
def train_and_test(rank, params):
    """
    Train a line-level CTC model, then evaluate the "best" weights.

    :param rank: value stored as ``params["training_params"]["ddp_rank"]``
        (the process index when started through ``mp.spawn``).
    :param params: full configuration dictionary; it is modified in place.
    """
    # Fixed seeds and deterministic cuDNN settings for reproducibility
    for seed_everything in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    ):
        seed_everything(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    params["training_params"]["ddp_rank"] = rank
    trainer = TrainerLineCTC(params)
    trainer.load_model()

    # Stopping criteria are presumably max_nb_epochs / max_training_time
    # from the training parameters - handled inside the trainer
    trainer.train()

    # Switch to the weights saved as "best" before the final evaluation
    trainer.params["training_params"]["load_epoch"] = "best"
    trainer.load_model()

    # Metrics computed on every split of every configured dataset
    final_metrics = ["cer", "wer", "wer_no_punct", "time"]
    for dataset_name in params["dataset_params"]["datasets"]:
        for set_name in ("test", "val", "train"):
            trainer.predict(
                f"{dataset_name}-{set_name}",
                [(dataset_name, set_name)],
                final_metrics,
                output=True,
            )
def run():
    """
    Build the configuration for line-level CTC training on the synthetic
    READ 2016 line dataset and launch ``train_and_test``.

    Training is started directly with rank 0, or through ``mp.spawn`` with
    one process per GPU when ``use_ddp`` is enabled and ``force_cpu`` is not.
    """
    dataset_name = "READ_2016"  # ["RIMES", "READ_2016"]
    dataset_level = "syn_line"
    val_set_name = f"{dataset_name}-val"

    # Rendering options of the synthetic line images
    synthetic_rendering = {
        "background_color_default": (255, 255, 255),
        "background_color_eps": 15,
        "text_color_default": (0, 0, 0),
        "text_color_eps": 15,
        "font_size_min": 30,
        "font_size_max": 50,
        "color_mode": "RGB",
        "padding_left_ratio_min": 0.02,
        "padding_left_ratio_max": 0.1,
        "padding_right_ratio_min": 0.02,
        "padding_right_ratio_max": 0.1,
        "padding_top_ratio_min": 0.02,
        "padding_top_ratio_max": 0.2,
        "padding_bottom_ratio_min": 0.02,
        "padding_bottom_ratio_max": 0.2,
    }
    synthetic_data = {
        "mode": "line_hw_to_printed",
        "init_proba": 1,
        "end_proba": 1,
        "num_steps_proba": 1e5,
        # NOTE(review): unusual key spelling - confirm it matches what the dataset code reads
        "probadocument_scheduler_function": exponential_scheduler,
        "config": synthetic_rendering,
    }
    sample_padding = {
        "min_height": "max",  # reach the max height of the training samples
        "min_width": "max",  # reach the max width of the training samples
        "min_pad": None,
        "max_pad": None,
        "mode": "br",  # pad at the bottom and on the right
        "train_only": False,  # pad at training time and at evaluation time
    }
    dataset_config = {
        "load_in_memory": True,  # keep all images in CPU memory
        "worker_per_gpu": 8,  # data loading processes per GPU
        "width_divisor": 8,  # image width will be divided by 8
        "height_divisor": 32,  # image height will be divided by 32
        "padding_value": 0,  # image padding value
        "padding_token": 1000,  # label padding value (None: default value is chosen)
        "padding_mode": "br",  # pad at the bottom and on the right
        "charset_mode": "CTC",  # add blank token
        "constraints": ["CTC_line"],  # padding for CTC requirements if necessary
        "normalize": True,  # normalize with mean and variance of the training dataset
        "padding": sample_padding,
        # to_RGB: a grayscale image becomes a 3-channel one, others are left as is
        "preprocessings": [{"type": "to_RGB"}],
        # Augmentation techniques used at training time
        "augmentation": line_aug_config(0.9, 0.1),
        "synthetic_data": synthetic_data,
    }
    dataset_params = {
        "datasets": {
            dataset_name: f"../../../Datasets/formatted/{dataset_name}_{dataset_level}",
        },
        "train": {
            "name": f"{dataset_name}-train",
            "datasets": [(dataset_name, "train")],
        },
        "val": {
            val_set_name: [(dataset_name, "val")],
        },
        "config": dataset_config,
    }

    model_params = {
        # Model classes to use for each module
        "models": {
            "encoder": FCN_Encoder,
            "decoder": Decoder,
        },
        "transfer_learning": None,
        "input_channels": 3,  # 1 for grayscale images, 3 for RGB ones (or grayscale as RGB)
        "enc_size": 256,
        "dropout_scheduler": {
            "function": exponential_dropout_scheduler,
            "T": 5e4,
        },
        "dropout": 0.5,
    }

    # The same metric names are tracked during training and validation
    tracked_metrics = ["loss_ctc", "cer", "wer", "wer_no_punct"]
    training_params = {
        "output_folder": "outputs/FCN_read_2016_line_syn",  # folder for logs and weights
        "max_nb_epochs": 10000,  # max number of epochs for the training
        "max_training_time": 3600 * 24 * 1.9,  # max training time limit (in seconds)
        "load_epoch": "last",  # ["best", "last"]: which saved weights to load
        "interval_save_weights": None,  # None: keep best and last only
        "use_ddp": False,  # use DistributedDataParallel
        "use_amp": True,  # enable automatic mixed precision
        "nb_gpu": torch.cuda.device_count(),
        "batch_size": 16,  # mini-batch size per GPU
        "optimizers": {
            "all": {
                "class": Adam,
                "args": {
                    "lr": 0.0001,
                    "amsgrad": False,
                },
            }
        },
        "lr_schedulers": None,  # learning rate schedulers
        "eval_on_valid": True,  # evaluate and log metrics on the validation set during training
        "eval_on_valid_interval": 2,  # interval (in epochs) between two evaluations
        "focus_metric": "cer",  # metric used to determine the best epoch
        "expected_metric_value": "low",  # ["high", "low"]: what is best for the focus metric
        "set_name_focus_metric": val_set_name,  # set used to select the best weights
        "train_metrics": list(tracked_metrics),  # metrics computed during training
        "eval_metrics": list(tracked_metrics),  # metrics computed on the validation set
        "force_cpu": False,  # True for debug purposes to run on cpu only
    }

    params = {
        "dataset_params": dataset_params,
        "model_params": model_params,
        "training_params": training_params,
    }

    distributed = training_params["use_ddp"] and not training_params["force_cpu"]
    if not distributed:
        train_and_test(0, params)
        return
    # One process per GPU; each process gets its index as first argument
    mp.spawn(train_and_test, args=(params,), nprocs=training_params["nb_gpu"])
# -*- coding: utf-8 -*-
import re
import time
import torch
from torch.cuda.amp import autocast
from torch.nn import CTCLoss
from dan.manager.training import OCRManager
from dan.ocr.utils import LM_ind_to_str
class TrainerLineCTC(OCRManager):
    """
    Training manager for line-level recognition with an encoder/decoder
    pair optimised with the CTC loss.
    """

    def __init__(self, params):
        super().__init__(params)

    def train_batch(self, batch_data, metric_names):
        """
        Run one optimisation step (forward, CTC loss, backward, optimizer step).

        :param batch_data: batch dictionary; the keys "imgs", "labels",
            "imgs_reduced_shape", "labels_len" and "raw_labels" are read.
        :param metric_names: not used by this method.
        :return: dictionary with the number of samples, the loss value,
            the predicted strings ("str_x") and the raw labels ("str_y").
        """
        images = batch_data["imgs"].to(self.device)
        targets = batch_data["labels"].to(self.device)
        # Index 1 of each reduced shape is used as the prediction length
        # (presumably the width after encoder downsampling - confirm)
        input_lengths = [shape[1] for shape in batch_data["imgs_reduced_shape"]]
        target_lengths = batch_data["labels_len"]

        criterion = CTCLoss(blank=self.dataset.tokens["blank"])
        self.zero_optimizers()

        with autocast(enabled=self.params["training_params"]["use_amp"]):
            features = self.models["encoder"](images)
            global_pred = self.models["decoder"](features)

        # CTCLoss expects (time, batch, classes)
        loss = criterion(
            global_pred.permute(2, 0, 1), targets, input_lengths, target_lengths
        )

        self.backward_loss(loss)
        self.step_optimizers()

        # Greedy decoding: most probable class at each position
        best_path = torch.argmax(global_pred, dim=1).cpu().numpy()
        return {
            "nb_samples": len(batch_data["raw_labels"]),
            "loss_ctc": loss.item(),
            "str_x": self.pred_to_str(best_path, input_lengths),
            "str_y": batch_data["raw_labels"],
        }

    def evaluate_batch(self, batch_data, metric_names):
        """
        Run a forward pass only and time it, for validation and test.

        :param batch_data: batch dictionary; the keys "imgs", "labels",
            "imgs_reduced_shape", "labels_len" and "raw_labels" are read.
        :param metric_names: not used by this method.
        :return: same entries as `train_batch` plus "time", the elapsed
            seconds from the forward pass to the decoded strings.
        """
        images = batch_data["imgs"].to(self.device)
        targets = batch_data["labels"].to(self.device)
        input_lengths = [shape[1] for shape in batch_data["imgs_reduced_shape"]]
        target_lengths = batch_data["labels_len"]

        criterion = CTCLoss(blank=self.dataset.tokens["blank"])

        started_at = time.time()
        with autocast(enabled=self.params["training_params"]["use_amp"]):
            features = self.models["encoder"](images)
            global_pred = self.models["decoder"](features)

        # CTCLoss expects (time, batch, classes)
        loss = criterion(
            global_pred.permute(2, 0, 1), targets, input_lengths, target_lengths
        )
        best_path = torch.argmax(global_pred, dim=1).cpu().numpy()
        predicted_text = self.pred_to_str(best_path, input_lengths)
        elapsed = time.time() - started_at

        return {
            "nb_samples": len(batch_data["raw_labels"]),
            "loss_ctc": loss.item(),
            "str_x": predicted_text,
            "str_y": batch_data["raw_labels"],
            "time": elapsed,
        }

    def ctc_remove_successives_identical_ind(self, ind):
        """
        Collapse every run of consecutive identical indices into one index.

        :param ind: iterable of indices.
        :return: list of indices where no two neighbours are equal.
        """
        collapsed = []
        for token in ind:
            if not collapsed or collapsed[-1] != token:
                collapsed.append(token)
        return collapsed

    def pred_to_str(self, pred, pred_len):
        """
        Convert predicted token indices to strings.

        Each row of ``pred`` is cut to its own length, repeated indices are
        merged, indices are mapped to characters (unknown ones become an
        empty string), then runs of spaces are squeezed and outer spaces
        are stripped.
        """
        texts = []
        for i, row in enumerate(pred):
            tokens = self.ctc_remove_successives_identical_ind(row[: pred_len[i]])
            raw_text = LM_ind_to_str(self.dataset.charset, tokens, oov_symbol="")
            texts.append(re.sub("( )+", " ", raw_text).strip(" "))
        return texts
# Utils
::: dan.manager.utils
# Synthetic data generation
::: dan.ocr.line.generate_synthetic
# Line
# Model utils
::: dan.ocr.line.model_utils
# Training
::: dan.ocr.line.train
# Utils
::: dan.ocr.line.utils
# Generate
Use the `teklia-dan generate` command to generate synthetic data for training.
......@@ -8,8 +8,5 @@ When `teklia-dan` is installed in your environment, you may use the following co
`teklia-dan train`
: To train a new DAN model. More details in [the dedicated section](./train/index.md).
`teklia-dan generate`
: To generate synthetic data to train DAN models. More details in [the dedicated section](./generate.md).
`teklia-dan predict`
: To run prediction on an image using a trained DAN model. More details in [the dedicated section](./predict.md).