From 52bc1f2a055bb0e4b09e338f609509770fb8c53b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A9lodie?= <melo.boillet@gmail.com>
Date: Tue, 25 Jul 2023 13:36:44 +0200
Subject: [PATCH] Float all Uniform operations

---
 dan/transforms.py | 29 +++++++++++++----------------
 1 file changed, 13 insertions(+), 16 deletions(-)

diff --git a/dan/transforms.py b/dan/transforms.py
index 2c596e3c..eb13eaaf 100644
--- a/dan/transforms.py
+++ b/dan/transforms.py
@@ -178,9 +178,9 @@ def get_list_augmenters(img, aug_configs, fill_value):
         if aug_config["type"] == "dpi":
             valid_factor = False
             while not valid_factor:
-                factor = Uniform(
+                factor = float(Uniform(
                     aug_config["min_factor"], aug_config["max_factor"]
-                ).sample()
+                ).sample())
                 valid_factor = not (
                     (
                         "max_width" in aug_config
@@ -202,12 +202,12 @@ def get_list_augmenters(img, aug_configs, fill_value):
             augmenters.append(DPIAdjusting(factor))
 
         elif aug_config["type"] == "zoom_ratio":
-            ratio_h = Uniform(
+            ratio_h = float(Uniform(
                 aug_config["min_ratio_h"], aug_config["max_ratio_h"]
-            ).sample()
-            ratio_w = Uniform(
+            ).sample())
+            ratio_w = float(Uniform(
                 aug_config["min_ratio_w"], aug_config["max_ratio_w"]
-            ).sample()
+            ).sample())
             augmenters.append(
                 ZoomRatio(
                     ratio_h=ratio_h, ratio_w=ratio_w, keep_dim=aug_config["keep_dim"]
@@ -215,7 +215,7 @@ def get_list_augmenters(img, aug_configs, fill_value):
             )
 
         elif aug_config["type"] == "perspective":
-            scale = Uniform(aug_config["min_factor"], aug_config["max_factor"]).sample()
+            scale = float(Uniform(aug_config["min_factor"], aug_config["max_factor"]).sample())
             augmenters.append(
                 RandomPerspective(
                     distortion_scale=scale,
@@ -231,15 +231,13 @@ def get_list_augmenters(img, aug_configs, fill_value):
                     aug_config["min_kernel_size"], aug_config["max_kernel_size"], (1,)
                 ).item()
             ) // 2 * 2 + 1
-            sigma = (
+            sigma = float(
                 Uniform(aug_config["min_sigma"], aug_config["max_sigma"])
                 .sample()
-                .item()
             )
-            alpha = (
+            alpha = float(
                 Uniform(aug_config["min_alpha"], aug_config["max_alpha"])
                 .sample()
-                .item()
             )
             augmenters.append(
                 ElasticDistortion(
@@ -282,10 +280,9 @@ def get_list_augmenters(img, aug_configs, fill_value):
             kernel_w = (
                 randint(aug_config["min_kernel"], max_kernel_w + 1, (1,)).item()
             ) // 2 * 2 + 1
-            sigma = (
+            sigma = float(
                 Uniform(aug_config["min_sigma"], aug_config["max_sigma"])
                 .sample()
-                .item()
             )
             augmenters.append(
                 GaussianBlur(kernel_size=(kernel_w, kernel_h), sigma=sigma)
@@ -295,10 +292,10 @@ def get_list_augmenters(img, aug_configs, fill_value):
             augmenters.append(GaussianNoise(std=aug_config["std"]))
 
         elif aug_config["type"] == "sharpen":
-            alpha = Uniform(aug_config["min_alpha"], aug_config["max_alpha"]).sample()
-            strength = Uniform(
+            alpha = float(Uniform(aug_config["min_alpha"], aug_config["max_alpha"]).sample())
+            strength = float(Uniform(
                 aug_config["min_strength"], aug_config["max_strength"]
-            ).sample()
+            ).sample())
             augmenters.append(Sharpen(alpha=alpha, strength=strength))
 
         else:
-- 
GitLab