From 62f0f47ca7e9784addb64fbb619d057ac6a0d2bc Mon Sep 17 00:00:00 2001 From: manonBlanco <blanco@teklia.com> Date: Thu, 13 Jul 2023 09:50:19 +0200 Subject: [PATCH] Apply suggestions --- dan/manager/ocr.py | 2 +- dan/predict/prediction.py | 5 +++-- dan/transforms.py | 14 +++++++++++--- dan/utils.py | 2 +- tests/test_prediction.py | 2 +- 5 files changed, 17 insertions(+), 8 deletions(-) diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py index 4c0ffd38..fb79581f 100644 --- a/dan/manager/ocr.py +++ b/dan/manager/ocr.py @@ -49,7 +49,7 @@ class OCRDatasetManager: self.params["config"]["padding_token"] = self.tokens["pad"] self.my_collate_function = OCRCollateFunction(self.params["config"]) - self.normalization = get_normalization_transforms() + self.normalization = get_normalization_transforms(from_pil_image=True) self.augmentation = ( get_augmentation_transforms() if self.params["config"]["augmentation"] diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py index 0c8f9dd0..ff87ca3e 100644 --- a/dan/predict/prediction.py +++ b/dan/predict/prediction.py @@ -83,10 +83,11 @@ class DAN: def preprocess(self, path): """ Preprocess an image. - :param path: Image path + :param path: Path of the image to load and preprocess. """ image = read_image(path) - return self.preprocessing_transforms(image) + preprocessed_image = self.preprocessing_transforms(image) + return self.normalization(preprocessed_image) def predict( self, diff --git a/dan/transforms.py b/dan/transforms.py index 7fa6db63..b0501e82 100644 --- a/dan/transforms.py +++ b/dan/transforms.py @@ -145,7 +145,9 @@ class ErosionDilation: return {"image": augmented_image} -def get_preprocessing_transforms(preprocessings: list, to_pil_image=False) -> Compose: +def get_preprocessing_transforms( + preprocessings: list, to_pil_image: bool = False +) -> Compose: """ Returns a list of transformations to be applied to the image. """ @@ -195,8 +197,14 @@ def get_augmentation_transforms() -> SomeOf: ) -def get_normalization_transforms() -> Compose: +def get_normalization_transforms(from_pil_image: bool = False) -> Compose: """ Returns a list of normalization transformations. """ - return Compose([ToTensor(), Normalize(IMAGENET_MEAN, IMAGENET_STD)]) + transforms = [] + + if from_pil_image: + transforms.append(ToTensor()) + + transforms.append(Normalize(IMAGENET_MEAN, IMAGENET_STD)) + return Compose(transforms) diff --git a/dan/utils.py b/dan/utils.py index 44b3cd23..1ec9a8cf 100644 --- a/dan/utils.py +++ b/dan/utils.py @@ -50,7 +50,7 @@ def pad_images(data): def read_image(path): """ Read image with torch - :param path: Image path + :param path: Path of the image to load. """ img = torchvision.read_image(path, mode=torchvision.ImageReadMode.RGB) return img.to(dtype=torch.get_default_dtype()).div(255) diff --git a/tests/test_prediction.py b/tests/test_prediction.py index f7ba6317..db87cdb4 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -103,7 +103,7 @@ def test_predict( {"text": "â“P", "confidence": 0.94}, {"text": "â’¸M", "confidence": 0.93}, {"text": "â“€Ch", "confidence": 0.96}, - {"text": "â“„Plombier", "confidence": 0.93}, + {"text": "â“„Plombier", "confidence": 0.94}, {"text": "â“…Patron?12241", "confidence": 0.93}, ], }, -- GitLab