diff --git a/dan/manager/dataset.py b/dan/manager/dataset.py index 9e9b52a17f700eff39d5f90c30942206a7a4af44..4afb776c9389ed7c9c7e6003a1a7e3068d5187ce 100644 --- a/dan/manager/dataset.py +++ b/dan/manager/dataset.py @@ -1,239 +1,90 @@ # -*- coding: utf-8 -*- +import copy import json import os -import random import numpy as np import torch -from torch.utils.data import DataLoader, Dataset -from torch.utils.data.distributed import DistributedSampler +from torch.utils.data import Dataset from torchvision.io import ImageReadMode, read_image from dan.datasets.utils import natural_sort -from dan.transforms import get_augmentation_transforms, get_preprocessing_transforms +from dan.utils import token_to_ind -class DatasetManager: - def __init__(self, params, device: str): - self.params = params - self.dataset_class = None - - self.my_collate_function = None - # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html - self.pin_memory = device != "cpu" - - self.train_dataset = None - self.valid_datasets = dict() - self.test_datasets = dict() +class OCRDataset(Dataset): + """ + Dataset class to handle dataset loading + """ - self.train_loader = None - self.valid_loaders = dict() - self.test_loaders = dict() + def __init__( + self, + set_name, + paths_and_sets, + charset, + tokens, + preprocessing_transforms, + augmentation_transforms, + load_in_memory=False, + mean=None, + std=None, + ): + self.set_name = set_name + self.charset = charset + self.tokens = tokens + self.load_in_memory = load_in_memory + self.mean = mean + self.std = std - self.train_sampler = None - self.valid_samplers = dict() - self.test_samplers = dict() + # Pre-processing, augmentation + self.preprocessing_transforms = preprocessing_transforms + self.augmentation_transforms = augmentation_transforms - self.generator = torch.Generator() - self.generator.manual_seed(0) + # Factor to reduce the height and width of the feature vector before feeding the decoder. 
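+        # An image array has shape (height, width, channel): the height is divided by 32, the width by 8 and the channel dimension is kept as is (see compute_final_size below).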
+ self.reduce_dims_factor = np.array([32, 8, 1]) - self.batch_size = { - "train": self.params["batch_size"], - "val": self.params["valid_batch_size"] - if "valid_batch_size" in self.params - else self.params["batch_size"], - "test": self.params["test_batch_size"] - if "test_batch_size" in self.params - else 1, - } + # Load samples and preprocess images if load_in_memory is True + self.samples = self.load_samples(paths_and_sets) - def apply_specific_treatment_after_dataset_loading(self, dataset): - raise NotImplementedError + # Curriculum config + self.curriculum_config = None - def load_datasets(self): + def __len__(self): """ - Load training and validation datasets + Return the dataset size """ - self.train_dataset = self.dataset_class( - self.params, - "train", - self.params["train"]["name"], - self.get_paths_and_sets(self.params["train"]["datasets"]), - augmentation_transforms=( - get_augmentation_transforms() - if self.params["config"]["augmentation"] - else None - ), - ) - - ( - self.params["config"]["mean"], - self.params["config"]["std"], - ) = self.train_dataset.compute_std_mean() - - self.my_collate_function = self.train_dataset.collate_function( - self.params["config"] - ) - self.apply_specific_treatment_after_dataset_loading(self.train_dataset) - - for custom_name in self.params["val"].keys(): - self.valid_datasets[custom_name] = self.dataset_class( - self.params, - "val", - custom_name, - self.get_paths_and_sets(self.params["val"][custom_name]), - augmentation_transforms=None, - ) - self.apply_specific_treatment_after_dataset_loading( - self.valid_datasets[custom_name] - ) + return len(self.samples) - def load_ddp_samplers(self): + def __getitem__(self, idx): """ - Load training and validation data samplers + Return an item from the dataset (image and label) """ - if self.params["use_ddp"]: - self.train_sampler = DistributedSampler( - self.train_dataset, - num_replicas=self.params["num_gpu"], - rank=self.params["ddp_rank"], - shuffle=True, - ) - for custom_name in self.valid_datasets.keys(): - self.valid_samplers[custom_name] = DistributedSampler( - self.valid_datasets[custom_name], - num_replicas=self.params["num_gpu"], - rank=self.params["ddp_rank"], - shuffle=False, - ) - else: - for custom_name in self.valid_datasets.keys(): - self.valid_samplers[custom_name] = None + # Load preprocessed image + sample = copy.deepcopy(self.samples[idx]) + if not self.load_in_memory: + sample["img"] = self.get_sample_img(idx) - def load_dataloaders(self): - """ - Load training and validation data loaders - """ - self.train_loader = DataLoader( - self.train_dataset, - batch_size=self.batch_size["train"], - shuffle=True if self.train_sampler is None else False, - drop_last=False, - batch_sampler=self.train_sampler, - sampler=self.train_sampler, - num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"], - pin_memory=self.pin_memory, - collate_fn=self.my_collate_function, - worker_init_fn=self.seed_worker, - generator=self.generator, - ) + # Convert to numpy + sample["img"] = np.array(sample["img"]) - for key in self.valid_datasets.keys(): - self.valid_loaders[key] = DataLoader( - self.valid_datasets[key], - batch_size=self.batch_size["val"], - sampler=self.valid_samplers[key], - batch_sampler=self.valid_samplers[key], - shuffle=False, - num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"], - pin_memory=self.pin_memory, - drop_last=False, - collate_fn=self.my_collate_function, - worker_init_fn=self.seed_worker, - generator=self.generator, - ) + # Apply data augmentation 
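+        # Augmentation transforms are only passed for the training set; they take the image as a keyword argument and return a dict holding the augmented image.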
+ if self.augmentation_transforms: + sample["img"] = self.augmentation_transforms(image=sample["img"])["image"] - @staticmethod - def seed_worker(worker_id): - worker_seed = torch.initial_seed() % 2**32 - np.random.seed(worker_seed) - random.seed(worker_seed) + # Image normalization + sample["img"] = (sample["img"] - self.mean) / self.std - def generate_test_loader(self, custom_name, sets_list): - """ - Load test dataset, data sampler and data loader - """ - if custom_name in self.test_loaders.keys(): - return - paths_and_sets = list() - for set_info in sets_list: - paths_and_sets.append( - {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]} - ) - self.test_datasets[custom_name] = self.dataset_class( - self.params, - "test", - custom_name, - paths_and_sets, - ) - self.apply_specific_treatment_after_dataset_loading( - self.test_datasets[custom_name] - ) - if self.params["use_ddp"]: - self.test_samplers[custom_name] = DistributedSampler( - self.test_datasets[custom_name], - num_replicas=self.params["num_gpu"], - rank=self.params["ddp_rank"], - shuffle=False, - ) - else: - self.test_samplers[custom_name] = None - self.test_loaders[custom_name] = DataLoader( - self.test_datasets[custom_name], - batch_size=self.batch_size["test"], - sampler=self.test_samplers[custom_name], - shuffle=False, - num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"], - pin_memory=self.pin_memory, - drop_last=False, - collate_fn=self.my_collate_function, - worker_init_fn=self.seed_worker, - generator=self.generator, + # Get final height and width + sample["img_reduced_shape"], sample["img_position"] = self.compute_final_size( + sample["img"] ) - def get_paths_and_sets(self, dataset_names_folds): - paths_and_sets = list() - for dataset_name, fold in dataset_names_folds: - path = self.params["datasets"][dataset_name] - paths_and_sets.append({"path": path, "set_name": fold}) - return paths_and_sets - - -class GenericDataset(Dataset): - """ - Main class to handle dataset loading - """ - - def __init__(self, params, set_name, custom_name, paths_and_sets): - self.params = params - self.name = custom_name - self.set_name = set_name - self.mean = ( - np.array(params["config"]["mean"]) - if "mean" in params["config"].keys() - else None - ) - self.std = ( - np.array(params["config"]["std"]) - if "std" in params["config"].keys() - else None - ) - self.preprocessing_transforms = get_preprocessing_transforms( - params["config"]["preprocessings"] - ) - self.load_in_memory = ( - self.params["config"]["load_in_memory"] - if "load_in_memory" in self.params["config"] - else True + # Convert label into tokens + sample["token_label"], sample["label_len"] = self.convert_sample_label( + sample["label"] ) - # Load samples and preprocess images if load_in_memory is True - self.samples = self.load_samples(paths_and_sets) - - self.curriculum_config = None - - def __len__(self): - return len(self.samples) + return sample @staticmethod def load_image(path): @@ -273,6 +124,17 @@ class GenericDataset(Dataset): ) return samples + def get_sample_img(self, i): + """ + Get image by index + """ + if self.load_in_memory: + return self.samples[i]["img"] + else: + return self.preprocessing_transforms( + self.load_image(self.samples[i]["path"]) + ) + def compute_std_mean(self): """ Compute cumulated variance and mean of whole dataset @@ -299,13 +161,27 @@ class GenericDataset(Dataset): self.std = np.sqrt(diff / nb_pixels) return self.mean, self.std - def get_sample_img(self, i): + def compute_final_size(self, img): """ - 
Get image by index + Compute the final image size and position after feature extraction """ - if self.load_in_memory: - return self.samples[i]["img"] - else: - return self.preprocessing_transforms( - self.load_image(self.samples[i]["path"]) - ) + image_reduced_shape = np.ceil(img.shape / self.reduce_dims_factor).astype(int) + + if self.set_name == "train": + image_reduced_shape = [max(1, t) for t in image_reduced_shape] + + image_position = [ + [0, img.shape[0]], + [0, img.shape[1]], + ] + return image_reduced_shape, image_position + + def convert_sample_label(self, label): + """ + Tokenize the label and return its length + """ + token_label = token_to_ind(self.charset, label) + token_label.append(self.tokens["end"]) + label_len = len(token_label) + token_label.insert(0, self.tokens["start"]) + return token_label, label_len diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py index 6e32c586423f67b23a0f7b6f0af0113061bd77f3..9425632335798b5eb7933d8462bd5da7e66b6158 100644 --- a/dan/manager/ocr.py +++ b/dan/manager/ocr.py @@ -1,37 +1,229 @@ # -*- coding: utf-8 -*- -import copy import os import pickle +import random import numpy as np import torch +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler -from dan.manager.dataset import DatasetManager, GenericDataset -from dan.utils import pad_images, pad_sequences_1D, token_to_ind +from dan.manager.dataset import OCRDataset +from dan.transforms import get_augmentation_transforms, get_preprocessing_transforms +from dan.utils import pad_images, pad_sequences_1D -class OCRDatasetManager(DatasetManager): - """ - Specific class to handle OCR/HTR tasks - """ - +class OCRDatasetManager: def __init__(self, params, device: str): - super(OCRDatasetManager, self).__init__(params, device) + self.params = params + + # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html + self.pin_memory = device != "cpu" + + self.train_dataset = None + self.valid_datasets = dict() + self.test_datasets = dict() + + self.train_loader = None + self.valid_loaders = dict() + self.test_loaders = dict() + + self.train_sampler = None + self.valid_samplers = dict() + self.test_samplers = dict() - self.dataset_class = OCRDataset - self.charset = ( - params["charset"] if "charset" in params else self.get_merged_charsets() + self.mean = ( + np.array(params["config"]["mean"]) + if "mean" in params["config"].keys() + else None ) + self.std = ( + np.array(params["config"]["std"]) + if "std" in params["config"].keys() + else None + ) + + self.generator = torch.Generator() + self.generator.manual_seed(0) - self.tokens = {"pad": len(self.charset) + 2} - self.tokens["end"] = len(self.charset) - self.tokens["start"] = len(self.charset) + 1 + self.batch_size = self.get_batch_size_by_set() + + self.load_in_memory = ( + self.params["config"]["load_in_memory"] + if "load_in_memory" in self.params["config"] + else True + ) + self.charset = self.get_charset() + self.tokens = self.get_tokens() self.params["config"]["padding_token"] = self.tokens["pad"] - def get_merged_charsets(self): + self.my_collate_function = OCRCollateFunction(self.params["config"]) + self.augmentation = ( + get_augmentation_transforms() + if self.params["config"]["augmentation"] + else None + ) + self.preprocessing = get_preprocessing_transforms( + params["config"]["preprocessings"] + ) + + def load_datasets(self): + """ + Load training and validation datasets + """ + self.train_dataset = OCRDataset( + set_name="train", + 
paths_and_sets=self.get_paths_and_sets(self.params["train"]["datasets"]), + charset=self.charset, + tokens=self.tokens, + preprocessing_transforms=self.preprocessing, + augmentation_transforms=self.augmentation, + load_in_memory=self.load_in_memory, + mean=self.mean, + std=self.std, + ) + + self.mean, self.std = self.train_dataset.compute_std_mean() + + for custom_name in self.params["val"].keys(): + self.valid_datasets[custom_name] = OCRDataset( + set_name="val", + paths_and_sets=self.get_paths_and_sets(self.params["val"][custom_name]), + charset=self.charset, + tokens=self.tokens, + preprocessing_transforms=self.preprocessing, + augmentation_transforms=None, + load_in_memory=self.load_in_memory, + mean=self.mean, + std=self.std, + ) + + def load_ddp_samplers(self): + """ + Load training and validation data samplers + """ + if self.params["use_ddp"]: + self.train_sampler = DistributedSampler( + self.train_dataset, + num_replicas=self.params["num_gpu"], + rank=self.params["ddp_rank"], + shuffle=True, + ) + for custom_name in self.valid_datasets.keys(): + self.valid_samplers[custom_name] = DistributedSampler( + self.valid_datasets[custom_name], + num_replicas=self.params["num_gpu"], + rank=self.params["ddp_rank"], + shuffle=False, + ) + else: + for custom_name in self.valid_datasets.keys(): + self.valid_samplers[custom_name] = None + + def load_dataloaders(self): + """ + Load training and validation data loaders + """ + self.train_loader = DataLoader( + self.train_dataset, + batch_size=self.batch_size["train"], + shuffle=True if self.train_sampler is None else False, + drop_last=False, + batch_sampler=self.train_sampler, + sampler=self.train_sampler, + num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"], + pin_memory=self.pin_memory, + collate_fn=self.my_collate_function, + worker_init_fn=self.seed_worker, + generator=self.generator, + ) + + for key in self.valid_datasets.keys(): + self.valid_loaders[key] = DataLoader( + self.valid_datasets[key], + batch_size=self.batch_size["val"], + sampler=self.valid_samplers[key], + batch_sampler=self.valid_samplers[key], + shuffle=False, + num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"], + pin_memory=self.pin_memory, + drop_last=False, + collate_fn=self.my_collate_function, + worker_init_fn=self.seed_worker, + generator=self.generator, + ) + + @staticmethod + def seed_worker(worker_id): + """ + Set worker seed + """ + worker_seed = torch.initial_seed() % 2**32 + np.random.seed(worker_seed) + random.seed(worker_seed) + + def generate_test_loader(self, custom_name, sets_list): + """ + Load test dataset, data sampler and data loader + """ + if custom_name in self.test_loaders.keys(): + return + paths_and_sets = list() + for set_info in sets_list: + paths_and_sets.append( + {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]} + ) + self.test_datasets[custom_name] = OCRDataset( + set_name="test", + paths_and_sets=paths_and_sets, + charset=self.charset, + tokens=self.tokens, + preprocessing_transforms=self.preprocessing, + augmentation_transforms=None, + load_in_memory=self.load_in_memory, + mean=self.mean, + std=self.std, + ) + + if self.params["use_ddp"]: + self.test_samplers[custom_name] = DistributedSampler( + self.test_datasets[custom_name], + num_replicas=self.params["num_gpu"], + rank=self.params["ddp_rank"], + shuffle=False, + ) + else: + self.test_samplers[custom_name] = None + + self.test_loaders[custom_name] = DataLoader( + self.test_datasets[custom_name], + batch_size=self.batch_size["test"], + 
sampler=self.test_samplers[custom_name], + shuffle=False, + num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"], + pin_memory=self.pin_memory, + drop_last=False, + collate_fn=self.my_collate_function, + worker_init_fn=self.seed_worker, + generator=self.generator, + ) + + def get_paths_and_sets(self, dataset_names_folds): + """ + Set the right path for each data set + """ + paths_and_sets = list() + for dataset_name, fold in dataset_names_folds: + path = self.params["datasets"][dataset_name] + paths_and_sets.append({"path": path, "set_name": fold}) + return paths_and_sets + + def get_charset(self): """ Merge the charset of the different datasets used """ + if "charset" in self.params: + return self.params["charset"] datasets = self.params["datasets"] charset = set() for key in datasets.keys(): @@ -41,81 +233,29 @@ class OCRDatasetManager(DatasetManager): charset.remove("") return sorted(list(charset)) - def apply_specific_treatment_after_dataset_loading(self, dataset): - dataset.charset = self.charset - dataset.tokens = self.tokens - dataset.convert_labels() - - -class OCRDataset(GenericDataset): - """ - Specific class to handle OCR/HTR datasets - """ - - def __init__( - self, - params, - set_name, - custom_name, - paths_and_sets, - augmentation_transforms=None, - ): - super(OCRDataset, self).__init__(params, set_name, custom_name, paths_and_sets) - self.charset = None - self.tokens = None - # Factor to reduce the height and width of the feature vector before feeding the decoder. - self.reduce_dims_factor = np.array([32, 8, 1]) - self.collate_function = OCRCollateFunction - self.augmentation_transforms = augmentation_transforms - - def __getitem__(self, idx): - sample = copy.deepcopy(self.samples[idx]) - - if not self.load_in_memory: - sample["img"] = self.get_sample_img(idx) - - # Convert to numpy - sample["img"] = np.array(sample["img"]) - - # Data augmentation - if self.augmentation_transforms: - sample["img"] = self.augmentation_transforms(image=sample["img"])["image"] - - # Normalization - sample["img"] = (sample["img"] - self.mean) / self.std - - sample["img_reduced_shape"] = np.ceil( - sample["img"].shape / self.reduce_dims_factor - ).astype(int) - - if self.set_name == "train": - sample["img_reduced_shape"] = [ - max(1, t) for t in sample["img_reduced_shape"] - ] - - sample["img_position"] = [ - [0, sample["img"].shape[0]], - [0, sample["img"].shape[1]], - ] - - return sample - - def convert_labels(self): + def get_tokens(self): """ - Label str to token at character level + Get special tokens """ - for i in range(len(self.samples)): - self.samples[i] = self.convert_sample_labels(self.samples[i]) - - def convert_sample_labels(self, sample): - label = sample["label"] + return { + "end": len(self.charset), + "start": len(self.charset) + 1, + "pad": len(self.charset) + 2, + } - sample["label"] = label - sample["token_label"] = token_to_ind(self.charset, label) - sample["token_label"].append(self.tokens["end"]) - sample["label_len"] = len(sample["token_label"]) - sample["token_label"].insert(0, self.tokens["start"]) - return sample + def get_batch_size_by_set(self): + """ + Return batch size for each set + """ + return { + "train": self.params["batch_size"], + "val": self.params["valid_batch_size"] + if "valid_batch_size" in self.params + else self.params["batch_size"], + "test": self.params["test_batch_size"] + if "test_batch_size" in self.params + else 1, + } class OCRCollateFunction: