diff --git a/dan/transforms.py b/dan/transforms.py index b0501e82990fa84cb704b30495e70d3ea0c2582a..ff39dcc692f865b1416bdbde583c5fa50c390fb8 100644 --- a/dan/transforms.py +++ b/dan/transforms.py @@ -119,30 +119,64 @@ class Erosion: return erode(np.array(x), self.kernel, iterations=self.iterations) -class ErosionDilation: +class ErosionDilation(ImageOnlyTransform): """ Random erosion or dilation """ - def __init__(self, min_kernel, max_kernel, iterations, p=1.0): + def __init__( + self, + min_kernel: int, + max_kernel: int, + iterations: int, + always_apply: bool = False, + p: float = 1.0, + ): + super(ErosionDilation, self).__init__(always_apply, p) self.min_kernel = min_kernel self.max_kernel = max_kernel self.iterations = iterations self.p = p self.always_apply = False - def __call__(self, image, force_apply=False): - if not (random.random() <= self.p or self.always_apply or force_apply): - return {"image": image} + def apply(self, img: np.ndarray, **params): kernel_h = randint(self.min_kernel, self.max_kernel) kernel_w = randint(self.min_kernel, self.max_kernel) kernel = np.ones((kernel_h, kernel_w), np.uint8) augmented_image = ( - Erosion(kernel, iterations=self.iterations)(image) + Erosion(kernel, iterations=self.iterations)(img) if random.random() < 0.5 - else Dilation(kernel=kernel, iterations=self.iterations)(image) + else Dilation(kernel=kernel, iterations=self.iterations)(img) ) - return {"image": augmented_image} + return augmented_image + + +class DPIAdjusting(ImageOnlyTransform): + """ + Resolution modification + """ + + def __init__( + self, + min_factor: float = 0.75, + max_factor: float = 1, + always_apply: bool = False, + p: float = 1.0, + ): + super(DPIAdjusting, self).__init__(always_apply, p) + self.min_factor = min_factor + self.max_factor = max_factor + self.p = p + self.always_apply = False + + def apply(self, img: np.ndarray, **params): + factor = float(Uniform(self.min_factor, self.max_factor).sample()) + img = Image.fromarray(img) + augmented_image = img.resize( + (int(np.ceil(img.width * factor)), int(np.ceil(img.height * factor))), + Image.BILINEAR, + ) + return np.array(augmented_image) def get_preprocessing_transforms( @@ -194,6 +228,9 @@ def get_augmentation_transforms() -> SomeOf: ], n=2, p=0.9, + ) + ], + p=0.9 )