diff --git a/dan/manager/dataset.py b/dan/manager/dataset.py
index 0947d6c439eabbc65ded6e0cc7c952b34535bf1e..5f6ad28b9d5ca2556820d9cd1b59a034b6347702 100644
--- a/dan/manager/dataset.py
+++ b/dan/manager/dataset.py
@@ -1,232 +1,86 @@
 # -*- coding: utf-8 -*-
 import json
 import os
-import random
 
 import numpy as np
 import torch
-from torch.utils.data import DataLoader, Dataset
-from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data import Dataset
 from torchvision.io import ImageReadMode, read_image
 
 from dan.datasets.utils import natural_sort
-from dan.transforms import (
-    get_augmentation_transforms,
-    get_normalization_transforms,
-    get_preprocessing_transforms,
-)
+from dan.utils import token_to_ind
 
 
-class DatasetManager:
-    def __init__(self, params, device: str):
-        self.params = params
-        self.dataset_class = None
-
-        self.my_collate_function = None
-        # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
-        self.pin_memory = device != "cpu"
-
-        self.train_dataset = None
-        self.valid_datasets = dict()
-        self.test_datasets = dict()
+class OCRDataset(Dataset):
+    """
+    Dataset class to handle dataset loading
+    """
 
-        self.train_loader = None
-        self.valid_loaders = dict()
-        self.test_loaders = dict()
+    def __init__(
+        self,
+        set_name,
+        paths_and_sets,
+        charset,
+        tokens,
+        preprocessing_transforms,
+        normalization_transforms,
+        augmentation_transforms,
+        load_in_memory=False,
+    ):
+        self.set_name = set_name
+        self.charset = charset
+        self.tokens = tokens
+        self.load_in_memory = load_in_memory
 
-        self.train_sampler = None
-        self.valid_samplers = dict()
-        self.test_samplers = dict()
+        # Pre-processing, augmentation, normalization
+        self.preprocessing_transforms = preprocessing_transforms
+        self.normalization_transforms = normalization_transforms
+        self.augmentation_transforms = augmentation_transforms
 
-        self.generator = torch.Generator()
-        self.generator.manual_seed(0)
+        # Factor to reduce the height and width of the feature vector before feeding the decoder.
+        self.reduce_dims_factor = np.array([32, 8, 1])
 
-        self.batch_size = {
-            "train": self.params["batch_size"],
-            "val": self.params["valid_batch_size"]
-            if "valid_batch_size" in self.params
-            else self.params["batch_size"],
-            "test": self.params["test_batch_size"]
-            if "test_batch_size" in self.params
-            else 1,
-        }
+        # Load samples and preprocess images if load_in_memory is True
+        self.samples = self.load_samples(paths_and_sets)
 
-    def apply_specific_treatment_after_dataset_loading(self, dataset):
-        raise NotImplementedError
+        # Curriculum config
+        self.curriculum_config = None
 
-    def load_datasets(self):
+    def __len__(self):
         """
-        Load training and validation datasets
+        Return the dataset size
         """
-        self.train_dataset = self.dataset_class(
-            self.params,
-            "train",
-            self.params["train"]["name"],
-            self.get_paths_and_sets(self.params["train"]["datasets"]),
-            normalization_transforms=get_normalization_transforms(),
-            augmentation_transforms=(
-                get_augmentation_transforms()
-                if self.params["config"]["augmentation"]
-                else None
-            ),
-        )
-
-        self.my_collate_function = self.train_dataset.collate_function(
-            self.params["config"]
-        )
-        self.apply_specific_treatment_after_dataset_loading(self.train_dataset)
-
-        for custom_name in self.params["val"].keys():
-            self.valid_datasets[custom_name] = self.dataset_class(
-                self.params,
-                "val",
-                custom_name,
-                self.get_paths_and_sets(self.params["val"][custom_name]),
-                normalization_transforms=get_normalization_transforms(),
-                augmentation_transforms=None,
-            )
-            self.apply_specific_treatment_after_dataset_loading(
-                self.valid_datasets[custom_name]
-            )
+        return len(self.samples)
 
-    def load_ddp_samplers(self):
+    def __getitem__(self, idx):
         """
-        Load training and validation data samplers
+        Return an item from the dataset (image and label)
         """
