Skip to content
Snippets Groups Projects
Commit a26125a5 authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Mélodie Boillet
Browse files

Save metrics results to YAML instead of plain text

parent 24b5eb14
No related branches found
No related tags found
1 merge request: !122 "Save metrics results to YAML instead of plain text"
...@@ -87,54 +87,66 @@ class MetricManager: ...@@ -87,54 +87,66 @@ class MetricManager:
value = None value = None
if output: if output:
if metric_name in ["nb_samples", "weights"]: if metric_name in ["nb_samples", "weights"]:
value = np.sum(self.epoch_metrics[metric_name]) value = int(np.sum(self.epoch_metrics[metric_name]))
elif metric_name in [ elif metric_name in [
"time", "time",
]: ]:
total_time = np.sum(self.epoch_metrics[metric_name]) value = int(np.sum(self.epoch_metrics[metric_name]))
sample_time = total_time / np.sum(self.epoch_metrics["nb_samples"]) sample_time = value / np.sum(self.epoch_metrics["nb_samples"])
display_values["sample_time"] = round(sample_time, 4) display_values["sample_time"] = float(round(sample_time, 4))
value = total_time
elif metric_name == "loer": elif metric_name == "loer":
display_values["pper"] = round( display_values["pper"] = float(
np.sum(self.epoch_metrics["nb_pp_op_layout"]) round(
/ np.sum(self.epoch_metrics["nb_gt_layout_token"]), np.sum(self.epoch_metrics["nb_pp_op_layout"])
4, / np.sum(self.epoch_metrics["nb_gt_layout_token"]),
4,
)
) )
elif metric_name == "map_cer_per_class": elif metric_name == "map_cer_per_class":
value = compute_global_mAP_per_class(self.epoch_metrics["map_cer"]) value = float(
compute_global_mAP_per_class(self.epoch_metrics["map_cer"])
)
for key in value.keys(): for key in value.keys():
display_values["map_cer_" + key] = round(value[key], 4) display_values["map_cer_" + key] = float(round(value[key], 4))
continue continue
elif metric_name == "layout_precision_per_class_per_threshold": elif metric_name == "layout_precision_per_class_per_threshold":
value = compute_global_precision_per_class_per_threshold( value = float(
self.epoch_metrics["map_cer"] compute_global_precision_per_class_per_threshold(
self.epoch_metrics["map_cer"]
)
) )
for key_class in value.keys(): for key_class in value.keys():
for threshold in value[key_class].keys(): for threshold in value[key_class].keys():
display_values[ display_values[
"map_cer_{}_{}".format(key_class, threshold) "map_cer_{}_{}".format(key_class, threshold)
] = round(value[key_class][threshold], 4) ] = float(round(value[key_class][threshold], 4))
continue continue
if metric_name == "cer": if metric_name == "cer":
value = np.sum(self.epoch_metrics["edit_chars"]) / np.sum( value = float(
self.epoch_metrics["nb_chars"] np.sum(self.epoch_metrics["edit_chars"])
/ np.sum(self.epoch_metrics["nb_chars"])
) )
if output: if output:
display_values["nb_chars"] = np.sum(self.epoch_metrics["nb_chars"]) display_values["nb_chars"] = int(
np.sum(self.epoch_metrics["nb_chars"])
)
elif metric_name == "wer": elif metric_name == "wer":
value = np.sum(self.epoch_metrics["edit_words"]) / np.sum( value = float(
self.epoch_metrics["nb_words"] np.sum(self.epoch_metrics["edit_words"])
/ np.sum(self.epoch_metrics["nb_words"])
) )
if output: if output:
display_values["nb_words"] = np.sum(self.epoch_metrics["nb_words"]) display_values["nb_words"] = int(
np.sum(self.epoch_metrics["nb_words"])
)
elif metric_name == "wer_no_punct": elif metric_name == "wer_no_punct":
value = np.sum(self.epoch_metrics["edit_words_no_punct"]) / np.sum( value = float(
self.epoch_metrics["nb_words_no_punct"] np.sum(self.epoch_metrics["edit_words_no_punct"])
/ np.sum(self.epoch_metrics["nb_words_no_punct"])
) )
if output: if output:
display_values["nb_words_no_punct"] = np.sum( display_values["nb_words_no_punct"] = int(
self.epoch_metrics["nb_words_no_punct"] np.sum(self.epoch_metrics["nb_words_no_punct"])
) )
elif metric_name in [ elif metric_name in [
"loss", "loss",
...@@ -143,15 +155,18 @@ class MetricManager: ...@@ -143,15 +155,18 @@ class MetricManager:
"syn_max_lines", "syn_max_lines",
"syn_prob_lines", "syn_prob_lines",
]: ]:
value = np.average( value = float(
self.epoch_metrics[metric_name], np.average(
weights=np.array(self.epoch_metrics["nb_samples"]), self.epoch_metrics[metric_name],
weights=np.array(self.epoch_metrics["nb_samples"]),
)
) )
elif metric_name == "map_cer": elif metric_name == "map_cer":
value = compute_global_mAP(self.epoch_metrics[metric_name]) value = float(compute_global_mAP(self.epoch_metrics[metric_name]))
elif metric_name == "loer": elif metric_name == "loer":
value = np.sum(self.epoch_metrics["edit_graph"]) / np.sum( value = float(
self.epoch_metrics["nb_nodes_and_edges"] np.sum(self.epoch_metrics["edit_graph"])
/ np.sum(self.epoch_metrics["nb_nodes_and_edges"])
) )
elif value is None: elif value is None:
continue continue
......
...@@ -12,6 +12,7 @@ import numpy as np ...@@ -12,6 +12,7 @@ import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
import yaml
from PIL import Image from PIL import Image
from torch.cuda.amp import GradScaler, autocast from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
...@@ -865,22 +866,21 @@ class GenericTrainingManager: ...@@ -865,22 +866,21 @@ class GenericTrainingManager:
metrics = self.metric_manager[custom_name].get_display_values(output=True) metrics = self.metric_manager[custom_name].get_display_values(output=True)
path = os.path.join( path = os.path.join(
self.paths["results"], self.paths["results"],
"predict_{}_{}.txt".format(custom_name, self.latest_epoch), "predict_{}_{}.yaml".format(custom_name, self.latest_epoch),
) )
with open(path, "w") as f: with open(path, "w") as f:
for metric_name in metrics.keys(): yaml.dump(metrics, stream=f)
f.write("{}: {}\n".format(metric_name, metrics[metric_name]))
# Log mlflow artifacts # Log mlflow artifacts
mlflow.log_artifact(path, "predictions") mlflow.log_artifact(path, "predictions")
def output_pred(self, name): def output_pred(self, name):
path = os.path.join( path = os.path.join(
self.paths["results"], "pred_{}_{}.txt".format(name, self.latest_epoch) self.paths["results"], "pred_{}_{}.yaml".format(name, self.latest_epoch)
) )
pred = "\n".join(self.metric_manager[name].get("pred")) pred = "\n".join(self.metric_manager[name].get("pred"))
with open(path, "w") as f: with open(path, "w") as f:
f.write(pred) yaml.dump(pred, stream=f)
def launch_ddp(self): def launch_ddp(self):
""" """
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import pytest import pytest
import torch import torch
import yaml
from dan.ocr.document.train import train_and_test from dan.ocr.document.train import train_and_test
from tests.conftest import FIXTURES from tests.conftest import FIXTURES
...@@ -13,33 +14,33 @@ from tests.conftest import FIXTURES ...@@ -13,33 +14,33 @@ from tests.conftest import FIXTURES
( (
"best_0.pt", "best_0.pt",
"last_3.pt", "last_3.pt",
[ {
"nb_chars: 43", "nb_chars": 43,
"cer: 1.2791", "cer": 1.2791,
"nb_words: 9", "nb_words": 9,
"wer: 1.0", "wer": 1.0,
"nb_words_no_punct: 9", "nb_words_no_punct": 9,
"wer_no_punct: 1.0", "wer_no_punct": 1.0,
"nb_samples: 2", "nb_samples": 2,
], },
[ {
"nb_chars: 41", "nb_chars": 41,
"cer: 1.2683", "cer": 1.2683,
"nb_words: 9", "nb_words": 9,
"wer: 1.0", "wer": 1.0,
"nb_words_no_punct: 9", "nb_words_no_punct": 9,
"wer_no_punct: 1.0", "wer_no_punct": 1.0,
"nb_samples: 2", "nb_samples": 2,
], },
[ {
"nb_chars: 49", "nb_chars": 49,
"cer: 1.1429", "cer": 1.1429,
"nb_words: 9", "nb_words": 9,
"wer: 1.0", "wer": 1.0,
"nb_words_no_punct: 9", "nb_words_no_punct": 9,
"wer_no_punct: 1.0", "wer_no_punct": 1.0,
"nb_samples: 2", "nb_samples": 2,
], },
), ),
), ),
) )
...@@ -136,11 +137,12 @@ def test_train_and_test( ...@@ -136,11 +137,12 @@ def test_train_and_test(
tmp_path tmp_path
/ training_config["training_params"]["output_folder"] / training_config["training_params"]["output_folder"]
/ "results" / "results"
/ f"predict_training-{split_name}_0.txt" / f"predict_training-{split_name}_0.yaml"
).open( ).open() as f:
"r",
) as f:
res = f.read().splitlines()
# Remove the times from the results as they vary # Remove the times from the results as they vary
res = [result for result in res if "time" not in result] res = {
metric: value
for metric, value in yaml.safe_load(f).items()
if "time" not in metric
}
assert res == expected_res assert res == expected_res
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment