From 526c783dfd0ef81a15e50704a34c5268c65610aa Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Wed, 28 Feb 2024 17:31:10 +0100 Subject: [PATCH] Add new transformation to resize to FixedSize --- dan/ocr/transforms.py | 25 +++++++++++++++++++++++++ docs/usage/train/config.md | 12 ++++++++++++ 2 files changed, 37 insertions(+) diff --git a/dan/ocr/transforms.py b/dan/ocr/transforms.py index 3a8c74f6..09e05b7f 100644 --- a/dan/ocr/transforms.py +++ b/dan/ocr/transforms.py @@ -44,6 +44,11 @@ class Preprocessing(str, Enum): Resize the width to a fixed value while keeping the original ratio """ + FixedResize = "fixed_resize" + """ + Resize both the width and the height to a fixed value + """ + class FixedHeightResize: """ @@ -79,6 +84,19 @@ class FixedWidthResize: return round(self.width * aspect_ratio) +class FixedResize: + """ + Resize an image tensor to a fixed width and height + """ + + def __init__(self, height: int, width: int) -> None: + self.height = height + self.width = width + + def __call__(self, img: Tensor) -> Tensor: + return resize(img, (self.height, self.width), antialias=False) + + class MaxResize: """ Resize an image tensor if it is bigger than the maximum size @@ -179,6 +197,13 @@ def get_preprocessing_transforms( ) case Preprocessing.FixedWidthResize: transforms.append(FixedWidthResize(width=preprocessing["fixed_width"])) + case Preprocessing.FixedResize: + transforms.append( + FixedResize( + height=preprocessing["fixed_height"], + width=preprocessing["fixed_width"], + ) + ) if to_pil_image: transforms.append(ToPILImage()) return Compose(transforms) diff --git a/docs/usage/train/config.md b/docs/usage/train/config.md index 20d18fa8..43da046d 100644 --- a/docs/usage/train/config.md +++ b/docs/usage/train/config.md @@ -169,6 +169,18 @@ Usage: ] ``` +- Resize to a fixed width and a fixed height + +```py +[ + { + "type": "fixed_resize", + "fixed_height": 1900, + "fixed_width": 1250, + } +] +``` + - Resize to a maximum size (only if the image is bigger than the given size) ```py -- GitLab