-        if self.params["use_ddp"]:
-            self.train_sampler = DistributedSampler(
-                self.train_dataset,
-                num_replicas=self.params["num_gpu"],
-                rank=self.params["ddp_rank"],
-                shuffle=True,
-            )
-            for custom_name in self.valid_datasets.keys():
-                self.valid_samplers[custom_name] = DistributedSampler(
-                    self.valid_datasets[custom_name],
-                    num_replicas=self.params["num_gpu"],
-                    rank=self.params["ddp_rank"],
-                    shuffle=False,
-                )
-        else:
-            for custom_name in self.valid_datasets.keys():
-                self.valid_samplers[custom_name] = None
-
-    def load_dataloaders(self):
-        """
-        Load training and validation data loaders
-        """
-        self.train_loader = DataLoader(
-            self.train_dataset,
-            batch_size=self.batch_size["train"],
-            shuffle=True if self.train_sampler is None else False,
-            drop_last=False,
-            batch_sampler=self.train_sampler,
-            sampler=self.train_sampler,
-            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-            pin_memory=self.pin_memory,
-            collate_fn=self.my_collate_function,
-            worker_init_fn=self.seed_worker,
-            generator=self.generator,
-        )
+        # Load preprocessed image
+        sample = dict(**self.samples[idx])
+        if not self.load_in_memory:
+            sample["img"] = self.get_sample_img(idx)
 
-        for key in self.valid_datasets.keys():
-            self.valid_loaders[key] = DataLoader(
-                self.valid_datasets[key],
-                batch_size=self.batch_size["val"],
-                sampler=self.valid_samplers[key],
-                batch_sampler=self.valid_samplers[key],
-                shuffle=False,
-                num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-                pin_memory=self.pin_memory,
-                drop_last=False,
-                collate_fn=self.my_collate_function,
-                worker_init_fn=self.seed_worker,
-                generator=self.generator,
-            )
+        # Apply data augmentation
+        if self.augmentation_transforms:
+            sample["img"] = self.augmentation_transforms(image=np.array(sample["img"]))[
+                "image"
+            ]
 
-    @staticmethod
-    def seed_worker(worker_id):
-        worker_seed = torch.initial_seed() % 2**32
-        np.random.seed(worker_seed)
-        random.seed(worker_seed)
+        # Image normalization
+        sample["img"] = self.normalization_transforms(sample["img"])
 
