From a26125a566274af201dc6d210a3e2283141b74e7 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Fri, 12 May 2023 07:45:36 +0000
Subject: [PATCH] Save metrics results to YAML instead of plain text

---
 dan/manager/metrics.py  | 65 ++++++++++++++++++++++-----------------
 dan/manager/training.py | 10 +++---
 tests/test_training.py  | 68 +++++++++++++++++++-------------------
 3 files changed, 78 insertions(+), 65 deletions(-)

diff --git a/dan/manager/metrics.py b/dan/manager/metrics.py
index fdcc0ea5..8c4682fe 100644
--- a/dan/manager/metrics.py
+++ b/dan/manager/metrics.py
@@ -87,54 +87,62 @@ class MetricManager:
             value = None
             if output:
                 if metric_name in ["nb_samples", "weights"]:
-                    value = np.sum(self.epoch_metrics[metric_name])
+                    value = int(np.sum(self.epoch_metrics[metric_name]))
                 elif metric_name in [
                     "time",
                 ]:
-                    total_time = np.sum(self.epoch_metrics[metric_name])
-                    sample_time = total_time / np.sum(self.epoch_metrics["nb_samples"])
-                    display_values["sample_time"] = round(sample_time, 4)
-                    value = total_time
+                    value = int(np.sum(self.epoch_metrics[metric_name]))
+                    sample_time = value / np.sum(self.epoch_metrics["nb_samples"])
+                    display_values["sample_time"] = float(round(sample_time, 4))
                 elif metric_name == "loer":
-                    display_values["pper"] = round(
-                        np.sum(self.epoch_metrics["nb_pp_op_layout"])
-                        / np.sum(self.epoch_metrics["nb_gt_layout_token"]),
-                        4,
+                    display_values["pper"] = float(
+                        round(
+                            np.sum(self.epoch_metrics["nb_pp_op_layout"])
+                            / np.sum(self.epoch_metrics["nb_gt_layout_token"]),
+                            4,
+                        )
                     )
                 elif metric_name == "map_cer_per_class":
                     value = compute_global_mAP_per_class(self.epoch_metrics["map_cer"])
                     for key in value.keys():
-                        display_values["map_cer_" + key] = round(value[key], 4)
+                        display_values["map_cer_" + key] = float(round(value[key], 4))
                     continue
                 elif metric_name == "layout_precision_per_class_per_threshold":
                     value = compute_global_precision_per_class_per_threshold(
                         self.epoch_metrics["map_cer"]
                     )
                     for key_class in value.keys():
                         for threshold in value[key_class].keys():
                             display_values[
                                 "map_cer_{}_{}".format(key_class, threshold)
-                            ] = round(value[key_class][threshold], 4)
+                            ] = float(round(value[key_class][threshold], 4))
                     continue
             if metric_name == "cer":
-                value = np.sum(self.epoch_metrics["edit_chars"]) / np.sum(
-                    self.epoch_metrics["nb_chars"]
+                value = float(
+                    np.sum(self.epoch_metrics["edit_chars"])
+                    / np.sum(self.epoch_metrics["nb_chars"])
                 )
                 if output:
-                    display_values["nb_chars"] = np.sum(self.epoch_metrics["nb_chars"])
+                    display_values["nb_chars"] = int(
+                        np.sum(self.epoch_metrics["nb_chars"])
+                    )
             elif metric_name == "wer":
-                value = np.sum(self.epoch_metrics["edit_words"]) / np.sum(
-                    self.epoch_metrics["nb_words"]
+                value = float(
+                    np.sum(self.epoch_metrics["edit_words"])
+                    / np.sum(self.epoch_metrics["nb_words"])
                 )
                 if output:
-                    display_values["nb_words"] = np.sum(self.epoch_metrics["nb_words"])
+                    display_values["nb_words"] = int(
+                        np.sum(self.epoch_metrics["nb_words"])
+                    )
             elif metric_name == "wer_no_punct":
-                value = np.sum(self.epoch_metrics["edit_words_no_punct"]) / np.sum(
-                    self.epoch_metrics["nb_words_no_punct"]
+                value = float(
+                    np.sum(self.epoch_metrics["edit_words_no_punct"])
+                    / np.sum(self.epoch_metrics["nb_words_no_punct"])
                 )
                 if output:
-                    display_values["nb_words_no_punct"] = np.sum(
-                        self.epoch_metrics["nb_words_no_punct"]
+                    display_values["nb_words_no_punct"] = int(
+                        np.sum(self.epoch_metrics["nb_words_no_punct"])
                     )
             elif metric_name in [
                 "loss",
@@ -143,15 +151,18 @@ class MetricManager:
                 "syn_max_lines",
                 "syn_prob_lines",
             ]:
-                value = np.average(
-                    self.epoch_metrics[metric_name],
-                    weights=np.array(self.epoch_metrics["nb_samples"]),
+                value = float(
+                    np.average(
+                        self.epoch_metrics[metric_name],
+                        weights=np.array(self.epoch_metrics["nb_samples"]),
+                    )
                 )
             elif metric_name == "map_cer":
-                value = compute_global_mAP(self.epoch_metrics[metric_name])
+                value = float(compute_global_mAP(self.epoch_metrics[metric_name]))
             elif metric_name == "loer":
-                value = np.sum(self.epoch_metrics["edit_graph"]) / np.sum(
-                    self.epoch_metrics["nb_nodes_and_edges"]
+                value = float(
+                    np.sum(self.epoch_metrics["edit_graph"])
+                    / np.sum(self.epoch_metrics["nb_nodes_and_edges"])
                 )
             elif value is None:
                 continue
diff --git a/dan/manager/training.py b/dan/manager/training.py
index dc9267b9..c52cebd6 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -12,6 +12,7 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+import yaml
 from PIL import Image
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn import CrossEntropyLoss
@@ -865,22 +866,21 @@ class GenericTrainingManager:
         metrics = self.metric_manager[custom_name].get_display_values(output=True)
         path = os.path.join(
             self.paths["results"],
-            "predict_{}_{}.txt".format(custom_name, self.latest_epoch),
+            "predict_{}_{}.yaml".format(custom_name, self.latest_epoch),
         )
         with open(path, "w") as f:
-            for metric_name in metrics.keys():
-                f.write("{}: {}\n".format(metric_name, metrics[metric_name]))
+            yaml.dump(metrics, stream=f)

         # Log mlflow artifacts
         mlflow.log_artifact(path, "predictions")

     def output_pred(self, name):
         path = os.path.join(
-            self.paths["results"], "pred_{}_{}.txt".format(name, self.latest_epoch)
+            self.paths["results"], "pred_{}_{}.yaml".format(name, self.latest_epoch)
         )
         pred = "\n".join(self.metric_manager[name].get("pred"))
         with open(path, "w") as f:
-            f.write(pred)
+            yaml.dump(pred, stream=f)

     def launch_ddp(self):
         """
diff --git a/tests/test_training.py b/tests/test_training.py
index a674090a..b2c4ebbb 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -2,6 +2,7 @@
 import pytest
 import torch
+import yaml

 from dan.ocr.document.train import train_and_test
 from tests.conftest import FIXTURES

@@ -13,33 +14,33 @@ from tests.conftest import FIXTURES
         (
             "best_0.pt",
             "last_3.pt",
-            [
-                "nb_chars: 43",
-                "cer: 1.2791",
-                "nb_words: 9",
-                "wer: 1.0",
-                "nb_words_no_punct: 9",
-                "wer_no_punct: 1.0",
-                "nb_samples: 2",
-            ],
-            [
-                "nb_chars: 41",
-                "cer: 1.2683",
-                "nb_words: 9",
-                "wer: 1.0",
-                "nb_words_no_punct: 9",
-                "wer_no_punct: 1.0",
-                "nb_samples: 2",
-            ],
-            [
-                "nb_chars: 49",
-                "cer: 1.1429",
-                "nb_words: 9",
-                "wer: 1.0",
-                "nb_words_no_punct: 9",
-                "wer_no_punct: 1.0",
-                "nb_samples: 2",
-            ],
+            {
+                "nb_chars": 43,
+                "cer": 1.2791,
+                "nb_words": 9,
+                "wer": 1.0,
+                "nb_words_no_punct": 9,
+                "wer_no_punct": 1.0,
+                "nb_samples": 2,
+            },
+            {
+                "nb_chars": 41,
+                "cer": 1.2683,
+                "nb_words": 9,
+                "wer": 1.0,
+                "nb_words_no_punct": 9,
+                "wer_no_punct": 1.0,
+                "nb_samples": 2,
+            },
+            {
+                "nb_chars": 49,
+                "cer": 1.1429,
+                "nb_words": 9,
+                "wer": 1.0,
+                "nb_words_no_punct": 9,
+                "wer_no_punct": 1.0,
+                "nb_samples": 2,
+            },
         ),
     ),
 )
@@ -136,11 +137,12 @@ def test_train_and_test(
         tmp_path
         / training_config["training_params"]["output_folder"]
"results" - / f"predict_training-{split_name}_0.txt" - ).open( - "r", - ) as f: - res = f.read().splitlines() + / f"predict_training-{split_name}_0.yaml" + ).open() as f: # Remove the times from the results as they vary - res = [result for result in res if "time" not in result] + res = { + metric: value + for metric, value in yaml.safe_load(f).items() + if "time" not in metric + } assert res == expected_res -- GitLab