Merge DatasetManager / GenericDataset / OCRDatasetManager / OCRDataset classes

All threads resolved!
2 files
+ 302
292
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 82
200
# -*- coding: utf-8 -*-
import json
import os
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision.io import ImageReadMode, read_image
from dan.datasets.utils import natural_sort
from dan.transforms import (
    get_augmentation_transforms,
    get_normalization_transforms,
    get_preprocessing_transforms,
)
from dan.utils import token_to_ind
class DatasetManager:
    def __init__(self, params, device: str):
        self.params = params
        self.dataset_class = None
        self.my_collate_function = None
        # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
        self.pin_memory = device != "cpu"

        self.train_dataset = None
        self.valid_datasets = dict()
        self.test_datasets = dict()

        self.train_loader = None
        self.valid_loaders = dict()
        self.test_loaders = dict()

        self.train_sampler = None
        self.valid_samplers = dict()
        self.test_samplers = dict()

        self.generator = torch.Generator()
        self.generator.manual_seed(0)

        self.batch_size = {
            "train": self.params["batch_size"],
            "val": self.params["valid_batch_size"]
            if "valid_batch_size" in self.params
            else self.params["batch_size"],
            "test": self.params["test_batch_size"]
            if "test_batch_size" in self.params
            else 1,
        }

    def apply_specific_treatment_after_dataset_loading(self, dataset):
        raise NotImplementedError

    def load_datasets(self):
        """
        Load training and validation datasets
        """
        self.train_dataset = self.dataset_class(
            self.params,
            "train",
            self.params["train"]["name"],
            self.get_paths_and_sets(self.params["train"]["datasets"]),
            normalization_transforms=get_normalization_transforms(),
            augmentation_transforms=(
                get_augmentation_transforms()
                if self.params["config"]["augmentation"]
                else None
            ),
        )
        self.my_collate_function = self.train_dataset.collate_function(
            self.params["config"]
        )
        self.apply_specific_treatment_after_dataset_loading(self.train_dataset)

        for custom_name in self.params["val"].keys():
            self.valid_datasets[custom_name] = self.dataset_class(
                self.params,
                "val",
                custom_name,
                self.get_paths_and_sets(self.params["val"][custom_name]),
                normalization_transforms=get_normalization_transforms(),
                augmentation_transforms=None,
            )
            self.apply_specific_treatment_after_dataset_loading(
                self.valid_datasets[custom_name]
            )

    def load_ddp_samplers(self):
        """
        Load training and validation data samplers
        """
        if self.params["use_ddp"]:
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=self.params["num_gpu"],
                rank=self.params["ddp_rank"],
                shuffle=True,
            )
            for custom_name in self.valid_datasets.keys():
                self.valid_samplers[custom_name] = DistributedSampler(
                    self.valid_datasets[custom_name],
                    num_replicas=self.params["num_gpu"],
                    rank=self.params["ddp_rank"],
                    shuffle=False,
                )
        else:
            for custom_name in self.valid_datasets.keys():
                self.valid_samplers[custom_name] = None
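For orientation, here is a minimal sketch of the params layout that DatasetManager reads; the keys are the ones dereferenced in the code above, while every concrete value and dataset name is a hypothetical example:

params = {
    "batch_size": 4,  # "valid_batch_size" / "test_batch_size" are optional overrides
    "use_ddp": False,  # when True, DistributedSampler instances are created
    "num_gpu": 1,
    "ddp_rank": 0,
    "worker_per_gpu": 4,  # num_workers = num_gpu * worker_per_gpu
    "datasets": {"my_dataset": "path/to/my_dataset"},
    "train": {"name": "my_dataset-train", "datasets": [("my_dataset", "train")]},
    "val": {"my_dataset-val": [("my_dataset", "val")]},
    "config": {"augmentation": True, "preprocessings": [], "load_in_memory": True},
}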
    def load_dataloaders(self):
        """
        Load training and validation data loaders
        """
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size["train"],
            # The DistributedSampler, when present, takes care of shuffling;
            # sampler is also mutually exclusive with batch_sampler in
            # DataLoader, so it is passed through `sampler` only
            shuffle=self.train_sampler is None,
            drop_last=False,
            sampler=self.train_sampler,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )

        for key in self.valid_datasets.keys():
            self.valid_loaders[key] = DataLoader(
                self.valid_datasets[key],
                batch_size=self.batch_size["val"],
                sampler=self.valid_samplers[key],
                shuffle=False,
                num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
                pin_memory=self.pin_memory,
                drop_last=False,
                collate_fn=self.my_collate_function,
                worker_init_fn=self.seed_worker,
                generator=self.generator,
            )

    @staticmethod
    def seed_worker(worker_id):
        # Derive per-worker numpy/random seeds from torch's initial seed so
        # that data loading randomness is reproducible across runs
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)
    def generate_test_loader(self, custom_name, sets_list):
        """
        Load test dataset, data sampler and data loader
        """
        if custom_name in self.test_loaders.keys():
            return
        paths_and_sets = list()
        for set_info in sets_list:
            paths_and_sets.append(
                {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
            )
        self.test_datasets[custom_name] = self.dataset_class(
            self.params,
            "test",
            custom_name,
            paths_and_sets,
            normalization_transforms=get_normalization_transforms(),
        )
        self.apply_specific_treatment_after_dataset_loading(
            self.test_datasets[custom_name]
        )
        if self.params["use_ddp"]:
            self.test_samplers[custom_name] = DistributedSampler(
                self.test_datasets[custom_name],
                num_replicas=self.params["num_gpu"],
                rank=self.params["ddp_rank"],
                shuffle=False,
            )
        else:
            self.test_samplers[custom_name] = None
        self.test_loaders[custom_name] = DataLoader(
            self.test_datasets[custom_name],
            batch_size=self.batch_size["test"],
            sampler=self.test_samplers[custom_name],
            shuffle=False,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            drop_last=False,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )

    def get_paths_and_sets(self, dataset_names_folds):
        paths_and_sets = list()
        for dataset_name, fold in dataset_names_folds:
            path = self.params["datasets"][dataset_name]
            paths_and_sets.append({"path": path, "set_name": fold})
        return paths_and_sets
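At evaluation time, generate_test_loader is called once per test split on a concrete manager subclass (e.g. the OCRDatasetManager this MR merges); each entry of sets_list pairs a key of params["datasets"] with a split name. A hypothetical call, with invented names:

manager.generate_test_loader("my_dataset-test", [("my_dataset", "test")])
for batch in manager.test_loaders["my_dataset-test"]:
    ...  # run inference on the batch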
class GenericDataset(Dataset):
    """
    Main class to handle dataset loading
    """

    def __init__(self, params, set_name, custom_name, paths_and_sets):
        self.params = params
        self.name = custom_name
        self.set_name = set_name
        self.preprocessing_transforms = get_preprocessing_transforms(
            params["config"]["preprocessings"]
        )
        self.load_in_memory = (
            self.params["config"]["load_in_memory"]
            if "load_in_memory" in self.params["config"]
            else True
        )
        # Load samples and preprocess images if load_in_memory is True
        self.samples = self.load_samples(paths_and_sets)
        self.curriculum_config = None

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def load_image(path):
@@ -276,3 +130,31 @@ class GenericDataset(Dataset):
            return self.preprocessing_transforms(
                self.load_image(self.samples[i]["path"])
            )
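The hunk header above elides the beginning of get_sample_img; only its tail is shown. From the call site in __getitem__ and the visible return, the method plausibly reads as below, but the signature and the load_in_memory branch are reconstructed assumptions, not part of the diff:

    def get_sample_img(self, i):
        # Assumed shape of the elided method: return the cached image when
        # samples are kept in memory, otherwise load and preprocess lazily
        if self.load_in_memory:
            return self.samples[i]["img"]
        return self.preprocessing_transforms(
            self.load_image(self.samples[i]["path"])
        )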
class OCRDataset(Dataset):
    """
    Dataset class to handle dataset loading
    """

    def __init__(
        self,
        set_name,
        paths_and_sets,
        charset,
        tokens,
        preprocessing_transforms,
        normalization_transforms,
        augmentation_transforms,
        load_in_memory=False,
    ):
        self.set_name = set_name
        self.charset = charset
        self.tokens = tokens
        self.load_in_memory = load_in_memory

        # Pre-processing, augmentation, normalization
        self.preprocessing_transforms = preprocessing_transforms
        self.normalization_transforms = normalization_transforms
        self.augmentation_transforms = augmentation_transforms

        # Factor to reduce the height and width of the feature vector before feeding the decoder.
        self.reduce_dims_factor = np.array([32, 8, 1])

        # Load samples and preprocess images if load_in_memory is True
        self.samples = self.load_samples(paths_and_sets)

        # Curriculum config
        self.curriculum_config = None

    def __len__(self):
        """
        Return the dataset size
        """
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Return an item from the dataset (image and label)
        """
        # Load preprocessed image
        sample = dict(**self.samples[idx])
        if not self.load_in_memory:
            sample["img"] = self.get_sample_img(idx)

        # Apply data augmentation
        if self.augmentation_transforms:
            sample["img"] = self.augmentation_transforms(image=np.array(sample["img"]))[
                "image"
            ]

        # Image normalization
        sample["img"] = self.normalization_transforms(sample["img"])

        # Get final height and width
        sample["img_reduced_shape"], sample["img_position"] = self.compute_final_size(
            sample["img"]
        )

        # Convert label into tokens
        sample["token_label"], sample["label_len"] = self.convert_sample_label(
            sample["label"]
        )

        return sample

    def compute_final_size(self, img):
        """
        Compute the final image size and position after feature extraction
        """
        final_c, final_h, final_w = img.shape
        image_reduced_shape = np.ceil(
            [final_h, final_w, final_c] / self.reduce_dims_factor
        ).astype(int)
        if self.set_name == "train":
            image_reduced_shape = [max(1, t) for t in image_reduced_shape]
        image_position = [
            [0, final_h],
            [0, final_w],
        ]
        return image_reduced_shape, image_position

    def convert_sample_label(self, label):
        """
        Tokenize the label and return its length
        """
        token_label = token_to_ind(self.charset, label)
        token_label.append(self.tokens["end"])
        label_len = len(token_label)
        token_label.insert(0, self.tokens["start"])
        return token_label, label_len
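To make the shape arithmetic concrete: with reduce_dims_factor = [32, 8, 1], a 3×128×1000 CHW image yields ceil([128, 1000, 3] / [32, 8, 1]) = [4, 125, 3], so the decoder sees a 4×125 feature map. Likewise, convert_sample_label appends the end token before measuring label_len, then prepends the start token, so label_len counts the characters plus the end token only. A small illustration with a hypothetical three-character charset:

import numpy as np

reduce_dims_factor = np.array([32, 8, 1])
final_c, final_h, final_w = 3, 128, 1000
print(np.ceil([final_h, final_w, final_c] / reduce_dims_factor).astype(int))
# -> [  4 125   3]

charset = ["a", "b", "c"]           # hypothetical charset
tokens = {"start": 3, "end": 4}     # hypothetical token indices
token_label = [charset.index(c) for c in "abc"]  # stands in for token_to_ind
token_label.append(tokens["end"])
label_len = len(token_label)        # 4: three characters + end token
token_label.insert(0, tokens["start"])
print(token_label, label_len)       # [3, 0, 1, 2, 4] 4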