Verified Commit 8ba39efd authored by Mélodie Boillet

Apply 2bb85c50

parent 5d35d16d
# -*- coding: utf-8 -*-
import copy
import json
import os
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision.io import ImageReadMode, read_image
from dan.datasets.utils import natural_sort
from dan.transforms import get_augmentation_transforms, get_preprocessing_transforms
from dan.utils import token_to_ind
class OCRDataset(Dataset):
    """
    Dataset class to handle dataset loading
    """

    def __init__(
        self,
        set_name,
        paths_and_sets,
        charset,
        tokens,
        preprocessing_transforms,
        augmentation_transforms,
        load_in_memory=False,
        mean=None,
        std=None,
    ):
        self.set_name = set_name
        self.charset = charset
        self.tokens = tokens
        self.load_in_memory = load_in_memory
        self.mean = mean
        self.std = std

        # Pre-processing, augmentation
        self.preprocessing_transforms = preprocessing_transforms
        self.augmentation_transforms = augmentation_transforms

        # Factor to reduce the height and width of the feature vector before feeding the decoder.
        self.reduce_dims_factor = np.array([32, 8, 1])

        # Load samples and preprocess images if load_in_memory is True
        self.samples = self.load_samples(paths_and_sets)

        # Curriculum config
        self.curriculum_config = None


class DatasetManager:
    def __init__(self, params, device: str):
        self.params = params
        self.dataset_class = None
        self.my_collate_function = None
        # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
        self.pin_memory = device != "cpu"

        self.train_dataset = None
        self.valid_datasets = dict()
        self.test_datasets = dict()

        self.train_loader = None
        self.valid_loaders = dict()
        self.test_loaders = dict()

        self.train_sampler = None
        self.valid_samplers = dict()
        self.test_samplers = dict()

        self.generator = torch.Generator()
        self.generator.manual_seed(0)

        self.batch_size = {
            "train": self.params["batch_size"],
            "val": self.params["valid_batch_size"]
            if "valid_batch_size" in self.params
            else self.params["batch_size"],
            "test": self.params["test_batch_size"]
            if "test_batch_size" in self.params
            else 1,
        }

    def apply_specific_treatment_after_dataset_loading(self, dataset):
        raise NotImplementedError
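    # Illustrative sketch (not part of the original module): how an OCRDataset
    # could be instantiated directly, mirroring how OCRDatasetManager.load_datasets
    # builds its training set further down. The charset, tokens, paths and
    # transform values below are placeholder assumptions.
    #
    #   dataset = OCRDataset(
    #       set_name="train",
    #       paths_and_sets=[{"path": "data/my_dataset", "set_name": "train"}],
    #       charset=["a", "b", "c", " "],
    #       tokens={"end": 4, "start": 5, "pad": 6},
    #       preprocessing_transforms=get_preprocessing_transforms([]),
    #       augmentation_transforms=None,
    #       load_in_memory=False,
    #       mean=None,
    #       std=None,
    #   )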
    def load_datasets(self):
        """
        Load training and validation datasets
        """
        self.train_dataset = self.dataset_class(
            self.params,
            "train",
            self.params["train"]["name"],
            self.get_paths_and_sets(self.params["train"]["datasets"]),
            augmentation_transforms=(
                get_augmentation_transforms()
                if self.params["config"]["augmentation"]
                else None
            ),
        )
        (
            self.params["config"]["mean"],
            self.params["config"]["std"],
        ) = self.train_dataset.compute_std_mean()

        self.my_collate_function = self.train_dataset.collate_function(
            self.params["config"]
        )
        self.apply_specific_treatment_after_dataset_loading(self.train_dataset)

        for custom_name in self.params["val"].keys():
            self.valid_datasets[custom_name] = self.dataset_class(
                self.params,
                "val",
                custom_name,
                self.get_paths_and_sets(self.params["val"][custom_name]),
                augmentation_transforms=None,
            )
            self.apply_specific_treatment_after_dataset_loading(
                self.valid_datasets[custom_name]
            )

    def __len__(self):
        """
        Return the dataset size
        """
        return len(self.samples)
    def load_ddp_samplers(self):
        """
        Load training and validation data samplers
        """
        if self.params["use_ddp"]:
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=self.params["num_gpu"],
                rank=self.params["ddp_rank"],
                shuffle=True,
            )
            for custom_name in self.valid_datasets.keys():
                self.valid_samplers[custom_name] = DistributedSampler(
                    self.valid_datasets[custom_name],
                    num_replicas=self.params["num_gpu"],
                    rank=self.params["ddp_rank"],
                    shuffle=False,
                )
        else:
            for custom_name in self.valid_datasets.keys():
                self.valid_samplers[custom_name] = None

    def __getitem__(self, idx):
        """
        Return an item from the dataset (image and label)
        """
        # Load preprocessed image
        sample = copy.deepcopy(self.samples[idx])
        if not self.load_in_memory:
            sample["img"] = self.get_sample_img(idx)
        # Convert to numpy
        sample["img"] = np.array(sample["img"])

        # Apply data augmentation
        if self.augmentation_transforms:
            sample["img"] = self.augmentation_transforms(image=sample["img"])["image"]

        # Image normalization
        sample["img"] = (sample["img"] - self.mean) / self.std

        # Get final height and width
        sample["img_reduced_shape"], sample["img_position"] = self.compute_final_size(
            sample["img"]
        )

    def load_dataloaders(self):
        """
        Load training and validation data loaders
        """
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size["train"],
            shuffle=True if self.train_sampler is None else False,
            drop_last=False,
            batch_sampler=self.train_sampler,
            sampler=self.train_sampler,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )

        for key in self.valid_datasets.keys():
            self.valid_loaders[key] = DataLoader(
                self.valid_datasets[key],
                batch_size=self.batch_size["val"],
                sampler=self.valid_samplers[key],
                batch_sampler=self.valid_samplers[key],
                shuffle=False,
                num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
                pin_memory=self.pin_memory,
                drop_last=False,
                collate_fn=self.my_collate_function,
                worker_init_fn=self.seed_worker,
                generator=self.generator,
            )

    @staticmethod
    def seed_worker(worker_id):
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def generate_test_loader(self, custom_name, sets_list):
        """
        Load test dataset, data sampler and data loader
        """
        if custom_name in self.test_loaders.keys():
            return
        paths_and_sets = list()
        for set_info in sets_list:
            paths_and_sets.append(
                {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
            )
        self.test_datasets[custom_name] = self.dataset_class(
            self.params,
            "test",
            custom_name,
            paths_and_sets,
        )
        self.apply_specific_treatment_after_dataset_loading(
            self.test_datasets[custom_name]
        )
        if self.params["use_ddp"]:
            self.test_samplers[custom_name] = DistributedSampler(
                self.test_datasets[custom_name],
                num_replicas=self.params["num_gpu"],
                rank=self.params["ddp_rank"],
                shuffle=False,
            )
        else:
            self.test_samplers[custom_name] = None
        self.test_loaders[custom_name] = DataLoader(
            self.test_datasets[custom_name],
            batch_size=self.batch_size["test"],
            sampler=self.test_samplers[custom_name],
            shuffle=False,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            drop_last=False,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )
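    # Note on reproducibility (illustrative, not part of the original code):
    # seed_worker re-seeds numpy and random inside every DataLoader worker from
    # the torch base seed, and the manually seeded torch.Generator fixes the
    # shuffle order, so repeated runs over the same data produce the same batch
    # sequence. A minimal sketch, assuming a generic `dataset`:
    #
    #   g = torch.Generator()
    #   g.manual_seed(0)
    #   loader = DataLoader(
    #       dataset, shuffle=True, worker_init_fn=DatasetManager.seed_worker, generator=g
    #   )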
def get_paths_and_sets(self, dataset_names_folds):
paths_and_sets = list()
for dataset_name, fold in dataset_names_folds:
path = self.params["datasets"][dataset_name]
paths_and_sets.append({"path": path, "set_name": fold})
return paths_and_sets
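    # Example (hypothetical dataset name and path, for illustration only):
    # with params["datasets"] = {"my_dataset": "data/my_dataset"},
    # get_paths_and_sets([("my_dataset", "train")]) returns
    # [{"path": "data/my_dataset", "set_name": "train"}].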
class GenericDataset(Dataset):
"""
Main class to handle dataset loading
"""
    def __init__(self, params, set_name, custom_name, paths_and_sets):
        self.params = params
        self.name = custom_name
        self.set_name = set_name
        self.mean = (
            np.array(params["config"]["mean"])
            if "mean" in params["config"].keys()
            else None
        )
        self.std = (
            np.array(params["config"]["std"])
            if "std" in params["config"].keys()
            else None
        )

        self.preprocessing_transforms = get_preprocessing_transforms(
            params["config"]["preprocessings"]
        )

        self.load_in_memory = (
            self.params["config"]["load_in_memory"]
            if "load_in_memory" in self.params["config"]
            else True
        )

        # Load samples and preprocess images if load_in_memory is True
        self.samples = self.load_samples(paths_and_sets)

        self.curriculum_config = None

    def __len__(self):
        return len(self.samples)

        # Convert label into tokens
        sample["token_label"], sample["label_len"] = self.convert_sample_label(
            sample["label"]
        )
        return sample
@staticmethod
def load_image(path):
@@ -273,6 +124,17 @@ class GenericDataset(Dataset):
)
return samples
def get_sample_img(self, i):
"""
Get image by index
"""
if self.load_in_memory:
return self.samples[i]["img"]
else:
return self.preprocessing_transforms(
self.load_image(self.samples[i]["path"])
)
def compute_std_mean(self):
"""
Compute cumulated variance and mean of whole dataset
@@ -299,13 +161,27 @@ class GenericDataset(Dataset):
self.std = np.sqrt(diff / nb_pixels)
return self.mean, self.std
    def get_sample_img(self, i):
        """
        Get image by index
        """
        if self.load_in_memory:
            return self.samples[i]["img"]
        else:
            return self.preprocessing_transforms(
                self.load_image(self.samples[i]["path"])
            )

    def compute_final_size(self, img):
        """
        Compute the final image size and position after feature extraction
        """
        image_reduced_shape = np.ceil(img.shape / self.reduce_dims_factor).astype(int)

        if self.set_name == "train":
            image_reduced_shape = [max(1, t) for t in image_reduced_shape]

        image_position = [
            [0, img.shape[0]],
            [0, img.shape[1]],
        ]
        return image_reduced_shape, image_position
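    # Worked example (hypothetical image size): with reduce_dims_factor = [32, 8, 1],
    # an image of shape (1000, 800, 3) gives
    #   image_reduced_shape = ceil([1000 / 32, 800 / 8, 3 / 1]) = [32, 100, 3]
    # and image_position = [[0, 1000], [0, 800]].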
def convert_sample_label(self, label):
"""
Tokenize the label and return its length
"""
token_label = token_to_ind(self.charset, label)
token_label.append(self.tokens["end"])
label_len = len(token_label)
token_label.insert(0, self.tokens["start"])
return token_label, label_len
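    # Example (hypothetical charset, for illustration only): with
    # charset = ["a", "b", "c"] and tokens = {"end": 3, "start": 4, "pad": 5},
    # convert_sample_label("ab") maps the label to [0, 1], appends the end token,
    # measures the length, then prepends the start token, returning ([4, 0, 1, 3], 3).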
# -*- coding: utf-8 -*-
import copy
import os
import pickle
import random
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from dan.manager.dataset import DatasetManager, GenericDataset
from dan.utils import pad_images, pad_sequences_1D, token_to_ind
from dan.manager.dataset import OCRDataset
from dan.transforms import get_augmentation_transforms, get_preprocessing_transforms
from dan.utils import pad_images, pad_sequences_1D
class OCRDatasetManager(DatasetManager):
    """
    Specific class to handle OCR/HTR tasks
    """

    def __init__(self, params, device: str):
        super(OCRDatasetManager, self).__init__(params, device)
        self.dataset_class = OCRDataset
        self.charset = (
            params["charset"] if "charset" in params else self.get_merged_charsets()
        )
        self.tokens = {"pad": len(self.charset) + 2}
        self.tokens["end"] = len(self.charset)
        self.tokens["start"] = len(self.charset) + 1

    def get_merged_charsets(self):


class OCRDatasetManager:
    def __init__(self, params, device: str):
        self.params = params
        # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
        self.pin_memory = device != "cpu"

        self.train_dataset = None
        self.valid_datasets = dict()
        self.test_datasets = dict()

        self.train_loader = None
        self.valid_loaders = dict()
        self.test_loaders = dict()

        self.train_sampler = None
        self.valid_samplers = dict()
        self.test_samplers = dict()

        self.mean = (
            np.array(params["config"]["mean"])
            if "mean" in params["config"].keys()
            else None
        )
        self.std = (
            np.array(params["config"]["std"])
            if "std" in params["config"].keys()
            else None
        )

        self.generator = torch.Generator()
        self.generator.manual_seed(0)

        self.batch_size = self.get_batch_size_by_set()
        self.load_in_memory = (
            self.params["config"]["load_in_memory"]
            if "load_in_memory" in self.params["config"]
            else True
        )

        self.charset = self.get_charset()
        self.tokens = self.get_tokens()
        self.params["config"]["padding_token"] = self.tokens["pad"]

        self.my_collate_function = OCRCollateFunction(self.params["config"])

        self.augmentation = (
            get_augmentation_transforms()
            if self.params["config"]["augmentation"]
            else None
        )
        self.preprocessing = get_preprocessing_transforms(
            params["config"]["preprocessings"]
        )
def load_datasets(self):
"""
Load training and validation datasets
"""
self.train_dataset = OCRDataset(
set_name="train",
paths_and_sets=self.get_paths_and_sets(self.params["train"]["datasets"]),
charset=self.charset,
tokens=self.tokens,
preprocessing_transforms=self.preprocessing,
augmentation_transforms=self.augmentation,
load_in_memory=self.load_in_memory,
mean=self.mean,
std=self.std,
)
self.mean, self.std = self.train_dataset.compute_std_mean()
for custom_name in self.params["val"].keys():
self.valid_datasets[custom_name] = OCRDataset(
set_name="val",
paths_and_sets=self.get_paths_and_sets(self.params["val"][custom_name]),
charset=self.charset,
tokens=self.tokens,
preprocessing_transforms=self.preprocessing,
augmentation_transforms=None,
load_in_memory=self.load_in_memory,
mean=self.mean,
std=self.std,
)
def load_ddp_samplers(self):
"""
Load training and validation data samplers
"""
if self.params["use_ddp"]:
self.train_sampler = DistributedSampler(
self.train_dataset,
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=True,
)
for custom_name in self.valid_datasets.keys():
self.valid_samplers[custom_name] = DistributedSampler(
self.valid_datasets[custom_name],
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=False,
)
else:
for custom_name in self.valid_datasets.keys():
self.valid_samplers[custom_name] = None
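    # Note (general PyTorch behaviour, not specific to this commit): when a
    # DistributedSampler is created with shuffle=True, the training loop is
    # expected to call set_epoch before each epoch so the shuffling differs
    # between epochs, e.g.:
    #
    #   self.train_sampler.set_epoch(epoch)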
def load_dataloaders(self):
"""
Load training and validation data loaders
"""
self.train_loader = DataLoader(
self.train_dataset,
batch_size=self.batch_size["train"],
shuffle=True if self.train_sampler is None else False,
            drop_last=False,
            sampler=self.train_sampler,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
)
for key in self.valid_datasets.keys():
self.valid_loaders[key] = DataLoader(
self.valid_datasets[key],
batch_size=self.batch_size["val"],
                sampler=self.valid_samplers[key],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
)
@staticmethod
def seed_worker(worker_id):
"""
Set worker seed
"""
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
def generate_test_loader(self, custom_name, sets_list):
"""
Load test dataset, data sampler and data loader
"""
if custom_name in self.test_loaders.keys():
return
paths_and_sets = list()
for set_info in sets_list:
paths_and_sets.append(
{"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
)
self.test_datasets[custom_name] = OCRDataset(
set_name="test",
paths_and_sets=paths_and_sets,
charset=self.charset,
tokens=self.tokens,
preprocessing_transforms=self.preprocessing,
augmentation_transforms=None,
load_in_memory=self.load_in_memory,
mean=self.mean,
std=self.std,
)
if self.params["use_ddp"]:
self.test_samplers[custom_name] = DistributedSampler(
self.test_datasets[custom_name],
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=False,
)
else:
self.test_samplers[custom_name] = None
self.test_loaders[custom_name] = DataLoader(
self.test_datasets[custom_name],
batch_size=self.batch_size["test"],
sampler=self.test_samplers[custom_name],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
)
def get_paths_and_sets(self, dataset_names_folds):
"""
Set the right path for each data set
"""
paths_and_sets = list()
for dataset_name, fold in dataset_names_folds:
path = self.params["datasets"][dataset_name]
paths_and_sets.append({"path": path, "set_name": fold})
return paths_and_sets
def get_charset(self):
"""
Merge the charset of the different datasets used
"""
if "charset" in self.params:
return self.params["charset"]
datasets = self.params["datasets"]
charset = set()
for key in datasets.keys():
@@ -41,81 +233,29 @@ class OCRDatasetManager(DatasetManager):
charset.remove("")
return sorted(list(charset))
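    # Example (hypothetical labels, for illustration only): if one dataset
    # contains the label "ab" and another the label "ba c", the merged charset
    # is sorted(set("ab") | set("ba c")) == [" ", "a", "b", "c"].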
def apply_specific_treatment_after_dataset_loading(self, dataset):
dataset.charset = self.charset
dataset.tokens = self.tokens
dataset.convert_labels()
class OCRDataset(GenericDataset):
"""
Specific class to handle OCR/HTR datasets
"""
def __init__(
self,
params,
set_name,
custom_name,
paths_and_sets,
augmentation_transforms=None,
):
super(OCRDataset, self).__init__(params, set_name, custom_name, paths_and_sets)
self.charset = None
self.tokens = None
# Factor to reduce the height and width of the feature vector before feeding the decoder.
self.reduce_dims_factor = np.array([32, 8, 1])
self.collate_function = OCRCollateFunction
self.augmentation_transforms = augmentation_transforms
def __getitem__(self, idx):
sample = copy.deepcopy(self.samples[idx])
if not self.load_in_memory:
sample["img"] = self.get_sample_img(idx)
# Convert to numpy
sample["img"] = np.array(sample["img"])
# Data augmentation
if self.augmentation_transforms:
sample["img"] = self.augmentation_transforms(image=sample["img"])["image"]
# Normalization
sample["img"] = (sample["img"] - self.mean) / self.std
sample["img_reduced_shape"] = np.ceil(
sample["img"].shape / self.reduce_dims_factor
).astype(int)
if self.set_name == "train":
sample["img_reduced_shape"] = [
max(1, t) for t in sample["img_reduced_shape"]
]
sample["img_position"] = [
[0, sample["img"].shape[0]],
[0, sample["img"].shape[1]],
]
return sample
    def convert_labels(self):
        """
        Label str to token at character level
        """
        for i in range(len(self.samples)):
            self.samples[i] = self.convert_sample_labels(self.samples[i])

    def convert_sample_labels(self, sample):
        label = sample["label"]
        sample["label"] = label
        sample["token_label"] = token_to_ind(self.charset, label)
        sample["token_label"].append(self.tokens["end"])
        sample["label_len"] = len(sample["token_label"])
        sample["token_label"].insert(0, self.tokens["start"])
        return sample

    def get_tokens(self):
        """
        Get special tokens
        """
        return {
            "end": len(self.charset),
            "start": len(self.charset) + 1,
            "pad": len(self.charset) + 2,
        }
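    # Example (hypothetical charset size): with a charset of 80 characters,
    # get_tokens() returns {"end": 80, "start": 81, "pad": 82}; the "pad" index
    # is also stored as params["config"]["padding_token"] in __init__ above.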
def get_batch_size_by_set(self):
"""
Return batch size for each set
"""
return {
"train": self.params["batch_size"],
"val": self.params["valid_batch_size"]
if "valid_batch_size" in self.params
else self.params["batch_size"],
"test": self.params["test_batch_size"]
if "test_batch_size" in self.params
else 1,
}
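    # Example (hypothetical values): with params = {"batch_size": 8} and no
    # "valid_batch_size" or "test_batch_size" keys, this returns
    # {"train": 8, "val": 8, "test": 1}.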
class OCRCollateFunction:
......