diff --git a/dan/transforms.py b/dan/transforms.py index eb13eaaf6ce9a67b0af353113dbb76ccccdbceda..5e9a151482536eace8ae9f8eee92455ed3c9cd67 100644 --- a/dan/transforms.py +++ b/dan/transforms.py @@ -178,9 +178,9 @@ def get_list_augmenters(img, aug_configs, fill_value): if aug_config["type"] == "dpi": valid_factor = False while not valid_factor: - factor = float(Uniform( - aug_config["min_factor"], aug_config["max_factor"] - ).sample()) + factor = float( + Uniform(aug_config["min_factor"], aug_config["max_factor"]).sample() + ) valid_factor = not ( ( "max_width" in aug_config @@ -202,12 +202,12 @@ def get_list_augmenters(img, aug_configs, fill_value): augmenters.append(DPIAdjusting(factor)) elif aug_config["type"] == "zoom_ratio": - ratio_h = float(Uniform( - aug_config["min_ratio_h"], aug_config["max_ratio_h"] - ).sample()) - ratio_w = float(Uniform( - aug_config["min_ratio_w"], aug_config["max_ratio_w"] - ).sample()) + ratio_h = float( + Uniform(aug_config["min_ratio_h"], aug_config["max_ratio_h"]).sample() + ) + ratio_w = float( + Uniform(aug_config["min_ratio_w"], aug_config["max_ratio_w"]).sample() + ) augmenters.append( ZoomRatio( ratio_h=ratio_h, ratio_w=ratio_w, keep_dim=aug_config["keep_dim"] @@ -215,7 +215,9 @@ def get_list_augmenters(img, aug_configs, fill_value): ) elif aug_config["type"] == "perspective": - scale = float(Uniform(aug_config["min_factor"], aug_config["max_factor"]).sample()) + scale = float( + Uniform(aug_config["min_factor"], aug_config["max_factor"]).sample() + ) augmenters.append( RandomPerspective( distortion_scale=scale, @@ -232,12 +234,10 @@ def get_list_augmenters(img, aug_configs, fill_value): ).item() ) // 2 * 2 + 1 sigma = float( - Uniform(aug_config["min_sigma"], aug_config["max_sigma"]) - .sample() + Uniform(aug_config["min_sigma"], aug_config["max_sigma"]).sample() ) alpha = float( - Uniform(aug_config["min_alpha"], aug_config["max_alpha"]) - .sample() + Uniform(aug_config["min_alpha"], aug_config["max_alpha"]).sample() ) augmenters.append( ElasticDistortion( @@ -281,8 +281,7 @@ def get_list_augmenters(img, aug_configs, fill_value): randint(aug_config["min_kernel"], max_kernel_w + 1, (1,)).item() ) // 2 * 2 + 1 sigma = float( - Uniform(aug_config["min_sigma"], aug_config["max_sigma"]) - .sample() + Uniform(aug_config["min_sigma"], aug_config["max_sigma"]).sample() ) augmenters.append( GaussianBlur(kernel_size=(kernel_w, kernel_h), sigma=sigma) @@ -292,10 +291,12 @@ def get_list_augmenters(img, aug_configs, fill_value): augmenters.append(GaussianNoise(std=aug_config["std"])) elif aug_config["type"] == "sharpen": - alpha = float(Uniform(aug_config["min_alpha"], aug_config["max_alpha"]).sample()) - strength = float(Uniform( - aug_config["min_strength"], aug_config["max_strength"] - ).sample()) + alpha = float( + Uniform(aug_config["min_alpha"], aug_config["max_alpha"]).sample() + ) + strength = float( + Uniform(aug_config["min_strength"], aug_config["max_strength"]).sample() + ) augmenters.append(Sharpen(alpha=alpha, strength=strength)) else: