diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py index 45eb1b2246a4f9cefb85a6466325781bbce0700e..8f9573f92a4e82949e8a8f0a677ffafa91c11016 100644 --- a/dan/manager/ocr.py +++ b/dan/manager/ocr.py @@ -49,7 +49,7 @@ class OCRDatasetManager: self.params["config"]["padding_token"] = self.tokens["pad"] self.my_collate_function = OCRCollateFunction(self.params["config"]) - self.normalization = get_normalization_transforms() + self.normalization = get_normalization_transforms(from_pil_image=True) self.augmentation = ( get_augmentation_transforms() if self.params["config"]["augmentation"] diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py index 0c8f9dd034501239175648230bf25df7ac1daab0..ff87ca3eb720bf015b82c349f3d753caf1421863 100644 --- a/dan/predict/prediction.py +++ b/dan/predict/prediction.py @@ -83,10 +83,11 @@ class DAN: def preprocess(self, path): """ Preprocess an image. - :param path: Image path + :param path: Path of the image to load and preprocess. """ image = read_image(path) - return self.preprocessing_transforms(image) + preprocessed_image = self.preprocessing_transforms(image) + return self.normalization(preprocessed_image) def predict( self, diff --git a/dan/transforms.py b/dan/transforms.py index 7fa6db63947f264841d351bdd03601f422f2080a..b0501e82990fa84cb704b30495e70d3ea0c2582a 100644 --- a/dan/transforms.py +++ b/dan/transforms.py @@ -145,7 +145,9 @@ class ErosionDilation: return {"image": augmented_image} -def get_preprocessing_transforms(preprocessings: list, to_pil_image=False) -> Compose: +def get_preprocessing_transforms( + preprocessings: list, to_pil_image: bool = False +) -> Compose: """ Returns a list of transformations to be applied to the image. """ @@ -195,8 +197,14 @@ def get_augmentation_transforms() -> SomeOf: ) -def get_normalization_transforms() -> Compose: +def get_normalization_transforms(from_pil_image: bool = False) -> Compose: """ Returns a list of normalization transformations. """ - return Compose([ToTensor(), Normalize(IMAGENET_MEAN, IMAGENET_STD)]) + transforms = [] + + if from_pil_image: + transforms.append(ToTensor()) + + transforms.append(Normalize(IMAGENET_MEAN, IMAGENET_STD)) + return Compose(transforms) diff --git a/dan/utils.py b/dan/utils.py index 44b3cd2365a3853a4d54bc3b6903280730f8e2c9..1ec9a8cfcf4207f16894196fe0e9b6c2b22ea82f 100644 --- a/dan/utils.py +++ b/dan/utils.py @@ -50,7 +50,7 @@ def pad_images(data): def read_image(path): """ Read image with torch - :param path: Image path + :param path: Path of the image to load. """ img = torchvision.read_image(path, mode=torchvision.ImageReadMode.RGB) return img.to(dtype=torch.get_default_dtype()).div(255) diff --git a/tests/test_prediction.py b/tests/test_prediction.py index f7ba63175c0f146697061435b720a181bcae8794..db87cdb451d6ef7ea80ceb1afb8b5a698836d552 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -103,7 +103,7 @@ def test_predict( {"text": "â“P", "confidence": 0.94}, {"text": "â’¸M", "confidence": 0.93}, {"text": "â“€Ch", "confidence": 0.96}, - {"text": "â“„Plombier", "confidence": 0.93}, + {"text": "â“„Plombier", "confidence": 0.94}, {"text": "â“…Patron?12241", "confidence": 0.93}, ], },