# dataset.py
# -*- coding: utf-8 -*-
import json
import os
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision.io import ImageReadMode, read_image
from dan.datasets.utils import natural_sort
from dan.transforms import (
get_augmentation_transforms,
get_normalization_transforms,
get_preprocessing_transforms,
)
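
# The `params` dictionary drives everything below. A minimal sketch of the
# layout it is assumed to have, inferred from the keys this module accesses
# (real configurations may carry additional entries; values are illustrative):
#
#   params = {
#       "batch_size": 4,
#       # optional: "valid_batch_size", "test_batch_size"
#       "use_ddp": False,
#       "num_gpu": 1,
#       "ddp_rank": 0,
#       "worker_per_gpu": 4,
#       "datasets": {"my_dataset": "path/to/my_dataset"},
#       "train": {"name": "training", "datasets": [("my_dataset", "train")]},
#       "val": {"validation": [("my_dataset", "val")]},
#       "config": {
#           "augmentation": True,
#           "preprocessings": [],
#           "load_in_memory": True,
#       },
#   }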
class DatasetManager:
def __init__(self, params, device: str):
self.params = params
self.dataset_class = None
self.my_collate_function = None
# Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
self.pin_memory = device != "cpu"
self.train_dataset = None
self.valid_datasets = dict()
self.test_datasets = dict()
self.train_loader = None
self.valid_loaders = dict()
self.test_loaders = dict()
self.train_sampler = None
self.valid_samplers = dict()
self.test_samplers = dict()
self.generator = torch.Generator()
self.generator.manual_seed(0)
        self.batch_size = {
            "train": self.params["batch_size"],
            "val": self.params.get("valid_batch_size", self.params["batch_size"]),
            "test": self.params.get("test_batch_size", 1),
        }
    def apply_specific_treatment_after_dataset_loading(self, dataset):
        """
        Hook for subclasses: post-process a dataset right after it is loaded.
        """
        raise NotImplementedError
def load_datasets(self):
"""
Load training and validation datasets
"""
self.train_dataset = self.dataset_class(
self.params,
"train",
self.params["train"]["name"],
self.get_paths_and_sets(self.params["train"]["datasets"]),
normalization_transforms=get_normalization_transforms(),
augmentation_transforms=(
get_augmentation_transforms()
if self.params["config"]["augmentation"]
else None
),
)
self.my_collate_function = self.train_dataset.collate_function(
self.params["config"]
)
self.apply_specific_treatment_after_dataset_loading(self.train_dataset)
for custom_name in self.params["val"].keys():
self.valid_datasets[custom_name] = self.dataset_class(
self.params,
"val",
custom_name,
self.get_paths_and_sets(self.params["val"][custom_name]),
normalization_transforms=get_normalization_transforms(),
augmentation_transforms=None,
)
self.apply_specific_treatment_after_dataset_loading(
self.valid_datasets[custom_name]
)
def load_ddp_samplers(self):
"""
Load training and validation data samplers
"""
if self.params["use_ddp"]:
self.train_sampler = DistributedSampler(
self.train_dataset,
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=True,
)
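            # NOTE: with a DistributedSampler, the training loop is expected to
            # call `self.train_sampler.set_epoch(epoch)` before each epoch so
            # that shuffling differs between epochs (standard PyTorch DDP usage).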
for custom_name in self.valid_datasets.keys():
self.valid_samplers[custom_name] = DistributedSampler(
self.valid_datasets[custom_name],
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=False,
)
else:
for custom_name in self.valid_datasets.keys():
self.valid_samplers[custom_name] = None
def load_dataloaders(self):
"""
Load training and validation data loaders
"""
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size["train"],
            # `shuffle` and `sampler` are mutually exclusive in DataLoader:
            # shuffle only when no DistributedSampler controls the ordering.
            shuffle=self.train_sampler is None,
            drop_last=False,
            sampler=self.train_sampler,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )
        for key in self.valid_datasets.keys():
            self.valid_loaders[key] = DataLoader(
                self.valid_datasets[key],
                batch_size=self.batch_size["val"],
                sampler=self.valid_samplers[key],
                shuffle=False,
                num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
                pin_memory=self.pin_memory,
                drop_last=False,
                collate_fn=self.my_collate_function,
                worker_init_fn=self.seed_worker,
                generator=self.generator,
            )
    @staticmethod
    def seed_worker(worker_id):
        """
        Seed numpy and random inside each DataLoader worker from the torch
        seed, so that augmentation randomness is reproducible.
        """
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)
def generate_test_loader(self, custom_name, sets_list):
"""
Load test dataset, data sampler and data loader
"""
if custom_name in self.test_loaders.keys():
return
paths_and_sets = list()
for set_info in sets_list:
paths_and_sets.append(
{"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
)
self.test_datasets[custom_name] = self.dataset_class(
self.params,
"test",
custom_name,
paths_and_sets,
normalization_transforms=get_normalization_transforms(),
)
self.apply_specific_treatment_after_dataset_loading(
self.test_datasets[custom_name]
)
if self.params["use_ddp"]:
self.test_samplers[custom_name] = DistributedSampler(
self.test_datasets[custom_name],
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=False,
)
else:
self.test_samplers[custom_name] = None
self.test_loaders[custom_name] = DataLoader(
self.test_datasets[custom_name],
batch_size=self.batch_size["test"],
sampler=self.test_samplers[custom_name],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
)
    def get_paths_and_sets(self, dataset_names_folds):
        """
        Resolve (dataset_name, fold) pairs into {"path", "set_name"} dicts
        using the dataset paths declared in the configuration.
        """
        paths_and_sets = list()
        for dataset_name, fold in dataset_names_folds:
            path = self.params["datasets"][dataset_name]
            paths_and_sets.append({"path": path, "set_name": fold})
        return paths_and_sets
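
# Typical driver code (a sketch only: DatasetManager is abstract, so
# `MyDatasetManager` stands for a hypothetical subclass that sets
# `self.dataset_class` and implements the post-loading hook):
#
#   manager = MyDatasetManager(params, device="cuda")
#   manager.load_datasets()
#   manager.load_ddp_samplers()
#   manager.load_dataloaders()
#   for batch in manager.train_loader:
#       ...  # training step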
class GenericDataset(Dataset):
"""
Main class to handle dataset loading
"""
def __init__(self, params, set_name, custom_name, paths_and_sets):
self.params = params
self.name = custom_name
self.set_name = set_name
self.preprocessing_transforms = get_preprocessing_transforms(
params["config"]["preprocessings"]
)
        self.load_in_memory = self.params["config"].get("load_in_memory", True)
# Load samples and preprocess images if load_in_memory is True
self.samples = self.load_samples(paths_and_sets)
self.curriculum_config = None
def __len__(self):
return len(self.samples)
@staticmethod
def load_image(path):
"""
Load an image as a torch.Tensor and scale the values between 0 and 1.
"""
img = read_image(path, mode=ImageReadMode.RGB)
return img.to(dtype=torch.get_default_dtype()).div(255)
def load_samples(self, paths_and_sets):
"""
Load images and labels
"""
samples = list()
for path_and_set in paths_and_sets:
path = path_and_set["path"]
with open(os.path.join(path, "labels.json"), "rb") as f:
gt_per_set = json.load(f)
set_name = path_and_set["set_name"]
gt = gt_per_set[set_name]
for filename in natural_sort(gt.keys()):
if isinstance(gt[filename], dict) and "text" in gt[filename]:
label = gt[filename]["text"]
else:
label = gt[filename]
samples.append(
{
"name": os.path.basename(filename),
"label": label,
"path": os.path.abspath(filename),
}
)
if self.load_in_memory:
samples[-1]["img"] = self.preprocessing_transforms(
self.load_image(filename)
)
return samples
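
    # Expected structure of labels.json, inferred from the parsing above
    # (per-file labels are either plain strings or dicts carrying a "text"
    # key; paths and texts below are illustrative only):
    #
    #   {
    #       "train": {"images/train/a.png": "ground truth text"},
    #       "val": {"images/val/b.png": {"text": "ground truth text"}},
    #       "test": {"images/test/c.png": "ground truth text"}
    #   }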
def get_sample_img(self, i):
"""
Get image by index
"""
if self.load_in_memory:
return self.samples[i]["img"]
else:
return self.preprocessing_transforms(
self.load_image(self.samples[i]["path"])
)
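
    # A concrete subclass is expected to provide __getitem__; a minimal sketch
    # under that assumption (a real subclass would also apply the
    # normalization/augmentation transforms passed to it by DatasetManager):
    #
    #   def __getitem__(self, idx):
    #       sample = dict(self.samples[idx])
    #       sample["img"] = self.get_sample_img(idx)
    #       return sample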