diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py index d65931f11d1725e9628f137abbd05aac1a28e658..c869d4e1fb7e2d6a3583e360adef5e654079607a 100644 --- a/dan/ocr/predict/inference.py +++ b/dan/ocr/predict/inference.py @@ -109,8 +109,8 @@ class DAN: ) self.mean, self.std = ( - torch.tensor(parameters["mean"]) / 255, - torch.tensor(parameters["std"]) / 255, + torch.tensor(parameters["mean"]) / 255 if "mean" in parameters else None, + torch.tensor(parameters["std"]) / 255 if "std" in parameters else None, ) self.preprocessing_transforms = get_preprocessing_transforms( parameters.get("preprocessings", []) @@ -124,11 +124,21 @@ class DAN: """ image = read_image(path) preprocessed_image = self.preprocessing_transforms(image) - normalized_image = torch.zeros(preprocessed_image.shape) - for ch in range(preprocessed_image.shape[0]): + + if self.mean is None and self.std is None: + return preprocessed_image, preprocessed_image + + size = preprocessed_image.shape + normalized_image = torch.zeros(size) + + mean = self.mean if self.mean is not None else torch.zeros(size[0]) + std = self.std if self.std is not None else torch.ones(size[0]) + + for ch in range(size[0]): normalized_image[ch, :, :] = ( - preprocessed_image[ch, :, :] - self.mean[ch] - ) / self.std[ch] + preprocessed_image[ch, :, :] - mean[ch] + ) / std[ch] + return preprocessed_image, normalized_image def predict( diff --git a/tests/test_prediction.py b/tests/test_prediction.py index 1f9fda21980af3e1f979321cab5ae79c2bd070ec..1fc5423e2bbaf53c231dcb2f1ef57cb2439a5151 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -37,11 +37,31 @@ PREDICTION_DATA_PATH = FIXTURES / "prediction" ), ), ) -def test_predict(image_name, expected_prediction): +@pytest.mark.parametrize("normalize", (True, False)) +def test_predict(image_name, expected_prediction, normalize, tmp_path): + # Update mean/std in parameters.yml + model_path = tmp_path / "models" + model_path.mkdir(exist_ok=True) + + shutil.copyfile( + PREDICTION_DATA_PATH / "model.pt", + model_path / "model.pt", + ) + shutil.copyfile( + PREDICTION_DATA_PATH / "charset.pkl", + model_path / "charset.pkl", + ) + + params = read_yaml(PREDICTION_DATA_PATH / "parameters.yml") + if not normalize: + del params["parameters"]["mean"] + del params["parameters"]["std"] + yaml.dump(params, (model_path / "parameters.yml").open("w")) + device = "cpu" dan_model = DAN(device) - dan_model.load(path=PREDICTION_DATA_PATH, mode="eval") + dan_model.load(path=model_path, mode="eval") image_path = PREDICTION_DATA_PATH / "images" / image_name _, image = dan_model.preprocess(str(image_path))