From a11abc5b091f9b79fda5e265e31e84fe30dfedf4 Mon Sep 17 00:00:00 2001
From: manonBlanco <blanco@teklia.com>
Date: Mon, 30 Oct 2023 14:24:15 +0100
Subject: [PATCH] Expect specific file name

---
 README.md                                     |  6 ++---
 dan/ocr/predict/inference.py                  | 23 +++++++------------
 docs/usage/predict/index.md                   |  6 ++---
 .../{popp_line_model.pt => model.pt}          |  0
 tests/test_prediction.py                      |  4 ++--
 5 files changed, 16 insertions(+), 23 deletions(-)
 rename tests/data/prediction/{popp_line_model.pt => model.pt} (100%)

diff --git a/README.md b/README.md
index 42e37351..5d50b36c 100644
--- a/README.md
+++ b/README.md
@@ -63,9 +63,9 @@ image = cv2.cvtColor(cv2.imread(IMAGE_PATH), cv2.COLOR_BGR2RGB)
 
 Then one can initialize and load the trained model with the parameters used during training. The directory passed as parameter should have:
 
-- a single model file `*.pt`,
-- a single charset file `*.pkl`,
-- a single parameters file `*parameters.yml`.
+- a `model.pt` file,
+- a `charset.pkl` file,
+- a `parameters.yml` file corresponding to the `inference_parameters.yml` file generated during training.
 
 ```python
 model_path = "models"
diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py
index c73c8890..2bea4bee 100644
--- a/dan/ocr/predict/inference.py
+++ b/dan/ocr/predict/inference.py
@@ -59,21 +59,14 @@ class DAN:
         :param mode: The mode to load the model (train or eval).
         :param use_language_model: Whether to use an explicit language model to rescore text hypotheses.
         """
-        model_path = list(path.glob("*.pt"))
-        assert len(model_path) == 1, f"Found {len(model_path)} model(s) `.pt` in {path}"
-        model_path = model_path.pop()
-
-        params_path = list(path.glob("*parameters.yml"))
-        assert (
-            len(params_path) == 1
-        ), f"Found {len(params_path)} parameter(s) `parameters.yml` in {path}"
-        params_path = params_path.pop()
-
-        charset_path = list(path.glob("*.pkl"))
-        assert (
-            len(charset_path) == 1
-        ), f"Found {len(charset_path)} charset(s) `.pkl` in {path}"
-        charset_path = charset_path.pop()
+        model_path = path / "model.pt"
+        assert model_path.is_file(), f"File {model_path} not found"
+
+        params_path = path / "parameters.yml"
+        assert params_path.is_file(), f"File {params_path} not found"
+
+        charset_path = path / "charset.pkl"
+        assert charset_path.is_file(), f"File {charset_path} not found"
 
         parameters = yaml.safe_load(params_path.read_text())["parameters"]
         parameters["decoder"]["device"] = self.device
diff --git a/docs/usage/predict/index.md b/docs/usage/predict/index.md
index 497d9f37..9ab4edaa 100644
--- a/docs/usage/predict/index.md
+++ b/docs/usage/predict/index.md
@@ -31,9 +31,9 @@ Use the `teklia-dan predict` command to apply a trained DAN model on an image.
 
 In the following examples the `models` directory should have:
 
-- a single model file `*.pt`,
-- a single charset file `*.pkl`,
-- a single parameters file `*parameters.yml`.
+- a `model.pt` file,
+- a `charset.pkl` file,
+- a `parameters.yml` file corresponding to the `inference_parameters.yml` file generated during training.
 
 ### Predict with confidence scores
 
diff --git a/tests/data/prediction/popp_line_model.pt b/tests/data/prediction/model.pt
similarity index 100%
rename from tests/data/prediction/popp_line_model.pt
rename to tests/data/prediction/model.pt
diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index f84b3406..ffa7bf70 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -635,8 +635,8 @@ def test_run_prediction_language_model(
     model_path.mkdir(exist_ok=True)
 
     shutil.copyfile(
-        PREDICTION_DATA_PATH / "popp_line_model.pt",
-        model_path / "popp_line_model.pt",
+        PREDICTION_DATA_PATH / "model.pt",
+        model_path / "model.pt",
     )
     shutil.copyfile(
         PREDICTION_DATA_PATH / "charset.pkl",
-- 
GitLab