Skip to content
Snippets Groups Projects
Commit e5418239 authored by Manon Blanco's avatar Manon Blanco
Browse files

Apply suggestions

parent 3b55303a
No related branches found
No related tags found
No related merge requests found
......@@ -49,7 +49,7 @@ class OCRDatasetManager:
self.params["config"]["padding_token"] = self.tokens["pad"]
self.my_collate_function = OCRCollateFunction(self.params["config"])
self.normalization = get_normalization_transforms()
self.normalization = get_normalization_transforms(from_pil_image=True)
self.augmentation = (
get_augmentation_transforms()
if self.params["config"]["augmentation"]
......
......@@ -83,10 +83,11 @@ class DAN:
def preprocess(self, path):
"""
Preprocess an image.
:param path: Image path
:param path: Path of the image to load and preprocess.
"""
image = read_image(path)
return self.preprocessing_transforms(image)
preprocessed_image = self.preprocessing_transforms(image)
return self.normalization(preprocessed_image)
def predict(
self,
......
......@@ -145,7 +145,9 @@ class ErosionDilation:
return {"image": augmented_image}
def get_preprocessing_transforms(preprocessings: list, to_pil_image=False) -> Compose:
def get_preprocessing_transforms(
preprocessings: list, to_pil_image: bool = False
) -> Compose:
"""
Returns a list of transformations to be applied to the image.
"""
......@@ -195,8 +197,14 @@ def get_augmentation_transforms() -> SomeOf:
)
def get_normalization_transforms() -> Compose:
def get_normalization_transforms(from_pil_image: bool = False) -> Compose:
"""
Returns a list of normalization transformations.
"""
return Compose([ToTensor(), Normalize(IMAGENET_MEAN, IMAGENET_STD)])
transforms = []
if from_pil_image:
transforms.append(ToTensor())
transforms.append(Normalize(IMAGENET_MEAN, IMAGENET_STD))
return Compose(transforms)
......@@ -50,7 +50,7 @@ def pad_images(data):
def read_image(path):
"""
Read image with torch
:param path: Image path
:param path: Path of the image to load.
"""
img = torchvision.read_image(path, mode=torchvision.ImageReadMode.RGB)
return img.to(dtype=torch.get_default_dtype()).div(255)
......
......@@ -103,7 +103,7 @@ def test_predict(
{"text": "ⓁP", "confidence": 0.94},
{"text": "ⒸM", "confidence": 0.93},
{"text": "ⓀCh", "confidence": 0.96},
{"text": "ⓄPlombier", "confidence": 0.93},
{"text": "ⓄPlombier", "confidence": 0.94},
{"text": "ⓅPatron?12241", "confidence": 0.93},
],
},
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment