From 11fb8d1950bb83881fc89f12620869fa602b8fe3 Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Thu, 11 May 2023 08:27:17 +0000 Subject: [PATCH] Use tmp folders in the training test --- dan/manager/training.py | 4 +--- dan/ocr/document/train.py | 2 +- dan/ocr/line/train.py | 2 +- tests/test_training.py | 31 ++++++++++++++++--------------- 4 files changed, 19 insertions(+), 20 deletions(-) diff --git a/dan/manager/training.py b/dan/manager/training.py index 03aefe42..47ccf6a1 100644 --- a/dan/manager/training.py +++ b/dan/manager/training.py @@ -78,9 +78,7 @@ class GenericTrainingManager: """ Create output folders for results and checkpoints """ - output_path = os.path.join( - "outputs", self.params["training_params"]["output_folder"] - ) + output_path = self.params["training_params"]["output_folder"] os.makedirs(output_path, exist_ok=True) checkpoints_path = os.path.join(output_path, "checkpoints") os.makedirs(checkpoints_path, exist_ok=True) diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py index ed33af04..24216c64 100644 --- a/dan/ocr/document/train.py +++ b/dan/ocr/document/train.py @@ -177,7 +177,7 @@ def get_config(): }, }, "training_params": { - "output_folder": "dan_esposalles_record", # folder name for checkpoint and results + "output_folder": "outputs/dan_esposalles_record", # folder name for checkpoint and results "max_nb_epochs": 710, # maximum number of epochs before to stop "max_training_time": 3600 * 24 diff --git a/dan/ocr/line/train.py b/dan/ocr/line/train.py index 98d92e44..60b1e4ba 100644 --- a/dan/ocr/line/train.py +++ b/dan/ocr/line/train.py @@ -148,7 +148,7 @@ def run(): "dropout": 0.5, }, "training_params": { - "output_folder": "FCN_read_2016_line_syn", # folder names for logs and weights + "output_folder": "outputs/FCN_read_2016_line_syn", # folder names for logs and weights "max_nb_epochs": 10000, # max number of epochs for the training "max_training_time": 3600 * 24 diff --git a/tests/test_training.py b/tests/test_training.py index 85a01927..a674090a 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -1,7 +1,5 @@ # -*- coding: utf-8 -*- -from pathlib import Path - import pytest import torch @@ -52,19 +50,23 @@ def test_train_and_test( val_res, test_res, training_config, + tmp_path, ): + # Use the tmp_path as base folder + training_config["training_params"]["output_folder"] = str( + tmp_path / training_config["training_params"]["output_folder"] + ) + train_and_test(0, training_config) # Check that the trained model is correct for model_name in [expected_best_model_name, expected_last_model_name]: expected_model = torch.load(FIXTURES / "training" / "models" / model_name) trained_model = torch.load( - Path( - "outputs", - training_config["training_params"]["output_folder"], - "checkpoints", - model_name, - ) + tmp_path + / training_config["training_params"]["output_folder"] + / "checkpoints" + / model_name, ) # Check the optimizers parameters @@ -130,13 +132,12 @@ def test_train_and_test( for split_name, expected_res in zip( ["train", "val", "test"], [training_res, val_res, test_res] ): - with open( - Path( - "outputs", - training_config["training_params"]["output_folder"], - "results", - f"predict_training-{split_name}_0.txt", - ), + with ( + tmp_path + / training_config["training_params"]["output_folder"] + / "results" + / f"predict_training-{split_name}_0.txt" + ).open( "r", ) as f: res = f.read().splitlines() -- GitLab