-    def generate_test_loader(self, custom_name, sets_list):
-        """
-        Load test dataset, data sampler and data loader
-        """
-        if custom_name in self.test_loaders.keys():
-            return
-        paths_and_sets = list()
-        for set_info in sets_list:
-            paths_and_sets.append(
-                {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
-            )
-        self.test_datasets[custom_name] = self.dataset_class(
-            self.params,
-            "test",
-            custom_name,
-            paths_and_sets,
-            normalization_transforms=get_normalization_transforms(),
-        )
-        self.apply_specific_treatment_after_dataset_loading(
-            self.test_datasets[custom_name]
+        # Get final height and width
+        sample["img_reduced_shape"], sample["img_position"] = self.compute_final_size(
+            sample["img"]
         )
-        if self.params["use_ddp"]:
-            self.test_samplers[custom_name] = DistributedSampler(
-                self.test_datasets[custom_name],
-                num_replicas=self.params["num_gpu"],
-                rank=self.params["ddp_rank"],
-                shuffle=False,
-            )
-        else:
-            self.test_samplers[custom_name] = None
-        self.test_loaders[custom_name] = DataLoader(
-            self.test_datasets[custom_name],
-            batch_size=self.batch_size["test"],
-            sampler=self.test_samplers[custom_name],
-            shuffle=False,
-            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-            pin_memory=self.pin_memory,
-            drop_last=False,
-            collate_fn=self.my_collate_function,
-            worker_init_fn=self.seed_worker,
-            generator=self.generator,
-        )
-
-    def get_paths_and_sets(self, dataset_names_folds):
-        paths_and_sets = list()
-        for dataset_name, fold in dataset_names_folds:
-            path = self.params["datasets"][dataset_name]
-            paths_and_sets.append({"path": path, "set_name": fold})
-        return paths_and_sets
-
-
-class GenericDataset(Dataset):
-    """
-    Main class to handle dataset loading
-    """
-
-    def __init__(self, params, set_name, custom_name, paths_and_sets):
-        self.params = params
-        self.name = custom_name
-        self.set_name = set_name
-        self.preprocessing_transforms = get_preprocessing_transforms(
-            params["config"]["preprocessings"]
-        )
-        self.load_in_memory = (
-            self.params["config"]["load_in_memory"]
-            if "load_in_memory" in self.params["config"]
-            else True
+
+        # Convert label into tokens
+        sample["token_label"], sample["label_len"] = self.convert_sample_label(
+            sample["label"]
         )
-
-        # Load samples and preprocess images if load_in_memory is True
-        self.samples = self.load_samples(paths_and_sets)
-
-        self.curriculum_config = None
-
-    def __len__(self):
-        return len(self.samples)
+
+        return sample
 
     @staticmethod
     def load_image(path):
@@ -276,3 +130,31 @@ class GenericDataset(Dataset):
         return self.preprocessing_transforms(
             self.load_image(self.samples[i]["path"])
         )
+
+    def compute_final_size(self, img):
+        """
+        Compute the final image size and position after feature extraction
+        """
+        final_c, final_h, final_w = img.shape
+        image_reduced_shape = np.ceil(
+            [final_h, final_w, final_c] / self.reduce_dims_factor
+        ).astype(int)
+
+        if self.set_name == "train":
+            image_reduced_shape = [max(1, t) for t in image_reduced_shape]
+
+        image_position = [
+            [0, final_h],
+            [0, final_w],
+        ]
+        return image_reduced_shape, image_position
+
+    def convert_sample_label(self, label):
+        """
+        Tokenize the label and return its length
+        """
+        token_label = token_to_ind(self.charset, label)
+        token_label.append(self.tokens["end"])
+        label_len = len(token_label)
+        token_label.insert(0, self.tokens["start"])
+        return token_label, label_len
diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index 1790f4a081cbaa265cf19483c7792f1f7af8e96b..383ac455aea8cd6ed104f81be85eb17bafec8952 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -1,35 +1,217 @@
 # -*- coding: utf-8 -*-
 import os
 import pickle
+import random
 
 import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 
-from dan.manager.dataset import DatasetManager, GenericDataset
-from dan.utils import pad_images, pad_sequences_1D, token_to_ind
+from dan.manager.dataset import OCRDataset
+from dan.transforms import (
+    get_augmentation_transforms,
+    get_normalization_transforms,
+    get_preprocessing_transforms,
+)
+from dan.utils import pad_images, pad_sequences_1D
 
 
-class OCRDatasetManager(DatasetManager):
-    """
-    Specific class to handle OCR/HTR tasks
-    """
-
+class OCRDatasetManager:
     def __init__(self, params, device: str):
-        super(OCRDatasetManager, self).__init__(params, device)
+        self.params = params
 
-        self.dataset_class = OCRDataset
-        self.charset = (
-            params["charset"] if "charset" in params else self.get_merged_charsets()
-        )
+        # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
+        self.pin_memory = device != "cpu"
+
+        self.train_dataset = None
+        self.valid_datasets = dict()
+        self.test_datasets = dict()
+
+        self.train_loader = None
+        self.valid_loaders = dict()
+        self.test_loaders = dict()
+
+        self.train_sampler = None
+        self.valid_samplers = dict()
+        self.test_samplers = dict()
 
-        self.tokens = {"pad": len(self.charset) + 2}
-        self.tokens["end"] = len(self.charset)
-        self.tokens["start"] = len(self.charset) + 1
+        self.generator = torch.Generator()
+        self.generator.manual_seed(0)
+
+        self.batch_size = self.get_batch_size_by_set()
+
+        self.load_in_memory = (
+            self.params["config"]["load_in_memory"]
+            if "load_in_memory" in self.params["config"]
+            else True
+        )
+        self.charset = self.get_charset()
+        self.tokens = self.get_tokens()
         self.params["config"]["padding_token"] = self.tokens["pad"]
 
-    def get_merged_charsets(self):
+        self.my_collate_function = OCRCollateFunction(self.params["config"])
+        self.normalization = get_normalization_transforms()
+        self.augmentation = (
+            get_augmentation_transforms()
+            if self.params["config"]["augmentation"]
+            else None
+        )
+        self.preprocessing = get_preprocessing_transforms(
+            params["config"]["preprocessings"]
+        )
+
+    def load_datasets(self):
+        """
+        Load training and validation datasets
+        """
+        self.train_dataset = OCRDataset(
+            set_name="train",
+            paths_and_sets=self.get_paths_and_sets(self.params["train"]["datasets"]),
+            charset=self.charset,
+            tokens=self.tokens,
+            preprocessing_transforms=self.preprocessing,
+            normalization_transforms=self.normalization,
+            augmentation_transforms=self.augmentation,
+            load_in_memory=self.load_in_memory,
+        )
+
+        for custom_name in self.params["val"].keys():
+            self.valid_datasets[custom_name] = OCRDataset(
+                set_name="val",
+                paths_and_sets=self.get_paths_and_sets(self.params["val"][custom_name]),
+                charset=self.charset,
+                tokens=self.tokens,
+                preprocessing_transforms=self.preprocessing,
+                normalization_transforms=self.normalization,
+                augmentation_transforms=None,
+                load_in_memory=self.load_in_memory,
+            )
+
+    def load_ddp_samplers(self):
+        """
+        Load training and validation data samplers
+        """
+        if self.params["use_ddp"]:
+            self.train_sampler = DistributedSampler(
+                self.train_dataset,
+                num_replicas=self.params["num_gpu"],
+                rank=self.params["ddp_rank"],
+                shuffle=True,
+            )
+            for custom_name in self.valid_datasets.keys():
+                self.valid_samplers[custom_name] = DistributedSampler(
+                    self.valid_datasets[custom_name],
+                    num_replicas=self.params["num_gpu"],
+                    rank=self.params["ddp_rank"],
+                    shuffle=False,
+                )
+        else:
+            for custom_name in self.valid_datasets.keys():
+                self.valid_samplers[custom_name] = None
+
+    def load_dataloaders(self):
+        """
+        Load training and validation data loaders
+        """
+        self.train_loader = DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size["train"],
+            shuffle=True if self.train_sampler is None else False,
+            drop_last=False,
+            batch_sampler=self.train_sampler,
+            sampler=self.train_sampler,
+            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
+            pin_memory=self.pin_memory,
+            collate_fn=self.my_collate_function,
+            worker_init_fn=self.seed_worker,
+            generator=self.generator,
+        )
+
+        for key in self.valid_datasets.keys():
+            self.valid_loaders[key] = DataLoader(
+                self.valid_datasets[key],
+                batch_size=self.batch_size["val"],
+                sampler=self.valid_samplers[key],
+                batch_sampler=self.valid_samplers[key],
+                shuffle=False,
+                num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
+                pin_memory=self.pin_memory,
+                drop_last=False,
+                collate_fn=self.my_collate_function,
+                worker_init_fn=self.seed_worker,
+                generator=self.generator,
+            )
+
+    @staticmethod
+    def seed_worker(worker_id):
+        """
+        Set worker seed
+        """
+        worker_seed = torch.initial_seed() % 2**32
+        np.random.seed(worker_seed)
+        random.seed(worker_seed)
+
+    def generate_test_loader(self, custom_name, sets_list):
+        """
+        Load test dataset, data sampler and data loader
+        """
+        if custom_name in self.test_loaders.keys():
+            return
+        paths_and_sets = list()
+        for set_info in sets_list:
+            paths_and_sets.append(
+                {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
+            )
+        self.test_datasets[custom_name] = OCRDataset(
+            set_name="test",
+            paths_and_sets=paths_and_sets,
+            charset=self.charset,
+            tokens=self.tokens,
+            preprocessing_transforms=self.preprocessing,
+            normalization_transforms=self.normalization,
+            augmentation_transforms=None,
+            load_in_memory=self.load_in_memory,
+        )
+
+        if self.params["use_ddp"]:
+            self.test_samplers[custom_name] = DistributedSampler(
+                self.test_datasets[custom_name],
+                num_replicas=self.params["num_gpu"],
+                rank=self.params["ddp_rank"],
+                shuffle=False,
+            )
+        else:
+            self.test_samplers[custom_name] = None
+        self.test_loaders[custom_name] = DataLoader(
+            self.test_datasets[custom_name],
+            batch_size=self.batch_size["test"],
+            sampler=self.test_samplers[custom_name],
+            shuffle=False,
+            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
+            pin_memory=self.pin_memory,
+            drop_last=False,
+            collate_fn=self.my_collate_function,
+            worker_init_fn=self.seed_worker,
+            generator=self.generator,
+        )
+
+    def get_paths_and_sets(self, dataset_names_folds):
+        """
+        Set the right path for each data set
+        """
+        paths_and_sets = list()
+        for dataset_name, fold in dataset_names_folds:
+            path = self.params["datasets"][dataset_name]
+            paths_and_sets.append({"path": path, "set_name": fold})
+        return paths_and_sets
+
+    def get_charset(self):
         """
         Merge the charset of the different datasets used
         """
+        if "charset" in self.params:
+            return self.params["charset"]
         datasets = self.params["datasets"]
         charset = set()
         for key in datasets.keys():
@@ -39,83 +221,29 @@ class OCRDatasetManager(DatasetManager):
             charset.remove("")
         return sorted(list(charset))
 
-    def apply_specific_treatment_after_dataset_loading(self, dataset):
-        dataset.charset = self.charset
-        dataset.tokens = self.tokens
-        dataset.convert_labels()
-
-
-class OCRDataset(GenericDataset):
-    """
-    Specific class to handle OCR/HTR datasets
-    """
+    def get_tokens(self):
+        """
+        Get special tokens
+        """
+        return {
+            "end": len(self.charset),
+            "start": len(self.charset) + 1,
+            "pad": len(self.charset) + 2,
+        }
 
-    def __init__(
-        self,
-        params,
-        set_name,
-        custom_name,
-        paths_and_sets,
-        normalization_transforms,
-        augmentation_transforms=None,
-    ):
-        super(OCRDataset, self).__init__(params, set_name, custom_name, paths_and_sets)
-        self.charset = None
-        self.tokens = None
-        # Factor to reduce the height and width of the feature vector before feeding the decoder.
-        self.reduce_dims_factor = np.array([32, 8, 1])
-        self.collate_function = OCRCollateFunction
-        self.normalization_transforms = normalization_transforms
-        self.augmentation_transforms = augmentation_transforms
-
-    def __getitem__(self, idx):
-        sample = dict(**self.samples[idx])
-
-        if not self.load_in_memory:
-            sample["img"] = self.get_sample_img(idx)
-
-        # Data augmentation
-        if self.augmentation_transforms:
-            sample["img"] = self.augmentation_transforms(image=np.array(sample["img"]))[
-                "image"
-            ]
-
-        # Normalization
-        sample["img"] = self.normalization_transforms(sample["img"])
-
-        # Get final height and width
-        final_c, final_h, final_w = sample["img"].shape
-        sample["img_reduced_shape"] = np.ceil(
-            [final_h, final_w, final_c] / self.reduce_dims_factor
-        ).astype(int)
-
-        if self.set_name == "train":
-            sample["img_reduced_shape"] = [
-                max(1, t) for t in sample["img_reduced_shape"]
-            ]
-
-        sample["img_position"] = [
-            [0, final_h],
-            [0, final_w],
-        ]
-        return sample
-
-    def convert_labels(self):
-        """
-        Label str to token at character level
-        """
-        for i in range(len(self.samples)):
-            self.samples[i] = self.convert_sample_labels(self.samples[i])
-
-    def convert_sample_labels(self, sample):
-        label = sample["label"]
-
-        sample["label"] = label
-        sample["token_label"] = token_to_ind(self.charset, label)
-        sample["token_label"].append(self.tokens["end"])
-        sample["label_len"] = len(sample["token_label"])
-        sample["token_label"].insert(0, self.tokens["start"])
-        return sample
+    def get_batch_size_by_set(self):
+        """
+        Return batch size for each set
+        """
+        return {
+            "train": self.params["batch_size"],
+            "val": self.params["valid_batch_size"]
+            if "valid_batch_size" in self.params
+            else self.params["batch_size"],
+            "test": self.params["test_batch_size"]
+            if "test_batch_size" in self.params
+            else 1,
+        }
 
 
 class OCRCollateFunction: