Skip to content
Snippets Groups Projects
Commit a62a2366 authored by Manon Blanco's avatar Manon Blanco Committed by Yoann Schneider
Browse files

Make normalization optional

parent 19a0518e
No related branches found
No related tags found
1 merge request!311Make normalization optional
......@@ -109,8 +109,8 @@ class DAN:
)
self.mean, self.std = (
torch.tensor(parameters["mean"]) / 255,
torch.tensor(parameters["std"]) / 255,
torch.tensor(parameters["mean"]) / 255 if "mean" in parameters else None,
torch.tensor(parameters["std"]) / 255 if "std" in parameters else None,
)
self.preprocessing_transforms = get_preprocessing_transforms(
parameters.get("preprocessings", [])
......@@ -124,11 +124,21 @@ class DAN:
"""
image = read_image(path)
preprocessed_image = self.preprocessing_transforms(image)
normalized_image = torch.zeros(preprocessed_image.shape)
for ch in range(preprocessed_image.shape[0]):
if self.mean is None and self.std is None:
return preprocessed_image, preprocessed_image
size = preprocessed_image.shape
normalized_image = torch.zeros(size)
mean = self.mean if self.mean is not None else torch.zeros(size[0])
std = self.std if self.std is not None else torch.ones(size[0])
for ch in range(size[0]):
normalized_image[ch, :, :] = (
preprocessed_image[ch, :, :] - self.mean[ch]
) / self.std[ch]
preprocessed_image[ch, :, :] - mean[ch]
) / std[ch]
return preprocessed_image, normalized_image
def predict(
......
......@@ -37,11 +37,31 @@ PREDICTION_DATA_PATH = FIXTURES / "prediction"
),
),
)
def test_predict(image_name, expected_prediction):
@pytest.mark.parametrize("normalize", (True, False))
def test_predict(image_name, expected_prediction, normalize, tmp_path):
# Update mean/std in parameters.yml
model_path = tmp_path / "models"
model_path.mkdir(exist_ok=True)
shutil.copyfile(
PREDICTION_DATA_PATH / "model.pt",
model_path / "model.pt",
)
shutil.copyfile(
PREDICTION_DATA_PATH / "charset.pkl",
model_path / "charset.pkl",
)
params = read_yaml(PREDICTION_DATA_PATH / "parameters.yml")
if not normalize:
del params["parameters"]["mean"]
del params["parameters"]["std"]
yaml.dump(params, (model_path / "parameters.yml").open("w"))
device = "cpu"
dan_model = DAN(device)
dan_model.load(path=PREDICTION_DATA_PATH, mode="eval")
dan_model.load(path=model_path, mode="eval")
image_path = PREDICTION_DATA_PATH / "images" / image_name
_, image = dan_model.preprocess(str(image_path))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment