Skip to content
Snippets Groups Projects
Verified Commit 3dfe694e authored by Mélodie's avatar Mélodie Committed by Mélodie Boillet
Browse files

Fix data augmentations

parent f60631d2
No related branches found
No related tags found
1 merge request!224Fix version 0.2.0-dev3 and later
......@@ -119,30 +119,64 @@ class Erosion:
return erode(np.array(x), self.kernel, iterations=self.iterations)
class ErosionDilation:
class ErosionDilation(ImageOnlyTransform):
"""
Random erosion or dilation
"""
def __init__(self, min_kernel, max_kernel, iterations, p=1.0):
def __init__(
self,
min_kernel: int,
max_kernel: int,
iterations: int,
always_apply: bool = False,
p: float = 1.0,
):
super(ErosionDilation, self).__init__(always_apply, p)
self.min_kernel = min_kernel
self.max_kernel = max_kernel
self.iterations = iterations
self.p = p
self.always_apply = False
def __call__(self, image, force_apply=False):
if not (random.random() <= self.p or self.always_apply or force_apply):
return {"image": image}
def apply(self, img: np.ndarray, **params):
kernel_h = randint(self.min_kernel, self.max_kernel)
kernel_w = randint(self.min_kernel, self.max_kernel)
kernel = np.ones((kernel_h, kernel_w), np.uint8)
augmented_image = (
Erosion(kernel, iterations=self.iterations)(image)
Erosion(kernel, iterations=self.iterations)(img)
if random.random() < 0.5
else Dilation(kernel=kernel, iterations=self.iterations)(image)
else Dilation(kernel=kernel, iterations=self.iterations)(img)
)
return {"image": augmented_image}
return augmented_image
class DPIAdjusting(ImageOnlyTransform):
"""
Resolution modification
"""
def __init__(
self,
min_factor: float = 0.75,
max_factor: float = 1,
always_apply: bool = False,
p: float = 1.0,
):
super(DPIAdjusting, self).__init__(always_apply, p)
self.min_factor = min_factor
self.max_factor = max_factor
self.p = p
self.always_apply = False
def apply(self, img: np.ndarray, **params):
factor = float(Uniform(self.min_factor, self.max_factor).sample())
img = Image.fromarray(img)
augmented_image = img.resize(
(int(np.ceil(img.width * factor)), int(np.ceil(img.height * factor))),
Image.BILINEAR,
)
return np.array(augmented_image)
def get_preprocessing_transforms(
......@@ -194,6 +228,9 @@ def get_augmentation_transforms() -> SomeOf:
],
n=2,
p=0.9,
)
],
p=0.9
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment