Compare revisions
Commits on Source (27)
Showing with 275 additions and 417 deletions
*gif filter=lfs diff=lfs merge=lfs -text
**/*.pt filter=lfs diff=lfs merge=lfs -text
......@@ -50,7 +50,7 @@ test:
- schedules
script:
- tox
- tox -- -v
# Make sure docs still build correctly
.docs:
......
.PHONY: release
release:
$(eval version:=$(shell cat VERSION))
echo $(version)
git commit VERSION -m "Version $(version)"
git tag $(version)
git push origin main $(version)
......@@ -4,7 +4,6 @@
For more details about this package, make sure to see the documentation available at https://teklia.gitlab.io/atr/dan/.
## Installation
To use DAN in your own scripts, install it using pip:
......@@ -55,7 +54,9 @@ See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/datasets/form
See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/train/) on the official DAN documentation.
### Synthetic data generation
See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/generate/) on the official DAN documentation.
### Model prediction
See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/predict/) on the official DAN documentation.
0.0.1
0.2.0-dev1
......@@ -15,12 +15,14 @@ from dan.transforms import apply_data_augmentation
class DatasetManager:
def __init__(self, params):
def __init__(self, params, device: str):
self.params = params
self.dataset_class = params["dataset_class"]
self.dataset_class = None
self.img_padding_value = params["config"]["padding_value"]
self.my_collate_function = None
# Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
self.pin_memory = device != "cpu"
self.train_dataset = None
self.valid_datasets = dict()
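Note: pin_memory only pays off when batches are later copied to a GPU; page-locked allocations are wasted on CPU-only runs, which is why it is now derived from the device. A minimal standalone sketch of the idea (not project code):

import torch
from torch.utils.data import DataLoader, TensorDataset

device = "cuda:0" if torch.cuda.is_available() else "cpu"
dataset = TensorDataset(torch.randn(8, 3))
# Pinned (page-locked) host memory enables asynchronous host-to-GPU copies.
loader = DataLoader(dataset, batch_size=2, pin_memory=(device != "cpu"))
for (batch,) in loader:
    batch = batch.to(device, non_blocking=True)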
......@@ -115,7 +117,7 @@ class DatasetManager:
batch_sampler=self.train_sampler,
sampler=self.train_sampler,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=True,
pin_memory=self.pin_memory,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
......@@ -129,7 +131,7 @@ class DatasetManager:
batch_sampler=self.valid_samplers[key],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=True,
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
......@@ -174,7 +176,7 @@ class DatasetManager:
sampler=self.test_samplers[custom_name],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=True,
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
......@@ -268,20 +270,10 @@ class GenericDataset(Dataset):
"label": label,
"unchanged_label": label,
"path": os.path.abspath(filename),
"nb_cols": 1
if "nb_cols" not in gt[filename]
else gt[filename]["nb_cols"],
}
)
if load_in_memory:
samples[-1]["img"] = GenericDataset.load_image(filename)
if type(gt[filename]) is dict:
if "lines" in gt[filename].keys():
samples[-1]["raw_line_seg_label"] = gt[filename]["lines"]
if "paragraphs" in gt[filename].keys():
samples[-1]["paragraphs_label"] = gt[filename]["paragraphs"]
if "pages" in gt[filename].keys():
samples[-1]["pages_label"] = gt[filename]["pages"]
return samples
def apply_preprocessing(self, preprocessings):
......@@ -346,7 +338,7 @@ class GenericDataset(Dataset):
for aug, set_name in zip(augs, ["train", "val", "test"]):
if aug and self.set_name == set_name:
return apply_data_augmentation(img, aug)
return img, list()
return img
def get_sample_img(self, i):
"""
......@@ -424,15 +416,6 @@ def apply_preprocessing(sample, preprocessings):
temp_img = np.expand_dims(temp_img, axis=2)
img = temp_img
resize_ratio = [ratio, ratio]
if resize_ratio != [1, 1] and "raw_line_seg_label" in sample:
for li in range(len(sample["raw_line_seg_label"])):
for side, ratio in zip(
(["bottom", "top"], ["right", "left"]), resize_ratio
):
for s in side:
sample["raw_line_seg_label"][li][s] = (
sample["raw_line_seg_label"][li][s] * ratio
)
sample["img"] = img
sample["resize_ratio"] = resize_ratio
......
......@@ -16,7 +16,6 @@ class MetricManager:
if "simara" in dataset_name and "page" in dataset_name:
self.post_processing_module = PostProcessingModuleSIMARA
self.matching_tokens = SIMARA_MATCHING_TOKENS
self.edit_and_num_edge_nodes = edit_and_num_items_for_ged_from_str_simara
else:
self.matching_tokens = dict()
......@@ -56,7 +55,6 @@ class MetricManager:
self.epoch_metrics = {
"nb_samples": list(),
"names": list(),
"ids": list(),
}
for metric_name in self.metric_names:
......@@ -87,71 +85,83 @@ class MetricManager:
value = None
if output:
if metric_name in ["nb_samples", "weights"]:
value = np.sum(self.epoch_metrics[metric_name])
value = int(np.sum(self.epoch_metrics[metric_name]))
elif metric_name in [
"time",
]:
total_time = np.sum(self.epoch_metrics[metric_name])
sample_time = total_time / np.sum(self.epoch_metrics["nb_samples"])
display_values["sample_time"] = round(sample_time, 4)
value = total_time
value = int(np.sum(self.epoch_metrics[metric_name]))
sample_time = value / np.sum(self.epoch_metrics["nb_samples"])
display_values["sample_time"] = float(round(sample_time, 4))
elif metric_name == "loer":
display_values["pper"] = round(
np.sum(self.epoch_metrics["nb_pp_op_layout"])
/ np.sum(self.epoch_metrics["nb_gt_layout_token"]),
4,
display_values["pper"] = float(
round(
np.sum(self.epoch_metrics["nb_pp_op_layout"])
/ np.sum(self.epoch_metrics["nb_gt_layout_token"]),
4,
)
)
elif metric_name == "map_cer_per_class":
value = compute_global_mAP_per_class(self.epoch_metrics["map_cer"])
for key in value.keys():
display_values["map_cer_" + key] = round(value[key], 4)
display_values["map_cer_" + key] = float(round(value[key], 4))
continue
elif metric_name == "layout_precision_per_class_per_threshold":
value = compute_global_precision_per_class_per_threshold(
self.epoch_metrics["map_cer"]
)
for key_class in value.keys():
for threshold in value[key_class].keys():
display_values[
"map_cer_{}_{}".format(key_class, threshold)
] = round(value[key_class][threshold], 4)
] = float(round(value[key_class][threshold], 4))
continue
if metric_name == "cer":
value = np.sum(self.epoch_metrics["edit_chars"]) / np.sum(
self.epoch_metrics["nb_chars"]
value = float(
np.sum(self.epoch_metrics["edit_chars"])
/ np.sum(self.epoch_metrics["nb_chars"])
)
if output:
display_values["nb_chars"] = np.sum(self.epoch_metrics["nb_chars"])
display_values["nb_chars"] = int(
np.sum(self.epoch_metrics["nb_chars"])
)
elif metric_name == "wer":
value = np.sum(self.epoch_metrics["edit_words"]) / np.sum(
self.epoch_metrics["nb_words"]
value = float(
np.sum(self.epoch_metrics["edit_words"])
/ np.sum(self.epoch_metrics["nb_words"])
)
if output:
display_values["nb_words"] = np.sum(self.epoch_metrics["nb_words"])
display_values["nb_words"] = int(
np.sum(self.epoch_metrics["nb_words"])
)
elif metric_name == "wer_no_punct":
value = np.sum(self.epoch_metrics["edit_words_no_punct"]) / np.sum(
self.epoch_metrics["nb_words_no_punct"]
value = float(
np.sum(self.epoch_metrics["edit_words_no_punct"])
/ np.sum(self.epoch_metrics["nb_words_no_punct"])
)
if output:
display_values["nb_words_no_punct"] = np.sum(
self.epoch_metrics["nb_words_no_punct"]
display_values["nb_words_no_punct"] = int(
np.sum(self.epoch_metrics["nb_words_no_punct"])
)
elif metric_name in [
"loss",
"loss_ctc",
"loss_ce",
"syn_max_lines",
"syn_prob_lines",
]:
value = np.average(
self.epoch_metrics[metric_name],
weights=np.array(self.epoch_metrics["nb_samples"]),
value = float(
np.average(
self.epoch_metrics[metric_name],
weights=np.array(self.epoch_metrics["nb_samples"]),
)
)
elif metric_name == "map_cer":
value = compute_global_mAP(self.epoch_metrics[metric_name])
value = float(compute_global_mAP(self.epoch_metrics[metric_name]))
elif metric_name == "loer":
value = np.sum(self.epoch_metrics["edit_graph"]) / np.sum(
self.epoch_metrics["nb_nodes_and_edges"]
value = float(
np.sum(self.epoch_metrics["edit_graph"])
/ np.sum(self.epoch_metrics["nb_nodes_and_edges"])
)
elif value is None:
continue
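These int() and float() casts replace numpy scalars with native Python types, presumably to keep the metric dictionaries serializable by the yaml.dump calls introduced later in this diff. A quick illustration:

import numpy as np
import yaml

print(yaml.safe_dump({"nb_samples": int(np.sum([1, 2, 3]))}))  # nb_samples: 6
# Without the cast, numpy scalars either raise RepresenterError (safe_dump)
# or serialize as opaque !!python/object tags (plain dump).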
......@@ -207,11 +217,8 @@ class MetricManager:
]
metrics["nb_words_no_punct"] = [len(gt) for gt in split_gt]
elif metric_name in [
"loss_ctc",
"loss_ce",
"loss",
"syn_max_lines",
"syn_prob_lines",
]:
metrics[metric_name] = [
values[metric_name],
......@@ -235,7 +242,7 @@ class MetricManager:
pp_pred.append(pp_module.post_process(pred))
metrics["nb_pp_op_layout"].append(pp_module.num_op)
metrics["nb_gt_layout_token"] = [
len(keep_only_tokens(str_x, self.layout_tokens))
len(keep_only_ner_tokens(str_x, self.layout_tokens))
for str_x in values["str_x"]
]
edit_and_num_items = [
......@@ -251,16 +258,16 @@ class MetricManager:
return self.epoch_metrics[name]
def keep_only_tokens(str, tokens):
def keep_only_ner_tokens(str, tokens):
"""
Remove all but layout tokens from string
Remove all but ner tokens from string
"""
return re.sub("([^" + tokens + "])", "", str)
def keep_all_but_tokens(str, tokens):
def keep_all_but_ner_tokens(str, tokens):
"""
Remove all layout tokens from string
Remove all ner tokens from string
"""
return re.sub("([" + tokens + "])", "", str)
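A quick illustration of the renamed helpers, using Ⓟ and ⓐ as hypothetical NER token characters (not the project's actual tokens):

import re

def keep_only_ner_tokens(s, tokens):
    return re.sub("([^" + tokens + "])", "", s)

def keep_all_but_ner_tokens(s, tokens):
    return re.sub("([" + tokens + "])", "", s)

print(keep_only_ner_tokens("ⓅParis ⓐAlice", "Ⓟⓐ"))    # Ⓟⓐ
print(keep_all_but_ner_tokens("ⓅParis ⓐAlice", "Ⓟⓐ"))  # Paris Alice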
......@@ -299,7 +306,7 @@ def format_string_for_wer(str, layout_tokens, remove_punct=False):
r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])", "", str
) # remove punctuation
if layout_tokens is not None:
str = keep_all_but_tokens(
str = keep_all_but_ner_tokens(
str, layout_tokens
) # remove layout tokens from metric
str = re.sub("([ \n])+", " ", str).strip() # keep only one space character
......@@ -311,7 +318,7 @@ def format_string_for_cer(str, layout_tokens):
Format string for CER computation: remove layout tokens and extra spaces
"""
if layout_tokens is not None:
str = keep_all_but_tokens(
str = keep_all_but_ner_tokens(
str, layout_tokens
) # remove layout tokens from metric
str = re.sub("([\n])+", "\n", str) # remove consecutive line breaks
......@@ -367,8 +374,8 @@ def compute_layout_precision_per_threshold(
pred, begin_token, end_token, associated_score=score, order_by_score=True
)
gt_list = extract_by_tokens(gt, begin_token, end_token)
pred_list = [keep_all_but_tokens(p, layout_tokens) for p in pred_list]
gt_list = [keep_all_but_tokens(gt, layout_tokens) for gt in gt_list]
pred_list = [keep_all_but_ner_tokens(p, layout_tokens) for p in pred_list]
gt_list = [keep_all_but_ner_tokens(gt, layout_tokens) for gt in gt_list]
precision_per_threshold = [
compute_layout_AP_for_given_threshold(gt_list, pred_list, threshold / 100)
for threshold in range(5, 51, 5)
......@@ -503,7 +510,7 @@ def str_to_graph_simara(str):
Compute graph from string of layout tokens for the SIMARA dataset at page level
"""
begin_layout_tokens = "".join(list(SIMARA_MATCHING_TOKENS.keys()))
layout_token_sequence = keep_only_tokens(str, begin_layout_tokens)
layout_token_sequence = keep_only_ner_tokens(str, begin_layout_tokens)
g = nx.DiGraph()
g.add_node("D", type="document", level=2, page=0)
token_name_dict = {"": "I", "": "D", "": "S", "": "C", "": "P", "": "A"}
......@@ -538,16 +545,3 @@ def graph_edit_distance(g1, g2):
):
new_edit = v
return new_edit
def edit_and_num_items_for_ged_from_str_simara(str_gt, str_pred):
"""
Compute graph edit distance and num nodes/edges for normalized graph edit distance
For the SIMARA dataset
"""
g_gt = str_to_graph_simara(str_gt)
g_pred = str_to_graph_simara(str_pred)
return (
graph_edit_distance(g_gt, g_pred),
g_gt.number_of_nodes() + g_gt.number_of_edges(),
)
......@@ -9,13 +9,11 @@ import torch
from fontTools.ttLib import TTFont
from PIL import Image, ImageDraw, ImageFont
from dan import logger
from dan.manager.dataset import DatasetManager, GenericDataset, apply_preprocessing
from dan.ocr.utils import LM_str_to_ind
from dan.utils import (
pad_image,
pad_image_width_random,
pad_image_width_right,
pad_images,
pad_sequences_1D,
rand,
......@@ -29,9 +27,10 @@ class OCRDatasetManager(DatasetManager):
Specific class to handle OCR/HTR tasks
"""
def __init__(self, params):
super(OCRDatasetManager, self).__init__(params)
def __init__(self, params, device: str):
super(OCRDatasetManager, self).__init__(params, device)
self.dataset_class = OCRDataset
self.charset = (
params["charset"] if "charset" in params else self.get_merged_charsets()
)
......@@ -41,30 +40,16 @@ class OCRDatasetManager(DatasetManager):
and self.params["config"]["synthetic_data"]
):
self.synthetic_data = self.params["config"]["synthetic_data"]
if "config" in self.synthetic_data:
self.synthetic_data["config"]["valid_fonts"] = self.get_valid_fonts()
if "new_tokens" in params:
self.charset = sorted(
list(set(self.charset).union(set(params["new_tokens"])))
)
self.tokens = {
"pad": params["config"]["padding_token"],
}
if self.params["config"]["charset_mode"].lower() == "ctc":
self.tokens["blank"] = len(self.charset)
self.tokens["pad"] = (
self.tokens["pad"] if self.tokens["pad"] else len(self.charset) + 1
)
self.params["config"]["padding_token"] = self.tokens["pad"]
elif self.params["config"]["charset_mode"] == "seq2seq":
self.tokens["end"] = len(self.charset)
self.tokens["start"] = len(self.charset) + 1
self.tokens["pad"] = (
self.tokens["pad"] if self.tokens["pad"] else len(self.charset) + 2
)
self.params["config"]["padding_token"] = self.tokens["pad"]
self.tokens["end"] = len(self.charset)
self.tokens["start"] = len(self.charset) + 1
self.tokens["pad"] = (
self.tokens["pad"] if self.tokens["pad"] else len(self.charset) + 2
)
self.params["config"]["padding_token"] = self.tokens["pad"]
def get_merged_charsets(self):
"""
......@@ -103,34 +88,6 @@ class OCRDatasetManager(DatasetManager):
[s["img"].shape[1] for s in self.train_dataset.samples]
)
def get_valid_fonts(self):
"""
Select fonts that are compatible with the alphabet
"""
font_path = self.synthetic_data["font_path"]
alphabet = self.charset.copy()
special_chars = ["\n"]
alphabet = [char for char in alphabet if char not in special_chars]
valid_fonts = list()
for fold_detail in os.walk(font_path):
if fold_detail[2]:
for font_name in fold_detail[2]:
if ".ttf" not in font_name:
continue
font_path = os.path.join(fold_detail[0], font_name)
to_add = True
if alphabet is not None:
for char in alphabet:
if not char_in_font(char, font_path):
to_add = False
break
if to_add:
valid_fonts.append(font_path)
else:
valid_fonts.append(font_path)
logger.info(f"Found {len(valid_fonts)} fonts.")
return valid_fonts
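# char_in_font is not shown in this diff; a plausible fontTools-based
# implementation (an assumption, not necessarily the project's exact code):
def char_in_font(unicode_char, font_path):
    from fontTools.ttLib import TTFont
    font = TTFont(font_path)
    # The character is usable if any unicode cmap table maps its codepoint.
    return any(
        table.isUnicode() and ord(unicode_char) in table.cmap
        for table in font["cmap"].tables
    )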
class OCRDataset(GenericDataset):
"""
......@@ -171,9 +128,7 @@ class OCRDataset(GenericDataset):
sample = self.generate_synthetic_data(sample)
# Data augmentation
sample["img"], sample["applied_da"] = self.apply_data_augmentation(
sample["img"]
)
sample["img"] = self.apply_data_augmentation(sample["img"])
if "max_size" in self.params["config"] and self.params["config"]["max_size"]:
max_ratio = max(
......@@ -191,49 +146,18 @@ class OCRDataset(GenericDataset):
if "normalize" in self.params["config"] and self.params["config"]["normalize"]:
sample["img"] = (sample["img"] - self.mean) / self.std
sample["img_shape"] = sample["img"].shape
sample["img_reduced_shape"] = np.ceil(
sample["img_shape"] / self.reduce_dims_factor
sample["img"].shape / self.reduce_dims_factor
).astype(int)
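# Worked example, assuming the default divisors from this repository
# (height_divisor=32, width_divisor=8, so reduce_dims_factor == [32, 8, 1]):
# a 1000 x 800 x 3 image gives
#   np.ceil([1000, 800, 3] / [32, 8, 1]).astype(int) == [32, 100, 3]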
# Padding to handle CTC requirements
if self.set_name == "train":
max_label_len = 0
height = 1
ctc_padding = False
if "CTC_line" in self.params["config"]["constraints"]:
max_label_len = sample["label_len"]
ctc_padding = True
if "CTC_va" in self.params["config"]["constraints"]:
max_label_len = max(sample["line_label_len"])
ctc_padding = True
if "CTC_pg" in self.params["config"]["constraints"]:
max_label_len = sample["label_len"]
height = max(sample["img_reduced_shape"][0], 1)
ctc_padding = True
if (
ctc_padding
and 2 * max_label_len + 1 > sample["img_reduced_shape"][1] * height
):
sample["img"] = pad_image_width_right(
sample["img"],
int(
np.ceil((2 * max_label_len + 1) / height)
* self.reduce_dims_factor[1]
),
self.padding_value,
)
sample["img_shape"] = sample["img"].shape
sample["img_reduced_shape"] = np.ceil(
sample["img_shape"] / self.reduce_dims_factor
).astype(int)
sample["img_reduced_shape"] = [
max(1, t) for t in sample["img_reduced_shape"]
]
sample["img_position"] = [
[0, sample["img_shape"][0]],
[0, sample["img_shape"][1]],
[0, sample["img"].shape[0]],
[0, sample["img"].shape[1]],
]
# Padding constraints to handle model needs
if "padding" in self.params["config"] and self.params["config"]["padding"]:
......@@ -264,10 +188,6 @@ class OCRDataset(GenericDataset):
padding_mode=self.params["config"]["padding"]["mode"],
return_position=True,
)
sample["img_reduced_position"] = [
np.ceil(p / factor).astype(int)
for p, factor in zip(sample["img_position"], self.reduce_dims_factor[:2])
]
return sample
def convert_labels(self):
......@@ -279,13 +199,10 @@ class OCRDataset(GenericDataset):
def convert_sample_labels(self, sample):
label = sample["label"]
line_labels = label.split("\n")
if "remove_linebreaks" in self.params["config"]["constraints"]:
full_label = label.replace("\n", " ").replace(" ", " ")
word_labels = full_label.split(" ")
else:
full_label = label
word_labels = label.replace("\n", " ").replace(" ", " ").split(" ")
sample["label"] = full_label
sample["token_label"] = LM_str_to_ind(self.charset, full_label)
......@@ -294,20 +211,6 @@ class OCRDataset(GenericDataset):
sample["label_len"] = len(sample["token_label"])
if "add_sot" in self.params["config"]["constraints"]:
sample["token_label"].insert(0, self.tokens["start"])
sample["line_label"] = line_labels
sample["token_line_label"] = [
LM_str_to_ind(self.charset, label) for label in line_labels
]
sample["line_label_len"] = [len(label) for label in line_labels]
sample["nb_lines"] = len(line_labels)
sample["word_label"] = word_labels
sample["token_word_label"] = [
LM_str_to_ind(self.charset, label) for label in word_labels
]
sample["word_label_len"] = [len(label) for label in word_labels]
sample["nb_words"] = len(word_labels)
return sample
def generate_synthetic_data(self, sample):
......@@ -444,7 +347,6 @@ class OCRDataset(GenericDataset):
sample["label_begin"] = pages[0][1]["begin"]
sample["label_sem"] = pages[0][1]["sem"]
sample["label"] = pages[0][1]
sample["nb_cols"] = pages[0][2]
else:
if pages[0][0].shape[0] != pages[1][0].shape[0]:
max_height = max(pages[0][0].shape[0], pages[1][0].shape[0])
......@@ -458,7 +360,6 @@ class OCRDataset(GenericDataset):
sample["label_begin"] = pages[0][1]["begin"] + pages[1][1]["begin"]
sample["label_sem"] = pages[0][1]["sem"] + pages[1][1]["sem"]
sample["img"] = np.concatenate([pages[0][0], pages[1][0]], axis=1)
sample["nb_cols"] = pages[0][2] + pages[1][2]
sample["label"] = sample["label_raw"]
if "" in self.charset:
sample["label"] = sample["label_begin"]
......@@ -555,123 +456,40 @@ class OCRCollateFunction:
self.config = config
def __call__(self, batch_data):
names = [batch_data[i]["name"] for i in range(len(batch_data))]
ids = [
batch_data[i]["name"].split("/")[-1].split(".")[0]
for i in range(len(batch_data))
]
applied_da = [batch_data[i]["applied_da"] for i in range(len(batch_data))]
labels = [batch_data[i]["token_label"] for i in range(len(batch_data))]
labels = pad_sequences_1D(labels, padding_value=self.label_padding_value)
labels = torch.tensor(labels).long()
reverse_labels = [
[
batch_data[i]["token_label"][0],
]
+ batch_data[i]["token_label"][-2:0:-1]
+ [
batch_data[i]["token_label"][-1],
]
for i in range(len(batch_data))
]
reverse_labels = pad_sequences_1D(
reverse_labels, padding_value=self.label_padding_value
)
reverse_labels = torch.tensor(reverse_labels).long()
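# The slice token_label[-2:0:-1] reverses the sequence while keeping the
# first and last tokens (start/end) in place, e.g.:
#   token_label    == [sot, a, b, c, eot]
#   reverse_labels == [sot, c, b, a, eot]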
labels_len = [batch_data[i]["label_len"] for i in range(len(batch_data))]
raw_labels = [batch_data[i]["label"] for i in range(len(batch_data))]
unchanged_labels = [
batch_data[i]["unchanged_label"] for i in range(len(batch_data))
]
nb_cols = [batch_data[i]["nb_cols"] for i in range(len(batch_data))]
nb_lines = [batch_data[i]["nb_lines"] for i in range(len(batch_data))]
line_raw = [batch_data[i]["line_label"] for i in range(len(batch_data))]
line_token = [batch_data[i]["token_line_label"] for i in range(len(batch_data))]
pad_line_token = list()
line_len = [batch_data[i]["line_label_len"] for i in range(len(batch_data))]
for i in range(max(nb_lines)):
current_lines = [
line_token[j][i] if i < nb_lines[j] else [self.label_padding_value]
for j in range(len(batch_data))
]
pad_line_token.append(
torch.tensor(
pad_sequences_1D(
current_lines, padding_value=self.label_padding_value
)
).long()
)
for j in range(len(batch_data)):
if i >= nb_lines[j]:
line_len[j].append(0)
line_len = [i for i in zip(*line_len)]
nb_words = [batch_data[i]["nb_words"] for i in range(len(batch_data))]
word_raw = [batch_data[i]["word_label"] for i in range(len(batch_data))]
word_token = [batch_data[i]["token_word_label"] for i in range(len(batch_data))]
pad_word_token = list()
word_len = [batch_data[i]["word_label_len"] for i in range(len(batch_data))]
for i in range(max(nb_words)):
current_words = [
word_token[j][i] if i < nb_words[j] else [self.label_padding_value]
for j in range(len(batch_data))
]
pad_word_token.append(
torch.tensor(
pad_sequences_1D(
current_words, padding_value=self.label_padding_value
)
).long()
)
for j in range(len(batch_data)):
if i >= nb_words[j]:
word_len[j].append(0)
word_len = [i for i in zip(*word_len)]
padding_mode = (
self.config["padding_mode"] if "padding_mode" in self.config else "br"
)
imgs = [batch_data[i]["img"] for i in range(len(batch_data))]
imgs_shape = [batch_data[i]["img_shape"] for i in range(len(batch_data))]
imgs_reduced_shape = [
batch_data[i]["img_reduced_shape"] for i in range(len(batch_data))
]
imgs_position = [batch_data[i]["img_position"] for i in range(len(batch_data))]
imgs_reduced_position = [
batch_data[i]["img_reduced_position"] for i in range(len(batch_data))
]
imgs = pad_images(
imgs, padding_value=self.img_padding_value, padding_mode=padding_mode
)
imgs = torch.tensor(imgs).float().permute(0, 3, 1, 2)
formatted_batch_data = {
"names": names,
"ids": ids,
"nb_lines": nb_lines,
"nb_cols": nb_cols,
"labels": labels,
"reverse_labels": reverse_labels,
"raw_labels": raw_labels,
"unchanged_labels": unchanged_labels,
"labels_len": labels_len,
"imgs": imgs,
"imgs_shape": imgs_shape,
"imgs_reduced_shape": imgs_reduced_shape,
"imgs_position": imgs_position,
"imgs_reduced_position": imgs_reduced_position,
"line_raw": line_raw,
"line_labels": pad_line_token,
"line_labels_len": line_len,
"nb_words": nb_words,
"word_raw": word_raw,
"word_labels": pad_word_token,
"word_labels_len": word_len,
"applied_da": applied_da,
formatted_key: [batch_data[i][initial_key] for i in range(len(batch_data))]
for formatted_key, initial_key in zip(
[
"names",
"labels_len",
"raw_labels",
"imgs_position",
"imgs_reduced_shape",
],
["name", "label_len", "label", "img_position", "img_reduced_shape"],
)
}
formatted_batch_data.update(
{
"imgs": imgs,
"labels": labels,
}
)
return formatted_batch_data
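The rewritten collate function drops the line-, word- and column-level fields and gathers the remaining per-sample values with a dict comprehension. The underlying pattern, as a minimal sketch with dummy data:

batch_data = [
    {"name": "a.png", "label": "foo", "label_len": 3},
    {"name": "b.png", "label": "ba", "label_len": 2},
]
key_map = {"names": "name", "raw_labels": "label", "labels_len": "label_len"}
batch = {out: [sample[src] for sample in batch_data] for out, src in key_map.items()}
# {'names': ['a.png', 'b.png'], 'raw_labels': ['foo', 'ba'], 'labels_len': [3, 2]}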
......
......@@ -4,14 +4,13 @@ import json
import os
import pickle
import random
import sys
from datetime import date
from time import time
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import yaml
from PIL import Image
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
......@@ -21,6 +20,7 @@ from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from dan.manager.metrics import MetricManager
from dan.manager.ocr import OCRDatasetManager
from dan.ocr.utils import LM_ind_to_str
from dan.schedulers import DropoutScheduler
......@@ -77,9 +77,7 @@ class GenericTrainingManager:
"""
Create output folders for results and checkpoints
"""
output_path = os.path.join(
"outputs", self.params["training_params"]["output_folder"]
)
output_path = self.params["training_params"]["output_folder"]
os.makedirs(output_path, exist_ok=True)
checkpoints_path = os.path.join(output_path, "checkpoints")
os.makedirs(checkpoints_path, exist_ok=True)
......@@ -118,8 +116,8 @@ class GenericTrainingManager:
if "worker_per_gpu" not in self.params["dataset_params"]
else self.params["dataset_params"]["worker_per_gpu"]
)
self.dataset = self.params["dataset_params"]["dataset_manager"](
self.params["dataset_params"]
self.dataset = OCRDatasetManager(
self.params["dataset_params"], device=self.device
)
self.dataset.load_datasets()
self.dataset.load_ddp_samplers()
......@@ -164,7 +162,10 @@ class GenericTrainingManager:
self.device = torch.device(
"cuda:0" if torch.cuda.is_available() else "cpu"
)
self.params["model_params"]["device"] = self.device.type
if self.device == "cpu":
self.params["model_params"]["device"] = "cpu"
else:
self.params["model_params"]["device"] = self.device.type
# Print GPU info
# global
if (
......@@ -342,12 +343,6 @@ class GenericTrainingManager:
if c in old_charset:
new_weights[i] = weights[old_charset.index(c)]
pretrained_chars.append(c)
if (
"transfered_charset_last_is_ctc_blank" in self.params["model_params"]
and self.params["model_params"]["transfered_charset_last_is_ctc_blank"]
):
new_weights[-1] = weights[-1]
pretrained_chars.append("<blank>")
checkpoint["{}_state_dict".format(state_dict_name)][key] = new_weights
self.models[model_name].load_state_dict(
{key: checkpoint["{}_state_dict".format(state_dict_name)][key]},
......@@ -522,7 +517,6 @@ class GenericTrainingManager:
return
params = copy.deepcopy(self.params)
params = class_to_str_dict(params)
params["date"] = date.today().strftime("%d/%m/%Y")
total_params = 0
for model_name in self.models.keys():
current_params = compute_nb_params(self.models[model_name])
......@@ -532,21 +526,6 @@ class GenericTrainingManager:
]
total_params += current_params
params["model_params"]["total_params"] = "{:,}".format(total_params)
params["hardware"] = dict()
if self.device != "cpu":
for i in range(self.params["training_params"]["nb_gpu"]):
params["hardware"][str(i)] = "{} {}".format(
torch.cuda.get_device_name(i), torch.cuda.get_device_properties(i)
)
else:
params["hardware"]["0"] = "CPU"
params["software"] = {
"python_version": sys.version,
"pytorch_version": torch.__version__,
"cuda_version": torch.version.cuda,
"cudnn_version": torch.backends.cudnn.version(),
}
with open(path, "w") as f:
json.dump(params, f, indent=4)
......@@ -640,7 +619,6 @@ class GenericTrainingManager:
batch_values, metric_names
)
batch_metrics["names"] = batch_data["names"]
batch_metrics["ids"] = batch_data["ids"]
# Merge metrics if Distributed Data Parallel is used
if self.params["training_params"]["use_ddp"]:
batch_metrics = self.merge_ddp_metrics(batch_metrics)
......@@ -785,7 +763,6 @@ class GenericTrainingManager:
batch_values, metric_names
)
batch_metrics["names"] = batch_data["names"]
batch_metrics["ids"] = batch_data["ids"]
# merge metrics values if Distributed Data Parallel is used
if self.params["training_params"]["use_ddp"]:
batch_metrics = self.merge_ddp_metrics(batch_metrics)
......@@ -805,9 +782,6 @@ class GenericTrainingManager:
mlflow_logging,
self.is_master,
)
if "cer_by_nb_cols" in metric_names:
self.log_cer_by_nb_cols(set_name)
return display_values
def predict(
......@@ -841,7 +815,6 @@ class GenericTrainingManager:
batch_values, metric_names
)
batch_metrics["names"] = batch_data["names"]
batch_metrics["ids"] = batch_data["ids"]
# merge batch metrics if Distributed Data Parallel is used
if self.params["training_params"]["use_ddp"]:
batch_metrics = self.merge_ddp_metrics(batch_metrics)
......@@ -868,22 +841,22 @@ class GenericTrainingManager:
metrics = self.metric_manager[custom_name].get_display_values(output=True)
path = os.path.join(
self.paths["results"],
"predict_{}_{}.txt".format(custom_name, self.latest_epoch),
"predict_{}_{}.yaml".format(custom_name, self.latest_epoch),
)
with open(path, "w") as f:
for metric_name in metrics.keys():
f.write("{}: {}\n".format(metric_name, metrics[metric_name]))
yaml.dump(metrics, stream=f)
# Log mlflow artifacts
mlflow.log_artifact(path, "predictions")
if mlflow_logging:
# Log mlflow artifacts
mlflow.log_artifact(path, "predictions")
def output_pred(self, name):
path = os.path.join(
self.paths["results"], "pred_{}_{}.txt".format(name, self.latest_epoch)
self.paths["results"], "pred_{}_{}.yaml".format(name, self.latest_epoch)
)
pred = "\n".join(self.metric_manager[name].get("pred"))
with open(path, "w") as f:
f.write(pred)
yaml.dump(pred, stream=f)
def launch_ddp(self):
"""
......@@ -916,14 +889,12 @@ class GenericTrainingManager:
"edit_chars_force_len",
"edit_chars_curr",
"nb_chars_curr",
"ids",
]:
metrics[metric_name] = self.cat_ddp_metric(metrics[metric_name])
elif metric_name in [
"nb_samples",
"loss",
"loss_ce",
"loss_ctc",
"loss_ce_end",
]:
metrics[metric_name] = self.sum_ddp_metric(
......@@ -1045,7 +1016,6 @@ class OCRManager(GenericTrainingManager):
{
"path": sample["path"],
"label": chunk,
"nb_cols": 1,
}
)
......@@ -1058,7 +1028,6 @@ class OCRManager(GenericTrainingManager):
Image.fromarray(img).save(img_path)
gt[set_name][img_name] = {
"text": sample["label"],
"nb_cols": sample["nb_cols"] if "nb_cols" in sample else 1,
}
if "line_label" in sample:
gt[set_name][img_name]["lines"] = sample["line_label"]
......@@ -1089,7 +1058,7 @@ class Manager(OCRManager):
info_dict["curriculum_config"] = self.dataset.train_dataset.curriculum_config
return info_dict
def apply_teacher_forcing(self, y, y_len, error_rate):
def add_label_noise(self, y, y_len, error_rate):
y_error = y.clone()
for b in range(len(y_len)):
for i in range(1, y_len[b]):
......@@ -1109,37 +1078,30 @@ class Manager(OCRManager):
reduced_size = [s[:2] for s in batch_data["imgs_reduced_shape"]]
y_len = batch_data["labels_len"]
# add label noise to the ground truth used for teacher forcing
if (
"teacher_forcing_error_rate" in self.params["training_params"]
and self.params["training_params"]["teacher_forcing_error_rate"] is not None
):
error_rate = self.params["training_params"]["teacher_forcing_error_rate"]
simulated_y_pred, y_len = self.apply_teacher_forcing(y, y_len, error_rate)
elif "teacher_forcing_scheduler" in self.params["training_params"]:
if "label_noise_scheduler" in self.params["training_params"]:
error_rate = (
self.params["training_params"]["teacher_forcing_scheduler"][
self.params["training_params"]["label_noise_scheduler"][
"min_error_rate"
]
+ min(
self.latest_step,
self.params["training_params"]["teacher_forcing_scheduler"][
self.params["training_params"]["label_noise_scheduler"][
"total_num_steps"
],
)
* (
self.params["training_params"]["teacher_forcing_scheduler"][
self.params["training_params"]["label_noise_scheduler"][
"max_error_rate"
]
- self.params["training_params"]["teacher_forcing_scheduler"][
- self.params["training_params"]["label_noise_scheduler"][
"min_error_rate"
]
)
/ self.params["training_params"]["teacher_forcing_scheduler"][
/ self.params["training_params"]["label_noise_scheduler"][
"total_num_steps"
]
)
simulated_y_pred, y_len = self.apply_teacher_forcing(y, y_len, error_rate)
simulated_y_pred, y_len = self.add_label_noise(y, y_len, error_rate)
else:
simulated_y_pred = y
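The renamed scheduler ramps the error rate linearly from min_error_rate to max_error_rate over total_num_steps (with both bounds set to 0.2 in the default config, the rate stays constant). The computation above boils down to this sketch:

def scheduled_error_rate(step, min_rate, max_rate, total_steps):
    # Linear interpolation between the two bounds, clamped at total_steps.
    return min_rate + min(step, total_steps) * (max_rate - min_rate) / total_steps

print(scheduled_error_rate(0, 0.1, 0.3, 5e4))      # 0.1
print(scheduled_error_rate(25000, 0.1, 0.3, 5e4))  # 0.2
print(scheduled_error_rate(1e9, 0.1, 0.3, 5e4))    # 0.3 (clamped)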
......@@ -1193,12 +1155,6 @@ class Manager(OCRManager):
"str_x": str_x,
"loss": sum_loss.item(),
"loss_ce": loss_ce.item(),
"syn_max_lines": self.dataset.train_dataset.get_syn_max_lines()
if self.params["dataset_params"]["config"]["synthetic_data"]
else 0,
"syn_prob_lines": self.dataset.train_dataset.get_syn_proba_lines()
if self.params["dataset_params"]["config"]["synthetic_data"]
else 0,
}
return values
......@@ -1252,10 +1208,6 @@ class Manager(OCRManager):
else:
features = self.models["encoder"](x)
features_size = features.size()
coverage_vector = torch.zeros(
(features.size(0), 1, features.size(2), features.size(3)),
device=self.device,
)
pos_features = self.models["decoder"].features_updater.get_pos_features(
features
)
......@@ -1284,7 +1236,6 @@ class Manager(OCRManager):
confidence_scores.append(
torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values
)
coverage_vector = torch.clamp(coverage_vector + weights, 0, 1)
predicted_tokens = torch.cat(
[
predicted_tokens,
......
......@@ -54,7 +54,6 @@ class OCRManager(GenericTrainingManager):
{
"path": sample["path"],
"label": chunk,
"nb_cols": 1,
}
)
......@@ -67,7 +66,6 @@ class OCRManager(GenericTrainingManager):
Image.fromarray(img).save(img_path)
gt[set_name][img_name] = {
"text": sample["label"],
"nb_cols": sample["nb_cols"] if "nb_cols" in sample else 1,
}
if "line_label" in sample:
gt[set_name][img_name]["lines"] = sample["line_label"]
......
......@@ -12,7 +12,6 @@ from torch.optim import Adam
from dan import logger
from dan.decoder import GlobalHTADecoder
from dan.manager.ocr import OCRDataset, OCRDatasetManager
from dan.manager.training import Manager
from dan.models import FCN_Encoder
from dan.schedulers import exponential_dropout_scheduler
......@@ -87,8 +86,6 @@ def get_config():
"aws_secret_access_key": "",
},
"dataset_params": {
"dataset_manager": OCRDatasetManager,
"dataset_class": OCRDataset,
"datasets": {
dataset_name: "{}/{}_{}{}".format(
dataset_path, dataset_name, dataset_level, dataset_variant
......@@ -117,7 +114,6 @@ def get_config():
"height_divisor": 32, # Image height will be divided by 32
"padding_value": 0, # Image padding value
"padding_token": None, # Label padding value
"charset_mode": "seq2seq", # add end-of-transcription and start-of-transcription tokens to charset
"constraints": [
"add_eot",
"add_sot",
......@@ -180,7 +176,7 @@ def get_config():
},
},
"training_params": {
"output_folder": "dan_esposalles_record", # folder name for checkpoint and results
"output_folder": "outputs/dan_esposalles_record", # folder name for checkpoint and results
"max_nb_epochs": 710, # maximum number of epochs before to stop
"max_training_time": 3600
* 24
......@@ -215,8 +211,6 @@ def get_config():
"cer",
"wer",
"wer_no_punct",
"syn_max_lines",
"syn_prob_lines",
], # Metrics name for training
"eval_metrics": [
"cer",
......@@ -226,7 +220,7 @@ def get_config():
"force_cpu": False, # True for debug purposes
"max_char_prediction": 1000, # max number of token prediction
# Keep teacher forcing rate to 20% during whole training
"teacher_forcing_scheduler": {
"label_noise_scheduler": {
"min_error_rate": 0.2,
"max_error_rate": 0.2,
"total_num_steps": 5e4,
......@@ -254,12 +248,6 @@ def serialize_config(config):
serialized_config["mlflow"]["aws_secret_access_key"] = ""
# Get the name of the class
serialized_config["dataset_params"]["dataset_manager"] = serialized_config[
"dataset_params"
]["dataset_manager"].__name__
serialized_config["dataset_params"]["dataset_class"] = serialized_config[
"dataset_params"
]["dataset_class"].__name__
serialized_config["model_params"]["models"]["encoder"] = serialized_config[
"model_params"
]["models"]["encoder"].__name__
......
......@@ -6,7 +6,6 @@ import torch
import torch.multiprocessing as mp
from torch.optim import Adam
from dan.manager.ocr import OCRDataset, OCRDatasetManager
from dan.models import FCN_Encoder
from dan.ocr.line.model_utils import Decoder
from dan.ocr.line.utils import TrainerLineCTC
......@@ -35,8 +34,6 @@ def run():
dataset_level = "page"
params = {
"dataset_params": {
"dataset_manager": OCRDatasetManager,
"dataset_class": OCRDataset,
"datasets": {
dataset_name: "../../../Datasets/formatted/{}_{}".format(
dataset_name, dataset_level
......
......@@ -6,7 +6,6 @@ import torch
import torch.multiprocessing as mp
from torch.optim import Adam
from dan.manager.ocr import OCRDataset, OCRDatasetManager
from dan.models import FCN_Encoder
from dan.ocr.line.model_utils import Decoder
from dan.ocr.line.utils import TrainerLineCTC
......@@ -61,8 +60,6 @@ def run():
dataset_level = "syn_line"
params = {
"dataset_params": {
"dataset_manager": OCRDatasetManager,
"dataset_class": OCRDataset,
"datasets": {
dataset_name: "../../../Datasets/formatted/{}_{}".format(
dataset_name, dataset_level
......@@ -151,7 +148,7 @@ def run():
"dropout": 0.5,
},
"training_params": {
"output_folder": "FCN_read_2016_line_syn", # folder names for logs and weights
"output_folder": "outputs/FCN_read_2016_line_syn", # folder names for logs and weights
"max_nb_epochs": 10000, # max number of epochs for the training
"max_training_time": 3600
* 24
......
......@@ -148,10 +148,6 @@ class DAN:
features = self.encoder(input_tensor.float())
features_size = features.size()
coverage_vector = torch.zeros(
(features.size(0), 1, features.size(2), features.size(3)),
device=self.device,
)
pos_features = self.decoder.features_updater.get_pos_features(features)
features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(
2, 0, 1
......@@ -179,7 +175,6 @@ class DAN:
confidence_scores.append(
torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values
)
coverage_vector = torch.clamp(coverage_vector + weights, 0, 1)
predicted_tokens = torch.cat(
[
predicted_tokens,
......
......@@ -41,7 +41,3 @@ def exponential_dropout_scheduler(dropout_rate, step, max_step):
def exponential_scheduler(init_value, end_value, step, max_step):
step = min(step, max_step - 1)
return init_value - (init_value - end_value) * (1 - np.exp(-10 * step / max_step))
def linear_scheduler(init_value, end_value, step, max_step):
return init_value + step * (end_value - init_value) / max_step
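For reference: at step 0 the exponential scheduler returns init_value exactly (exp(0) = 1), and near max_step it has essentially converged to end_value (exp(-10) ≈ 4.5e-5), whereas the new linear_scheduler moves at a constant rate. A quick numeric check:

import numpy as np

# exponential: init - (init - end) * (1 - exp(-10 * step / max_step))
print(0.5 - 0.5 * (1 - np.exp(-10 * 0 / 100)))   # 0.5
print(0.5 - 0.5 * (1 - np.exp(-10 * 50 / 100)))  # ~0.0034
# linear: init + step * (end - init) / max_step
print(0.5 + 50 * (0.0 - 0.5) / 100)              # 0.25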
......@@ -328,9 +328,8 @@ def apply_data_augmentation(img, da_config):
"""
Apply data augmentation strategy on input image
"""
applied_da = list()
if da_config["proba"] != 1 and rand() > da_config["proba"]:
return img, applied_da
return img
# Convert to PIL Image
img = img[:, :, 0] if img.shape[2] == 1 else img
......@@ -345,12 +344,11 @@ def apply_data_augmentation(img, da_config):
for augmenter in augmenters:
img = augmenter(img)
applied_da.append(type(augmenter).__name__)
# convert to numpy array
img = np.array(img)
img = np.expand_dims(img, axis=2) if len(img.shape) == 2 else img
return img, applied_da
return img
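# Callers now receive only the image. Typical usage with this repository's
# aug_config helper (a sketch, assuming the dan package is importable):
#   from dan.transforms import aug_config, apply_data_augmentation
#   img = apply_data_augmentation(img, aug_config(0.9, 0.1))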
def apply_transform(img, transform):
......
......@@ -18,7 +18,6 @@ All hyperparameters are specified and editable in the training scripts (meaning
| `dataset_params.config.width_divisor` | Factor to reduce the width of the feature vector before feeding the decoder. | `int` | `8` |
| `dataset_params.config.padding_value` | Image padding value. | `int` | `0` |
| `dataset_params.config.padding_token` | Transcription padding value. | `int` | `None` |
| `dataset_params.config.charset_mode` | Whether to add end-of-transcription and start-of-transcription tokens to charset. | `str` | `seq2seq` |
| `dataset_params.config.constraints` | Whether to add end-of-transcription and start-of-transcription tokens in labels. | `list` | `["add_eot", "add_sot"]` |
| `dataset_params.config.normalize` | Normalize with mean and variance of training dataset. | `bool` | `True` |
| `dataset_params.config.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#data-preprocessing)) |
......@@ -270,9 +269,9 @@ The following configuration can be used by default. It must be defined in `datas
| `training_params.train_metrics` | List of metrics to compute during training. | `list` | `["cer", "wer", "wer_no_punct"]` |
| `training_params.force_cpu` | Whether to train on CPU (for debugging). | `bool` | `False` |
| `training_params.max_char_prediction` | Maximum number of characters to predict. | `int` | `1000` |
| `training_params.teacher_forcing_scheduler.min_error_rate` | Minimum ratio of teacher forcing. | `float` | `0.2` |
| `training_params.teacher_forcing_scheduler.max_error_rate` | Maximum ratio of teacher forcing. | `float` | `0.2` |
| `training_params.teacher_forcing_scheduler.total_num_steps` | Number of steps before stopping teacher forcing. | `float` | `5e4` |
| `training_params.label_noise_scheduler.min_error_rate` | Minimum ratio of label noise. | `float` | `0.2` |
| `training_params.label_noise_scheduler.max_error_rate` | Maximum ratio of label noise. | `float` | `0.2` |
| `training_params.label_noise_scheduler.total_num_steps` | Number of steps over which the error rate ramps from the minimum to the maximum. | `float` | `5e4` |
## MLFlow logging
......
arkindex-export==0.1.2
arkindex-export==0.1.3
boto3==1.26.124
editdistance==0.6.2
fontTools==4.39.3
......
......@@ -4,6 +4,12 @@ from pathlib import Path
import pytest
from arkindex_export import open_database
from torch.optim import Adam
from dan.decoder import GlobalHTADecoder
from dan.models import FCN_Encoder
from dan.schedulers import exponential_dropout_scheduler
from dan.transforms import aug_config
FIXTURES = Path(__file__).resolve().parent / "data"
......@@ -35,3 +41,131 @@ def demo_db(database_path):
Open a connection to a known demo database
"""
open_database(database_path)
@pytest.fixture
def training_config():
return {
"dataset_params": {
"datasets": {
"training": "./tests/data/training/training_dataset",
},
"train": {
"name": "training-train",
"datasets": [
("training", "train"),
],
},
"val": {
"training-val": [
("training", "val"),
],
},
"test": {
"training-test": [
("training", "test"),
],
},
"config": {
"load_in_memory": True, # Load all images in CPU memory
"width_divisor": 8, # Image width will be divided by 8
"height_divisor": 32, # Image height will be divided by 32
"padding_value": 0, # Image padding value
"padding_token": None, # Label padding value
"constraints": [
"add_eot",
"add_sot",
], # add end-of-transcription and start-of-transcription tokens in labels
"normalize": True, # Normalize with mean and variance of training dataset
"preprocessings": [
{
"type": "to_RGB",
# if grayscale image, produce an RGB one (3 channels with the same value), otherwise do nothing
},
],
"augmentation": aug_config(0.9, 0.1),
"synthetic_data": None,
},
},
"model_params": {
"models": {
"encoder": FCN_Encoder,
"decoder": GlobalHTADecoder,
},
"transfer_learning": None,
"transfered_charset": True, # Transfer learning of the decision layer based on charset of the line HTR model
"additional_tokens": 1, # for decision layer = [<eot>, ], only for transferred charset
"input_channels": 3, # number of channels of input image
"dropout": 0.5, # dropout rate for encoder
"enc_dim": 256, # dimension of extracted features
"nb_layers": 5, # encoder
"h_max": 500, # maximum height for encoder output (for 2D positional embedding)
"w_max": 1000, # maximum width for encoder output (for 2D positional embedding)
"l_max": 15000, # max predicted sequence (for 1D positional embedding)
"dec_num_layers": 8, # number of transformer decoder layers
"dec_num_heads": 4, # number of heads in transformer decoder layers
"dec_res_dropout": 0.1, # dropout in transformer decoder layers
"dec_pred_dropout": 0.1, # dropout rate before decision layer
"dec_att_dropout": 0.1, # dropout rate in multi head attention
"dec_dim_feedforward": 256, # number of dimension for feedforward layer in transformer decoder layers
"use_2d_pe": True, # use 2D positional embedding
"use_1d_pe": True, # use 1D positional embedding
"use_lstm": False,
"attention_win": 100, # length of attention window
# Curriculum dropout
"dropout_scheduler": {
"function": exponential_dropout_scheduler,
"T": 5e4,
},
},
"training_params": {
"output_folder": "dan_trained_model", # folder name for checkpoint and results
"max_nb_epochs": 4, # maximum number of epochs before to stop
"max_training_time": 1200, # maximum time before to stop (in seconds)
"load_epoch": "last", # ["best", "last"]: last to continue training, best to evaluate
"interval_save_weights": None, # None: keep best and last only
"batch_size": 2, # mini-batch size for training
"valid_batch_size": 2, # mini-batch size for valdiation
"use_ddp": False, # Use DistributedDataParallel
"nb_gpu": 0,
"optimizers": {
"all": {
"class": Adam,
"args": {
"lr": 0.0001,
"amsgrad": False,
},
},
},
"lr_schedulers": None, # Learning rate schedulers
"eval_on_valid": True, # Whether to eval and logs metrics on validation set during training or not
"eval_on_valid_interval": 2, # Interval (in epochs) to evaluate during training
"focus_metric": "cer", # Metrics to focus on to determine best epoch
"expected_metric_value": "low", # ["high", "low"] What is best for the focus metric value
"set_name_focus_metric": "training-val", # Which dataset to focus on to select best weights
"train_metrics": [
"loss_ce",
"cer",
"wer",
"wer_no_punct",
], # Metrics name for training
"eval_metrics": [
"cer",
"wer",
"wer_no_punct",
], # Metrics name for evaluation on validation set during training
"force_cpu": True, # True for debug purposes
"max_char_prediction": 30, # max number of token prediction
# Keep teacher forcing rate to 20% during whole training
"label_noise_scheduler": {
"min_error_rate": 0.2,
"max_error_rate": 0.2,
"total_num_steps": 5e4,
},
},
}
@pytest.fixture
def prediction_data_path():
return FIXTURES / "prediction"