# -*- coding: utf-8 -*-
import json
import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision.io import ImageReadMode, read_image

from dan.datasets.utils import natural_sort
from dan.transforms import (
    get_augmentation_transforms,
    get_normalization_transforms,
    get_preprocessing_transforms,
)


class DatasetManager:
    def __init__(self, params, device: str):
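        """
        ``params`` is a configuration dict; a hypothetical minimal layout,
        inferred from the keys accessed in this class (not a definitive
        schema), looks like::

            {
                "batch_size": 8,
                "valid_batch_size": 4,  # optional, defaults to batch_size
                "test_batch_size": 1,  # optional, defaults to 1
                "use_ddp": False,
                "num_gpu": 1,
                "ddp_rank": 0,
                "worker_per_gpu": 4,
                "datasets": {"my-dataset": "/data/my-dataset"},
                "train": {"name": "train", "datasets": [["my-dataset", "train"]]},
                "val": {"my-val": [["my-dataset", "val"]]},
                "config": {"augmentation": True, "preprocessings": []},
            }

        ``device`` is the compute device name, e.g. "cpu" or "cuda".
        """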
        self.params = params
        self.dataset_class = None

        self.my_collate_function = None
        # Whether batches should be allocated in page-locked (pinned) memory,
        # which speeds up host-to-GPU copies, see
        # https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
        self.pin_memory = device != "cpu"

        self.train_dataset = None
        self.valid_datasets = dict()
        self.test_datasets = dict()

        self.train_loader = None
        self.valid_loaders = dict()
        self.test_loaders = dict()

        self.train_sampler = None
        self.valid_samplers = dict()
        self.test_samplers = dict()

        self.generator = torch.Generator()
        self.generator.manual_seed(0)
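        # A fixed-seed generator, combined with seed_worker below, follows the
        # PyTorch reproducibility recipe: DataLoader shuffling and worker-side
        # RNG become deterministic across runs.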

        self.batch_size = {
            "train": self.params["batch_size"],
            "val": self.params.get("valid_batch_size", self.params["batch_size"]),
            "test": self.params.get("test_batch_size", 1),
        }

    def apply_specific_treatment_after_dataset_loading(self, dataset):
        """
        Hook for subclasses to post-process a freshly loaded dataset; must be
        overridden.
        """
        raise NotImplementedError

    def load_datasets(self):
        """
        Load training and validation datasets
        """
        self.train_dataset = self.dataset_class(
            self.params,
            "train",
            self.params["train"]["name"],
            self.get_paths_and_sets(self.params["train"]["datasets"]),
            normalization_transforms=get_normalization_transforms(),
            augmentation_transforms=(
                get_augmentation_transforms()
                if self.params["config"]["augmentation"]
                else None
            ),
        )

        self.my_collate_function = self.train_dataset.collate_function(
            self.params["config"]
        )
        self.apply_specific_treatment_after_dataset_loading(self.train_dataset)

        for custom_name in self.params["val"].keys():
            self.valid_datasets[custom_name] = self.dataset_class(
                self.params,
                "val",
                custom_name,
                self.get_paths_and_sets(self.params["val"][custom_name]),
                normalization_transforms=get_normalization_transforms(),
                augmentation_transforms=None,
            )
            self.apply_specific_treatment_after_dataset_loading(
                self.valid_datasets[custom_name]
            )

    def load_ddp_samplers(self):
        """
        Load training and validation data samplers
        """
        if self.params["use_ddp"]:
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=self.params["num_gpu"],
                rank=self.params["ddp_rank"],
                shuffle=True,
            )
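            # Note: for shuffling to differ between epochs, the training loop
            # is expected to call train_sampler.set_epoch(epoch) at the start
            # of each epoch (per the DistributedSampler docs); that call is
            # assumed to happen outside this class.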
            for custom_name in self.valid_datasets.keys():
                self.valid_samplers[custom_name] = DistributedSampler(
                    self.valid_datasets[custom_name],
                    num_replicas=self.params["num_gpu"],
                    rank=self.params["ddp_rank"],
                    shuffle=False,
                )
        else:
            for custom_name in self.valid_datasets.keys():
                self.valid_samplers[custom_name] = None

    def load_dataloaders(self):
        """
        Load training and validation data loaders
        """
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size["train"],
            # shuffle and sampler are mutually exclusive: shuffle only when no
            # DistributedSampler is in use
            shuffle=self.train_sampler is None,
            sampler=self.train_sampler,
            drop_last=False,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )

        for key in self.valid_datasets.keys():
            self.valid_loaders[key] = DataLoader(
                self.valid_datasets[key],
                batch_size=self.batch_size["val"],
                sampler=self.valid_samplers[key],
                shuffle=False,
                num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
                pin_memory=self.pin_memory,
                drop_last=False,
                collate_fn=self.my_collate_function,
                worker_init_fn=self.seed_worker,
                generator=self.generator,
            )

    @staticmethod
    def seed_worker(worker_id):
        """
        Re-seed numpy and random in each DataLoader worker from the torch seed
        so that worker-side random operations are reproducible.
        """
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def generate_test_loader(self, custom_name, sets_list):
        """
        Load test dataset, data sampler and data loader
        """
        if custom_name in self.test_loaders:
            return
        paths_and_sets = [
            {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
            for set_info in sets_list
        ]
        self.test_datasets[custom_name] = self.dataset_class(
            self.params,
            "test",
            custom_name,
            paths_and_sets,
            normalization_transforms=get_normalization_transforms(),
        )
        self.apply_specific_treatment_after_dataset_loading(
            self.test_datasets[custom_name]
        )
        if self.params["use_ddp"]:
            self.test_samplers[custom_name] = DistributedSampler(
                self.test_datasets[custom_name],
                num_replicas=self.params["num_gpu"],
                rank=self.params["ddp_rank"],
                shuffle=False,
            )
        else:
            self.test_samplers[custom_name] = None
        self.test_loaders[custom_name] = DataLoader(
            self.test_datasets[custom_name],
            batch_size=self.batch_size["test"],
            sampler=self.test_samplers[custom_name],
            shuffle=False,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            drop_last=False,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )

    def get_paths_and_sets(self, dataset_names_folds):
        """
        Resolve (dataset name, fold) pairs into dicts holding the dataset path
        and the split name.
        """
        paths_and_sets = list()
        for dataset_name, fold in dataset_names_folds:
            path = self.params["datasets"][dataset_name]
            paths_and_sets.append({"path": path, "set_name": fold})
        return paths_and_sets
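    # Example with hypothetical values: assuming
    # params["datasets"] == {"my-dataset": "/data/my-dataset"}, then
    # get_paths_and_sets([("my-dataset", "train")]) returns
    # [{"path": "/data/my-dataset", "set_name": "train"}].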


class GenericDataset(Dataset):
    """
    Main class to handle dataset loading
    """

    def __init__(self, params, set_name, custom_name, paths_and_sets):
        self.params = params
        self.name = custom_name
        self.set_name = set_name

        self.preprocessing_transforms = get_preprocessing_transforms(
            params["config"]["preprocessings"]
        )
        self.load_in_memory = self.params["config"].get("load_in_memory", True)

        # Load samples and preprocess images if load_in_memory is True
        self.samples = self.load_samples(paths_and_sets)

        self.curriculum_config = None

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def load_image(path):
        """
        Load an image as a torch.Tensor and scale the values between 0 and 1.
        """
        img = read_image(path, mode=ImageReadMode.RGB)
        return img.to(dtype=torch.get_default_dtype()).div(255)

    def load_samples(self, paths_and_sets):
        """
        Load images and labels
        """
        samples = list()

        for path_and_set in paths_and_sets:
            path = path_and_set["path"]
            with open(os.path.join(path, "labels.json"), "rb") as f:
                gt_per_set = json.load(f)
            set_name = path_and_set["set_name"]
            gt = gt_per_set[set_name]
            for filename in natural_sort(gt.keys()):
                if isinstance(gt[filename], dict) and "text" in gt[filename]:
                    label = gt[filename]["text"]
                else:
                    label = gt[filename]
                samples.append(
                    {
                        "name": os.path.basename(filename),
                        "label": label,
                        "path": os.path.abspath(filename),
                    }
                )
                if self.load_in_memory:
                    samples[-1]["img"] = self.preprocessing_transforms(
                        self.load_image(filename)
                    )
        return samples

    def get_sample_img(self, i):
        """
        Return the preprocessed image of the i-th sample, either from memory
        or loaded from disk on the fly.
        """
        if self.load_in_memory:
            return self.samples[i]["img"]
        return self.preprocessing_transforms(self.load_image(self.samples[i]["path"]))
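

# Minimal usage sketch (hypothetical subclass and parameter values; the real
# entry point lives elsewhere in the project):
#
#     class MyDatasetManager(DatasetManager):
#         def __init__(self, params, device):
#             super().__init__(params, device)
#             self.dataset_class = MyDataset  # a GenericDataset subclass
#
#         def apply_specific_treatment_after_dataset_loading(self, dataset):
#             pass  # e.g. derive a charset from dataset.samples
#
#     manager = MyDatasetManager(params, device="cuda")
#     manager.load_datasets()
#     manager.load_ddp_samplers()
#     manager.load_dataloaders()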