From 11fb8d1950bb83881fc89f12620869fa602b8fe3 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Thu, 11 May 2023 08:27:17 +0000
Subject: [PATCH] Use tmp folders in the training test

---
 dan/manager/training.py   |  4 +---
 dan/ocr/document/train.py |  2 +-
 dan/ocr/line/train.py     |  2 +-
 tests/test_training.py    | 31 ++++++++++++++++---------------
 4 files changed, 19 insertions(+), 20 deletions(-)

diff --git a/dan/manager/training.py b/dan/manager/training.py
index 03aefe42..47ccf6a1 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -78,9 +78,7 @@ class GenericTrainingManager:
         """
         Create output folders for results and checkpoints
         """
-        output_path = os.path.join(
-            "outputs", self.params["training_params"]["output_folder"]
-        )
+        output_path = self.params["training_params"]["output_folder"]
         os.makedirs(output_path, exist_ok=True)
         checkpoints_path = os.path.join(output_path, "checkpoints")
         os.makedirs(checkpoints_path, exist_ok=True)
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index ed33af04..24216c64 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -177,7 +177,7 @@ def get_config():
             },
         },
         "training_params": {
-            "output_folder": "dan_esposalles_record",  # folder name for checkpoint and results
+            "output_folder": "outputs/dan_esposalles_record",  # folder name for checkpoint and results
             "max_nb_epochs": 710,  # maximum number of epochs before to stop
             "max_training_time": 3600
             * 24
diff --git a/dan/ocr/line/train.py b/dan/ocr/line/train.py
index 98d92e44..60b1e4ba 100644
--- a/dan/ocr/line/train.py
+++ b/dan/ocr/line/train.py
@@ -148,7 +148,7 @@ def run():
             "dropout": 0.5,
         },
         "training_params": {
-            "output_folder": "FCN_read_2016_line_syn",  # folder names for logs and weights
+            "output_folder": "outputs/FCN_read_2016_line_syn",  # folder names for logs and weights
             "max_nb_epochs": 10000,  # max number of epochs for the training
             "max_training_time": 3600
             * 24
diff --git a/tests/test_training.py b/tests/test_training.py
index 85a01927..a674090a 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -1,7 +1,5 @@
 # -*- coding: utf-8 -*-
 
-from pathlib import Path
-
 import pytest
 import torch
 
@@ -52,19 +50,23 @@ def test_train_and_test(
     val_res,
     test_res,
     training_config,
+    tmp_path,
 ):
+    # Use the tmp_path as base folder
+    training_config["training_params"]["output_folder"] = str(
+        tmp_path / training_config["training_params"]["output_folder"]
+    )
+
     train_and_test(0, training_config)
 
     # Check that the trained model is correct
     for model_name in [expected_best_model_name, expected_last_model_name]:
         expected_model = torch.load(FIXTURES / "training" / "models" / model_name)
         trained_model = torch.load(
-            Path(
-                "outputs",
-                training_config["training_params"]["output_folder"],
-                "checkpoints",
-                model_name,
-            )
+            tmp_path
+            / training_config["training_params"]["output_folder"]
+            / "checkpoints"
+            / model_name,
         )
 
         # Check the optimizers parameters
@@ -130,13 +132,12 @@ def test_train_and_test(
     for split_name, expected_res in zip(
         ["train", "val", "test"], [training_res, val_res, test_res]
     ):
-        with open(
-            Path(
-                "outputs",
-                training_config["training_params"]["output_folder"],
-                "results",
-                f"predict_training-{split_name}_0.txt",
-            ),
+        with (
+            tmp_path
+            / training_config["training_params"]["output_folder"]
+            / "results"
+            / f"predict_training-{split_name}_0.txt"
+        ).open(
             "r",
         ) as f:
             res = f.read().splitlines()
-- 
GitLab