Skip to content
Snippets Groups Projects

Fix version 0.2.0-dev3 and later

Merged Mélodie Boillet requested to merge fix-dev3 into main
All threads resolved!
1 file
+ 45
8
Compare changes
  • Side-by-side
  • Inline
+ 45
8
@@ -119,30 +119,64 @@ class Erosion:
return erode(np.array(x), self.kernel, iterations=self.iterations)
class ErosionDilation:
class ErosionDilation(ImageOnlyTransform):
"""
Random erosion or dilation
"""
def __init__(self, min_kernel, max_kernel, iterations, p=1.0):
def __init__(
self,
min_kernel: int,
max_kernel: int,
iterations: int,
always_apply: bool = False,
p: float = 1.0,
):
super(ErosionDilation, self).__init__(always_apply, p)
self.min_kernel = min_kernel
self.max_kernel = max_kernel
self.iterations = iterations
self.p = p
self.always_apply = False
def __call__(self, image, force_apply=False):
if not (random.random() <= self.p or self.always_apply or force_apply):
return {"image": image}
def apply(self, img: np.ndarray, **params):
kernel_h = randint(self.min_kernel, self.max_kernel)
kernel_w = randint(self.min_kernel, self.max_kernel)
kernel = np.ones((kernel_h, kernel_w), np.uint8)
augmented_image = (
Erosion(kernel, iterations=self.iterations)(image)
Erosion(kernel, iterations=self.iterations)(img)
if random.random() < 0.5
else Dilation(kernel=kernel, iterations=self.iterations)(image)
else Dilation(kernel=kernel, iterations=self.iterations)(img)
)
return {"image": augmented_image}
return augmented_image
class DPIAdjusting(ImageOnlyTransform):
"""
Resolution modification
"""
def __init__(
self,
min_factor: float = 0.75,
max_factor: float = 1,
always_apply: bool = False,
p: float = 1.0,
):
super(DPIAdjusting, self).__init__(always_apply, p)
self.min_factor = min_factor
self.max_factor = max_factor
self.p = p
self.always_apply = False
def apply(self, img: np.ndarray, **params):
factor = float(Uniform(self.min_factor, self.max_factor).sample())
img = Image.fromarray(img)
augmented_image = img.resize(
(int(np.ceil(img.width * factor)), int(np.ceil(img.height * factor))),
Image.BILINEAR,
)
return np.array(augmented_image)
def get_preprocessing_transforms(
@@ -194,6 +228,9 @@ def get_augmentation_transforms() -> SomeOf:
],
n=2,
p=0.9,
)
],
p=0.9
)
Loading