From 782250788dbd354aaeaf09e8b62bb052097e3702 Mon Sep 17 00:00:00 2001
From: Bastien Abadie <bastien@nextcairn.com>
Date: Tue, 12 Jul 2022 11:12:55 +0200
Subject: [PATCH 1/3] Auto generated fixes

---
 .isort.cfg                               |   2 +-
 training/README.md                       |   2 +-
 training/evaluate.py                     | 112 ++++++---
 training/model_params.py                 |  10 +-
 training/normalization_params.py         |  37 +--
 training/notify-slack.py                 |  48 ++--
 training/predict.py                      | 116 ++++++---
 training/retrieve_experiments_configs.py |  90 ++++---
 training/run_dla_experiment.sh           |   4 +-
 training/run_experiment.py               | 295 ++++++++++++++---------
 training/train.py                        | 183 +++++++++-----
 training/utils/evaluation_utils.py       | 185 ++++++++------
 training/utils/model.py                  |  91 ++++---
 training/utils/object_metrics.py         |  91 ++++---
 training/utils/params_config.py          |  22 +-
 training/utils/pixel_metrics.py          |  21 +-
 training/utils/prediction_utils.py       |  58 +++--
 training/utils/preprocessing.py          | 110 +++++----
 training/utils/training_pixel_metrics.py |  21 +-
 training/utils/training_utils.py         |  81 ++++---
 training/utils/utils.py                  |  94 +++++---
 21 files changed, 1018 insertions(+), 655 deletions(-)

diff --git a/.isort.cfg b/.isort.cfg
index b58e6a1..60de553 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -1,2 +1,2 @@
 [settings]
-known_third_party = cv2,numpy,pytest,requests,setuptools,torch,yaml
+known_third_party = cv2,evaluate,imageio,matplotlib,normalization_params,numpy,predict,pytest,requests,sacred,setuptools,shapely,torch,torchsummary,torchvision,tqdm,train,utils,yaml
diff --git a/training/README.md b/training/README.md
index bc67b9f..40b6d6b 100644
--- a/training/README.md
+++ b/training/README.md
@@ -85,7 +85,7 @@ In the root directory, one has to create an `experiments.csv` file (see `example
 | `restore_model`   | Name of a saved model to resume or fine-tune a training                                          |                                                    |
 | `loss`            | Whether to use an initial loss (`initial`) or the best (`best`) saved loss of the restored model | `initial`                                          |
 
-Note: All the steps are dependant, e.g to run the `"prediction"` step, one **needs** the results of the `"normalization_params"` and `"train"` steps.
+Note: All the steps are dependent, e.g. to run the `"prediction"` step, one **needs** the results of the `"normalization_params"` and `"train"` steps.
 
 #### Example
 
diff --git a/training/evaluate.py b/training/evaluate.py
index cdcdb5a..20e5df3 100755
--- a/training/evaluate.py
+++ b/training/evaluate.py
@@ -8,21 +8,28 @@
     Use it to evaluate a trained network.
 """
 
-import os
 import logging
+import os
 import time
+
 import cv2
 import numpy as np
-from tqdm import tqdm
-from shapely.geometry import Polygon
 import torch
 import utils.evaluation_utils as ev_utils
-import utils.pixel_metrics as p_metrics
 import utils.object_metrics as o_metrics
+import utils.pixel_metrics as p_metrics
+from shapely.geometry import Polygon
+from tqdm import tqdm
 
 
-def run(log_path: str, classes_names: list, set: str, data_paths: dict,
-        dataset: str, params: dict):
+def run(
+    log_path: str,
+    classes_names: list,
+    set: str,
+    data_paths: dict,
+    dataset: str,
+    params: dict,
+):
     """
     Run the evaluation.
     :param log_path: Path to save the evaluation results and load the model.
@@ -33,58 +40,83 @@ def run(log_path: str, classes_names: list, set: str, data_paths: dict,
     :param params: The evaluation parameters.
     """
     # Run evaluation.
-    logging.info('Starting evaluation: '+dataset)
+    logging.info("Starting evaluation: " + dataset)
     starting_time = time.time()
 
     label_dir = [dir for dir in data_paths if dataset in str(dir)][0]
-    
-    pixel_metrics = {channel: {metric: [] for metric in ['iou', 'precision', 'recall', 'fscore']} for channel in classes_names[1:]}
-    object_metrics = {channel: {metric: {} for metric in ['precision', 'recall', 'fscore', 'AP']} for channel in classes_names[1:]}
+
+    pixel_metrics = {
+        channel: {metric: [] for metric in ["iou", "precision", "recall", "fscore"]}
+        for channel in classes_names[1:]
+    }
+    object_metrics = {
+        channel: {metric: {} for metric in ["precision", "recall", "fscore", "AP"]}
+        for channel in classes_names[1:]
+    }
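+    # For each class, count correct detections per confidence rank (95 down to
+    # 0, step 5) at every IoU threshold from 50 to 95 (step 5).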
     rank_scores = {
-       channel: {
-            iou: {
-                rank: {'True': 0, 'Total': 0} for rank in range(95, -5, -5)
-            } for iou in range(50, 100, 5)
-        } for channel in classes_names[1:]}
+        channel: {
+            iou: {rank: {"True": 0, "Total": 0} for rank in range(95, -5, -5)}
+            for iou in range(50, 100, 5)
+        }
+        for channel in classes_names[1:]
+    }
     number_of_gt = {channel: 0 for channel in classes_names[1:]}
-    for img_name in tqdm(os.listdir(label_dir), desc="Evaluation (prog) "+set):
+    for img_name in tqdm(os.listdir(label_dir), desc="Evaluation (prog) " + set):
         gt_regions = ev_utils.read_json(os.path.join(label_dir, img_name))
-        pred_regions = ev_utils.read_json(os.path.join(log_path, params.prediction_path, set, dataset, img_name))
-        assert(gt_regions['img_size'] == pred_regions['img_size'])
+        pred_regions = ev_utils.read_json(
+            os.path.join(log_path, params.prediction_path, set, dataset, img_name)
+        )
+        assert gt_regions["img_size"] == pred_regions["img_size"]
         gt_polys = ev_utils.get_polygons(gt_regions, classes_names)
         pred_polys = ev_utils.get_polygons(pred_regions, classes_names)
 
-        pixel_metrics = p_metrics.compute_metrics(gt_polys, pred_polys, classes_names[1:], pixel_metrics)
+        pixel_metrics = p_metrics.compute_metrics(
+            gt_polys, pred_polys, classes_names[1:], pixel_metrics
+        )
+
+        image_rank_scores = o_metrics.compute_rank_scores(
+            gt_polys, pred_polys, classes_names[1:]
+        )
+        rank_scores = o_metrics.update_rank_scores(
+            rank_scores, image_rank_scores, classes_names[1:]
+        )
+        number_of_gt = {
+            channel: number_of_gt[channel] + len(gt_polys[channel])
+            for channel in classes_names[1:]
+        }
 
-        image_rank_scores = o_metrics.compute_rank_scores(gt_polys, pred_polys, classes_names[1:])
-        rank_scores = o_metrics.update_rank_scores(rank_scores, image_rank_scores, classes_names[1:])
-        number_of_gt = {channel: number_of_gt[channel]+len(gt_polys[channel]) for channel in classes_names[1:]}
-        
-    object_metrics = o_metrics.get_mean_results(rank_scores, number_of_gt, classes_names[1:], object_metrics)
+    object_metrics = o_metrics.get_mean_results(
+        rank_scores, number_of_gt, classes_names[1:], object_metrics
+    )
 
     # Print the results.
     print(set)
     for channel in classes_names[1:]:
         print(channel)
-        print('IoU       = ', np.round(np.mean(pixel_metrics[channel]['iou']), 4))
-        print('Precision = ', np.round(np.mean(pixel_metrics[channel]['precision']), 4))
-        print('Recall    = ', np.round(np.mean(pixel_metrics[channel]['recall']), 4))
-        print('F-score   = ', np.round(np.mean(pixel_metrics[channel]['fscore']), 4))
+        print("IoU       = ", np.round(np.mean(pixel_metrics[channel]["iou"]), 4))
+        print("Precision = ", np.round(np.mean(pixel_metrics[channel]["precision"]), 4))
+        print("Recall    = ", np.round(np.mean(pixel_metrics[channel]["recall"]), 4))
+        print("F-score   = ", np.round(np.mean(pixel_metrics[channel]["fscore"]), 4))
 
-        aps = object_metrics[channel]['AP']
-        print('AP [IOU=0.50] = ', np.round(aps[50], 4))
-        print('AP [IOU=0.75] = ', np.round(aps[75], 4))
-        print('AP [IOU=0.95] = ', np.round(aps[95], 4))
-        print('AP [0.5,0.95] = ', np.round(np.mean(list(aps.values())), 4))
-        print('\n')
+        aps = object_metrics[channel]["AP"]
+        print("AP [IOU=0.50] = ", np.round(aps[50], 4))
+        print("AP [IOU=0.75] = ", np.round(aps[75], 4))
+        print("AP [IOU=0.95] = ", np.round(aps[95], 4))
+        print("AP [0.5,0.95] = ", np.round(np.mean(list(aps.values())), 4))
+        print("\n")
 
     os.makedirs(os.path.join(log_path, params.evaluation_path, set), exist_ok=True)
-    #ev_utils.save_graphical_results(object_metrics, classes_names[1:],
+    # ev_utils.save_graphical_results(object_metrics, classes_names[1:],
     #                                os.path.join(log_path, params.evaluation_path, set))
-    ev_utils.save_results(pixel_metrics, object_metrics, classes_names[1:],
-                          os.path.join(log_path, params.evaluation_path, set), dataset)
+    ev_utils.save_results(
+        pixel_metrics,
+        object_metrics,
+        classes_names[1:],
+        os.path.join(log_path, params.evaluation_path, set),
+        dataset,
+    )
 
     end = time.gmtime(time.time() - starting_time)
-    logging.info('Finished evaluating in %2d:%2d:%2d',
-                 end.tm_hour, end.tm_min, end.tm_sec)
-
+    logging.info(
+        "Finished evaluating in %2d:%2d:%2d", end.tm_hour, end.tm_min, end.tm_sec
+    )
diff --git a/training/model_params.py b/training/model_params.py
index cfb5487..69e9d74 100755
--- a/training/model_params.py
+++ b/training/model_params.py
@@ -13,14 +13,16 @@
 """
 
 import logging
+
 import torch
 from sacred import Experiment
 from torchsummary import summary
 from utils.model import Net
 
-ex = Experiment('Get model parameters')
-logging.basicConfig(level=logging.INFO,
-                    format='%(asctime)s - %(levelname)s - %(message)s')
+ex = Experiment("Get model parameters")
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
 
 
 @ex.config
@@ -29,7 +31,7 @@ def default_config():
     Define the default configuration for the experiment.
     """
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    logging.info('Running on %s', device)
+    logging.info("Running on %s", device)
     img_size = 768
     no_of_classes = 2
 
diff --git a/training/normalization_params.py b/training/normalization_params.py
index aa04ae9..9b6b7a2 100755
--- a/training/normalization_params.py
+++ b/training/normalization_params.py
@@ -8,14 +8,16 @@
     Use it to get the mean and standard deviation of the training set.
 """
 
-import os
 import logging
+import os
+
 import numpy as np
-from tqdm import tqdm
-from torchvision import transforms
+import utils.preprocessing as pprocessing
 from torch.utils.data import DataLoader
+from torchvision import transforms
+from tqdm import tqdm
 from utils.params_config import Params
-import utils.preprocessing as pprocessing
+
 
 def run(log_path: str, data_paths: dict, params: Params, img_size: int):
     """
@@ -27,30 +29,31 @@ def run(log_path: str, data_paths: dict, params: Params, img_size: int):
     :param img_size: The network input image size.
     """
     dataset = pprocessing.PredictionDataset(
-        data_paths['train']['image'],
-        transform=transforms.Compose([pprocessing.Rescale(img_size),
-                                      pprocessing.ToTensor()]))
-    loader = DataLoader(dataset, batch_size=1,
-                        shuffle=False, num_workers=2)
+        data_paths["train"]["image"],
+        transform=transforms.Compose(
+            [pprocessing.Rescale(img_size), pprocessing.ToTensor()]
+        ),
+    )
+    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)
 
     # Compute mean and std.
     mean = []
     std = []
     for data in tqdm(loader, desc="Computing parameters (prog)"):
-        image = data['image'].numpy()
+        image = data["image"].numpy()
         mean.append(np.mean(image, axis=(0, 2, 3)))
         std.append(np.std(image, axis=(0, 2, 3)))
 
     mean = np.array(mean).mean(axis=0)
     std = np.array(std).mean(axis=0)
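+    # Note: averaging the per-image standard deviations only approximates the
+    # dataset-wide standard deviation.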
 
-    logging.info('Mean: {}'.format(np.uint8(mean)))
-    logging.info(' Std: {}'.format(np.uint8(std)))
-    
-    with open(os.path.join(log_path, params.mean), 'w') as file:
+    logging.info("Mean: {}".format(np.uint8(mean)))
+    logging.info(" Std: {}".format(np.uint8(std)))
+
+    with open(os.path.join(log_path, params.mean), "w") as file:
         for value in mean:
-            file.write(str(np.uint8(value))+'\n')
+            file.write(str(np.uint8(value)) + "\n")
 
-    with open(os.path.join(log_path, params.std), 'w') as file:
+    with open(os.path.join(log_path, params.std), "w") as file:
         for value in std:
-            file.write(str(np.uint8(value))+'\n')
+            file.write(str(np.uint8(value)) + "\n")
diff --git a/training/notify-slack.py b/training/notify-slack.py
index d3d2e24..e4805cd 100755
--- a/training/notify-slack.py
+++ b/training/notify-slack.py
@@ -1,11 +1,12 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
-import requests
-from pathlib import Path
-import os
-import json
 import argparse
+import json
+import os
+from pathlib import Path
+
+import requests
 
 config_path = Path("~/.notify-slack-cfg").expanduser()
 if not config_path.exists():
@@ -16,17 +17,17 @@ if not config_path.exists():
     """
     raise ValueError(error_msg)
 
-SLACK_NOTIFY_ICON=":ubuntu:"
-SLACK_BOT_USERNAME="Bash Notifier"
+SLACK_NOTIFY_ICON = ":ubuntu:"
+SLACK_BOT_USERNAME = "Bash Notifier"
 SLACK_WEBHOOK_SERVICE = config_path.read_text().strip()
-SLACK_URL=f"https://hooks.slack.com/services/{SLACK_WEBHOOK_SERVICE}"
+SLACK_URL = f"https://hooks.slack.com/services/{SLACK_WEBHOOK_SERVICE}"
 LOG_PATH = Path("DLA_train.log")
 DLA_LENGTH = 18
 
 ICONS = {
-    "INFO": ':information_source: ',
-    "WARN": ':warning: ',
-    "ERROR": ':github_changes_requested: ',
+    "INFO": ":information_source: ",
+    "WARN": ":warning: ",
+    "ERROR": ":github_changes_requested: ",
 }
 
 
@@ -67,19 +68,24 @@ def run(message, log_file, number_of_lines):
 def main():
     parser = argparse.ArgumentParser("Send message to slack")
     parser.add_argument("message", help="Message to be sent to slack", type=str)
-    parser.add_argument("--log_file",
-                        help="Log file from where last N files will be included with the message. "
-                             "Use None as file name if don't want to include a log file.",
-                        type=Path,
-                        default=LOG_PATH)
-    parser.add_argument("-N", "--number_of_lines",
-                        help="Number of lines to be included from the end of the log file."
-                             "Use 0 to not include a log file.",
-                        type=int,
-                        default=DLA_LENGTH)
+    parser.add_argument(
+        "--log_file",
+        help="Log file from where last N files will be included with the message. "
+        "Use None as file name if don't want to include a log file.",
+        type=Path,
+        default=LOG_PATH,
+    )
+    parser.add_argument(
+        "-N",
+        "--number_of_lines",
+        help="Number of lines to be included from the end of the log file."
+        "Use 0 to not include a log file.",
+        type=int,
+        default=DLA_LENGTH,
+    )
     args = parser.parse_args()
     run(**vars(args))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/training/predict.py b/training/predict.py
index 336db81..972800a 100755
--- a/training/predict.py
+++ b/training/predict.py
@@ -8,17 +8,21 @@
     Use it to run predictions on images with a trained network.
 """
 
-import os
 import logging
+import os
 import time
+
 import cv2
 import numpy as np
-from tqdm import tqdm
-from shapely.geometry import Polygon
 import torch
 import utils.prediction_utils as pr_utils
+from shapely.geometry import Polygon
+from tqdm import tqdm
 
-def get_predicted_polygons(probas: np.ndarray, min_cc: int, classes_names: list) -> dict:
+
+def get_predicted_polygons(
+    probas: np.ndarray, min_cc: int, classes_names: list
+) -> dict:
     """
     Clean the predictions and retrieve the detected object coordinates.
     :param probas: The probability maps obtained by the model.
@@ -31,26 +35,39 @@ def get_predicted_polygons(probas: np.ndarray, min_cc: int, classes_names: list)
     max_probas = np.argmax(probas, axis=0)
     for channel in range(1, probas.shape[0]):
         # Keep pixels with highest probability.
-        channel_probas = np.uint8(max_probas == channel) \
-                             * probas[channel, :, :]
+        channel_probas = np.uint8(max_probas == channel) * probas[channel, :, :]
         # Retrieve the polygons contours.
         bin_img = channel_probas.copy()
         bin_img[bin_img > 0] = 1
-        contours, hierarchy = cv2.findContours(np.uint8(bin_img), cv2.RETR_EXTERNAL,
-                                               cv2.CHAIN_APPROX_SIMPLE)
+        contours, hierarchy = cv2.findContours(
+            np.uint8(bin_img), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+        )
         # Remove small connected components.
         if min_cc > 0:
-            contours = [contour for contour in contours if cv2.contourArea(contour) > min_cc]
+            contours = [
+                contour for contour in contours if cv2.contourArea(contour) > min_cc
+            ]
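+        # Attach a confidence score, computed from the probability map, to each
+        # remaining contour.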
         page_contours[classes_names[channel]] = [
-            {'confidence': pr_utils.compute_confidence(contour, channel_probas),
-             'polygon': contour} for contour in contours
+            {
+                "confidence": pr_utils.compute_confidence(contour, channel_probas),
+                "polygon": contour,
+            }
+            for contour in contours
         ]
     return page_contours
 
 
-def run(prediction_path: str, log_path: str, img_size: int, colors: list,
-        classes_names: list, save_image: list, min_cc: int,
-        loaders: dict, net):
+def run(
+    prediction_path: str,
+    log_path: str,
+    img_size: int,
+    colors: list,
+    classes_names: list,
+    save_image: list,
+    min_cc: int,
+    loaders: dict,
+    net,
+):
     """
     Run the prediction.
     :param prediction_path: The path to save the predictions.
@@ -69,41 +86,64 @@ def run(prediction_path: str, log_path: str, img_size: int, colors: list,
     # Run prediction.
     net.eval()
 
-    logging.info('Starting predicting')
+    logging.info("Starting predicting")
     starting_time = time.time()
 
     with torch.no_grad():
-        for set, loader in zip(['train', 'val', 'test'],
-                               loaders.values()):
+        for set, loader in zip(["train", "val", "test"], loaders.values()):
             seen_datasets = []
             # Create folders to save the predictions.
-            os.makedirs(os.path.join(log_path, prediction_path, set),
-                        exist_ok=True)
+            os.makedirs(os.path.join(log_path, prediction_path, set), exist_ok=True)
 
-            for i, data in enumerate(tqdm(loader, desc="Prediction (prog) "+set), 0):
+            for i, data in enumerate(tqdm(loader, desc="Prediction (prog) " + set), 0):
                 # Create dataset folders to save the predictions.
-                if data['dataset'][0] not in seen_datasets:
-                    os.makedirs(os.path.join(log_path, prediction_path, set, data['dataset'][0]),
-                                exist_ok=True)
-                    seen_datasets.append(data['dataset'][0])
+                if data["dataset"][0] not in seen_datasets:
+                    os.makedirs(
+                        os.path.join(
+                            log_path, prediction_path, set, data["dataset"][0]
+                        ),
+                        exist_ok=True,
+                    )
+                    seen_datasets.append(data["dataset"][0])
 
                 # Generate and save the predictions.
-                output = net(data['image'].to(device).float())
-                input_size = [element.numpy()[0] for element in data['size'][:2]]
+                output = net(data["image"].to(device).float())
+                input_size = [element.numpy()[0] for element in data["size"][:2]]
 
-                assert(output.shape[0] == 1)
-                polygons = get_predicted_polygons(output[0].cpu().numpy(), min_cc, classes_names)
-                polygons = pr_utils.resize_polygons(polygons, input_size, img_size, data['padding'])
+                assert output.shape[0] == 1
+                polygons = get_predicted_polygons(
+                    output[0].cpu().numpy(), min_cc, classes_names
+                )
+                polygons = pr_utils.resize_polygons(
+                    polygons, input_size, img_size, data["padding"]
+                )
 
-                polygons['img_size'] = [int(element) for element in input_size]
-                pr_utils.save_prediction(polygons,
-                    os.path.join(log_path, prediction_path, set, data['dataset'][0], data['name'][0]))
+                polygons["img_size"] = [int(element) for element in input_size]
+                pr_utils.save_prediction(
+                    polygons,
+                    os.path.join(
+                        log_path,
+                        prediction_path,
+                        set,
+                        data["dataset"][0],
+                        data["name"][0],
+                    ),
+                )
                 if set in save_image:
-                    pr_utils.save_prediction_image(polygons, colors, input_size,
-                                                   os.path.join(log_path, prediction_path, set,
-                                                                data['dataset'][0], data['name'][0]))
+                    pr_utils.save_prediction_image(
+                        polygons,
+                        colors,
+                        input_size,
+                        os.path.join(
+                            log_path,
+                            prediction_path,
+                            set,
+                            data["dataset"][0],
+                            data["name"][0],
+                        ),
+                    )
 
     end = time.gmtime(time.time() - starting_time)
-    logging.info('Finished predicting in %2d:%2d:%2d',
-                 end.tm_hour, end.tm_min, end.tm_sec)
-
+    logging.info(
+        "Finished predicting in %2d:%2d:%2d", end.tm_hour, end.tm_min, end.tm_sec
+    )
diff --git a/training/retrieve_experiments_configs.py b/training/retrieve_experiments_configs.py
index 1d36e95..c2709f4 100644
--- a/training/retrieve_experiments_configs.py
+++ b/training/retrieve_experiments_configs.py
@@ -8,16 +8,17 @@
     Use it to get the configurations for running the experiments.
 """
 
-import os
+import argparse
 import csv
 import json
 import logging
-import argparse
+import os
 
-logging.basicConfig(level=logging.INFO,
-                    format='%(asctime)s - %(levelname)s - %(message)s')
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
 
-TMP_DIR = './tmp'
+TMP_DIR = "./tmp"
 STEPS = ["normalization_params", "train", "prediction", "evaluation"]
 
 
@@ -29,49 +30,64 @@ def run(config):
     os.makedirs(TMP_DIR, exist_ok=True)
 
     with open(config) as config_file:
-        reader = csv.DictReader(config_file, delimiter=',')
+        reader = csv.DictReader(config_file, delimiter=",")
         for index, row in enumerate(reader, 1):
 
             json_dict = {}
             # Get experiment name.
-            assert row['experiment_name'] != ''
-            json_dict['experiment_name'] = row['experiment_name']
+            assert row["experiment_name"] != ""
+            json_dict["experiment_name"] = row["experiment_name"]
 
             # Get steps as a list of names.
-            json_dict['steps'] = row['steps'].split(';')
+            json_dict["steps"] = row["steps"].split(";")
 
             # Get train/val/test folders.
-            json_dict['data_paths'] = {}
-            for set in ['train', 'val', 'test']:
-                json_dict['data_paths'][set] = {}
-                if set in ['train', 'val']:
-                    for folder, key in zip(['images', 'labels', 'labels_json'], ['image', 'mask', 'json']):
-                        if row[set] != '':
-                            json_dict['data_paths'][set][key] = [
-                                os.path.join(element, set, folder) for element in row[set].split(';')]
+            json_dict["data_paths"] = {}
+            for set in ["train", "val", "test"]:
+                json_dict["data_paths"][set] = {}
+                if set in ["train", "val"]:
+                    for folder, key in zip(
+                        ["images", "labels", "labels_json"], ["image", "mask", "json"]
+                    ):
+                        if row[set] != "":
+                            json_dict["data_paths"][set][key] = [
+                                os.path.join(element, set, folder)
+                                for element in row[set].split(";")
+                            ]
                         else:
-                            json_dict['data_paths'][set][key] = []
+                            json_dict["data_paths"][set][key] = []
                 else:
-                    for folder, key in zip(['images', 'labels_json'], ['image', 'json']):
-                        if row[set] != '':
-                            json_dict['data_paths'][set][key] = [
-                                os.path.join(element, set, folder) for element in row[set].split(';')]
+                    for folder, key in zip(
+                        ["images", "labels_json"], ["image", "json"]
+                    ):
+                        if row[set] != "":
+                            json_dict["data_paths"][set][key] = [
+                                os.path.join(element, set, folder)
+                                for element in row[set].split(";")
+                            ]
                         else:
-                            json_dict['data_paths'][set][key] = []
+                            json_dict["data_paths"][set][key] = []
 
             # Get restore model.
-            if row['restore_model'] != '':
-                json_dict['training'] = {'restore_model': row['restore_model']}
-                if row['loss'] != '':
-                    json_dict['training']['loss'] = row['loss']
+            if row["restore_model"] != "":
+                json_dict["training"] = {"restore_model": row["restore_model"]}
+                if row["loss"] != "":
+                    json_dict["training"]["loss"] = row["loss"]
 
             # Save configuration file.
-            json_file = str(index)+'_'+row['experiment_name']+'.json'
-            with open(os.path.join(TMP_DIR, json_file), 'w') as file:
-                json.dump({key: value for key, value in json_dict.items() if value}, file, indent=4)
+            json_file = str(index) + "_" + row["experiment_name"] + ".json"
+            with open(os.path.join(TMP_DIR, json_file), "w") as file:
+                json.dump(
+                    {key: value for key, value in json_dict.items() if value},
+                    file,
+                    indent=4,
+                )
 
-    logging.info(f"Retrieved {index} experiment configurations from {config}"
-                 if index > 1 else f"Retrieved {index} experiment configuration from {config}")
+    logging.info(
+        f"Retrieved {index} experiment configurations from {config}"
+        if index > 1
+        else f"Retrieved {index} experiment configuration from {config}"
+    )
 
 
 def main():
@@ -80,13 +96,15 @@ def main():
     """
     parser = argparse.ArgumentParser(
         description="Script to retrieve the experiments configs.",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--config', type=str, required=True,
-                        help='Path to the configurations file')
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--config", type=str, required=True, help="Path to the configurations file"
+    )
 
     args = parser.parse_args()
     run(**(vars(args)))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/training/run_dla_experiment.sh b/training/run_dla_experiment.sh
index 28b8f5a..4350b59 100644
--- a/training/run_dla_experiment.sh
+++ b/training/run_dla_experiment.sh
@@ -45,12 +45,12 @@ for filename in ${TMP_DIR}/*; do
         && (python3 notify-slack.py "INFO: Experiment completed" --log_file DLA_train_"${index}".log) \
         || (python3 notify-slack.py "ERROR: Experiment failed" --log_file DLA_train_"${index}".log ; exit)
     fi
-    
+
     index=$((index+1))
 
 done
 
-echo 
+echo
 echo "Experiments done!"
 
 rm -r $TMP_DIR
diff --git a/training/run_experiment.py b/training/run_experiment.py
index bd1f6ed..10562f1 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -8,32 +8,37 @@
     Use it to train, predict and evaluate a model.
 """
 
-import os
-import sys
 import json
 import logging
-from tqdm import tqdm
+import os
+import sys
 from pathlib import Path
-from sacred import Experiment
-from sacred.observers import MongoObserver
+
+import evaluate
+import normalization_params
+import predict
 import torch
 import torch.optim as optim
+import train
+import utils.preprocessing as pprocessing
+import utils.training_utils as tr_utils
+from sacred import Experiment
+from sacred.observers import MongoObserver
 from torch.cuda.amp import GradScaler
 from torch.utils.data import DataLoader
 from torchvision import transforms
-import normalization_params
-import train, predict, evaluate
+from tqdm import tqdm
 from utils import model, utils
 from utils.params_config import Params
-import utils.preprocessing as pprocessing
-import utils.training_utils as tr_utils
 
 STEPS = ["normalization_params", "train", "prediction", "evaluation"]
 
-ex = Experiment('Doc-UFCN')
-logging.basicConfig(level=logging.INFO,
-                    format='%(asctime)s - %(levelname)s - %(message)s')
-mongo_url='mongodb://user:password@omniboard.vpn/sacred'
+ex = Experiment("Doc-UFCN")
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+mongo_url = "mongodb://user:password@omniboard.vpn/sacred"
+
 
 @ex.config
 def default_config():
@@ -76,15 +81,18 @@ def default_config():
         "omniboard": False,
         "min_cc": 0,
         "save_image": [],
-        "use_amp": False
+        "use_amp": False,
     }
-    assert global_params['batch_size'] is not None or global_params['no_of_params'] is not None, "Please provide a batch size or a maximum number of parameters"
+    assert (
+        global_params["batch_size"] is not None
+        or global_params["no_of_params"] is not None
+    ), "Please provide a batch size or a maximum number of parameters"
     params = Params().to_dict()
 
     # Load the current experiment parameters.
-    experiment_name = 'doc-ufcn'
-    log_path = 'runs/'+experiment_name.lower().replace(' ', '_').replace('-', '_')
-    tb_path = 'events/'
+    experiment_name = "doc-ufcn"
+    log_path = "runs/" + experiment_name.lower().replace(" ", "_").replace("-", "_")
+    tb_path = "events/"
     steps = ["normalization_params", "train", "prediction", "evaluation"]
     for step in steps:
         assert step in STEPS
@@ -93,37 +101,43 @@ def default_config():
         "train": {
             "image": ["./data/train/images/"],
             "mask": ["./data/train/labels/"],
-            "json": ["./data/train/labels_json/"]
+            "json": ["./data/train/labels_json/"],
         },
         "val": {
             "image": ["./data/val/images/"],
             "mask": ["./data/val/labels/"],
-            "json": ["./data/val/labels_json/"]
+            "json": ["./data/val/labels_json/"],
         },
         "test": {
             "image": ["./data/test/images/"],
-            "json": ["./data/test/labels_json/"]
-        }
+            "json": ["./data/test/labels_json/"],
+        },
     }
-    exp_data_paths = {set:
-        {key: [Path(element).expanduser() for element in value]
-        for key, value in paths.items()}
+    exp_data_paths = {
+        set: {
+            key: [Path(element).expanduser() for element in value]
+            for key, value in paths.items()
+        }
         for set, paths in data_paths.items()
     }
 
-    training = {
-        "restore_model": None,
-        "loss": 'initial'
-    }
-    training['loss'] = training['loss'].lower()
-    assert training['loss'] in ['initial', 'best']
-    if "train" in steps and global_params['omniboard'] is True:
+    training = {"restore_model": None, "loss": "initial"}
+    training["loss"] = training["loss"].lower()
+    assert training["loss"] in ["initial", "best"]
+    if "train" in steps and global_params["omniboard"] is True:
         ex.observers.append(MongoObserver(mongo_url))
 
 
 @ex.capture
-def save_config(log_path: str, experiment_name: str, global_params: dict,
-                params: Params, steps: list, data_paths: dict, training: dict):
+def save_config(
+    log_path: str,
+    experiment_name: str,
+    global_params: dict,
+    params: Params,
+    steps: list,
+    data_paths: dict,
+    training: dict,
+):
     """
     Save the current configuration.
     :param log_path: Path to save the experiment information and model.
@@ -137,13 +151,13 @@ def save_config(log_path: str, experiment_name: str, global_params: dict,
     """
     os.makedirs(log_path, exist_ok=True)
     json_dict = {
-        'global_params': global_params,
-        'params': params,
-        'steps': steps,
-        'data_paths': data_paths,
-        'training': training
+        "global_params": global_params,
+        "params": params,
+        "steps": steps,
+        "data_paths": data_paths,
+        "training": training,
     }
-    with open(os.path.join(log_path, experiment_name+'.json'), 'w') as config_file:
+    with open(os.path.join(log_path, experiment_name + ".json"), "w") as config_file:
         json.dump(json_dict, config_file, indent=4)
 
 
@@ -158,24 +172,26 @@ def get_mean_std(log_path: str, params: Params) -> dict:
     """
     params = Params.from_dict(params)
     if not os.path.isfile(os.path.join(log_path, params.mean)):
-        logging.error('No file found at %s', os.path.join(log_path, params.mean))
+        logging.error("No file found at %s", os.path.join(log_path, params.mean))
         sys.exit()
     else:
-        with open(os.path.join(log_path, params.mean), 'r') as file:
+        with open(os.path.join(log_path, params.mean), "r") as file:
             mean = file.read().splitlines()
             mean = [int(value) for value in mean]
     if not os.path.isfile(os.path.join(log_path, params.std)):
-        logging.error('No file found at %s', os.path.join(log_path, params.std))
+        logging.error("No file found at %s", os.path.join(log_path, params.std))
         sys.exit()
     else:
-        with open(os.path.join(log_path, params.std), 'r') as file:
+        with open(os.path.join(log_path, params.std), "r") as file:
             std = file.read().splitlines()
             std = [int(value) for value in std]
-    return {'mean': mean, 'std': std}
+    return {"mean": mean, "std": std}
 
 
 @ex.capture
-def training_loaders(norm_params: dict, exp_data_paths: dict, global_params: dict) -> dict:
+def training_loaders(
+    norm_params: dict, exp_data_paths: dict, global_params: dict
+) -> dict:
     """
     Generate the loaders to use during the training step.
     :param norm_params: The mean and std values used during image normalization.
@@ -184,29 +200,44 @@ def training_loaders(norm_params: dict, exp_data_paths: dict, global_params: dic
     :return loaders: A dictionary with the loaders.
     """
     loaders = {}
-    t = tqdm(['train', 'val'])
+    t = tqdm(["train", "val"])
     t.set_description("Loading data")
-    for set, images, masks in zip(t,
-                                  [exp_data_paths['train']['image'], exp_data_paths['val']['image']],
-                                  [exp_data_paths['train']['mask'], exp_data_paths['val']['mask']]):
+    for set, images, masks in zip(
+        t,
+        [exp_data_paths["train"]["image"], exp_data_paths["val"]["image"]],
+        [exp_data_paths["train"]["mask"], exp_data_paths["val"]["mask"]],
+    ):
         dataset = pprocessing.TrainingDataset(
-            images, masks,
-            global_params['classes_colors'], transform=transforms.Compose([
-                pprocessing.Rescale(global_params['img_size']),
-                pprocessing.Normalize(norm_params['mean'], norm_params['std'])])
+            images,
+            masks,
+            global_params["classes_colors"],
+            transform=transforms.Compose(
+                [
+                    pprocessing.Rescale(global_params["img_size"]),
+                    pprocessing.Normalize(norm_params["mean"], norm_params["std"]),
+                ]
+            ),
+        )
+        loaders[set] = DataLoader(
+            dataset,
+            num_workers=2,
+            pin_memory=True,
+            batch_sampler=utils.Sampler(
+                dataset,
+                bin_size=global_params["bin_size"],
+                batch_size=global_params["batch_size"],
+                nb_params=global_params["no_of_params"],
+            ),
+            collate_fn=utils.DLACollateFunction(),
         )
-        loaders[set] = DataLoader(dataset, num_workers=2, pin_memory=True,
-                                  batch_sampler=utils.Sampler(dataset, bin_size=global_params["bin_size"],
-                                                              batch_size=global_params["batch_size"],
-                                                              nb_params=global_params["no_of_params"]), 
-                                  collate_fn=utils.DLACollateFunction())
         logging.info(f"{set}: Found {len(dataset)} images")
     return loaders
 
 
 @ex.capture
-def prediction_loaders(norm_params: dict, exp_data_paths: dict,
-                       global_params: dict) -> dict:
+def prediction_loaders(
+    norm_params: dict, exp_data_paths: dict, global_params: dict
+) -> dict:
     """
     Generate the loaders to use during the prediction step.
     :param norm_params: The mean and std values used during image normalization.
@@ -216,18 +247,27 @@ def prediction_loaders(norm_params: dict, exp_data_paths: dict,
     """
     loaders = {}
     for set, images in zip(
-            ['train', 'val', 'test'],
-            [exp_data_paths['train']['image'], exp_data_paths['val']['image'], exp_data_paths['test']['image']]):
+        ["train", "val", "test"],
+        [
+            exp_data_paths["train"]["image"],
+            exp_data_paths["val"]["image"],
+            exp_data_paths["test"]["image"],
+        ],
+    ):
         dataset = pprocessing.PredictionDataset(
             images,
-            transform=transforms.Compose([
-                pprocessing.Rescale(global_params['img_size']),
-                pprocessing.Normalize(norm_params['mean'], norm_params['std']),
-                pprocessing.Pad(),
-                pprocessing.ToTensor()])
+            transform=transforms.Compose(
+                [
+                    pprocessing.Rescale(global_params["img_size"]),
+                    pprocessing.Normalize(norm_params["mean"], norm_params["std"]),
+                    pprocessing.Pad(),
+                    pprocessing.ToTensor(),
+                ]
+            ),
+        )
+        loaders[set + "_loader"] = DataLoader(
+            dataset, batch_size=1, shuffle=False, num_workers=2, pin_memory=True
         )
-        loaders[set+'_loader'] = DataLoader(dataset, batch_size=1, shuffle=False,
-                                            num_workers=2, pin_memory=True)
     return loaders
 
 
@@ -236,45 +276,52 @@ def training_initialization(global_params: dict, training: dict, log_path: str)
     """
     Initialize the training step.
     :param global_params: Global parameters of the experiment entered by the user.
-    :param training: Training parameters.    
+    :param training: Training parameters.
     :param log_path: Path to save the experiment information and model.
     :return tr_params: A dictionary with the training parameters.
     """
-    no_of_classes = len(global_params['classes_names'])
-    ex.log_scalar('no_of_classes', no_of_classes)
-    net = model.load_network(no_of_classes, global_params['use_amp'], ex)
-    
-    if training['restore_model'] is None:
+    no_of_classes = len(global_params["classes_names"])
+    ex.log_scalar("no_of_classes", no_of_classes)
+    net = model.load_network(no_of_classes, global_params["use_amp"], ex)
+
+    if training["restore_model"] is None:
         net.apply(model.weights_init)
         tr_params = {
-            'net': net,
-            'criterion': tr_utils.Diceloss(no_of_classes),
-            'optimizer': optim.Adam(net.parameters(), lr=global_params['learning_rate']),
-            'saved_epoch': 0,
-            'best_loss': 10e5,
-            'scaler': GradScaler(enabled=global_params['use_amp']),
-            'use_amp': global_params['use_amp']
+            "net": net,
+            "criterion": tr_utils.Diceloss(no_of_classes),
+            "optimizer": optim.Adam(
+                net.parameters(), lr=global_params["learning_rate"]
+            ),
+            "saved_epoch": 0,
+            "best_loss": 10e5,
+            "scaler": GradScaler(enabled=global_params["use_amp"]),
+            "use_amp": global_params["use_amp"],
         }
     else:
-    # Restore model to resume training.
+        # Restore model to resume training.
         checkpoint, net, optimizer, scaler = model.restore_model(
-            net, optim.Adam(net.parameters(), lr=global_params['learning_rate']),
-            GradScaler(enabled=global_params['use_amp']), log_path, training['restore_model'])
+            net,
+            optim.Adam(net.parameters(), lr=global_params["learning_rate"]),
+            GradScaler(enabled=global_params["use_amp"]),
+            log_path,
+            training["restore_model"],
+        )
         tr_params = {
-            'net': net,
-            'criterion': tr_utils.Diceloss(no_of_classes),
-            'optimizer': optimizer,
-            'saved_epoch': checkpoint['epoch'],
-            'best_loss': checkpoint['best_loss'] if training['loss'] == 'best' else 10e5,
-            'scaler': scaler,
-            'use_amp': global_params['use_amp']
+            "net": net,
+            "criterion": tr_utils.Diceloss(no_of_classes),
+            "optimizer": optimizer,
+            "saved_epoch": checkpoint["epoch"],
+            "best_loss": checkpoint["best_loss"]
+            if training["loss"] == "best"
+            else 10e5,
+            "scaler": scaler,
+            "use_amp": global_params["use_amp"],
         }
     return tr_params
 
 
 @ex.capture
-def prediction_initialization(params: dict, global_params: dict,
-                              log_path: str) -> dict:
+def prediction_initialization(params: dict, global_params: dict, log_path: str) -> dict:
     """
     Initialize the prediction step.
     :param params: The global parameters of the experiment.
@@ -283,7 +330,7 @@ def prediction_initialization(params: dict, global_params: dict,
     :return: A dictionary with the prediction parameters.
     """
     params = Params.from_dict(params)
-    no_of_classes = len(global_params['classes_names'])
+    no_of_classes = len(global_params["classes_names"])
     net = model.load_network(no_of_classes, False, ex)
 
     _, net, _, _ = model.restore_model(net, None, None, log_path, params.model_path)
@@ -291,8 +338,15 @@ def prediction_initialization(params: dict, global_params: dict,
 
 
 @ex.automain
-def run(global_params: dict, params: Params, log_path: str,
-        tb_path: str, steps: list, exp_data_paths: dict, training: dict):
+def run(
+    global_params: dict,
+    params: Params,
+    log_path: str,
+    tb_path: str,
+    steps: list,
+    exp_data_paths: dict,
+    training: dict,
+):
     """
     Main program.
     :param global_params: Global parameters of the experiment entered by the user.
@@ -301,7 +355,7 @@ def run(global_params: dict, params: Params, log_path: str,
     :param tb_path: Path to save the Tensorboard events.
     :param steps: List of the steps to run.
     :param exp_data_paths: Path to the data folders.
-    :param training: Training parameters.    
+    :param training: Training parameters.
     """
     if len(steps) == 0:
         logging.info("No step to run, exiting execution.")
@@ -311,8 +365,9 @@ def run(global_params: dict, params: Params, log_path: str,
         save_config()
 
         if "normalization_params" in steps:
-            normalization_params.run(log_path, exp_data_paths,
-                                     params, global_params['img_size'])
+            normalization_params.run(
+                log_path, exp_data_paths, params, global_params["img_size"]
+            )
 
         if "train" in steps or "prediction" in steps:
             # Get the mean and std values.
@@ -322,23 +377,45 @@ def run(global_params: dict, params: Params, log_path: str,
             # Generate the loaders and start training.
             loaders = training_loaders(norm_params)
             tr_params = training_initialization()
-            train.run(params.model_path, log_path, tb_path, global_params['no_of_epochs'],
-                      norm_params, global_params['classes_names'], loaders, tr_params, ex)
+            train.run(
+                params.model_path,
+                log_path,
+                tb_path,
+                global_params["no_of_epochs"],
+                norm_params,
+                global_params["classes_names"],
+                loaders,
+                tr_params,
+                ex,
+            )
 
         if "prediction" in steps:
             # Generate the loaders and start predicting.
             loaders = prediction_loaders(norm_params)
             net = prediction_initialization()
-            predict.run(params.prediction_path, log_path, global_params['img_size'],
-                        global_params['classes_colors'], global_params['classes_names'],
-                        global_params['save_image'], global_params['min_cc'], loaders, net)
+            predict.run(
+                params.prediction_path,
+                log_path,
+                global_params["img_size"],
+                global_params["classes_colors"],
+                global_params["classes_names"],
+                global_params["save_image"],
+                global_params["min_cc"],
+                loaders,
+                net,
+            )
 
         if "evaluation" in steps:
             for set in exp_data_paths.keys():
                 for dataset in exp_data_paths[set]["json"]:
                     if os.path.isdir(dataset):
-                        evaluate.run(log_path, global_params['classes_names'], set,
-                                     exp_data_paths[set]['json'], str(dataset.parent.parent.name),
-                                     params)
+                        evaluate.run(
+                            log_path,
+                            global_params["classes_names"],
+                            set,
+                            exp_data_paths[set]["json"],
+                            str(dataset.parent.parent.name),
+                            params,
+                        )
                     else:
                         logging.info(f"{dataset} folder not found.")
diff --git a/training/train.py b/training/train.py
index 3089936..0315bcb 100755
--- a/training/train.py
+++ b/training/train.py
@@ -8,19 +8,21 @@
     Use it to train a model.
 """
 
-import sys
-import os
 import logging
+import os
+import sys
 import time
+
 import numpy as np
-from tqdm import tqdm
 import torch
-from torch.utils.tensorboard import SummaryWriter
+import utils.training_pixel_metrics as p_metrics
+import utils.training_utils as tr_utils
 from torch.cuda.amp import autocast
+from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
 from utils import model
 from utils.params_config import Params
-import utils.training_pixel_metrics as p_metrics
-import utils.training_utils as tr_utils
+
 
 def init_metrics(no_of_classes: int) -> dict:
     """
@@ -41,17 +43,27 @@ def log_metrics(ex, epoch: int, metrics: dict, writer, step: str):
     :param step: String indicating whether to log training or validation metrics.
     """
     for key in metrics.keys():
-        writer.add_scalar(step+'_'+key, metrics[key], epoch)
-        ex.log_scalar(step.lower()+'.'+key, metrics[key], epoch)
-        if step == 'Training':
-            logging.info('  TRAIN {}: {}={}'.format(epoch, key, round(metrics[key], 4)))
+        writer.add_scalar(step + "_" + key, metrics[key], epoch)
+        ex.log_scalar(step.lower() + "." + key, metrics[key], epoch)
+        if step == "Training":
+            logging.info("  TRAIN {}: {}={}".format(epoch, key, round(metrics[key], 4)))
         else:
-            logging.info('    VALID {}: {}={}'.format(epoch, key, round(metrics[key], 4)))
-
-
-def run_one_epoch(loader, params: dict, writer, epochs: list,
-                  no_of_epochs: int, device: str, norm_params: dict,
-                  classes_names: list, step: str):
+            logging.info(
+                "    VALID {}: {}={}".format(epoch, key, round(metrics[key], 4))
+            )
+
+
+def run_one_epoch(
+    loader,
+    params: dict,
+    writer,
+    epochs: list,
+    no_of_epochs: int,
+    device: str,
+    norm_params: dict,
+    classes_names: list,
+    step: str,
+):
     """
     Run one epoch of training (or validation).
     :param loader: The loader containing the images and masks.
@@ -64,46 +76,50 @@ def run_one_epoch(loader, params: dict, writer, epochs: list,
     :param classes_names: The names of the classes involved during the experiment.
     :param step: String indicating whether to run a training or validation step.
     :return params: The updated training parameters.
-    :return epoch_values: The metrics computed during the epoch. 
+    :return epoch_values: The metrics computed during the epoch.
     """
     metrics = init_metrics(len(classes_names))
     epoch = epochs[0]
 
     t = tqdm(loader)
-    if step == 'Training':
-        t.set_description("TRAIN (prog) {}/{}".format(epoch, no_of_epochs+epochs[1]))
+    if step == "Training":
+        t.set_description("TRAIN (prog) {}/{}".format(epoch, no_of_epochs + epochs[1]))
     else:
-        t.set_description("VALID (prog) {}/{}".format(epoch, no_of_epochs+epochs[1]))
+        t.set_description("VALID (prog) {}/{}".format(epoch, no_of_epochs + epochs[1]))
 
     for index, data in enumerate(t, 1):
-        params['optimizer'].zero_grad()
-        with autocast(enabled=params['use_amp']):
-            if params['use_amp']:
-                output = params['net'](data['image'].to(device).half())
+        params["optimizer"].zero_grad()
+        with autocast(enabled=params["use_amp"]):
+            if params["use_amp"]:
+                output = params["net"](data["image"].to(device).half())
             else:
-                output = params['net'](data['image'].to(device).float())
-            loss = params['criterion'](output, data['mask'].to(device).long())
+                output = params["net"](data["image"].to(device).float())
+            loss = params["criterion"](output, data["mask"].to(device).long())
 
         for pred in range(output.shape[0]):
-            current_pred = np.argmax(output[pred, :, :, :].cpu().detach().numpy(), axis=0)
-            current_label = data['mask'][pred, :, :].cpu().detach().numpy()
-            batch_metrics = p_metrics.compute_metrics(current_pred, current_label,
-                                                      loss.item(), classes_names)
+            current_pred = np.argmax(
+                output[pred, :, :, :].cpu().detach().numpy(), axis=0
+            )
+            current_label = data["mask"][pred, :, :].cpu().detach().numpy()
+            batch_metrics = p_metrics.compute_metrics(
+                current_pred, current_label, loss.item(), classes_names
+            )
             metrics = p_metrics.update_metrics(metrics, batch_metrics)
-           
-        epoch_values = tr_utils.get_epoch_values(metrics, classes_names, index+1)
+
+        epoch_values = tr_utils.get_epoch_values(metrics, classes_names, index + 1)
         display_values = epoch_values
-        display_values['loss'] = round(display_values['loss'], 4)
+        display_values["loss"] = round(display_values["loss"], 4)
         t.set_postfix(values=str(display_values))
 
         if step == "Training":
-            params['scaler'].scale(loss).backward()
-            params['scaler'].step(params['optimizer'])
-            params['scaler'].update()
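+            # Scale the loss for AMP, step the optimizer through the scaler,
+            # then update the scale factor for the next iteration.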
+            params["scaler"].scale(loss).backward()
+            params["scaler"].step(params["optimizer"])
+            params["scaler"].update()
             # Display prediction images in Tensorboard every 100 mini-batches.
             if index == 1 or index % 100 == 99:
-                tr_utils.display_training(output, data['image'], data['mask'], writer,
-                                          epoch, norm_params)
+                tr_utils.display_training(
+                    output, data["image"], data["mask"], writer, epoch, norm_params
+                )
 
     if step == "Training":
         return params, epoch_values
@@ -111,8 +127,17 @@ def run_one_epoch(loader, params: dict, writer, epochs: list,
         return epoch_values
 
 
-def run(model_path: str, log_path: str, tb_path: str, no_of_epochs: int,
-        norm_params: dict, classes_names: list, loaders: dict, tr_params: dict, ex):
+def run(
+    model_path: str,
+    log_path: str,
+    tb_path: str,
+    no_of_epochs: int,
+    norm_params: dict,
+    classes_names: list,
+    loaders: dict,
+    tr_params: dict,
+    ex,
+):
     """
     Run the training.
     :param model_path: The path to save the trained model.
@@ -129,49 +154,73 @@ def run(model_path: str, log_path: str, tb_path: str, no_of_epochs: int,
 
     # Run training.
     writer = SummaryWriter(os.path.join(log_path, tb_path))
-    logging.info('Starting training')
+    logging.info("Starting training")
     starting_time = time.time()
 
-    for epoch in range(1, no_of_epochs+1):
-        current_epoch = epoch + tr_params['saved_epoch']
+    for epoch in range(1, no_of_epochs + 1):
+        current_epoch = epoch + tr_params["saved_epoch"]
         # Run training.
-        tr_params['net'].train()
+        tr_params["net"].train()
         tr_params, epoch_values = run_one_epoch(
-            loaders['train'], tr_params, writer,
-            [current_epoch, tr_params['saved_epoch']],
-            no_of_epochs, device, norm_params, classes_names, step="Training")
+            loaders["train"],
+            tr_params,
+            writer,
+            [current_epoch, tr_params["saved_epoch"]],
+            no_of_epochs,
+            device,
+            norm_params,
+            classes_names,
+            step="Training",
+        )
 
         log_metrics(ex, current_epoch, epoch_values, writer, step="Training")
 
         with torch.no_grad():
             # Run evaluation.
-            tr_params['net'].eval()
-            epoch_values = run_one_epoch(loaders['val'], tr_params, writer,
-                                         [current_epoch, tr_params['saved_epoch']],
-                                         no_of_epochs, device, norm_params,
-                                         classes_names, step="Validation")
+            tr_params["net"].eval()
+            epoch_values = run_one_epoch(
+                loaders["val"],
+                tr_params,
+                writer,
+                [current_epoch, tr_params["saved_epoch"]],
+                no_of_epochs,
+                device,
+                norm_params,
+                classes_names,
+                step="Validation",
+            )
             log_metrics(ex, current_epoch, epoch_values, writer, step="Validation")
             # Keep best model.
-            if epoch_values['loss'] < tr_params['best_loss']:
-                tr_params['best_loss'] = epoch_values['loss']
-                model.save_model(current_epoch+1, tr_params['net'].state_dict(), epoch_values['loss'],
-                                 tr_params['optimizer'].state_dict(), tr_params['scaler'].state_dict(),
-                                 os.path.join(log_path, model_path))
-                logging.info('Best model (epoch %d) saved', current_epoch)
+            if epoch_values["loss"] < tr_params["best_loss"]:
+                tr_params["best_loss"] = epoch_values["loss"]
+                model.save_model(
+                    current_epoch + 1,
+                    tr_params["net"].state_dict(),
+                    epoch_values["loss"],
+                    tr_params["optimizer"].state_dict(),
+                    tr_params["scaler"].state_dict(),
+                    os.path.join(log_path, model_path),
+                )
+                logging.info("Best model (epoch %d) saved", current_epoch)
 
     # Save last model.
-    path = os.path.join(log_path, 'last_'+model_path).replace('model', 'model_0')
+    path = os.path.join(log_path, "last_" + model_path).replace("model", "model_0")
     index = 1
     while os.path.exists(path):
-        path = path.replace(str(index-1), str(index))
+        path = path.replace(str(index - 1), str(index))
         index += 1
 
-    model.save_model(current_epoch, tr_params['net'].state_dict(), epoch_values['loss'],
-                     tr_params['optimizer'].state_dict(), tr_params['scaler'].state_dict(), path)
-    logging.info('Last model (epoch %d) saved', current_epoch)
+    model.save_model(
+        current_epoch,
+        tr_params["net"].state_dict(),
+        epoch_values["loss"],
+        tr_params["optimizer"].state_dict(),
+        tr_params["scaler"].state_dict(),
+        path,
+    )
+    logging.info("Last model (epoch %d) saved", current_epoch)
 
     end = time.gmtime(time.time() - starting_time)
-    logging.info('Finished training in %2d:%2d:%2d',
-                 end.tm_hour, end.tm_min, end.tm_sec)
-
-
+    logging.info(
+        "Finished training in %2d:%2d:%2d", end.tm_hour, end.tm_min, end.tm_sec
+    )
diff --git a/training/utils/evaluation_utils.py b/training/utils/evaluation_utils.py
index 98d24ab..8d7cd85 100755
--- a/training/utils/evaluation_utils.py
+++ b/training/utils/evaluation_utils.py
@@ -4,18 +4,19 @@
 """
     The evaluation utils module
     ======================
-    
+
-    Use it to during the evaluation stage.
+    Use it during the evaluation stage.
 """
 
+import json
 import os
+
 import cv2
-import json
+import matplotlib.font_manager as fm
+import matplotlib.pyplot as plt
 import numpy as np
-from shapely.geometry import Polygon, MultiPolygon
 from matplotlib import ticker
-import matplotlib.pyplot as plt
-import matplotlib.font_manager as fm
+from shapely.geometry import MultiPolygon, Polygon
 
 
 def read_json(filename: str) -> dict:
@@ -24,7 +25,7 @@ def read_json(filename: str) -> dict:
     :param filename: Path to the file to read.
     :return: A dictionary with the file content.
     """
-    with open(filename, 'r') as file:
+    with open(filename, "r") as file:
         return json.load(file)
 
 
@@ -39,15 +40,19 @@ def get_polygons(regions: dict, classes: list) -> dict:
     polys = {}
     for index, channel in enumerate(classes[1:], 1):
         if channel in regions.keys():
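+            # buffer(0) repairs invalid (e.g. self-intersecting) polygons before later geometric operations.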
-            polys[channel] = [(polygon['confidence'], Polygon(polygon['polygon']).buffer(0))
-                              for polygon in regions[channel]]
+            polys[channel] = [
+                (polygon["confidence"], Polygon(polygon["polygon"]).buffer(0))
+                for polygon in regions[channel]
+            ]
     return polys
 
+
 # Save the metrics.
 
 
-def save_results(pixel_results: dict, object_results: dict, classes: list,
-                 path: str, dataset: str):
+def save_results(
+    pixel_results: dict, object_results: dict, classes: list, path: str, dataset: str
+):
     """
     Save the pixel and object results into a json file.
     :param pixel_results: The results obtained at pixel level.
@@ -57,21 +62,28 @@ def save_results(pixel_results: dict, object_results: dict, classes: list,
     :param dataset: The name of the current dataset.
     """
     json_dict = {channel: {} for channel in classes}
-    
+
     for channel in classes:
-        json_dict[channel]['iou'] = np.round(np.mean(pixel_results[channel]['iou']), 4)
-        json_dict[channel]['precision'] = np.round(np.mean(pixel_results[channel]['precision']), 4)
-        json_dict[channel]['recall'] = np.round(np.mean(pixel_results[channel]['recall']), 4)
-        json_dict[channel]['fscore'] = np.round(np.mean(pixel_results[channel]['fscore']), 4)
-        aps = object_results[channel]['AP']
-        json_dict[channel]['AP@[.5]'] = np.round(aps[50], 4)
-        json_dict[channel]['AP@[.75]'] = np.round(aps[75], 4)
-        json_dict[channel]['AP@[.95]'] = np.round(aps[95], 4)
-        json_dict[channel]['AP@[.5,.95]'] = np.round(np.mean(list(aps.values())), 4)
-
-    with open(os.path.join(path, dataset+'_results.json'), 'w') as json_file:
+        json_dict[channel]["iou"] = np.round(np.mean(pixel_results[channel]["iou"]), 4)
+        json_dict[channel]["precision"] = np.round(
+            np.mean(pixel_results[channel]["precision"]), 4
+        )
+        json_dict[channel]["recall"] = np.round(
+            np.mean(pixel_results[channel]["recall"]), 4
+        )
+        json_dict[channel]["fscore"] = np.round(
+            np.mean(pixel_results[channel]["fscore"]), 4
+        )
+        aps = object_results[channel]["AP"]
+        json_dict[channel]["AP@[.5]"] = np.round(aps[50], 4)
+        json_dict[channel]["AP@[.75]"] = np.round(aps[75], 4)
+        json_dict[channel]["AP@[.95]"] = np.round(aps[95], 4)
+        json_dict[channel]["AP@[.5,.95]"] = np.round(np.mean(list(aps.values())), 4)
+
+    with open(os.path.join(path, dataset + "_results.json"), "w") as json_file:
         json.dump(json_dict, json_file, indent=4)
 
+
 def save_graphical_results(results: dict, classes: list, path: str):
     """
     Plot various curves involving the computed metrics.
@@ -80,9 +92,10 @@ def save_graphical_results(results: dict, classes: list, path: str):
     :param path: The path to the results directory.
     """
     plot_precision_recall_curve(results, classes, path)
-    plot_rank_score(results, classes, 'Precision', path)
-    plot_rank_score(results, classes, 'Recall', path)
-    plot_rank_score(results, classes, 'F-score', path)
+    plot_rank_score(results, classes, "Precision", path)
+    plot_rank_score(results, classes, "Recall", path)
+    plot_rank_score(results, classes, "F-score", path)
+
 
 def generate_figure(params: dict, rotation: bool = None):
     """
@@ -94,18 +107,17 @@ def generate_figure(params: dict, rotation: bool = None):
     :return axis: The created axis to plot.
     :return fp_light: The loaded font property used in the figures.
     """
-    fp_light = fm.FontProperties(fname='./utils/font/Quicksand-Light.ttf',
-                                 size=11)
-    fp_medium = fm.FontProperties(fname='./utils/font/Quicksand-Medium.ttf',
-                                  size=11)
-    fig = plt.figure(figsize=params['size'])
+    fp_light = fm.FontProperties(fname="./utils/font/Quicksand-Light.ttf", size=11)
+    fp_medium = fm.FontProperties(fname="./utils/font/Quicksand-Medium.ttf", size=11)
+    fig = plt.figure(figsize=params["size"])
     axis = fig.add_subplot(111)
-    axis.set_xlabel(params['xlabel'], fontproperties=fp_light)
-    axis.set_ylabel(params['ylabel'], fontproperties=fp_light)
-    axis.set_xticklabels(params['xticks'], fontproperties=fp_light)
-    axis.set_yticklabels(params['yticks'], fontproperties=fp_light,
-                         rotation=rotation, va="center")
-    plt.title(params['title'], fontproperties=fp_medium, fontsize=16, pad=20)
+    axis.set_xlabel(params["xlabel"], fontproperties=fp_light)
+    axis.set_ylabel(params["ylabel"], fontproperties=fp_light)
+    axis.set_xticklabels(params["xticks"], fontproperties=fp_light)
+    axis.set_yticklabels(
+        params["yticks"], fontproperties=fp_light, rotation=rotation, va="center"
+    )
+    plt.title(params["title"], fontproperties=fp_medium, fontsize=16, pad=20)
     return fig, axis, fp_light
 
 
@@ -119,13 +131,12 @@ def plot_rank_score(scores: dict, classes: list, metric: str, path: str):
     :param path: The path to the results directory.
     """
     params = {
-        'size': (12, 8),
-        'title': metric+" vs. confidence score for various IoU thresholds",
-        'xlabel': "Confidence score",
-        'ylabel': metric,
-        'xticks': [0, 0.50, 0.55, 0.60, 0.65, 0.70,
-                   0.75, 0.80, 0.85, 0.90, 0.95],
-        'yticks': [0, 0.2, 0.4, 0.6, 0.8, 1]
+        "size": (12, 8),
+        "title": metric + " vs. confidence score for various IoU thresholds",
+        "xlabel": "Confidence score",
+        "ylabel": metric,
+        "xticks": [0, 0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95],
+        "yticks": [0, 0.2, 0.4, 0.6, 0.8, 1],
     }
     colors = plt.cm.RdPu(np.linspace(0.2, 1, 10))
     for channel in classes:
@@ -133,28 +144,40 @@ def plot_rank_score(scores: dict, classes: list, metric: str, path: str):
         axis.grid(color="grey", alpha=0.2)
         axis.xaxis.set_major_locator(ticker.MultipleLocator(5))
         for index, iou in enumerate(range(50, 100, 5)):
-            if metric == 'Precision':
-                score = list(scores[channel]['precision'][iou].values())
-                rank = list(scores[channel]['precision'][iou].keys())
-            if metric == 'Recall':
-                score = list(scores[channel]['recall'][iou].values())
-                rank = list(scores[channel]['recall'][iou].keys())
-            if metric == 'F-score':
-                score = list(scores[channel]['fscore'][iou].values())
-                rank = list(scores[channel]['fscore'][iou].keys())
-            axis.plot(rank, score, label="{:.2f}".format(iou/100),
-                      alpha=1, color=colors[index], linewidth=2)
-            axis.scatter(rank, score, color=colors[index],
-                         facecolors='none', linewidth=1, marker='o')
+            if metric == "Precision":
+                score = list(scores[channel]["precision"][iou].values())
+                rank = list(scores[channel]["precision"][iou].keys())
+            if metric == "Recall":
+                score = list(scores[channel]["recall"][iou].values())
+                rank = list(scores[channel]["recall"][iou].keys())
+            if metric == "F-score":
+                score = list(scores[channel]["fscore"][iou].values())
+                rank = list(scores[channel]["fscore"][iou].keys())
+            axis.plot(
+                rank,
+                score,
+                label="{:.2f}".format(iou / 100),
+                alpha=1,
+                color=colors[index],
+                linewidth=2,
+            )
+            axis.scatter(
+                rank,
+                score,
+                color=colors[index],
+                facecolors="none",
+                linewidth=1,
+                marker="o",
+            )
         axis.set_xlim([49, 96])
         axis.set_ylim([0, 1])
         plt.legend(prop=fp_light, loc="lower left")
-        plt.savefig(os.path.join(path, metric+'_'+channel+'.png'),
-                    bbox_inches='tight')
+        plt.savefig(
+            os.path.join(path, metric + "_" + channel + ".png"), bbox_inches="tight"
+        )
 
 
-def plot_precision_recall_curve(object_metrics: dict, classes: list,
-                                path: str):
+def plot_precision_recall_curve(object_metrics: dict, classes: list, path: str):
     """
     Plot the precision-recall curve for different IoU thresholds.
     :param object_metrics: The computed precisions and recalls to plot.
@@ -162,26 +185,40 @@ def plot_precision_recall_curve(object_metrics: dict, classes: list,
     :param path: The path to the results directory.
     """
     params = {
-        'size': (12, 8),
-        'title': "Precision-recall curve for various IoU thresholds",
-        'xlabel': "Recall",
-        'ylabel': "Precision",
-        'xticks': [0, 0.2, 0.4, 0.6, 0.8, 1],
-        'yticks': [0, 0.2, 0.4, 0.6, 0.8, 1]
+        "size": (12, 8),
+        "title": "Precision-recall curve for various IoU thresholds",
+        "xlabel": "Recall",
+        "ylabel": "Precision",
+        "xticks": [0, 0.2, 0.4, 0.6, 0.8, 1],
+        "yticks": [0, 0.2, 0.4, 0.6, 0.8, 1],
     }
     colors = plt.cm.RdPu(np.linspace(0.2, 1, 10))
     for channel in classes:
         _, axis, fp_light = generate_figure(params)
         axis.grid(color="grey", alpha=0.2)
         for index, iou in enumerate(range(50, 100, 5)):
-            current_pr = list(object_metrics[channel]['precision'][iou].values())
-            current_rec = list(object_metrics[channel]['recall'][iou].values())
-            axis.plot(current_rec, current_pr, label="{:.2f}".format(iou/100),
-                      alpha=1, color=colors[index], linewidth=2)
-            axis.scatter(current_rec, current_pr, color=colors[index],
-                         facecolors='none', linewidth=1, marker='o')
+            current_pr = list(object_metrics[channel]["precision"][iou].values())
+            current_rec = list(object_metrics[channel]["recall"][iou].values())
+            axis.plot(
+                current_rec,
+                current_pr,
+                label="{:.2f}".format(iou / 100),
+                alpha=1,
+                color=colors[index],
+                linewidth=2,
+            )
+            axis.scatter(
+                current_rec,
+                current_pr,
+                color=colors[index],
+                facecolors="none",
+                linewidth=1,
+                marker="o",
+            )
         axis.set_xlim([0, 1])
         axis.set_ylim([0, 1])
         plt.legend(prop=fp_light, loc="lower right")
-        plt.savefig(os.path.join(path, 'Precision-recall_'+channel+'.png'),
-                    bbox_inches='tight')
+        plt.savefig(
+            os.path.join(path, "Precision-recall_" + channel + ".png"),
+            bbox_inches="tight",
+        )
diff --git a/training/utils/model.py b/training/utils/model.py
index c705220..ebb1305 100755
--- a/training/utils/model.py
+++ b/training/utils/model.py
@@ -8,11 +8,12 @@
     Use it to define, load and restore a model.
 """
 
-import sys
-import os
+import copy
 import logging
+import os
+import sys
 import time
-import copy
+
 import torch
 import torch.nn as nn
 from torch.cuda.amp import autocast
@@ -23,6 +24,7 @@ class Net(nn.Module):
     The Net class is used to generate a network.
     The class contains different useful layers.
     """
+
     def __init__(self, no_of_classes: int, use_amp: bool):
         """
         Constructor of the Net class.
@@ -40,8 +42,7 @@ class Net(nn.Module):
         self.conv_block1 = self.conv_block(256, 128)
         self.conv_block2 = self.conv_block(256, 64)
         self.conv_block3 = self.conv_block(128, 32)
-        self.last_conv = nn.Conv2d(64, no_of_classes, 3,
-                                   stride=1, padding=1)
+        self.last_conv = nn.Conv2d(64, no_of_classes, 3, stride=1, padding=1)
         self.softmax = nn.Softmax(dim=1)
 
     @staticmethod
@@ -55,18 +56,27 @@ class Net(nn.Module):
         :return: The sequence of the convolutions.
         """
         modules = []
-        modules.append(nn.Conv2d(input_size, output_size, 3, stride=1,
-                                 dilation=1, padding=1, bias=False))
-        modules.append(nn.BatchNorm2d(output_size,
-                                      track_running_stats=False))
+        modules.append(
+            nn.Conv2d(
+                input_size, output_size, 3, stride=1, dilation=1, padding=1, bias=False
+            )
+        )
+        modules.append(nn.BatchNorm2d(output_size, track_running_stats=False))
         modules.append(nn.ReLU(inplace=True))
         modules.append(nn.Dropout(p=0.4))
         for i in [2, 4, 8, 16]:
-            modules.append(nn.Conv2d(output_size, output_size, 3,
-                                     stride=1, dilation=i, padding=i,
-                                     bias=False))
-            modules.append(nn.BatchNorm2d(output_size,
-                                          track_running_stats=False))
+            modules.append(
+                nn.Conv2d(
+                    output_size,
+                    output_size,
+                    3,
+                    stride=1,
+                    dilation=i,
+                    padding=i,
+                    bias=False,
+                )
+            )
+            modules.append(nn.BatchNorm2d(output_size, track_running_stats=False))
             modules.append(nn.ReLU(inplace=True))
             modules.append(nn.Dropout(p=0.4))
         return nn.Sequential(*modules)
@@ -81,17 +91,16 @@ class Net(nn.Module):
         :return: The sequence of the convolutions.
         """
         return nn.Sequential(
-            nn.Conv2d(input_size, output_size, 3,
-                      stride=1, padding=1, bias=False),
+            nn.Conv2d(input_size, output_size, 3, stride=1, padding=1, bias=False),
             nn.BatchNorm2d(output_size, track_running_stats=False),
             nn.ReLU(inplace=True),
             nn.Dropout(p=0.4),
             # Does the upsampling.
-            nn.ConvTranspose2d(output_size, output_size,
-                               2, stride=2, bias=False),
+            nn.ConvTranspose2d(output_size, output_size, 2, stride=2, bias=False),
             nn.BatchNorm2d(output_size, track_running_stats=False),
             nn.ReLU(inplace=True),
-            nn.Dropout(p=0.4))
+            nn.Dropout(p=0.4),
+        )
 
     def forward(self, x):
         """
@@ -142,11 +151,11 @@ def load_network(no_of_classes: int, use_amp: bool, ex):
     net = Net(no_of_classes, use_amp)
     # Allow parallel running if more than 1 gpu available.
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    logging.info('Running on %s', device)
+    logging.info("Running on %s", device)
     if torch.cuda.device_count() > 1:
         logging.info("Let's use %d GPUs", torch.cuda.device_count())
         net = nn.DataParallel(net)
-        ex.log_scalar('gpus.number', torch.cuda.device_count())
+        ex.log_scalar("gpus.number", torch.cuda.device_count())
     return net.to(device)
 
 
@@ -165,33 +174,39 @@ def restore_model(net, optimizer, scaler, log_path: str, model_path: str):
     """
     starting_time = time.time()
     if not os.path.isfile(os.path.join(log_path, model_path)):
-        logging.error('No model found at %s',
-                      os.path.join(log_path, model_path))
+        logging.error("No model found at %s", os.path.join(log_path, model_path))
         sys.exit()
     else:
         if torch.cuda.is_available():
             checkpoint = torch.load(os.path.join(log_path, model_path))
         else:
-            checkpoint = torch.load(os.path.join(log_path, model_path),
-                                    map_location=torch.device('cpu'))
+            checkpoint = torch.load(
+                os.path.join(log_path, model_path), map_location=torch.device("cpu")
+            )
         loaded_checkpoint = {}
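+        # Reconcile the "module." prefix that nn.DataParallel adds to state_dict keys.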
         if torch.cuda.device_count() > 1:
             for key in checkpoint["state_dict"].keys():
-                if 'module' not in key:
-                    loaded_checkpoint['module.'+key] = checkpoint["state_dict"][key]
+                if "module" not in key:
+                    loaded_checkpoint["module." + key] = checkpoint["state_dict"][key]
                 else:
                     loaded_checkpoint = checkpoint["state_dict"]
         else:
             for key in checkpoint["state_dict"].keys():
-                loaded_checkpoint[key.replace("module.", "")] = checkpoint["state_dict"][key]
+                loaded_checkpoint[key.replace("module.", "")] = checkpoint[
+                    "state_dict"
+                ][key]
         net.load_state_dict(loaded_checkpoint)
 
         if optimizer is not None:
-            optimizer.load_state_dict(checkpoint['optimizer'])
+            optimizer.load_state_dict(checkpoint["optimizer"])
         if scaler is not None:
-            scaler.load_state_dict(checkpoint['scaler'])
-        logging.info('Loaded checkpoint %s (epoch %d) in %1.5fs',
-                     model_path, checkpoint['epoch'], (time.time() - starting_time))
+            scaler.load_state_dict(checkpoint["scaler"])
+        logging.info(
+            "Loaded checkpoint %s (epoch %d) in %1.5fs",
+            model_path,
+            checkpoint["epoch"],
+            (time.time() - starting_time),
+        )
         return checkpoint, net, optimizer, scaler
 
 
@@ -205,9 +220,11 @@ def save_model(epoch: int, model, loss: float, optimizer, scaler, filename: str)
     :param scaler: The scaler used for AMP.
     :param filename: The name of the model file.
     """
-    model_params = {'epoch': epoch,
-                    'state_dict': copy.deepcopy(model),
-                    'best_loss': loss,
-                    'optimizer': copy.deepcopy(optimizer),
-                    'scaler': scaler}
+    model_params = {
+        "epoch": epoch,
+        "state_dict": copy.deepcopy(model),
+        "best_loss": loss,
+        "optimizer": copy.deepcopy(optimizer),
+        "scaler": scaler,
+    }
     torch.save(model_params, filename)
diff --git a/training/utils/object_metrics.py b/training/utils/object_metrics.py
index b07568d..fec1393 100755
--- a/training/utils/object_metrics.py
+++ b/training/utils/object_metrics.py
@@ -67,8 +67,7 @@ def __rank_predicted_objects(labels: list, predictions: list) -> dict:
     ious = __get_ious(labels, predictions)
 
     scores = {index: prediction[0] for index, prediction in enumerate(predictions)}
-    tuples_score_iou = [(v, ious[k])
-                        for k, v in scores.items()]
+    tuples_score_iou = [(v, ious[k]) for k, v in scores.items()]
     scores = sorted(tuples_score_iou, key=lambda item: (-item[0], -item[1]))
     return scores
 
@@ -84,23 +83,17 @@ def compute_rank_scores(labels: list, predictions: list, classes: list) -> dict:
-    :return scores: The scores obtained for a each rank, IoU
+    :return scores: The scores obtained for each rank, IoU
                     threshold and class.
     """
-    scores = {
-        channel: {
-            iou: None for iou in range(50, 100, 5)
-        } for channel in classes
-    }
+    scores = {channel: {iou: None for iou in range(50, 100, 5)} for channel in classes}
     for channel in classes:
         channel_scores = __rank_predicted_objects(labels[channel], predictions[channel])
         for iou in range(50, 100, 5):
-            rank_scores = {
-                rank: {'True': 0, 'Total': 0} for rank in range(95, -5, -5)
-            }
+            rank_scores = {rank: {"True": 0, "Total": 0} for rank in range(95, -5, -5)}
             for rank in range(95, -5, -5):
-                rank_objects = list(filter(
-                    lambda item: item[0] >= rank/100, channel_scores))
-                rank_scores[rank]['True'] = sum(x[1] > iou / 100
-                                                for x in rank_objects)
-                rank_scores[rank]['Total'] = len(rank_objects)
+                rank_objects = list(
+                    filter(lambda item: item[0] >= rank / 100, channel_scores)
+                )
+                rank_scores[rank]["True"] = sum(x[1] > iou / 100 for x in rank_objects)
+                rank_scores[rank]["Total"] = len(rank_objects)
             scores[channel][iou] = rank_scores
     return scores
 
@@ -116,24 +109,22 @@ def update_rank_scores(global_scores: dict, image_scores: dict, classes: list) -
     for channel in classes:
         for iou in range(50, 100, 5):
             for rank in range(95, -5, -5):
-                global_scores[channel][iou][rank]['True'] += \
-                    image_scores[channel][iou][rank]['True']
-                global_scores[channel][iou][rank]['Total'] += \
-                    image_scores[channel][iou][rank]['Total']
+                global_scores[channel][iou][rank]["True"] += image_scores[channel][iou][
+                    rank
+                ]["True"]
+                global_scores[channel][iou][rank]["Total"] += image_scores[channel][
+                    iou
+                ][rank]["Total"]
     return global_scores
 
 
 def __init_results() -> dict:
     """
-    Initialize the results dictionnary by generating dictionary for
+    Initialize the results dictionary by generating a dictionary for
     the different rank and Intersection-over-Union thresholds.
-    :return: The initialized results dictionnary.
+    :return: The initialized results dictionary.
     """
-    return {
-        iou: {
-            rank: 0 for rank in range(95, -5, -5)
-        } for iou in range(50, 100, 5)
-    }
+    return {iou: {rank: 0 for rank in range(95, -5, -5)} for iou in range(50, 100, 5)}
 
 
 def __get_average_precision(precisions: list, recalls: list) -> float:
@@ -154,20 +145,22 @@ def __get_average_precision(precisions: list, recalls: list) -> float:
         max_precision = np.max(precisions)
         argmax_precision = np.argmax(precisions)
         max_recall = recalls[argmax_precision]
-        rp_tuples.append({'p': max_precision, 'r': max_recall})
-        for _ in range(argmax_precision+1):
+        rp_tuples.append({"p": max_precision, "r": max_recall})
+        for _ in range(argmax_precision + 1):
             precisions.pop(0)
             recalls.pop(0)
-    rp_tuples[-1]['r'] = 1
+    rp_tuples[-1]["r"] = 1
 
-    ps = [rp_tuple['p'] for rp_tuple in rp_tuples]
-    rs = [rp_tuple['r'] for rp_tuple in rp_tuples]
+    ps = [rp_tuple["p"] for rp_tuple in rp_tuples]
+    rs = [rp_tuple["r"] for rp_tuple in rp_tuples]
     ps.insert(0, ps[0])
     rs.insert(0, 0)
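+    # AP is the area under the interpolated precision-recall curve (trapezoidal rule).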
     return np.trapz(ps, x=rs)
 
-def get_mean_results(global_scores: dict, true_gt: dict, classes: list,
-                     results: dict) -> dict:
+
+def get_mean_results(
+    global_scores: dict, true_gt: dict, classes: list, results: dict
+) -> dict:
     """
-    Get the mean metrics values for all the set.
+    Get the mean metric values over the whole set.
     :param global_scores: The overall computed scores.
@@ -183,21 +176,27 @@ def get_mean_results(global_scores: dict, true_gt: dict, classes: list,
         aps = {iou: 0 for iou in range(50, 100, 5)}
         for iou in range(50, 100, 5):
             for rank in range(95, -5, -5):
-                true_predicted = global_scores[channel][iou][rank]['True']
-                predicted = global_scores[channel][iou][rank]['Total']
+                true_predicted = global_scores[channel][iou][rank]["True"]
+                predicted = global_scores[channel][iou][rank]["Total"]
 
-                precisions[iou][rank] = true_predicted / predicted if predicted != 0 else 1
-                recalls[iou][rank] = true_predicted / true_gt[channel] if true_gt[channel] != 0 else 1
+                precisions[iou][rank] = (
+                    true_predicted / predicted if predicted != 0 else 1
+                )
+                recalls[iou][rank] = (
+                    true_predicted / true_gt[channel] if true_gt[channel] != 0 else 1
+                )
 
                 if precisions[iou][rank] + recalls[iou][rank] != 0:
-                    fscores[iou][rank] = 2 * \
-                        (precisions[iou][rank] * recalls[iou][rank]) / \
-                        (precisions[iou][rank] + recalls[iou][rank])
+                    fscores[iou][rank] = (
+                        2
+                        * (precisions[iou][rank] * recalls[iou][rank])
+                        / (precisions[iou][rank] + recalls[iou][rank])
+                    )
             aps[iou] = __get_average_precision(
-                list(precisions[iou].values()),
-                list(recalls[iou].values()))
-            results[channel]['precision'] = precisions
-            results[channel]['recall'] = recalls
-            results[channel]['fscore'] = fscores
-            results[channel]['AP'] = aps
+                list(precisions[iou].values()), list(recalls[iou].values())
+            )
+            results[channel]["precision"] = precisions
+            results[channel]["recall"] = recalls
+            results[channel]["fscore"] = fscores
+            results[channel]["AP"] = aps
     return results
diff --git a/training/utils/params_config.py b/training/utils/params_config.py
index 646a694..6b60bc6 100755
--- a/training/utils/params_config.py
+++ b/training/utils/params_config.py
@@ -13,6 +13,7 @@ class BaseParams:
     """
     This is a global class for the configuration parameters.
     """
+
     def to_dict(self):
         """
         Maps the class attributes to a dictionary.
@@ -41,19 +42,20 @@ class Params(BaseParams):
-    :param str: Path to the file containing the standard deviation values of training set.
+    :param std: Path to the file containing the standard deviation values of the training set.
     :param model_path: Path to store the obtained model.
     :param train_image_path: Path to the directory containing training images.
-    :param train_mask_path: Path to the directory containing training masks. 
-    :param val_image_path: Path to the directory containing validation images. 
-    :param val_mask_path: Path to the directory containing validation masks. 
-    :param test_image_path: Path to the directory containing testing images. 
+    :param train_mask_path: Path to the directory containing training masks.
+    :param val_image_path: Path to the directory containing validation images.
+    :param val_mask_path: Path to the directory containing validation masks.
+    :param test_image_path: Path to the directory containing testing images.
-    :param train_image_path: Path to the directory containing training images.
     :param classes_files: File containing the color codes of the classes involved
                           in the experiment.
-    ;param prediction_path: Path to the directory to save the predictions. 
+    :param prediction_path: Path to the directory to save the predictions.
     """
+
     def __init__(self, **kwargs):
 
-        self.mean = kwargs.get('mean', 'mean')
-        self.std = kwargs.get('std', 'std')
-        self.model_path = kwargs.get('model_path', 'model.pth')
-        self.prediction_path = kwargs.get('prediction_path', 'prediction')
-        self.evaluation_path = kwargs.get('evaluation_path', 'results')
+        self.mean = kwargs.get("mean", "mean")
+        self.std = kwargs.get("std", "std")
+        self.model_path = kwargs.get("model_path", "model.pth")
+        self.prediction_path = kwargs.get("prediction_path", "prediction")
+        self.evaluation_path = kwargs.get("evaluation_path", "results")
diff --git a/training/utils/pixel_metrics.py b/training/utils/pixel_metrics.py
index a290ca4..e317d44 100755
--- a/training/utils/pixel_metrics.py
+++ b/training/utils/pixel_metrics.py
@@ -16,8 +16,9 @@
 import numpy as np
 
 
-def compute_metrics(labels: list, predictions: list, classes: list,
-                    global_metrics: dict) -> dict:
+def compute_metrics(
+    labels: list, predictions: list, classes: list, global_metrics: dict
+) -> dict:
     """
     Compute the pixel level metrics between prediction and label areas of
     a given page.
@@ -35,17 +36,17 @@ def compute_metrics(labels: list, predictions: list, classes: list,
         gt_area = np.sum([gt.area for _, gt in labels[channel]])
         pred_area = np.sum([pred.area for _, pred in predictions[channel]])
 
-        global_metrics[channel]['iou'].append(get_iou(inter, gt_area, pred_area))
+        global_metrics[channel]["iou"].append(get_iou(inter, gt_area, pred_area))
         precision = get_precision(inter, pred_area)
         recall = get_recall(inter, gt_area)
-        global_metrics[channel]['precision'].append(precision)
-        global_metrics[channel]['recall'].append(recall)
+        global_metrics[channel]["precision"].append(precision)
+        global_metrics[channel]["recall"].append(recall)
         if precision + recall != 0:
-            global_metrics[channel]['fscore'].append(
-                2*precision*recall / (precision+recall)
+            global_metrics[channel]["fscore"].append(
+                2 * precision * recall / (precision + recall)
             )
         else:
-            global_metrics[channel]['fscore'].append(0)
+            global_metrics[channel]["fscore"].append(0)
     return global_metrics
 
 
@@ -66,7 +67,7 @@ def get_iou(intersection: float, label_area: float, predicted_area: float) -> fl
     if intersection == 0 and union != 0:
         return 0
     # Objects to detect and/or predicted that intersect.
-    return intersection / union  
+    return intersection / union
 
 
 def get_precision(intersection: float, predicted_area: float) -> float:
@@ -95,5 +96,3 @@ def get_recall(intersection, label_area) -> float:
     if label_area == 0:
         return 1
     return intersection / label_area
-
-    
diff --git a/training/utils/prediction_utils.py b/training/utils/prediction_utils.py
index 5eb3909..79ccdc3 100755
--- a/training/utils/prediction_utils.py
+++ b/training/utils/prediction_utils.py
@@ -4,18 +4,20 @@
 """
     The prediction utils module
     ======================
-    
+
-    Use it to during the prediction stage.
+    Use it during the prediction stage.
 """
 
-import cv2
 import json
-import numpy as np
+
+import cv2
 import imageio as io
+import numpy as np
 
 
-def resize_polygons(polygons: dict, image_size: tuple, input_size: tuple,
-                    padding: tuple) -> dict:
+def resize_polygons(
+    polygons: dict, image_size: tuple, input_size: tuple, padding: tuple
+) -> dict:
     """
     Resize the detected polygons to the original input image size.
     :param polygons: The polygons to resize.
@@ -32,22 +34,33 @@ def resize_polygons(polygons: dict, image_size: tuple, input_size: tuple,
 
     for channel in polygons.keys():
         for index, polygon in enumerate(polygons[channel]):
-            x_points = [element[0][1] for element in polygon['polygon']]
-            y_points = [element[0][0] for element in polygon['polygon']]
-            x_points = [int((element - padding['top']) * ratio[0]) for element in x_points]
-            y_points = [int((element - padding['left']) * ratio[1]) for element in y_points]
-            
-            x_points = [int(element) if element < image_size[0] else int(image_size[0]) for element in x_points]
-            y_points = [int(element) if element < image_size[1] else int(image_size[1]) for element in y_points]
+            x_points = [element[0][1] for element in polygon["polygon"]]
+            y_points = [element[0][0] for element in polygon["polygon"]]
+            x_points = [
+                int((element - padding["top"]) * ratio[0]) for element in x_points
+            ]
+            y_points = [
+                int((element - padding["left"]) * ratio[1]) for element in y_points
+            ]
+
+            x_points = [
+                int(element) if element < image_size[0] else int(image_size[0])
+                for element in x_points
+            ]
+            y_points = [
+                int(element) if element < image_size[1] else int(image_size[1])
+                for element in y_points
+            ]
             x_points = [int(element) if element > 0 else 0 for element in x_points]
             y_points = [int(element) if element > 0 else 0 for element in y_points]
-            assert(max(x_points) <= image_size[0])
-            assert(min(x_points) >= 0)                              
-            assert(max(y_points) <= image_size[1])
-            assert(min(y_points) >= 0)                              
-            polygons[channel][index]['polygon'] = list(zip(y_points, x_points))
+            assert max(x_points) <= image_size[0]
+            assert min(x_points) >= 0
+            assert max(y_points) <= image_size[1]
+            assert min(y_points) >= 0
+            polygons[channel][index]["polygon"] = list(zip(y_points, x_points))
     return polygons
 
+
 def compute_confidence(region: np.ndarray, probas: np.ndarray) -> float:
     """
     Compute the confidence score of a detected polygon.
@@ -60,6 +73,7 @@ def compute_confidence(region: np.ndarray, probas: np.ndarray) -> float:
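+    # Confidence is the mean predicted probability over the region's pixels.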
     confidence = np.sum(mask * probas) / np.sum(mask)
     return round(confidence, 4)
 
+
 # Save the prediction coordinates and images.
 
 
@@ -71,25 +85,25 @@ def save_prediction(polygons: dict, filename: str):
                      confidence scores.
     :param filename: The filename to save the detected polygons.
     """
-    with open(filename.replace('png', 'json'), 'w') as outfile:
+    with open(filename.replace("png", "json"), "w") as outfile:
         json.dump(polygons, outfile, indent=4)
 
+
 def save_prediction_image(polygons, colors, input_size, filename: str):
     """
     Save the detected polygon to an image.
     :param polygons: The detected polygons coordinates.
     :param colors: The colors corresponding to each involved class.
-    :param input_size: The origianl input image size. 
+    :param input_size: The original input image size.
     :param filename: The filename to save the prediction image.
     """
     image = np.zeros((input_size[0], input_size[1], 3))
     index = 1
     for channel in polygons.keys():
-        if channel == 'img_size':
+        if channel == "img_size":
             continue
         color = [int(element) for element in colors[index]]
         for polygon in polygons[channel]:
-            cv2.drawContours(image, [np.array(polygon['polygon'])], 0, color, -1)
+            cv2.drawContours(image, [np.array(polygon["polygon"])], 0, color, -1)
         index += 1
     io.imsave(filename, np.uint8(image))
-
diff --git a/training/utils/preprocessing.py b/training/utils/preprocessing.py
index a399569..93a20fb 100755
--- a/training/utils/preprocessing.py
+++ b/training/utils/preprocessing.py
@@ -9,19 +9,24 @@
 """
 
 import os
+
 import cv2
-import torch
 import numpy as np
+import torch
 from torch.utils.data import Dataset
+
 from .utils import rgb_to_gray_array, rgb_to_gray_value
 
+
 class TrainingDataset(Dataset):
     """
     The TrainingDataset class is used to prepare the images and labels to
-    run training step.
+    run the training step.
     """
-    def __init__(self, images_dir: str, masks_dir: str, colors: list,
-                 transform: list = None):
+
+    def __init__(
+        self, images_dir: str, masks_dir: str, colors: list, transform: list = None
+    ):
         """
         Constructor of the TrainingDataset class.
         :param images_dir: The directories containing the images.
@@ -31,15 +36,12 @@ class TrainingDataset(Dataset):
         """
         self.images_dir = images_dir
         self.images = [
-            (dir.parent.parent.name, dir/element)
+            (dir.parent.parent.name, dir / element)
             for dir in self.images_dir
             for element in os.listdir(dir)
         ]
         self.masks_dir = masks_dir
-        self.masks = {
-            dir.parent.parent.name: dir
-            for dir in self.masks_dir
-        }
+        self.masks = {dir.parent.parent.name: dir for dir in self.masks_dir}
         self.colors = colors
         self.transform = transform
 
@@ -65,25 +67,25 @@ class TrainingDataset(Dataset):
             image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
-        label = cv2.imread(str(self.masks[img_name[0]]/img_name[1].name))
+        label = cv2.imread(str(self.masks[img_name[0]] / img_name[1].name))
         if len(label.shape) < 3:
             label = cv2.cvtColor(label, cv2.COLOR_GRAY2RGB)
         label = cv2.cvtColor(label, cv2.COLOR_BGR2RGB)
         label = rgb_to_gray_array(label)
-        
+
         # Transform the label into a categorical label.
         new_label = np.zeros_like(label)
         for index, value in enumerate(self.colors):
             color = rgb_to_gray_value(value)
             new_label[label == color] = index
 
-        sample = {'image': image, 'mask': new_label, 'size': image.shape[0:2]}
+        sample = {"image": image, "mask": new_label, "size": image.shape[0:2]}
 
         # Apply the transformations.
         if self.transform:
             sample = self.transform(sample)
 
-        sample['size'] = sample['image'].shape[0:2]
+        sample["size"] = sample["image"].shape[0:2]
 
         return sample
 
@@ -93,6 +95,7 @@ class PredictionDataset(Dataset):
     The PredictionDataset class is used to prepare the images to
-    run prediction step.
+    run the prediction step.
     """
+
     def __init__(self, images_dir: str, transform: list = None):
         """
         Constructor of the PredictionDataset class.
@@ -101,7 +104,7 @@ class PredictionDataset(Dataset):
         """
         self.images_dir = images_dir
         self.images = [
-            (dir.parent.parent.name, dir/element)
+            (dir.parent.parent.name, dir / element)
             for dir in self.images_dir
             for element in os.listdir(dir)
         ]
@@ -129,8 +132,12 @@ class PredictionDataset(Dataset):
             image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
-        sample = {'image': image, 'name': img_name.name,
-                  'dataset': self.images[idx][0], 'size': image.shape[0:2]}
+        sample = {
+            "image": image,
+            "name": img_name.name,
+            "dataset": self.images[idx][0],
+            "size": image.shape[0:2],
+        }
 
         # Apply the transformations.
         if self.transform:
@@ -141,11 +148,12 @@ class PredictionDataset(Dataset):
 # Transformations
 
 
-class Rescale():
+class Rescale:
     """
     The Rescale class is used to rescale the image of a sample into a
     given size.
     """
+
     def __init__(self, output_size: int):
         """
         Constructor of the Rescale class.
@@ -160,28 +168,29 @@ class Rescale():
         :param sample: The sample to rescale.
         :return sample: The rescaled sample.
         """
-        old_size = sample['image'].shape[:2]
+        old_size = sample["image"].shape[:2]
         # Compute the new sizes.
         ratio = float(self.output_size) / max(old_size)
         new_size = [int(x * ratio) for x in old_size]
 
         # Resize the image.
         if max(old_size) != self.output_size:
-            image = cv2.resize(sample['image'], (new_size[1], new_size[0]))
-            sample['image'] = image
+            image = cv2.resize(sample["image"], (new_size[1], new_size[0]))
+            sample["image"] = image
 
         # Resize the label. MUST BE AVOIDED.
-        if 'mask' in sample.keys():
-            if max(sample['mask'].shape[:2]) != self.output_size:
-                mask = cv2.resize(sample['mask'], (new_size[1], new_size[0]))
-                sample['mask'] = mask
+        if "mask" in sample.keys():
+            if max(sample["mask"].shape[:2]) != self.output_size:
+                mask = cv2.resize(sample["mask"], (new_size[1], new_size[0]))
+                sample["mask"] = mask
         return sample
 
 
-class Pad():
+class Pad:
     """
     The Pad class is used to pad the image of a sample to make it divisible by 8.
     """
+
     def __init__(self):
         """
         Constructor of the Pad class.
@@ -197,28 +206,42 @@ class Pad():
         # Compute the padding parameters.
         delta_w = 0
         delta_h = 0
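+        # Round each side up to the next multiple of 8; the padding is split between opposite edges.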
-        if sample['image'].shape[0] % 8 != 0:
-            delta_h = int(8 * np.ceil(sample['image'].shape[0] / 8)) - sample['image'].shape[0]
-        if sample['image'].shape[1] % 8 != 0:
-            delta_w = int(8 * np.ceil(sample['image'].shape[1] / 8)) - sample['image'].shape[1]
+        if sample["image"].shape[0] % 8 != 0:
+            delta_h = (
+                int(8 * np.ceil(sample["image"].shape[0] / 8))
+                - sample["image"].shape[0]
+            )
+        if sample["image"].shape[1] % 8 != 0:
+            delta_w = (
+                int(8 * np.ceil(sample["image"].shape[1] / 8))
+                - sample["image"].shape[1]
+            )
 
         top, bottom = delta_h // 2, delta_h - (delta_h // 2)
         left, right = delta_w // 2, delta_w - (delta_w // 2)
 
         # Add padding to have same size images.
-        image = cv2.copyMakeBorder(sample['image'], top, bottom, left, right,
-                                   cv2.BORDER_CONSTANT, value=[0, 0, 0])
-        sample['image'] = image
-        sample['padding'] = {'top': top, 'left': left}
+        image = cv2.copyMakeBorder(
+            sample["image"],
+            top,
+            bottom,
+            left,
+            right,
+            cv2.BORDER_CONSTANT,
+            value=[0, 0, 0],
+        )
+        sample["image"] = image
+        sample["padding"] = {"top": top, "left": left}
         return sample
 
 
-class Normalize():
+class Normalize:
     """
     The Normalize class is used to normalize the image of a sample.
     The mean value and standard deviation must be first computed on the
     training dataset.
     """
+
     def __init__(self, mean: list, std: list):
         """
         Constructor of the Normalize class.
@@ -238,26 +261,27 @@ class Normalize():
         :param sample: The sample with the image to normalize.
         :return sample: The sample with the normalized image.
         """
-        image = np.zeros(sample['image'].shape)
-        for channel in range(sample['image'].shape[2]):
-            image[:, :, channel] = (np.float32(sample['image'][:, :, channel])
-                                        - self.mean[channel]) \
-                                        / self.std[channel]
-        sample['image'] = image
+        image = np.zeros(sample["image"].shape)
+        for channel in range(sample["image"].shape[2]):
+            image[:, :, channel] = (
+                np.float32(sample["image"][:, :, channel]) - self.mean[channel]
+            ) / self.std[channel]
+        sample["image"] = image
         return sample
 
 
-class ToTensor():
+class ToTensor:
     """
-    The ToTensor class is used convert ndarrays into Tensors.
+    The ToTensor class is used to convert ndarrays into Tensors.
     """
+
     def __call__(self, sample: dict) -> dict:
         """
         Transform the sample image and label into Tensors.
         :param sample: The initial sample.
         :return sample: The sample made of Tensors.
         """
-        sample['image'] = torch.from_numpy(sample['image'].transpose((2, 0, 1)))
-        if 'mask' is sample.keys():
-            sample['mask'] = torch.from_numpy(sample['mask'])
+        sample["image"] = torch.from_numpy(sample["image"].transpose((2, 0, 1)))
+        if "mask" is sample.keys():
+            sample["mask"] = torch.from_numpy(sample["mask"])
         return sample
diff --git a/training/utils/training_pixel_metrics.py b/training/utils/training_pixel_metrics.py
index cf1a460..868df34 100755
--- a/training/utils/training_pixel_metrics.py
+++ b/training/utils/training_pixel_metrics.py
@@ -14,8 +14,9 @@
 import numpy as np
 
 
-def compute_metrics(pred: np.ndarray, label: np.ndarray, loss: float,
-                    classes: list) -> dict:
+def compute_metrics(
+    pred: np.ndarray, label: np.ndarray, loss: float, classes: list
+) -> dict:
     """
     Compute the metrics between a prediction and a label mask.
     :param pred: The prediction made by the network.
@@ -37,15 +38,14 @@ def update_metrics(metrics: dict, batch_metrics: dict) -> dict:
     :param batch_metrics: The current batch metrics.
     :return metrics: The updated global metrics.
     """
-    for i in range(metrics['matrix'].shape[0]):
-        for j in range(metrics['matrix'].shape[1]):  
-            metrics['matrix'][i][j] += batch_metrics['matrix'][i][j]
-    metrics['loss'] += batch_metrics['loss']
+    for i in range(metrics["matrix"].shape[0]):
+        for j in range(metrics["matrix"].shape[1]):
+            metrics["matrix"][i][j] += batch_metrics["matrix"][i][j]
+    metrics["loss"] += batch_metrics["loss"]
     return metrics
 
 
-def confusion_matrix(pred: np.ndarray, label: np.ndarray,
-                     classes: list) -> np.array:
+def confusion_matrix(pred: np.ndarray, label: np.ndarray, classes: list) -> np.array:
     """
     Get the confusion matrix between the prediction and the given label.
     :param pred: The prediction made by the network.
@@ -59,7 +59,7 @@ def confusion_matrix(pred: np.ndarray, label: np.ndarray,
         for j in range(size):
             bin_label = label == i
             bin_pred = pred == j
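+            # Cell [j, i] counts pixels of true class i predicted as class j.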
-            confusion_matrix[j, i] = (bin_pred*bin_label).sum()
+            confusion_matrix[j, i] = (bin_pred * bin_label).sum()
     return confusion_matrix
 
 
@@ -78,6 +78,5 @@ def iou(confusion_matrix: np.ndarray, channel: str) -> float:
     if true_positives == 0:
         return 0
     elif tpfn + tpfp == true_positives:
-    	return 0
+        return 0
     return true_positives / (tpfn + tpfp - true_positives)
-
diff --git a/training/utils/training_utils.py b/training/utils/training_utils.py
index 9218461..b169f27 100755
--- a/training/utils/training_utils.py
+++ b/training/utils/training_utils.py
@@ -8,8 +8,8 @@
-    Use it to during the training stage.
+    Use it during the training stage.
 """
 
-import numpy as np
 import matplotlib.pyplot as plt
+import numpy as np
 import torch
 import torch.nn as nn
 import utils.training_pixel_metrics as p_metrics
@@ -21,6 +21,7 @@ class Diceloss(nn.Module):
     """
     The Diceloss class is used during training.
     """
+
     def __init__(self, num_classes: int):
         """
         Constructor of the Diceloss class.
@@ -30,21 +31,26 @@ class Diceloss(nn.Module):
         self.num_classes = num_classes
 
     def forward(self, pred: np.ndarray, target: np.ndarray) -> float:
-       """
-       Compute the Dice loss between a label and a prediction mask.
-       :param pred: The prediction made by the network.
-       :param target: The label mask.
-       :return: The Dice loss.
-       """
-       label = nn.functional.one_hot(target, num_classes=self.num_classes).permute(0,3,1,2).contiguous()
-
-       smooth = 1.
-       iflat = pred.contiguous().view(-1)
-       tflat = label.contiguous().view(-1)
-       intersection = (iflat * tflat).sum()
-       A_sum = torch.sum(iflat * iflat)
-       B_sum = torch.sum(tflat * tflat)
-       return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth) )
+        """
+        Compute the Dice loss between a label and a prediction mask.
+        :param pred: The prediction made by the network.
+        :param target: The label mask.
+        :return: The Dice loss.
+        """
+        label = (
+            nn.functional.one_hot(target, num_classes=self.num_classes)
+            .permute(0, 3, 1, 2)
+            .contiguous()
+        )
+
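+        # Dice loss: 1 - (2 * intersection + smooth) / (sum(pred^2) + sum(target^2) + smooth).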
+        smooth = 1.0
+        iflat = pred.contiguous().view(-1)
+        tflat = label.contiguous().view(-1)
+        intersection = (iflat * tflat).sum()
+        A_sum = torch.sum(iflat * iflat)
+        B_sum = torch.sum(tflat * tflat)
+        return 1 - ((2.0 * intersection + smooth) / (A_sum + B_sum + smooth))
+
 
 # Plot the prediction during training.
 
@@ -56,8 +62,7 @@ def plot_prediction(output: np.ndarray) -> np.ndarray:
     :param output: The predictions of the batch images.
     :return prediction: The array of categorical predictions.
     """
-    prediction = np.zeros((output.shape[0], 1, output.shape[2],
-                           output.shape[3]))
+    prediction = np.zeros((output.shape[0], 1, output.shape[2], output.shape[3]))
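+    # Collapse the class dimension with an argmax to get one label map per image.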
     for pred in range(output.shape[0]):
         current_pred = output[pred, :, :, :]
         new = np.argmax(current_pred, axis=0)
@@ -66,8 +71,14 @@ def plot_prediction(output: np.ndarray) -> np.ndarray:
     return prediction
 
 
-def display_training(output: np.ndarray, image: np.ndarray, label: np.ndarray,
-                     writer, epoch: int, norm_params: list):
+def display_training(
+    output: np.ndarray,
+    image: np.ndarray,
+    label: np.ndarray,
+    writer,
+    epoch: int,
+    norm_params: list,
+):
     """
     Define the figure to plot a batch images, labels and current predictions.
     Add it to Tensorboard.
@@ -80,26 +91,30 @@ def display_training(output: np.ndarray, image: np.ndarray, label: np.ndarray,
                         to normalize the images.
     """
     predictions = plot_prediction(output.cpu().detach().numpy())
-    fig, axs = plt.subplots(predictions.shape[0], 3,
-                            figsize=(10, 3*predictions.shape[0]),
-                            gridspec_kw={'hspace': 0.2, 'wspace': 0.05})
+    fig, axs = plt.subplots(
+        predictions.shape[0],
+        3,
+        figsize=(10, 3 * predictions.shape[0]),
+        gridspec_kw={"hspace": 0.2, "wspace": 0.05},
+    )
     for pred in range(predictions.shape[0]):
         current_input = image.cpu().numpy()[pred, :, :, :]
         current_input = current_input.transpose((1, 2, 0))
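+        # Undo the training normalization (x * std + mean) so images display in their original range.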
         for channel in range(current_input.shape[2]):
-            current_input[:, :, channel] = (current_input[:, :, channel]
-                                            * norm_params['std'][channel]) \
-                                            + norm_params['mean'][channel]
+            current_input[:, :, channel] = (
+                current_input[:, :, channel] * norm_params["std"][channel]
+            ) + norm_params["mean"][channel]
         if predictions.shape[0] > 1:
             axs[pred, 0].imshow(current_input.astype(np.uint8))
-            axs[pred, 1].imshow(label.cpu()[pred, :, :], cmap='gray')
-            axs[pred, 2].imshow(predictions[pred, 0, :, :], cmap='gray')
+            axs[pred, 1].imshow(label.cpu()[pred, :, :], cmap="gray")
+            axs[pred, 2].imshow(predictions[pred, 0, :, :], cmap="gray")
         else:
             axs[0].imshow(current_input.astype(np.uint8))
-            axs[1].imshow(label.cpu()[pred, :, :], cmap='gray')
-            axs[2].imshow(predictions[pred, 0, :, :], cmap='gray')
+            axs[1].imshow(label.cpu()[pred, :, :], cmap="gray")
+            axs[2].imshow(predictions[pred, 0, :, :], cmap="gray")
     _ = [axi.set_axis_off() for axi in axs.ravel()]
-    writer.add_figure('Image_Label_Prediction', fig, global_step=epoch)
+    writer.add_figure("Image_Label_Prediction", fig, global_step=epoch)
+
 
 # Display the metrics during training.
 
@@ -114,6 +129,8 @@ def get_epoch_values(metrics: dict, classes: list, batch: int) -> dict:
     """
     values = {}
     for channel in classes[1:]:
-        values['iou_'+channel] = round(p_metrics.iou(metrics['matrix'], classes.index(channel)), 6)
+        values["iou_" + channel] = round(
+            p_metrics.iou(metrics["matrix"], classes.index(channel)), 6
+        )
     values["loss"] = metrics["loss"] / batch
     return values
diff --git a/training/utils/utils.py b/training/utils/utils.py
index 5dfc778..47e4051 100755
--- a/training/utils/utils.py
+++ b/training/utils/utils.py
@@ -9,9 +9,10 @@
 """
 
 import copy
-import torch
 import random
+
 import numpy as np
+import torch
 
 # Useful functions.
 
@@ -25,8 +26,7 @@ def rgb_to_gray_value(rgb: tuple) -> int:
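+    # The weights are the ITU-R BT.601 luma coefficients (0.299, 0.587, 0.114).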
     try:
         return int(rgb[0] * 0.299 + rgb[1] * 0.587 + rgb[2] * 0.114)
     except TypeError:
-        return int(int(rgb[0]) * 0.299 + int(rgb[1]) * 0.587 +
-                   int(rgb[2]) * 0.114)
+        return int(int(rgb[0]) * 0.299 + int(rgb[1]) * 0.587 + int(rgb[2]) * 0.114)
 
 
 def rgb_to_gray_array(rgb: np.ndarray) -> np.ndarray:
@@ -35,8 +35,7 @@ def rgb_to_gray_array(rgb: np.ndarray) -> np.ndarray:
     :param rgb: The RGB array to transform.
     :return: The corresponding gray array.
     """
-    gray_array = rgb[:, :, 0] * 0.299 + rgb[:, :, 1] * 0.587 \
-        + rgb[:, :, 2] * 0.114
+    gray_array = rgb[:, :, 0] * 0.299 + rgb[:, :, 1] * 0.587 + rgb[:, :, 2] * 0.114
     return np.uint8(gray_array)
 
 
@@ -52,7 +51,7 @@ def create_buckets(images_sizes, bin_size):
 
     bucket = {}
     current = min_size + bin_size - 1
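+    # Bucket keys are the upper bounds of fixed-width size bins; values are image indices.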
-    while(current < max_size):
+    while current < max_size:
         bucket[current] = []
         current += bin_size
     bucket[max_size] = []
@@ -61,38 +60,50 @@ def create_buckets(images_sizes, bin_size):
         dict_index = (((value - min_size) // bin_size) + 1) * bin_size + min_size - 1
         bucket[min(dict_index, max_size)].append(index)
 
-    bucket = {dict_index: values for dict_index, values in bucket.items() if len(values) > 0}
+    bucket = {
+        dict_index: values for dict_index, values in bucket.items() if len(values) > 0
+    }
     return bucket
 
 
 class Sampler(torch.utils.data.Sampler):
-
     def __init__(self, data, bin_size=20, batch_size=None, nb_params=None):
 
         self.bin_size = bin_size
         self.batch_size = batch_size
         self.nb_params = nb_params
-        
-        self.data_sizes = [image['size'] for image in data]
 
-        self.vertical = {index: image['size'][1] for index, image in enumerate(data) if image['size'][0] > image['size'][1]}
-        self.horizontal = {index: image['size'][0] for index, image in enumerate(data) if image['size'][0] <= image['size'][1]}
+        self.data_sizes = [image["size"] for image in data]
+
+        self.vertical = {
+            index: image["size"][1]
+            for index, image in enumerate(data)
+            if image["size"][0] > image["size"][1]
+        }
+        self.horizontal = {
+            index: image["size"][0]
+            for index, image in enumerate(data)
+            if image["size"][0] <= image["size"][1]
+        }
 
         self.buckets = [
-            create_buckets(self.vertical, self.bin_size) if len(self.vertical) > 0 else {},
-            create_buckets(self.horizontal, self.bin_size) if len(self.horizontal) > 0 else {},
+            create_buckets(self.vertical, self.bin_size)
+            if len(self.vertical) > 0
+            else {},
+            create_buckets(self.horizontal, self.bin_size)
+            if len(self.horizontal) > 0
+            else {},
         ]
 
-    def __len__ (self):
-        return (len(self.vertical) + len(self.horizontal))
-
+    def __len__(self):
+        return len(self.vertical) + len(self.horizontal)
 
     def __iter__(self):
         buckets = copy.deepcopy(self.buckets)
         for index, bucket in enumerate(buckets):
             for key in bucket.keys():
                 random.shuffle(buckets[index][key])
-        
+
         if self.batch_size is not None and self.nb_params is None:
             final_indices = []
             index_current = -1
@@ -107,7 +118,7 @@ class Sampler(torch.utils.data.Sampler):
                         current_batch_size += 1
                         final_indices[index_current].append(index)
             random.shuffle(final_indices)
-        
+
         elif self.nb_params is not None:
             final_indices = []
             index_current = -1
@@ -115,7 +126,9 @@ class Sampler(torch.utils.data.Sampler):
                 current_params = self.nb_params
                 for key in sorted(bucket.keys(), reverse=True):
                     for index in bucket[key]:
-                        element_params = self.data_sizes[index][0] * self.data_sizes[index][1] * 3
+                        element_params = (
+                            self.data_sizes[index][0] * self.data_sizes[index][1] * 3
+                        )
                         if current_params + element_params > self.nb_params:
                             current_params = 0
                             final_indices.append([])
@@ -127,7 +140,9 @@ class Sampler(torch.utils.data.Sampler):
         return iter(final_indices)
 
 
-def pad_images_masks(images: list, masks: list, image_padding_value: int, mask_padding_value: int):
+def pad_images_masks(
+    images: list, masks: list, image_padding_value: int, mask_padding_value: int
+):
     """
     Pad images and masks to create batches.
     :param images: The batch images to pad.
@@ -148,30 +163,43 @@ def pad_images_masks(images: list, masks: list, image_padding_value: int, mask_p
     if max_width % 8 != 0:
         max_width = int(8 * np.ceil(max_width / 8))
 
-    padded_images = np.ones((len(images), max_height, max_width, images[0].shape[2])) * image_padding_value
+    padded_images = (
+        np.ones((len(images), max_height, max_width, images[0].shape[2]))
+        * image_padding_value
+    )
     padded_masks = np.ones((len(masks), max_height, max_width)) * mask_padding_value
     for index, (image, mask) in enumerate(zip(images, masks)):
         delta_h = max_height - image.shape[0]
         delta_w = max_width - image.shape[1]
         top, bottom = delta_h // 2, delta_h - (delta_h // 2)
         left, right = delta_w // 2, delta_w - (delta_w // 2)
-        padded_images[index, top:padded_images.shape[1]-bottom, left:padded_images.shape[2]-right, :] = image
-        padded_masks[index, top:padded_masks.shape[1]-bottom, left:padded_masks.shape[2]-right] = mask
+        padded_images[
+            index,
+            top : padded_images.shape[1] - bottom,
+            left : padded_images.shape[2] - right,
+            :,
+        ] = image
+        padded_masks[
+            index,
+            top : padded_masks.shape[1] - bottom,
+            left : padded_masks.shape[2] - right,
+        ] = mask
 
     return padded_images, padded_masks
 
 
 class DLACollateFunction:
-
     def __init__(self):
         self.image_padding_token = 0
         self.mask_padding_token = 0
-        
-    def __call__(self, batch):
-        image = [item['image'] for item in batch]
-        mask = [item['mask'] for item in batch]
-        pad_image, pad_mask = pad_images_masks(image, mask,
-            self.image_padding_token, self.mask_padding_token)
-        return {'image': torch.tensor(pad_image).permute(0, 3, 1, 2),
-                'mask': torch.tensor(pad_mask)}
 
+    def __call__(self, batch):
+        image = [item["image"] for item in batch]
+        mask = [item["mask"] for item in batch]
+        pad_image, pad_mask = pad_images_masks(
+            image, mask, self.image_padding_token, self.mask_padding_token
+        )
+        return {
+            "image": torch.tensor(pad_image).permute(0, 3, 1, 2),
+            "mask": torch.tensor(pad_mask),
+        }
-- 
GitLab
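
Note on the reflowed pad_images_masks above: the black reformatting is purely
cosmetic, and the padding behaviour it preserves can be sanity-checked with a
minimal sketch. This assumes the snippet runs from the training/ directory (so
the module imports as utils.utils); the toy shapes are illustrative only, not
taken from the test suite:

    import numpy as np
    from utils.utils import pad_images_masks

    # Two channels-last images of different sizes, with matching masks.
    images = [np.ones((100, 52, 3)), np.ones((97, 60, 3))]
    masks = [np.zeros((100, 52)), np.zeros((97, 60))]

    padded_images, padded_masks = pad_images_masks(
        images, masks, image_padding_value=0, mask_padding_value=0
    )

    # Each dimension is padded to the batch maximum, then rounded up to a
    # multiple of 8 (heights 100 -> 104, widths 60 -> 64).
    assert padded_images.shape == (2, 104, 64, 3)
    assert padded_masks.shape == (2, 104, 64)
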


From d6696a351446471abc6cf22d76777fd6febe6f58 Mon Sep 17 00:00:00 2001
From: Bastien Abadie <bastien@nextcairn.com>
Date: Tue, 12 Jul 2022 11:18:57 +0200
Subject: [PATCH 2/3] Manual fixes

---
 training/evaluate.py               |  3 ---
 training/experiments_config.json   |  0
 training/model_params.py           |  4 ++--
 training/predict.py                |  1 -
 training/run_experiment.py         | 11 ++++++-----
 training/train.py                  |  2 --
 training/utils/evaluation_utils.py |  3 +--
 training/utils/preprocessing.py    |  2 +-
 8 files changed, 10 insertions(+), 16 deletions(-)
 mode change 100755 => 100644 training/experiments_config.json

diff --git a/training/evaluate.py b/training/evaluate.py
index 20e5df3..f8aad1b 100755
--- a/training/evaluate.py
+++ b/training/evaluate.py
@@ -12,13 +12,10 @@ import logging
 import os
 import time
 
-import cv2
 import numpy as np
-import torch
 import utils.evaluation_utils as ev_utils
 import utils.object_metrics as o_metrics
 import utils.pixel_metrics as p_metrics
-from shapely.geometry import Polygon
 from tqdm import tqdm
 
 
diff --git a/training/experiments_config.json b/training/experiments_config.json
old mode 100755
new mode 100644
diff --git a/training/model_params.py b/training/model_params.py
index 69e9d74..14e5321 100755
--- a/training/model_params.py
+++ b/training/model_params.py
@@ -32,8 +32,8 @@ def default_config():
     """
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     logging.info("Running on %s", device)
-    img_size = 768
-    no_of_classes = 2
+    img_size = 768  # noqa: F841
+    no_of_classes = 2  # noqa: F841
 
 
 @ex.automain
diff --git a/training/predict.py b/training/predict.py
index 972800a..29aa308 100755
--- a/training/predict.py
+++ b/training/predict.py
@@ -16,7 +16,6 @@ import cv2
 import numpy as np
 import torch
 import utils.prediction_utils as pr_utils
-from shapely.geometry import Polygon
 from tqdm import tqdm
 
 
diff --git a/training/run_experiment.py b/training/run_experiment.py
index 10562f1..f1c769c 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -17,7 +17,6 @@ from pathlib import Path
 import evaluate
 import normalization_params
 import predict
-import torch
 import torch.optim as optim
 import train
 import utils.preprocessing as pprocessing
@@ -87,12 +86,14 @@ def default_config():
         global_params["batch_size"] is not None
         or global_params["no_of_params"] is not None
     ), "Please provide a batch size or a maximum number of parameters"
-    params = Params().to_dict()
+    params = Params().to_dict()  # noqa: F841
 
     # Load the current experiment parameters.
     experiment_name = "doc-ufcn"
-    log_path = "runs/" + experiment_name.lower().replace(" ", "_").replace("-", "_")
-    tb_path = "events/"
+    log_path = "runs/" + experiment_name.lower().replace(  # noqa: F841
+        " ", "_"
+    ).replace("-", "_")
+    tb_path = "events/"  # noqa: F841
     steps = ["normalization_params", "train", "prediction", "evaluation"]
     for step in steps:
         assert step in STEPS
@@ -113,7 +114,7 @@ def default_config():
             "json": ["./data/test/labels_json/"],
         },
     }
-    exp_data_paths = {
+    exp_data_paths = {  # noqa: F841
         set: {
             key: [Path(element).expanduser() for element in value]
             for key, value in paths.items()
diff --git a/training/train.py b/training/train.py
index 0315bcb..1592745 100755
--- a/training/train.py
+++ b/training/train.py
@@ -10,7 +10,6 @@
 
 import logging
 import os
-import sys
 import time
 
 import numpy as np
@@ -21,7 +20,6 @@ from torch.cuda.amp import autocast
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 from utils import model
-from utils.params_config import Params
 
 
 def init_metrics(no_of_classes: int) -> dict:
diff --git a/training/utils/evaluation_utils.py b/training/utils/evaluation_utils.py
index 8d7cd85..1b92d75 100755
--- a/training/utils/evaluation_utils.py
+++ b/training/utils/evaluation_utils.py
@@ -11,12 +11,11 @@
 import json
 import os
 
-import cv2
 import matplotlib.font_manager as fm
 import matplotlib.pyplot as plt
 import numpy as np
 from matplotlib import ticker
-from shapely.geometry import MultiPolygon, Polygon
+from shapely.geometry import Polygon
 
 
 def read_json(filename: str) -> dict:
diff --git a/training/utils/preprocessing.py b/training/utils/preprocessing.py
index 93a20fb..b55069f 100755
--- a/training/utils/preprocessing.py
+++ b/training/utils/preprocessing.py
@@ -282,6 +282,6 @@ class ToTensor:
         :return sample: The sample made of Tensors.
         """
         sample["image"] = torch.from_numpy(sample["image"].transpose((2, 0, 1)))
-        if "mask" is sample.keys():
+        if "mask" == sample.keys():
             sample["mask"] = torch.from_numpy(sample["mask"])
         return sample
-- 
GitLab
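
The `# noqa: F841` markers added above in model_params.py and run_experiment.py
are deliberate rather than suppressed dead code: sacred captures every local
variable defined inside a function decorated with its config decorator and
injects it into the experiment configuration, so flake8 sees an unused variable
even though the value is consumed. A minimal sketch of the pattern, with a
hypothetical experiment name (only Experiment, ex.config and ex.automain come
from sacred itself):

    from sacred import Experiment

    ex = Experiment("demo")

    @ex.config
    def my_config():
        img_size = 768  # noqa: F841, captured by sacred, not dead code

    @ex.automain
    def main(img_size):
        # sacred injects the captured config value here.
        print(img_size)
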


From 23b4a095d1833bc2cf940cf71f879c08c99de17c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com>
Date: Tue, 12 Jul 2022 10:32:40 +0000
Subject: [PATCH 3/3] Apply 1 suggestion(s) to 1 file(s)

---
 training/utils/preprocessing.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/training/utils/preprocessing.py b/training/utils/preprocessing.py
index b55069f..467e0ef 100755
--- a/training/utils/preprocessing.py
+++ b/training/utils/preprocessing.py
@@ -282,6 +282,6 @@ class ToTensor:
         :return sample: The sample made of Tensors.
         """
         sample["image"] = torch.from_numpy(sample["image"].transpose((2, 0, 1)))
-        if "mask" == sample.keys():
+        if "mask" in sample.keys():
             sample["mask"] = torch.from_numpy(sample["mask"])
         return sample
-- 
GitLab
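
The ToTensor guard fixed in this last patch went through three shapes, and only
the final one does what the docstring intends. A minimal sketch of why, using
an illustrative sample dict:

    sample = {"image": object(), "mask": object()}

    # Always False: identity test against a freshly created dict view
    # (CPython even emits a SyntaxWarning for `is` with a literal).
    print("mask" is sample.keys())
    # Always False: a str never compares equal to a dict_keys view.
    print("mask" == sample.keys())
    # True whenever the key exists: the intended membership test.
    print("mask" in sample.keys())

Note that `"mask" in sample` would be equivalent to the final version and is
the slightly more idiomatic spelling.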