diff --git a/dan/manager/training.py b/dan/manager/training.py index 03aefe429eb60f09a26ac8754676733f511c2436..47ccf6a17e214a051942e1c35dcb6c516eb3109f 100644 --- a/dan/manager/training.py +++ b/dan/manager/training.py @@ -78,9 +78,7 @@ class GenericTrainingManager: """ Create output folders for results and checkpoints """ - output_path = os.path.join( - "outputs", self.params["training_params"]["output_folder"] - ) + output_path = self.params["training_params"]["output_folder"] os.makedirs(output_path, exist_ok=True) checkpoints_path = os.path.join(output_path, "checkpoints") os.makedirs(checkpoints_path, exist_ok=True) diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py index ed33af04e1f6205365c8bbbe42a4a905f0e5a3d8..24216c6485e3bd4090d0c7c1f5bc07edaf7752c3 100644 --- a/dan/ocr/document/train.py +++ b/dan/ocr/document/train.py @@ -177,7 +177,7 @@ def get_config(): }, }, "training_params": { - "output_folder": "dan_esposalles_record", # folder name for checkpoint and results + "output_folder": "outputs/dan_esposalles_record", # folder name for checkpoint and results "max_nb_epochs": 710, # maximum number of epochs before to stop "max_training_time": 3600 * 24 diff --git a/dan/ocr/line/train.py b/dan/ocr/line/train.py index 98d92e447798be65d563ec9236a0803dd31c3c97..60b1e4bab62fff513f92d52773bba7ea4c56e76c 100644 --- a/dan/ocr/line/train.py +++ b/dan/ocr/line/train.py @@ -148,7 +148,7 @@ def run(): "dropout": 0.5, }, "training_params": { - "output_folder": "FCN_read_2016_line_syn", # folder names for logs and weights + "output_folder": "outputs/FCN_read_2016_line_syn", # folder names for logs and weights "max_nb_epochs": 10000, # max number of epochs for the training "max_training_time": 3600 * 24 diff --git a/tests/test_training.py b/tests/test_training.py index 85a01927766335330f1c0279af857ba40e3fa1bb..a674090ab9701161e5d8938c002fc8759bbf35d2 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -1,7 +1,5 @@ # -*- coding: utf-8 -*- -from pathlib import Path - import pytest import torch @@ -52,19 +50,23 @@ def test_train_and_test( val_res, test_res, training_config, + tmp_path, ): + # Use the tmp_path as base folder + training_config["training_params"]["output_folder"] = str( + tmp_path / training_config["training_params"]["output_folder"] + ) + train_and_test(0, training_config) # Check that the trained model is correct for model_name in [expected_best_model_name, expected_last_model_name]: expected_model = torch.load(FIXTURES / "training" / "models" / model_name) trained_model = torch.load( - Path( - "outputs", - training_config["training_params"]["output_folder"], - "checkpoints", - model_name, - ) + tmp_path + / training_config["training_params"]["output_folder"] + / "checkpoints" + / model_name, ) # Check the optimizers parameters @@ -130,13 +132,12 @@ def test_train_and_test( for split_name, expected_res in zip( ["train", "val", "test"], [training_res, val_res, test_res] ): - with open( - Path( - "outputs", - training_config["training_params"]["output_folder"], - "results", - f"predict_training-{split_name}_0.txt", - ), + with ( + tmp_path + / training_config["training_params"]["output_folder"] + / "results" + / f"predict_training-{split_name}_0.txt" + ).open( "r", ) as f: res = f.read().splitlines()