Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: atr/dan
Commits on Source (11)
@@ -4,7 +4,6 @@
 For more details about this package, make sure to see the documentation available at https://teklia.gitlab.io/atr/dan/.

 ## Installation

 To use DAN in your own scripts, install it using pip:
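(The actual command block is collapsed in this diff view. For a local checkout of the repository it is presumably an editable install along the lines of the following; the exact command is an assumption, not part of this diff.)

    pip install -e .  # hypothetical: editable install from a cloned repository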
@@ -55,7 +54,9 @@ See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/datasets/format/)
 See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/train/) on the official DAN documentation.

+### Synthetic data generation
+
+See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/generate/) on the official DAN documentation.

 ### Model prediction

 See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/predict/) on the official DAN documentation.
@@ -268,20 +268,10 @@ class GenericDataset(Dataset):
                     "label": label,
                     "unchanged_label": label,
                     "path": os.path.abspath(filename),
-                    "nb_cols": 1
-                    if "nb_cols" not in gt[filename]
-                    else gt[filename]["nb_cols"],
                 }
             )
             if load_in_memory:
                 samples[-1]["img"] = GenericDataset.load_image(filename)
-            if type(gt[filename]) is dict:
-                if "lines" in gt[filename].keys():
-                    samples[-1]["raw_line_seg_label"] = gt[filename]["lines"]
-                if "paragraphs" in gt[filename].keys():
-                    samples[-1]["paragraphs_label"] = gt[filename]["paragraphs"]
-                if "pages" in gt[filename].keys():
-                    samples[-1]["pages_label"] = gt[filename]["pages"]
         return samples

     def apply_preprocessing(self, preprocessings):
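This is the first of a series of hunks that remove the `nb_cols` ground-truth field end to end (dataset loading here, then the collate function, the training managers, and the default config further down), together with the raw line/paragraph/page segmentation labels, whose preprocessing counterpart is dropped in the next hunk.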
@@ -424,15 +414,6 @@ def apply_preprocessing(sample, preprocessings):
                 temp_img = np.expand_dims(temp_img, axis=2)
                 img = temp_img
             resize_ratio = [ratio, ratio]
-    if resize_ratio != [1, 1] and "raw_line_seg_label" in sample:
-        for li in range(len(sample["raw_line_seg_label"])):
-            for side, ratio in zip(
-                (["bottom", "top"], ["right", "left"]), resize_ratio
-            ):
-                for s in side:
-                    sample["raw_line_seg_label"][li][s] = (
-                        sample["raw_line_seg_label"][li][s] * ratio
-                    )
     sample["img"] = img
     sample["resize_ratio"] = resize_ratio
@@ -87,71 +87,84 @@ class MetricManager:
             value = None
             if output:
                 if metric_name in ["nb_samples", "weights"]:
-                    value = np.sum(self.epoch_metrics[metric_name])
+                    value = int(np.sum(self.epoch_metrics[metric_name]))
                 elif metric_name in [
                     "time",
                 ]:
-                    total_time = np.sum(self.epoch_metrics[metric_name])
-                    sample_time = total_time / np.sum(self.epoch_metrics["nb_samples"])
-                    display_values["sample_time"] = round(sample_time, 4)
-                    value = total_time
+                    value = int(np.sum(self.epoch_metrics[metric_name]))
+                    sample_time = value / np.sum(self.epoch_metrics["nb_samples"])
+                    display_values["sample_time"] = float(round(sample_time, 4))
                 elif metric_name == "loer":
-                    display_values["pper"] = round(
-                        np.sum(self.epoch_metrics["nb_pp_op_layout"])
-                        / np.sum(self.epoch_metrics["nb_gt_layout_token"]),
-                        4,
+                    display_values["pper"] = float(
+                        round(
+                            np.sum(self.epoch_metrics["nb_pp_op_layout"])
+                            / np.sum(self.epoch_metrics["nb_gt_layout_token"]),
+                            4,
+                        )
                     )
                 elif metric_name == "map_cer_per_class":
-                    value = compute_global_mAP_per_class(self.epoch_metrics["map_cer"])
+                    value = float(
+                        compute_global_mAP_per_class(self.epoch_metrics["map_cer"])
+                    )
                     for key in value.keys():
-                        display_values["map_cer_" + key] = round(value[key], 4)
+                        display_values["map_cer_" + key] = float(round(value[key], 4))
                     continue
                 elif metric_name == "layout_precision_per_class_per_threshold":
-                    value = compute_global_precision_per_class_per_threshold(
-                        self.epoch_metrics["map_cer"]
+                    value = float(
+                        compute_global_precision_per_class_per_threshold(
+                            self.epoch_metrics["map_cer"]
+                        )
                     )
                     for key_class in value.keys():
                         for threshold in value[key_class].keys():
                             display_values[
                                 "map_cer_{}_{}".format(key_class, threshold)
-                            ] = round(value[key_class][threshold], 4)
+                            ] = float(round(value[key_class][threshold], 4))
                     continue
             if metric_name == "cer":
-                value = np.sum(self.epoch_metrics["edit_chars"]) / np.sum(
-                    self.epoch_metrics["nb_chars"]
+                value = float(
+                    np.sum(self.epoch_metrics["edit_chars"])
+                    / np.sum(self.epoch_metrics["nb_chars"])
                 )
                 if output:
-                    display_values["nb_chars"] = np.sum(self.epoch_metrics["nb_chars"])
+                    display_values["nb_chars"] = int(
+                        np.sum(self.epoch_metrics["nb_chars"])
+                    )
             elif metric_name == "wer":
-                value = np.sum(self.epoch_metrics["edit_words"]) / np.sum(
-                    self.epoch_metrics["nb_words"]
+                value = float(
+                    np.sum(self.epoch_metrics["edit_words"])
+                    / np.sum(self.epoch_metrics["nb_words"])
                 )
                 if output:
-                    display_values["nb_words"] = np.sum(self.epoch_metrics["nb_words"])
+                    display_values["nb_words"] = int(
+                        np.sum(self.epoch_metrics["nb_words"])
+                    )
             elif metric_name == "wer_no_punct":
-                value = np.sum(self.epoch_metrics["edit_words_no_punct"]) / np.sum(
-                    self.epoch_metrics["nb_words_no_punct"]
+                value = float(
+                    np.sum(self.epoch_metrics["edit_words_no_punct"])
+                    / np.sum(self.epoch_metrics["nb_words_no_punct"])
                 )
                 if output:
-                    display_values["nb_words_no_punct"] = np.sum(
-                        self.epoch_metrics["nb_words_no_punct"]
+                    display_values["nb_words_no_punct"] = int(
+                        np.sum(self.epoch_metrics["nb_words_no_punct"])
                     )
             elif metric_name in [
                 "loss",
                 "loss_ctc",
                 "loss_ce",
                 "syn_max_lines",
                 "syn_prob_lines",
             ]:
-                value = np.average(
-                    self.epoch_metrics[metric_name],
-                    weights=np.array(self.epoch_metrics["nb_samples"]),
+                value = float(
+                    np.average(
+                        self.epoch_metrics[metric_name],
+                        weights=np.array(self.epoch_metrics["nb_samples"]),
+                    )
                 )
             elif metric_name == "map_cer":
-                value = compute_global_mAP(self.epoch_metrics[metric_name])
+                value = float(compute_global_mAP(self.epoch_metrics[metric_name]))
             elif metric_name == "loer":
-                value = np.sum(self.epoch_metrics["edit_graph"]) / np.sum(
-                    self.epoch_metrics["nb_nodes_and_edges"]
+                value = float(
+                    np.sum(self.epoch_metrics["edit_graph"])
+                    / np.sum(self.epoch_metrics["nb_nodes_and_edges"])
                 )
             elif value is None:
                 continue
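The pattern in this hunk is a systematic cast of numpy scalars (results of np.sum, np.average, and the mAP helpers) to builtin int/float. The motivation shows up in the training-manager hunk below, where metrics are now serialized with PyYAML: numpy scalar types are not dumped as plain YAML scalars. A minimal sketch of the difference, assuming only numpy and PyYAML:

    import numpy as np
    import yaml

    # A numpy scalar is dumped as a python-object tag, not a readable number:
    print(yaml.dump({"cer": np.float64(0.1234)}))
    # -> cer: !!python/object/apply:numpy.core.multiarray.scalar ... (binary payload)

    # Casting to the builtin type first yields clean, portable YAML:
    print(yaml.dump({"cer": float(np.float64(0.1234))}))
    # -> cer: 0.1234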
@@ -210,8 +223,6 @@ class MetricManager:
                 "loss_ctc",
                 "loss_ce",
                 "loss",
-                "syn_max_lines",
-                "syn_prob_lines",
             ]:
                 metrics[metric_name] = [
                     values[metric_name],
@@ -9,7 +9,6 @@ import torch
-from fontTools.ttLib import TTFont
 from PIL import Image, ImageDraw, ImageFont

 from dan import logger
 from dan.manager.dataset import DatasetManager, GenericDataset, apply_preprocessing
 from dan.ocr.utils import LM_str_to_ind
 from dan.utils import (
@@ -42,13 +41,6 @@ class OCRDatasetManager(DatasetManager):
             and self.params["config"]["synthetic_data"]
         ):
             self.synthetic_data = self.params["config"]["synthetic_data"]
-            if "config" in self.synthetic_data:
-                self.synthetic_data["config"]["valid_fonts"] = self.get_valid_fonts()
-
-        if "new_tokens" in params:
-            self.charset = sorted(
-                list(set(self.charset).union(set(params["new_tokens"])))
-            )

         self.tokens = {
             "pad": params["config"]["padding_token"],
@@ -104,34 +96,6 @@ class OCRDatasetManager(DatasetManager):
             [s["img"].shape[1] for s in self.train_dataset.samples]
         )

-    def get_valid_fonts(self):
-        """
-        Select fonts that are compatible with the alphabet
-        """
-        font_path = self.synthetic_data["font_path"]
-        alphabet = self.charset.copy()
-        special_chars = ["\n"]
-        alphabet = [char for char in alphabet if char not in special_chars]
-        valid_fonts = list()
-        for fold_detail in os.walk(font_path):
-            if fold_detail[2]:
-                for font_name in fold_detail[2]:
-                    if ".ttf" not in font_name:
-                        continue
-                    font_path = os.path.join(fold_detail[0], font_name)
-                    to_add = True
-                    if alphabet is not None:
-                        for char in alphabet:
-                            if not char_in_font(char, font_path):
-                                to_add = False
-                                break
-                        if to_add:
-                            valid_fonts.append(font_path)
-                    else:
-                        valid_fonts.append(font_path)
-        logger.info(f"Found {len(valid_fonts)} fonts.")
-        return valid_fonts

 class OCRDataset(GenericDataset):
     """
@@ -445,7 +409,6 @@ class OCRDataset(GenericDataset):
                 sample["label_begin"] = pages[0][1]["begin"]
                 sample["label_sem"] = pages[0][1]["sem"]
                 sample["label"] = pages[0][1]
-                sample["nb_cols"] = pages[0][2]
             else:
                 if pages[0][0].shape[0] != pages[1][0].shape[0]:
                     max_height = max(pages[0][0].shape[0], pages[1][0].shape[0])
@@ -459,7 +422,6 @@ class OCRDataset(GenericDataset):
                 sample["label_begin"] = pages[0][1]["begin"] + pages[1][1]["begin"]
                 sample["label_sem"] = pages[0][1]["sem"] + pages[1][1]["sem"]
                 sample["img"] = np.concatenate([pages[0][0], pages[1][0]], axis=1)
-                sample["nb_cols"] = pages[0][2] + pages[1][2]
             sample["label"] = sample["label_raw"]
             if "" in self.charset:
                 sample["label"] = sample["label_begin"]
@@ -587,7 +549,6 @@ class OCRCollateFunction:
             batch_data[i]["unchanged_label"] for i in range(len(batch_data))
         ]
-        nb_cols = [batch_data[i]["nb_cols"] for i in range(len(batch_data))]
         nb_lines = [batch_data[i]["nb_lines"] for i in range(len(batch_data))]
         line_raw = [batch_data[i]["line_label"] for i in range(len(batch_data))]
         line_token = [batch_data[i]["token_line_label"] for i in range(len(batch_data))]
@@ -652,7 +613,6 @@ class OCRCollateFunction:
             "names": names,
             "ids": ids,
             "nb_lines": nb_lines,
-            "nb_cols": nb_cols,
             "labels": labels,
             "reverse_labels": reverse_labels,
             "raw_labels": raw_labels,
@@ -12,6 +12,7 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+import yaml
 from PIL import Image
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn import CrossEntropyLoss
@@ -805,9 +806,6 @@ class GenericTrainingManager:
                 mlflow_logging,
                 self.is_master,
             )
-            if "cer_by_nb_cols" in metric_names:
-                self.log_cer_by_nb_cols(set_name)
-
         return display_values

     def predict(
@@ -868,22 +866,21 @@ class GenericTrainingManager:
         metrics = self.metric_manager[custom_name].get_display_values(output=True)
         path = os.path.join(
             self.paths["results"],
-            "predict_{}_{}.txt".format(custom_name, self.latest_epoch),
+            "predict_{}_{}.yaml".format(custom_name, self.latest_epoch),
         )
         with open(path, "w") as f:
-            for metric_name in metrics.keys():
-                f.write("{}: {}\n".format(metric_name, metrics[metric_name]))
+            yaml.dump(metrics, stream=f)

         # Log mlflow artifacts
         mlflow.log_artifact(path, "predictions")

     def output_pred(self, name):
         path = os.path.join(
-            self.paths["results"], "pred_{}_{}.txt".format(name, self.latest_epoch)
+            self.paths["results"], "pred_{}_{}.yaml".format(name, self.latest_epoch)
         )
         pred = "\n".join(self.metric_manager[name].get("pred"))
         with open(path, "w") as f:
-            f.write(pred)
+            yaml.dump(pred, stream=f)

     def launch_ddp(self):
         """
@@ -1045,7 +1042,6 @@ class OCRManager(GenericTrainingManager):
                         {
                             "path": sample["path"],
                             "label": chunk,
-                            "nb_cols": 1,
                         }
                     )
@@ -1058,7 +1054,6 @@ class OCRManager(GenericTrainingManager):
                 Image.fromarray(img).save(img_path)
                 gt[set_name][img_name] = {
                     "text": sample["label"],
-                    "nb_cols": sample["nb_cols"] if "nb_cols" in sample else 1,
                 }
                 if "line_label" in sample:
                     gt[set_name][img_name]["lines"] = sample["line_label"]
@@ -1193,12 +1188,6 @@ class Manager(OCRManager):
             "str_x": str_x,
             "loss": sum_loss.item(),
             "loss_ce": loss_ce.item(),
-            "syn_max_lines": self.dataset.train_dataset.get_syn_max_lines()
-            if self.params["dataset_params"]["config"]["synthetic_data"]
-            else 0,
-            "syn_prob_lines": self.dataset.train_dataset.get_syn_proba_lines()
-            if self.params["dataset_params"]["config"]["synthetic_data"]
-            else 0,
         }
         return values
@@ -1252,10 +1241,6 @@ class Manager(OCRManager):
         else:
             features = self.models["encoder"](x)
         features_size = features.size()
-        coverage_vector = torch.zeros(
-            (features.size(0), 1, features.size(2), features.size(3)),
-            device=self.device,
-        )
         pos_features = self.models["decoder"].features_updater.get_pos_features(
             features
         )
@@ -1284,7 +1269,6 @@ class Manager(OCRManager):
                 confidence_scores.append(
                     torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values
                 )
-                coverage_vector = torch.clamp(coverage_vector + weights, 0, 1)
                 predicted_tokens = torch.cat(
                     [
                         predicted_tokens,
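This pair of hunks, together with the matching pair in the DAN predict module below, drops `coverage_vector`: judging by this diff it was allocated and clamped on every decoding step but never read anywhere else, i.e. it was dead state.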
@@ -54,7 +54,6 @@ class OCRManager(GenericTrainingManager):
                         {
                             "path": sample["path"],
                             "label": chunk,
-                            "nb_cols": 1,
                         }
                     )
@@ -67,7 +66,6 @@ class OCRManager(GenericTrainingManager):
                 Image.fromarray(img).save(img_path)
                 gt[set_name][img_name] = {
                     "text": sample["label"],
-                    "nb_cols": sample["nb_cols"] if "nb_cols" in sample else 1,
                 }
                 if "line_label" in sample:
                     gt[set_name][img_name]["lines"] = sample["line_label"]
@@ -212,8 +212,6 @@ def get_config():
             "cer",
             "wer",
             "wer_no_punct",
-            "syn_max_lines",
-            "syn_prob_lines",
         ],  # Metrics name for training
         "eval_metrics": [
             "cer",
@@ -148,10 +148,6 @@ class DAN:
         features = self.encoder(input_tensor.float())
         features_size = features.size()
-        coverage_vector = torch.zeros(
-            (features.size(0), 1, features.size(2), features.size(3)),
-            device=self.device,
-        )
         pos_features = self.decoder.features_updater.get_pos_features(features)
         features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(
             2, 0, 1
@@ -179,7 +175,6 @@ class DAN:
             confidence_scores.append(
                 torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values
             )
-            coverage_vector = torch.clamp(coverage_vector + weights, 0, 1)
             predicted_tokens = torch.cat(
                 [
                     predicted_tokens,
@@ -41,7 +41,3 @@ def exponential_dropout_scheduler(dropout_rate, step, max_step):
 def exponential_scheduler(init_value, end_value, step, max_step):
     step = min(step, max_step - 1)
     return init_value - (init_value - end_value) * (1 - np.exp(-10 * step / max_step))
-
-
-def linear_scheduler(init_value, end_value, step, max_step):
-    return init_value + step * (end_value - init_value) / max_step
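Only the exponential schedule remains; it decays from init_value toward end_value by a factor 1 - exp(-10 * step / max_step). A quick standalone check of its endpoints:

    import numpy as np

    def exponential_scheduler(init_value, end_value, step, max_step):
        step = min(step, max_step - 1)
        return init_value - (init_value - end_value) * (1 - np.exp(-10 * step / max_step))

    print(exponential_scheduler(1.0, 0.0, 0, 100))   # 1.0: starts at init_value
    print(exponential_scheduler(1.0, 0.0, 99, 100))  # ~5e-05: has almost reached end_value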
@@ -2,6 +2,7 @@
 import pytest
 import torch
+import yaml

 from dan.ocr.document.train import train_and_test
 from tests.conftest import FIXTURES
@@ -13,33 +14,33 @@ from tests.conftest import FIXTURES
         (
             "best_0.pt",
             "last_3.pt",
-            [
-                "nb_chars: 43",
-                "cer: 1.2791",
-                "nb_words: 9",
-                "wer: 1.0",
-                "nb_words_no_punct: 9",
-                "wer_no_punct: 1.0",
-                "nb_samples: 2",
-            ],
-            [
-                "nb_chars: 41",
-                "cer: 1.2683",
-                "nb_words: 9",
-                "wer: 1.0",
-                "nb_words_no_punct: 9",
-                "wer_no_punct: 1.0",
-                "nb_samples: 2",
-            ],
-            [
-                "nb_chars: 49",
-                "cer: 1.1429",
-                "nb_words: 9",
-                "wer: 1.0",
-                "nb_words_no_punct: 9",
-                "wer_no_punct: 1.0",
-                "nb_samples: 2",
-            ],
+            {
+                "nb_chars": 43,
+                "cer": 1.2791,
+                "nb_words": 9,
+                "wer": 1.0,
+                "nb_words_no_punct": 9,
+                "wer_no_punct": 1.0,
+                "nb_samples": 2,
+            },
+            {
+                "nb_chars": 41,
+                "cer": 1.2683,
+                "nb_words": 9,
+                "wer": 1.0,
+                "nb_words_no_punct": 9,
+                "wer_no_punct": 1.0,
+                "nb_samples": 2,
+            },
+            {
+                "nb_chars": 49,
+                "cer": 1.1429,
+                "nb_words": 9,
+                "wer": 1.0,
+                "nb_words_no_punct": 9,
+                "wer_no_punct": 1.0,
+                "nb_samples": 2,
+            },
         ),
     ),
 )
@@ -136,11 +137,12 @@ def test_train_and_test(
             tmp_path
             / training_config["training_params"]["output_folder"]
             / "results"
-            / f"predict_training-{split_name}_0.txt"
-        ).open(
-            "r",
-        ) as f:
-            res = f.read().splitlines()
+            / f"predict_training-{split_name}_0.yaml"
+        ).open() as f:
             # Remove the times from the results as they vary
-            res = [result for result in res if "time" not in result]
+            res = {
+                metric: value
+                for metric, value in yaml.safe_load(f).items()
+                if "time" not in metric
+            }
         assert res == expected_res