Skip to content
Snippets Groups Projects
Verified Commit 3dfe694e authored by Mélodie's avatar Mélodie Committed by Mélodie Boillet
Browse files

Fix data augmentations

parent f60631d2
No related branches found
No related tags found
1 merge request!224Fix version 0.2.0-dev3 and later
...@@ -119,30 +119,64 @@ class Erosion: ...@@ -119,30 +119,64 @@ class Erosion:
return erode(np.array(x), self.kernel, iterations=self.iterations) return erode(np.array(x), self.kernel, iterations=self.iterations)
class ErosionDilation: class ErosionDilation(ImageOnlyTransform):
""" """
Random erosion or dilation Random erosion or dilation
""" """
def __init__(self, min_kernel, max_kernel, iterations, p=1.0): def __init__(
self,
min_kernel: int,
max_kernel: int,
iterations: int,
always_apply: bool = False,
p: float = 1.0,
):
super(ErosionDilation, self).__init__(always_apply, p)
self.min_kernel = min_kernel self.min_kernel = min_kernel
self.max_kernel = max_kernel self.max_kernel = max_kernel
self.iterations = iterations self.iterations = iterations
self.p = p self.p = p
self.always_apply = False self.always_apply = False
def __call__(self, image, force_apply=False): def apply(self, img: np.ndarray, **params):
if not (random.random() <= self.p or self.always_apply or force_apply):
return {"image": image}
kernel_h = randint(self.min_kernel, self.max_kernel) kernel_h = randint(self.min_kernel, self.max_kernel)
kernel_w = randint(self.min_kernel, self.max_kernel) kernel_w = randint(self.min_kernel, self.max_kernel)
kernel = np.ones((kernel_h, kernel_w), np.uint8) kernel = np.ones((kernel_h, kernel_w), np.uint8)
augmented_image = ( augmented_image = (
Erosion(kernel, iterations=self.iterations)(image) Erosion(kernel, iterations=self.iterations)(img)
if random.random() < 0.5 if random.random() < 0.5
else Dilation(kernel=kernel, iterations=self.iterations)(image) else Dilation(kernel=kernel, iterations=self.iterations)(img)
) )
return {"image": augmented_image} return augmented_image
class DPIAdjusting(ImageOnlyTransform):
"""
Resolution modification
"""
def __init__(
self,
min_factor: float = 0.75,
max_factor: float = 1,
always_apply: bool = False,
p: float = 1.0,
):
super(DPIAdjusting, self).__init__(always_apply, p)
self.min_factor = min_factor
self.max_factor = max_factor
self.p = p
self.always_apply = False
def apply(self, img: np.ndarray, **params):
factor = float(Uniform(self.min_factor, self.max_factor).sample())
img = Image.fromarray(img)
augmented_image = img.resize(
(int(np.ceil(img.width * factor)), int(np.ceil(img.height * factor))),
Image.BILINEAR,
)
return np.array(augmented_image)
def get_preprocessing_transforms( def get_preprocessing_transforms(
...@@ -194,6 +228,9 @@ def get_augmentation_transforms() -> SomeOf: ...@@ -194,6 +228,9 @@ def get_augmentation_transforms() -> SomeOf:
], ],
n=2, n=2,
p=0.9, p=0.9,
)
],
p=0.9
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment