Skip to content
Snippets Groups Projects
Commit 01206c87 authored by Manon Blanco, committed by Mélodie Boillet
Browse files

Apply suggestions

parent af2336aa
No related branches found
No related tags found
1 merge request: !201 "Load image using torch + use training pre-processing function during prediction"
This commit is part of merge request !201. Comments created here will be created in the context of that merge request.
......@@ -49,7 +49,7 @@ class OCRDatasetManager:
self.params["config"]["padding_token"] = self.tokens["pad"]
self.my_collate_function = OCRCollateFunction(self.params["config"])
self.normalization = get_normalization_transforms()
self.normalization = get_normalization_transforms(from_pil_image=True)
self.augmentation = (
get_augmentation_transforms()
if self.params["config"]["augmentation"]
......
......@@ -83,10 +83,11 @@ class DAN:
def preprocess(self, path):
"""
Preprocess an image.
:param path: Image path
:param path: Path of the image to load and preprocess.
"""
image = read_image(path)
return self.preprocessing_transforms(image)
preprocessed_image = self.preprocessing_transforms(image)
return self.normalization(preprocessed_image)
def predict(
self,
......
......@@ -145,7 +145,9 @@ class ErosionDilation:
return {"image": augmented_image}
def get_preprocessing_transforms(preprocessings: list, to_pil_image=False) -> Compose:
def get_preprocessing_transforms(
preprocessings: list, to_pil_image: bool = False
) -> Compose:
"""
Returns a list of transformations to be applied to the image.
"""
......@@ -195,8 +197,14 @@ def get_augmentation_transforms() -> SomeOf:
)
def get_normalization_transforms() -> Compose:
def get_normalization_transforms(from_pil_image: bool = False) -> Compose:
"""
Returns a list of normalization transformations.
"""
return Compose([ToTensor(), Normalize(IMAGENET_MEAN, IMAGENET_STD)])
transforms = []
if from_pil_image:
transforms.append(ToTensor())
transforms.append(Normalize(IMAGENET_MEAN, IMAGENET_STD))
return Compose(transforms)
......@@ -50,7 +50,7 @@ def pad_images(data):
def read_image(path):
"""
Read image with torch
:param path: Image path
:param path: Path of the image to load.
"""
img = torchvision.read_image(path, mode=torchvision.ImageReadMode.RGB)
return img.to(dtype=torch.get_default_dtype()).div(255)
......
......@@ -103,7 +103,7 @@ def test_predict(
{"text": "ⓁP", "confidence": 0.94},
{"text": "ⒸM", "confidence": 0.93},
{"text": "ⓀCh", "confidence": 0.96},
{"text": "ⓄPlombier", "confidence": 0.93},
{"text": "ⓄPlombier", "confidence": 0.94},
{"text": "ⓅPatron?12241", "confidence": 0.93},
],
},
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment