Commit 2bb85c50, authored by Solene Tarride, committed by Mélodie Boillet

Merge DatasetManager / GenericDataset / OCRDatasetManager / OCRDataset classes

parent fdaf48f4
Merge request !191: Merge DatasetManager / GenericDataset / OCRDatasetManager / OCRDataset classes
# --- dan/manager/dataset.py: code removed by this commit (DatasetManager and GenericDataset classes) ---
# Removed imports:
import random
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from dan.transforms import (
get_augmentation_transforms,
get_normalization_transforms,
get_preprocessing_transforms,
)
class DatasetManager:
def __init__(self, params, device: str):
self.params = params
self.dataset_class = None
self.my_collate_function = None
# Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
self.pin_memory = device != "cpu"
self.train_dataset = None
self.valid_datasets = dict()
self.test_datasets = dict()
self.train_loader = None
self.valid_loaders = dict()
self.test_loaders = dict()
self.train_sampler = None
self.valid_samplers = dict()
self.test_samplers = dict()
self.generator = torch.Generator()
self.generator.manual_seed(0)
self.batch_size = {
"train": self.params["batch_size"],
"val": self.params["valid_batch_size"]
if "valid_batch_size" in self.params
else self.params["batch_size"],
"test": self.params["test_batch_size"]
if "test_batch_size" in self.params
else 1,
}
def apply_specific_treatment_after_dataset_loading(self, dataset):
raise NotImplementedError
def load_datasets(self):
"""
Load training and validation datasets
"""
self.train_dataset = self.dataset_class(
self.params,
"train",
self.params["train"]["name"],
self.get_paths_and_sets(self.params["train"]["datasets"]),
normalization_transforms=get_normalization_transforms(),
augmentation_transforms=(
get_augmentation_transforms()
if self.params["config"]["augmentation"]
else None
),
)
self.my_collate_function = self.train_dataset.collate_function(
self.params["config"]
)
self.apply_specific_treatment_after_dataset_loading(self.train_dataset)
for custom_name in self.params["val"].keys():
self.valid_datasets[custom_name] = self.dataset_class(
self.params,
"val",
custom_name,
self.get_paths_and_sets(self.params["val"][custom_name]),
normalization_transforms=get_normalization_transforms(),
augmentation_transforms=None,
)
self.apply_specific_treatment_after_dataset_loading(
self.valid_datasets[custom_name]
)
def load_ddp_samplers(self):
"""
Load training and validation data samplers
"""
if self.params["use_ddp"]:
self.train_sampler = DistributedSampler(
self.train_dataset,
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=True,
)
for custom_name in self.valid_datasets.keys():
self.valid_samplers[custom_name] = DistributedSampler(
self.valid_datasets[custom_name],
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=False,
)
else:
for custom_name in self.valid_datasets.keys():
self.valid_samplers[custom_name] = None
def load_dataloaders(self):
"""
Load training and validation data loaders
"""
self.train_loader = DataLoader(
self.train_dataset,
batch_size=self.batch_size["train"],
shuffle=True if self.train_sampler is None else False,
drop_last=False,
batch_sampler=self.train_sampler,
sampler=self.train_sampler,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
)
for key in self.valid_datasets.keys():
self.valid_loaders[key] = DataLoader(
self.valid_datasets[key],
batch_size=self.batch_size["val"],
sampler=self.valid_samplers[key],
batch_sampler=self.valid_samplers[key],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
)
@staticmethod
def seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
def generate_test_loader(self, custom_name, sets_list):
"""
Load test dataset, data sampler and data loader
"""
if custom_name in self.test_loaders.keys():
return
paths_and_sets = list()
for set_info in sets_list:
paths_and_sets.append(
{"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
)
self.test_datasets[custom_name] = self.dataset_class(
self.params,
"test",
custom_name,
paths_and_sets,
normalization_transforms=get_normalization_transforms(),
)
self.apply_specific_treatment_after_dataset_loading(
self.test_datasets[custom_name]
)
if self.params["use_ddp"]:
self.test_samplers[custom_name] = DistributedSampler(
self.test_datasets[custom_name],
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=False,
)
else:
self.test_samplers[custom_name] = None
self.test_loaders[custom_name] = DataLoader(
self.test_datasets[custom_name],
batch_size=self.batch_size["test"],
sampler=self.test_samplers[custom_name],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
)
def get_paths_and_sets(self, dataset_names_folds):
paths_and_sets = list()
for dataset_name, fold in dataset_names_folds:
path = self.params["datasets"][dataset_name]
paths_and_sets.append({"path": path, "set_name": fold})
return paths_and_sets
class GenericDataset(Dataset):
"""
Main class to handle dataset loading
"""
def __init__(self, params, set_name, custom_name, paths_and_sets):
self.params = params
self.name = custom_name
self.set_name = set_name
self.preprocessing_transforms = get_preprocessing_transforms(
params["config"]["preprocessings"]
)
self.load_in_memory = (
self.params["config"]["load_in_memory"]
if "load_in_memory" in self.params["config"]
else True
)
# Load samples and preprocess images if load_in_memory is True
self.samples = self.load_samples(paths_and_sets)
self.curriculum_config = None
def __len__(self):
return len(self.samples)
# --- dan/manager/dataset.py: code added by this commit (the merged OCRDataset class) ---
# -*- coding: utf-8 -*-
import json
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.io import ImageReadMode, read_image
from dan.datasets.utils import natural_sort
from dan.utils import token_to_ind
class OCRDataset(Dataset):
"""
Dataset class to handle dataset loading
"""
def __init__(
self,
set_name,
paths_and_sets,
charset,
tokens,
preprocessing_transforms,
normalization_transforms,
augmentation_transforms,
load_in_memory=False,
):
self.set_name = set_name
self.charset = charset
self.tokens = tokens
self.load_in_memory = load_in_memory
# Pre-processing, augmentation, normalization
self.preprocessing_transforms = preprocessing_transforms
self.normalization_transforms = normalization_transforms
self.augmentation_transforms = augmentation_transforms
# Factor to reduce the height and width of the feature vector before feeding the decoder.
self.reduce_dims_factor = np.array([32, 8, 1])
# Load samples and preprocess images if load_in_memory is True
self.samples = self.load_samples(paths_and_sets)
# Curriculum config
self.curriculum_config = None
def __len__(self):
"""
Return the dataset size
"""
return len(self.samples)
def __getitem__(self, idx):
"""
Return an item from the dataset (image and label)
"""
# Load preprocessed image
sample = dict(**self.samples[idx])
if not self.load_in_memory:
sample["img"] = self.get_sample_img(idx)
# Apply data augmentation
if self.augmentation_transforms:
sample["img"] = self.augmentation_transforms(image=np.array(sample["img"]))[
"image"
]
# Image normalization
sample["img"] = self.normalization_transforms(sample["img"])
# Get final height and width
sample["img_reduced_shape"], sample["img_position"] = self.compute_final_size(
sample["img"]
)
# Convert label into tokens
sample["token_label"], sample["label_len"] = self.convert_sample_label(
sample["label"]
)
return sample
@staticmethod
def load_image(path):
# (collapsed diff context) @@ -276,3 +130,31 @@ class GenericDataset(Dataset):
return self.preprocessing_transforms(
self.load_image(self.samples[i]["path"])
)
def compute_final_size(self, img):
"""
Compute the final image size and position after feature extraction
"""
final_c, final_h, final_w = img.shape
image_reduced_shape = np.ceil(
[final_h, final_w, final_c] / self.reduce_dims_factor
).astype(int)
if self.set_name == "train":
image_reduced_shape = [max(1, t) for t in image_reduced_shape]
image_position = [
[0, final_h],
[0, final_w],
]
return image_reduced_shape, image_position
def convert_sample_label(self, label):
"""
Tokenize the label and return its length
"""
token_label = token_to_ind(self.charset, label)
token_label.append(self.tokens["end"])
label_len = len(token_label)
token_label.insert(0, self.tokens["start"])
return token_label, label_len
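For readers skimming the two helpers above, the following standalone sketch reproduces their arithmetic on toy inputs. The toy charset, the inlined stand-in for dan.utils.token_to_ind (one index per character of the label) and the example image shape are assumptions for illustration only, not part of the commit.

import numpy as np

# Assumed stand-in for dan.utils.token_to_ind: map each character to its charset index.
def token_to_ind(charset, label):
    return [charset.index(c) for c in label]

toy_charset = ["a", "b", "c", " "]
tokens = {"end": len(toy_charset), "start": len(toy_charset) + 1, "pad": len(toy_charset) + 2}

# convert_sample_label: character indices wrapped with <start> and <end>;
# label_len counts the characters plus the <end> token, but not <start>.
label = "abc a"
token_label = token_to_ind(toy_charset, label)
token_label.append(tokens["end"])
label_len = len(token_label)
token_label.insert(0, tokens["start"])
print(token_label, label_len)  # [5, 0, 1, 2, 3, 0, 4] 6

# compute_final_size: height divided by 32, width by 8, channels kept as-is.
reduce_dims_factor = np.array([32, 8, 1])
final_c, final_h, final_w = 3, 128, 1000  # assumed example image shape (C, H, W)
img_reduced_shape = np.ceil([final_h, final_w, final_c] / reduce_dims_factor).astype(int)
print(img_reduced_shape)  # [4 125 3] -> H/32, W/8, C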
# --- OCR manager module: code removed by this commit (OCRDatasetManager built on DatasetManager) ---
# Removed imports:
from dan.manager.dataset import DatasetManager, GenericDataset
from dan.utils import pad_images, pad_sequences_1D, token_to_ind
class OCRDatasetManager(DatasetManager):
"""
Specific class to handle OCR/HTR tasks
"""
def __init__(self, params, device: str):
super(OCRDatasetManager, self).__init__(params, device)
self.dataset_class = OCRDataset
self.charset = (
params["charset"] if "charset" in params else self.get_merged_charsets()
)
self.tokens = {"pad": len(self.charset) + 2}
self.tokens["end"] = len(self.charset)
self.tokens["start"] = len(self.charset) + 1
def get_merged_charsets(self):
# (body collapsed in the diff view)
# --- OCR manager module: code added by this commit (merged OCRDatasetManager, no DatasetManager base class) ---
# -*- coding: utf-8 -*-
import os
import pickle
import random
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from dan.manager.dataset import OCRDataset
from dan.transforms import (
get_augmentation_transforms,
get_normalization_transforms,
get_preprocessing_transforms,
)
from dan.utils import pad_images, pad_sequences_1D
class OCRDatasetManager:
def __init__(self, params, device: str):
self.params = params
# Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
self.pin_memory = device != "cpu"
self.train_dataset = None
self.valid_datasets = dict()
self.test_datasets = dict()
self.train_loader = None
self.valid_loaders = dict()
self.test_loaders = dict()
self.train_sampler = None
self.valid_samplers = dict()
self.test_samplers = dict()
self.generator = torch.Generator()
self.generator.manual_seed(0)
self.batch_size = self.get_batch_size_by_set()
self.load_in_memory = (
self.params["config"]["load_in_memory"]
if "load_in_memory" in self.params["config"]
else True
)
self.charset = self.get_charset()
self.tokens = self.get_tokens()
self.params["config"]["padding_token"] = self.tokens["pad"]
self.my_collate_function = OCRCollateFunction(self.params["config"])
self.normalization = get_normalization_transforms()
self.augmentation = (
get_augmentation_transforms()
if self.params["config"]["augmentation"]
else None
)
self.preprocessing = get_preprocessing_transforms(
params["config"]["preprocessings"]
)
def load_datasets(self):
"""
Load training and validation datasets
"""
self.train_dataset = OCRDataset(
set_name="train",
paths_and_sets=self.get_paths_and_sets(self.params["train"]["datasets"]),
charset=self.charset,
tokens=self.tokens,
preprocessing_transforms=self.preprocessing,
normalization_transforms=self.normalization,
augmentation_transforms=self.augmentation,
load_in_memory=self.load_in_memory,
)
for custom_name in self.params["val"].keys():
self.valid_datasets[custom_name] = OCRDataset(
set_name="val",
paths_and_sets=self.get_paths_and_sets(self.params["val"][custom_name]),
charset=self.charset,
tokens=self.tokens,
preprocessing_transforms=self.preprocessing,
normalization_transforms=self.normalization,
augmentation_transforms=None,
load_in_memory=self.load_in_memory,
)
def load_ddp_samplers(self):
"""
Load training and validation data samplers
"""
if self.params["use_ddp"]:
self.train_sampler = DistributedSampler(
self.train_dataset,
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=True,
)
for custom_name in self.valid_datasets.keys():
self.valid_samplers[custom_name] = DistributedSampler(
self.valid_datasets[custom_name],
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=False,
)
else:
for custom_name in self.valid_datasets.keys():
self.valid_samplers[custom_name] = None
def load_dataloaders(self):
"""
Load training and validation data loaders
"""
self.train_loader = DataLoader(
self.train_dataset,
batch_size=self.batch_size["train"],
shuffle=True if self.train_sampler is None else False,
drop_last=False,
batch_sampler=self.train_sampler,
sampler=self.train_sampler,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
)
for key in self.valid_datasets.keys():
self.valid_loaders[key] = DataLoader(
self.valid_datasets[key],
batch_size=self.batch_size["val"],
sampler=self.valid_samplers[key],
batch_sampler=self.valid_samplers[key],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
)
@staticmethod
def seed_worker(worker_id):
"""
Set worker seed
"""
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
def generate_test_loader(self, custom_name, sets_list):
"""
Load test dataset, data sampler and data loader
"""
if custom_name in self.test_loaders.keys():
return
paths_and_sets = list()
for set_info in sets_list:
paths_and_sets.append(
{"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
)
self.test_datasets[custom_name] = OCRDataset(
set_name="test",
paths_and_sets=paths_and_sets,
charset=self.charset,
tokens=self.tokens,
preprocessing_transforms=self.preprocessing,
normalization_transforms=self.normalization,
augmentation_transforms=None,
load_in_memory=self.load_in_memory,
)
if self.params["use_ddp"]:
self.test_samplers[custom_name] = DistributedSampler(
self.test_datasets[custom_name],
num_replicas=self.params["num_gpu"],
rank=self.params["ddp_rank"],
shuffle=False,
)
else:
self.test_samplers[custom_name] = None
self.test_loaders[custom_name] = DataLoader(
self.test_datasets[custom_name],
batch_size=self.batch_size["test"],
sampler=self.test_samplers[custom_name],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
)
def get_paths_and_sets(self, dataset_names_folds):
"""
Set the right path for each data set
"""
paths_and_sets = list()
for dataset_name, fold in dataset_names_folds:
path = self.params["datasets"][dataset_name]
paths_and_sets.append({"path": path, "set_name": fold})
return paths_and_sets
def get_charset(self):
"""
Merge the charset of the different datasets used
"""
if "charset" in self.params:
return self.params["charset"]
datasets = self.params["datasets"]
charset = set()
for key in datasets.keys():
# (collapsed diff context) @@ -39,83 +221,29 @@ class OCRDatasetManager(DatasetManager):
charset.remove("")
return sorted(list(charset))
def get_tokens(self):
"""
Get special tokens
"""
return {
"end": len(self.charset),
"start": len(self.charset) + 1,
"pad": len(self.charset) + 2,
}
def get_batch_size_by_set(self):
"""
Return batch size for each set
"""
return {
"train": self.params["batch_size"],
"val": self.params["valid_batch_size"]
if "valid_batch_size" in self.params
else self.params["batch_size"],
"test": self.params["test_batch_size"]
if "test_batch_size" in self.params
else 1,
}
# --- Code removed by this commit: the DatasetManager hook and the GenericDataset-based OCRDataset ---
def apply_specific_treatment_after_dataset_loading(self, dataset):
dataset.charset = self.charset
dataset.tokens = self.tokens
dataset.convert_labels()
class OCRDataset(GenericDataset):
"""
Specific class to handle OCR/HTR datasets
"""
def __init__(
self,
params,
set_name,
custom_name,
paths_and_sets,
normalization_transforms,
augmentation_transforms=None,
):
super(OCRDataset, self).__init__(params, set_name, custom_name, paths_and_sets)
self.charset = None
self.tokens = None
# Factor to reduce the height and width of the feature vector before feeding the decoder.
self.reduce_dims_factor = np.array([32, 8, 1])
self.collate_function = OCRCollateFunction
self.normalization_transforms = normalization_transforms
self.augmentation_transforms = augmentation_transforms
def __getitem__(self, idx):
sample = dict(**self.samples[idx])
if not self.load_in_memory:
sample["img"] = self.get_sample_img(idx)
# Data augmentation
if self.augmentation_transforms:
sample["img"] = self.augmentation_transforms(image=np.array(sample["img"]))[
"image"
]
# Normalization
sample["img"] = self.normalization_transforms(sample["img"])
# Get final height and width
final_c, final_h, final_w = sample["img"].shape
sample["img_reduced_shape"] = np.ceil(
[final_h, final_w, final_c] / self.reduce_dims_factor
).astype(int)
if self.set_name == "train":
sample["img_reduced_shape"] = [
max(1, t) for t in sample["img_reduced_shape"]
]
sample["img_position"] = [
[0, final_h],
[0, final_w],
]
return sample
def convert_labels(self):
"""
Label str to token at character level
"""
for i in range(len(self.samples)):
self.samples[i] = self.convert_sample_labels(self.samples[i])
def convert_sample_labels(self, sample):
label = sample["label"]
sample["label"] = label
sample["token_label"] = token_to_ind(self.charset, label)
sample["token_label"].append(self.tokens["end"])
sample["label_len"] = len(sample["token_label"])
sample["token_label"].insert(0, self.tokens["start"])
return sample
# --- End of removed code; the OCRCollateFunction class below is kept ---
class OCRCollateFunction:
# ... (OCRCollateFunction body collapsed in the diff view)
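As a rough orientation for how the merged manager is meant to be driven, here is a hedged usage sketch. The parameter keys mirror the lookups made in the OCRDatasetManager code above, but the concrete values (dataset name, paths, batch size) are invented for illustration, and running it requires a dataset already formatted as the repository expects.

# Hypothetical configuration; keys follow the dictionary accesses in OCRDatasetManager above.
params = {
    "datasets": {"my_dataset": "path/to/formatted/dataset"},  # assumed example path
    "train": {"datasets": [("my_dataset", "train")]},
    "val": {"my_dataset-val": [("my_dataset", "val")]},
    "batch_size": 4,
    "use_ddp": False,
    "num_gpu": 1,
    "worker_per_gpu": 4,
    "config": {
        "load_in_memory": True,
        "augmentation": True,
        "preprocessings": [],
    },
}

manager = OCRDatasetManager(params, device="cuda")
manager.load_datasets()      # builds the train and validation OCRDataset instances
manager.load_ddp_samplers()  # leaves samplers as None here since use_ddp is False
manager.load_dataloaders()   # wraps the datasets in DataLoaders with OCRCollateFunction
batch = next(iter(manager.train_loader))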