Skip to content
Snippets Groups Projects

Load the model via a path to a folder

Merged Manon Blanco requested to merge load-model-via-path into main
All threads resolved!
Files
5
+ 13
14
@@ -49,20 +49,25 @@ class DAN:
def load(
self,
model_path: Path,
params_path: Path,
charset_path: Path,
path: Path,
mode: str = "eval",
use_language_model: bool = False,
) -> None:
"""
Load a trained model.
:param model_path: Path to the model.
:param params_path: Path to the parameters.
:param charset_path: Path to the charset.
:param path: Path to the directory containing the model, the YAML parameters file and the charset file.
:param mode: The mode to load the model (train or eval).
:param use_language_model: Whether to use an explicit language model to rescore text hypotheses.
"""
model_path = path / "model.pt"
assert model_path.is_file(), f"File {model_path} not found"
params_path = path / "parameters.yml"
assert params_path.is_file(), f"File {params_path} not found"
charset_path = path / "charset.pkl"
assert charset_path.is_file(), f"File {charset_path} not found"
parameters = yaml.safe_load(params_path.read_text())["parameters"]
parameters["decoder"]["device"] = self.device
@@ -410,8 +415,6 @@ def run(
image: Optional[Path],
image_dir: Optional[Path],
model: Path,
parameters: Path,
charset: Path,
output: Path,
confidence_score: bool,
confidence_score_levels: List[Level],
@@ -434,9 +437,7 @@ def run(
Predict a single image save the output
:param image: Path to the image to predict.
:param image_dir: Path to the folder where the images to predict are stored.
:param model: Path to the model to use for prediction.
:param parameters: Path to the YAML parameters file.
:param charset: Path to the charset.
:param model: Path to the directory containing the model, the YAML parameters file and the charset file to use for prediction.
:param output: Path to the output folder where the results will be saved.
:param confidence_score: Whether to compute confidence score.
:param attention_map: Whether to plot the attention map.
@@ -460,9 +461,7 @@ def run(
cuda_device = f":{gpu_device}" if gpu_device is not None else ""
device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
dan_model = DAN(device, temperature)
dan_model.load(
model, parameters, charset, mode="eval", use_language_model=use_language_model
)
dan_model.load(model, mode="eval", use_language_model=use_language_model)
# Do not use LM with invalid LM weight
use_language_model = dan_model.lm_decoder is not None
Loading