
Merge DatasetManager / GenericDataset / OCRDatasetManager / OCRDataset classes

All threads resolved!

2 files changed: +302 −292. The file shown below accounts for +82 −200 of those.
 # -*- coding: utf-8 -*-
 import json
 import os
-import random
 
 import numpy as np
 import torch
-from torch.utils.data import DataLoader, Dataset
-from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data import Dataset
 from torchvision.io import ImageReadMode, read_image
 
 from dan.datasets.utils import natural_sort
-from dan.transforms import (
-    get_augmentation_transforms,
-    get_normalization_transforms,
-    get_preprocessing_transforms,
-)
+from dan.utils import token_to_ind
 
 
-class DatasetManager:
-    def __init__(self, params, device: str):
-        self.params = params
-        self.dataset_class = None
-        self.my_collate_function = None
-        # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
-        self.pin_memory = device != "cpu"
-
-        self.train_dataset = None
-        self.valid_datasets = dict()
-        self.test_datasets = dict()
-
-        self.train_loader = None
-        self.valid_loaders = dict()
-        self.test_loaders = dict()
-
-        self.train_sampler = None
-        self.valid_samplers = dict()
-        self.test_samplers = dict()
-
-        self.generator = torch.Generator()
-        self.generator.manual_seed(0)
-
-        self.batch_size = {
-            "train": self.params["batch_size"],
-            "val": self.params["valid_batch_size"]
-            if "valid_batch_size" in self.params
-            else self.params["batch_size"],
-            "test": self.params["test_batch_size"]
-            if "test_batch_size" in self.params
-            else 1,
-        }
-
-        # Curriculum config
-        self.curriculum_config = None
-
-    def apply_specific_treatment_after_dataset_loading(self, dataset):
-        raise NotImplementedError
-
-    def load_datasets(self):
-        """
-        Load training and validation datasets
-        """
-        self.train_dataset = self.dataset_class(
-            self.params,
-            "train",
-            self.params["train"]["name"],
-            self.get_paths_and_sets(self.params["train"]["datasets"]),
-            normalization_transforms=get_normalization_transforms(),
-            augmentation_transforms=(
-                get_augmentation_transforms()
-                if self.params["config"]["augmentation"]
-                else None
-            ),
-        )
-        self.my_collate_function = self.train_dataset.collate_function(
-            self.params["config"]
-        )
-        self.apply_specific_treatment_after_dataset_loading(self.train_dataset)
-
-        for custom_name in self.params["val"].keys():
-            self.valid_datasets[custom_name] = self.dataset_class(
-                self.params,
-                "val",
-                custom_name,
-                self.get_paths_and_sets(self.params["val"][custom_name]),
-                normalization_transforms=get_normalization_transforms(),
-                augmentation_transforms=None,
-            )
-            self.apply_specific_treatment_after_dataset_loading(
-                self.valid_datasets[custom_name]
-            )
-
-    def load_ddp_samplers(self):
-        """
-        Load training and validation data samplers
-        """
-        if self.params["use_ddp"]:
-            self.train_sampler = DistributedSampler(
-                self.train_dataset,
-                num_replicas=self.params["num_gpu"],
-                rank=self.params["ddp_rank"],
-                shuffle=True,
-            )
-            for custom_name in self.valid_datasets.keys():
-                self.valid_samplers[custom_name] = DistributedSampler(
-                    self.valid_datasets[custom_name],
-                    num_replicas=self.params["num_gpu"],
-                    rank=self.params["ddp_rank"],
-                    shuffle=False,
-                )
-        else:
-            for custom_name in self.valid_datasets.keys():
-                self.valid_samplers[custom_name] = None
-
-    def load_dataloaders(self):
-        """
-        Load training and validation data loaders
-        """
-        self.train_loader = DataLoader(
-            self.train_dataset,
-            batch_size=self.batch_size["train"],
-            shuffle=True if self.train_sampler is None else False,
-            drop_last=False,
-            batch_sampler=self.train_sampler,
-            sampler=self.train_sampler,
-            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-            pin_memory=self.pin_memory,
-            collate_fn=self.my_collate_function,
-            worker_init_fn=self.seed_worker,
-            generator=self.generator,
-        )
-
-        for key in self.valid_datasets.keys():
-            self.valid_loaders[key] = DataLoader(
-                self.valid_datasets[key],
-                batch_size=self.batch_size["val"],
-                sampler=self.valid_samplers[key],
-                batch_sampler=self.valid_samplers[key],
-                shuffle=False,
-                num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-                pin_memory=self.pin_memory,
-                drop_last=False,
-                collate_fn=self.my_collate_function,
-                worker_init_fn=self.seed_worker,
-                generator=self.generator,
-            )
-
-    @staticmethod
-    def seed_worker(worker_id):
-        worker_seed = torch.initial_seed() % 2**32
-        np.random.seed(worker_seed)
-        random.seed(worker_seed)
-
-    def generate_test_loader(self, custom_name, sets_list):
-        """
-        Load test dataset, data sampler and data loader
-        """
-        if custom_name in self.test_loaders.keys():
-            return
-        paths_and_sets = list()
-        for set_info in sets_list:
-            paths_and_sets.append(
-                {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
-            )
-        self.test_datasets[custom_name] = self.dataset_class(
-            self.params,
-            "test",
-            custom_name,
-            paths_and_sets,
-            normalization_transforms=get_normalization_transforms(),
-        )
-        self.apply_specific_treatment_after_dataset_loading(
-            self.test_datasets[custom_name]
-        )
-        if self.params["use_ddp"]:
-            self.test_samplers[custom_name] = DistributedSampler(
-                self.test_datasets[custom_name],
-                num_replicas=self.params["num_gpu"],
-                rank=self.params["ddp_rank"],
-                shuffle=False,
-            )
-        else:
-            self.test_samplers[custom_name] = None
-        self.test_loaders[custom_name] = DataLoader(
-            self.test_datasets[custom_name],
-            batch_size=self.batch_size["test"],
-            sampler=self.test_samplers[custom_name],
-            shuffle=False,
-            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-            pin_memory=self.pin_memory,
-            drop_last=False,
-            collate_fn=self.my_collate_function,
-            worker_init_fn=self.seed_worker,
-            generator=self.generator,
-        )
-
-    def get_paths_and_sets(self, dataset_names_folds):
-        paths_and_sets = list()
-        for dataset_name, fold in dataset_names_folds:
-            path = self.params["datasets"][dataset_name]
-            paths_and_sets.append({"path": path, "set_name": fold})
-        return paths_and_sets
-
-
-class GenericDataset(Dataset):
+class OCRDataset(Dataset):
     """
-    Main class to handle dataset loading
+    Dataset class to handle dataset loading
     """
 
-    def __init__(self, params, set_name, custom_name, paths_and_sets):
-        self.params = params
-        self.name = custom_name
-        self.set_name = set_name
-        self.preprocessing_transforms = get_preprocessing_transforms(
-            params["config"]["preprocessings"]
-        )
-        self.load_in_memory = (
-            self.params["config"]["load_in_memory"]
-            if "load_in_memory" in self.params["config"]
-            else True
-        )
+    def __init__(
+        self,
+        set_name,
+        paths_and_sets,
+        charset,
+        tokens,
+        preprocessing_transforms,
+        normalization_transforms,
+        augmentation_transforms,
+        load_in_memory=False,
+    ):
+        self.set_name = set_name
+        self.charset = charset
+        self.tokens = tokens
+        self.load_in_memory = load_in_memory
+
+        # Pre-processing, augmentation, normalization
+        self.preprocessing_transforms = preprocessing_transforms
+        self.normalization_transforms = normalization_transforms
+        self.augmentation_transforms = augmentation_transforms
+
+        # Factor to reduce the height and width of the feature vector before feeding the decoder.
+        self.reduce_dims_factor = np.array([32, 8, 1])
 
         # Load samples and preprocess images if load_in_memory is True
         self.samples = self.load_samples(paths_and_sets)
 
-        self.curriculum_config = None
-
     def __len__(self):
+        """
+        Return the dataset size
+        """
         return len(self.samples)
 
+    def __getitem__(self, idx):
+        """
+        Return an item from the dataset (image and label)
+        """
+        # Load preprocessed image
+        sample = dict(**self.samples[idx])
+        if not self.load_in_memory:
+            sample["img"] = self.get_sample_img(idx)
+
+        # Apply data augmentation
+        if self.augmentation_transforms:
+            sample["img"] = self.augmentation_transforms(image=np.array(sample["img"]))[
+                "image"
+            ]
+
+        # Image normalization
+        sample["img"] = self.normalization_transforms(sample["img"])
+
+        # Get final height and width
+        sample["img_reduced_shape"], sample["img_position"] = self.compute_final_size(
+            sample["img"]
+        )
+
+        # Convert label into tokens
+        sample["token_label"], sample["label_len"] = self.convert_sample_label(
+            sample["label"]
+        )
+        return sample
+
     @staticmethod
     def load_image(path):
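For review convenience, here is a minimal sketch of how the merged OCRDataset might be driven now that it no longer reads params itself: the caller builds the transform callables (the get_*_transforms helpers from dan.transforms that this file used to import) and owns the DataLoader wiring that the removed DatasetManager handled. Everything beyond the imports is illustrative: the config values, paths, charset and token indices are made up, the module to import OCRDataset from is not named in this diff, and the project-specific collate function is omitted.

import random

import numpy as np
import torch
from torch.utils.data import DataLoader

from dan.transforms import (
    get_augmentation_transforms,
    get_normalization_transforms,
    get_preprocessing_transforms,
)

# from <this module> import OCRDataset  # module path not shown in this diff


def seed_worker(worker_id):
    # Same worker seeding the removed DatasetManager.seed_worker performed.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


config = {"preprocessings": [], "augmentation": True}       # hypothetical config values
charset = list("abc ")                                      # hypothetical charset
tokens = {"start": len(charset), "end": len(charset) + 1}   # hypothetical token indices

train_dataset = OCRDataset(
    set_name="train",
    # Same dict structure the removed get_paths_and_sets used to build.
    paths_and_sets=[{"path": "datasets/my_dataset", "set_name": "train"}],
    charset=charset,
    tokens=tokens,
    preprocessing_transforms=get_preprocessing_transforms(config["preprocessings"]),
    normalization_transforms=get_normalization_transforms(),
    augmentation_transforms=(
        get_augmentation_transforms() if config["augmentation"] else None
    ),
    load_in_memory=False,
)

# DataLoader wiring mirroring the removed load_dataloaders (collate_fn omitted here;
# default collation will not batch variable-sized images).
generator = torch.Generator()
generator.manual_seed(0)
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
    worker_init_fn=seed_worker,
    generator=generator,
)

The second hunk of the diff adds the two helpers that __getitem__ relies on: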
@@ -276,3 +130,31 @@ class GenericDataset(Dataset):
         return self.preprocessing_transforms(
             self.load_image(self.samples[i]["path"])
         )
+
+    def compute_final_size(self, img):
+        """
+        Compute the final image size and position after feature extraction
+        """
+        final_c, final_h, final_w = img.shape
+        image_reduced_shape = np.ceil(
+            [final_h, final_w, final_c] / self.reduce_dims_factor
+        ).astype(int)
+
+        if self.set_name == "train":
+            image_reduced_shape = [max(1, t) for t in image_reduced_shape]
+
+        image_position = [
+            [0, final_h],
+            [0, final_w],
+        ]
+        return image_reduced_shape, image_position
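As a quick standalone check of the reduce_dims_factor arithmetic in compute_final_size, here is what it returns for a hypothetical 3×128×1000 CHW image; the factor [32, 8, 1] reflects the height/width reduction applied before the decoder, as the comment in __init__ states, and the image size is made up for the example.

import numpy as np
import torch

reduce_dims_factor = np.array([32, 8, 1])

img = torch.zeros((3, 128, 1000))  # hypothetical normalized image, CHW
final_c, final_h, final_w = img.shape

image_reduced_shape = np.ceil(
    [final_h, final_w, final_c] / reduce_dims_factor
).astype(int)
image_position = [[0, final_h], [0, final_w]]

print(image_reduced_shape.tolist())  # [4, 125, 3]  (128/32, 1000/8, 3/1)
print(image_position)                # [[0, 128], [0, 1000]]

The hunk then adds the label-encoding helper: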
 
 
+
+    def convert_sample_label(self, label):
+        """
+        Tokenize the label and return its length
+        """
+        token_label = token_to_ind(self.charset, label)
+        token_label.append(self.tokens["end"])
+        label_len = len(token_label)
+        token_label.insert(0, self.tokens["start"])
+        return token_label, label_len
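Finally, a small illustration of what convert_sample_label produces, assuming dan.utils.token_to_ind simply maps each character of the label to its index in the charset; the charset and special-token indices are made up, and the stand-in function below only mimics that assumed behaviour.

charset = ["a", "b", "c", " "]
tokens = {"start": len(charset), "end": len(charset) + 1}  # 4 and 5


def token_to_ind(charset, label):
    # Stand-in for dan.utils.token_to_ind (assumed behaviour).
    return [charset.index(char) for char in label]


def convert_sample_label(label):
    token_label = token_to_ind(charset, label)
    token_label.append(tokens["end"])
    label_len = len(token_label)  # counts the end token but not the start token
    token_label.insert(0, tokens["start"])
    return token_label, label_len


print(convert_sample_label("ab c"))
# ([4, 0, 1, 3, 2, 5], 5): [start, a, b, space, c, end], label_len = 5

Note that label_len includes the appended end token but not the prepended start token, exactly as in the added method.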