Skip to content
Snippets Groups Projects
Commit 1354a600 authored by Manon Blanco's avatar Manon Blanco
Browse files

Expect specific file name

parent 7f956b79
No related branches found
No related tags found
No related merge requests found
This commit is part of merge request !305. Comments created here will be created in the context of that merge request.
...@@ -63,9 +63,9 @@ image = cv2.cvtColor(cv2.imread(IMAGE_PATH), cv2.COLOR_BGR2RGB) ...@@ -63,9 +63,9 @@ image = cv2.cvtColor(cv2.imread(IMAGE_PATH), cv2.COLOR_BGR2RGB)
Then one can initialize and load the trained model with the parameters used during training. The directory passed as parameter should have: Then one can initialize and load the trained model with the parameters used during training. The directory passed as parameter should have:
- a single model file `*.pt`, - a `model.pt` file,
- a single charset file `*.pkl`, - a `charset.pkl` file,
- a single parameters file `*parameters.yml`. - a `parameters.yml` file corresponding to the `inference_parameters.yml` file generated during training.
```python ```python
model_path = "models" model_path = "models"
......
...@@ -59,21 +59,14 @@ class DAN: ...@@ -59,21 +59,14 @@ class DAN:
:param mode: The mode to load the model (train or eval). :param mode: The mode to load the model (train or eval).
:param use_language_model: Whether to use an explicit language model to rescore text hypotheses. :param use_language_model: Whether to use an explicit language model to rescore text hypotheses.
""" """
model_path = list(path.glob("*.pt")) model_path = path / "model.pt"
assert len(model_path) == 1, f"Found {len(model_path)} model(s) `.pt` in {path}" assert model_path.is_file(), f"File {model_path} not found"
model_path = model_path.pop()
params_path = path / "parameters.yml"
params_path = list(path.glob("*parameters.yml")) assert params_path.is_file(), f"File {params_path} not found"
assert (
len(params_path) == 1 charset_path = path / "charset.pkl"
), f"Found {len(params_path)} parameter(s) `parameters.yml` in {path}" assert charset_path.is_file(), f"File {charset_path} not found"
params_path = params_path.pop()
charset_path = list(path.glob("*.pkl"))
assert (
len(charset_path) == 1
), f"Found {len(charset_path)} charset(s) `.pkl` in {path}"
charset_path = charset_path.pop()
parameters = yaml.safe_load(params_path.read_text())["parameters"] parameters = yaml.safe_load(params_path.read_text())["parameters"]
parameters["decoder"]["device"] = self.device parameters["decoder"]["device"] = self.device
......
...@@ -33,9 +33,9 @@ Use the `teklia-dan predict` command to apply a trained DAN model on an image. ...@@ -33,9 +33,9 @@ Use the `teklia-dan predict` command to apply a trained DAN model on an image.
In the following examples the `models` directory should have: In the following examples the `models` directory should have:
- a single model file `*.pt`, - a `model.pt` file,
- a single charset file `*.pkl`, - a `charset.pkl` file,
- a single parameters file `*parameters.yml`. - a `parameters.yml` file corresponding to the `inference_parameters.yml` file generated during training.
### Predict with confidence scores ### Predict with confidence scores
......
File moved
...@@ -639,8 +639,8 @@ def test_run_prediction_language_model( ...@@ -639,8 +639,8 @@ def test_run_prediction_language_model(
model_path.mkdir(exist_ok=True) model_path.mkdir(exist_ok=True)
shutil.copyfile( shutil.copyfile(
PREDICTION_DATA_PATH / "popp_line_model.pt", PREDICTION_DATA_PATH / "model.pt",
model_path / "popp_line_model.pt", model_path / "model.pt",
) )
shutil.copyfile( shutil.copyfile(
PREDICTION_DATA_PATH / "charset.pkl", PREDICTION_DATA_PATH / "charset.pkl",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment