From 52bc1f2a055bb0e4b09e338f609509770fb8c53b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A9lodie?= <melo.boillet@gmail.com> Date: Tue, 25 Jul 2023 13:36:44 +0200 Subject: [PATCH] Float all Uniform operations --- dan/transforms.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/dan/transforms.py b/dan/transforms.py index 2c596e3c..eb13eaaf 100644 --- a/dan/transforms.py +++ b/dan/transforms.py @@ -178,9 +178,9 @@ def get_list_augmenters(img, aug_configs, fill_value): if aug_config["type"] == "dpi": valid_factor = False while not valid_factor: - factor = Uniform( + factor = float(Uniform( aug_config["min_factor"], aug_config["max_factor"] - ).sample() + ).sample()) valid_factor = not ( ( "max_width" in aug_config @@ -202,12 +202,12 @@ def get_list_augmenters(img, aug_configs, fill_value): augmenters.append(DPIAdjusting(factor)) elif aug_config["type"] == "zoom_ratio": - ratio_h = Uniform( + ratio_h = float(Uniform( aug_config["min_ratio_h"], aug_config["max_ratio_h"] - ).sample() - ratio_w = Uniform( + ).sample()) + ratio_w = float(Uniform( aug_config["min_ratio_w"], aug_config["max_ratio_w"] - ).sample() + ).sample()) augmenters.append( ZoomRatio( ratio_h=ratio_h, ratio_w=ratio_w, keep_dim=aug_config["keep_dim"] @@ -215,7 +215,7 @@ def get_list_augmenters(img, aug_configs, fill_value): ) elif aug_config["type"] == "perspective": - scale = Uniform(aug_config["min_factor"], aug_config["max_factor"]).sample() + scale = float(Uniform(aug_config["min_factor"], aug_config["max_factor"]).sample()) augmenters.append( RandomPerspective( distortion_scale=scale, @@ -231,15 +231,13 @@ def get_list_augmenters(img, aug_configs, fill_value): aug_config["min_kernel_size"], aug_config["max_kernel_size"], (1,) ).item() ) // 2 * 2 + 1 - sigma = ( + sigma = float( Uniform(aug_config["min_sigma"], aug_config["max_sigma"]) .sample() - .item() ) - alpha = ( + alpha = float( Uniform(aug_config["min_alpha"], aug_config["max_alpha"]) .sample() - .item() ) augmenters.append( ElasticDistortion( @@ -282,10 +280,9 @@ def get_list_augmenters(img, aug_configs, fill_value): kernel_w = ( randint(aug_config["min_kernel"], max_kernel_w + 1, (1,)).item() ) // 2 * 2 + 1 - sigma = ( + sigma = float( Uniform(aug_config["min_sigma"], aug_config["max_sigma"]) .sample() - .item() ) augmenters.append( GaussianBlur(kernel_size=(kernel_w, kernel_h), sigma=sigma) @@ -295,10 +292,10 @@ def get_list_augmenters(img, aug_configs, fill_value): augmenters.append(GaussianNoise(std=aug_config["std"])) elif aug_config["type"] == "sharpen": - alpha = Uniform(aug_config["min_alpha"], aug_config["max_alpha"]).sample() - strength = Uniform( + alpha = float(Uniform(aug_config["min_alpha"], aug_config["max_alpha"]).sample()) + strength = float(Uniform( aug_config["min_strength"], aug_config["max_strength"] - ).sample() + ).sample()) augmenters.append(Sharpen(alpha=alpha, strength=strength)) else: -- GitLab