# -*- coding: utf-8 -*-
import os
import pickle
import random

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from dan.manager.dataset import OCRDataset
from dan.transforms import (
    get_augmentation_transforms,
    get_normalization_transforms,
    get_preprocessing_transforms,
)
from dan.utils import pad_images, pad_sequences_1D


class OCRDatasetManager:
    def __init__(self, params, device: str):
        self.params = params
        # Whether data should be copied to pinned memory for faster GPU transfer,
        # see https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
        self.pin_memory = device != "cpu"

        self.train_dataset = None
        self.valid_datasets = dict()
        self.test_datasets = dict()

        self.train_loader = None
        self.valid_loaders = dict()
        self.test_loaders = dict()

        self.train_sampler = None
        self.valid_samplers = dict()
        self.test_samplers = dict()

        # Seeded generator to make data shuffling reproducible
        self.generator = torch.Generator()
        self.generator.manual_seed(0)

        self.load_in_memory = (
            self.params["config"]["load_in_memory"]
            if "load_in_memory" in self.params["config"]
            else True
        )
        self.charset = self.get_charset()
        self.tokens = self.get_tokens()
        self.params["config"]["padding_token"] = self.tokens["pad"]

        self.my_collate_function = OCRCollateFunction(self.params["config"])
        self.normalization = get_normalization_transforms()
        self.augmentation = (
            get_augmentation_transforms()
            if self.params["config"]["augmentation"]
            else None
        )
        self.preprocessing = get_preprocessing_transforms(
            params["config"]["preprocessings"], to_pil_image=True
        )

    def load_datasets(self):
        """
        Load training and validation datasets
        """
        self.train_dataset = OCRDataset(
            set_name="train",
            paths_and_sets=self.get_paths_and_sets(self.params["train"]["datasets"]),
            charset=self.charset,
            tokens=self.tokens,
            preprocessing_transforms=self.preprocessing,
            normalization_transforms=self.normalization,
            augmentation_transforms=self.augmentation,
            load_in_memory=self.load_in_memory,
        )

        for custom_name in self.params["val"].keys():
            self.valid_datasets[custom_name] = OCRDataset(
                set_name="val",
                paths_and_sets=self.get_paths_and_sets(
                    self.params["val"][custom_name]
                ),
                charset=self.charset,
                tokens=self.tokens,
                preprocessing_transforms=self.preprocessing,
                normalization_transforms=self.normalization,
                augmentation_transforms=None,
                load_in_memory=self.load_in_memory,
            )

    def load_ddp_samplers(self):
        """
        Load training and validation data samplers
        """
        if self.params["use_ddp"]:
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=self.params["num_gpu"],
                rank=self.params["ddp_rank"],
                shuffle=True,
            )
            for custom_name in self.valid_datasets.keys():
                self.valid_samplers[custom_name] = DistributedSampler(
                    self.valid_datasets[custom_name],
                    num_replicas=self.params["num_gpu"],
                    rank=self.params["ddp_rank"],
                    shuffle=False,
                )
        else:
            for custom_name in self.valid_datasets.keys():
                self.valid_samplers[custom_name] = None

    def load_dataloaders(self):
        """
        Load training and validation data loaders
        """
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.params["batch_size"],
            # `sampler` and `shuffle` are mutually exclusive: only shuffle when
            # no DistributedSampler is in use
            shuffle=self.train_sampler is None,
            drop_last=False,
            sampler=self.train_sampler,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )

        for key in self.valid_datasets.keys():
            self.valid_loaders[key] = DataLoader(
                self.valid_datasets[key],
                batch_size=1,
                sampler=self.valid_samplers[key],
                shuffle=False,
                num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
                pin_memory=self.pin_memory,
                drop_last=False,
                collate_fn=self.my_collate_function,
                worker_init_fn=self.seed_worker,
                generator=self.generator,
            )

    @staticmethod
    def seed_worker(worker_id):
        """
        Set worker seed
        """
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def generate_test_loader(self, custom_name, sets_list):
        """
        Load test dataset, data sampler and data loader
        """
        if custom_name in self.test_loaders.keys():
            return
        paths_and_sets = list()
        for set_info in sets_list:
            paths_and_sets.append(
                {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
            )
        self.test_datasets[custom_name] = OCRDataset(
            set_name="test",
            paths_and_sets=paths_and_sets,
            charset=self.charset,
            tokens=self.tokens,
            preprocessing_transforms=self.preprocessing,
            normalization_transforms=self.normalization,
            augmentation_transforms=None,
            load_in_memory=self.load_in_memory,
        )
        if self.params["use_ddp"]:
            self.test_samplers[custom_name] = DistributedSampler(
                self.test_datasets[custom_name],
                num_replicas=self.params["num_gpu"],
                rank=self.params["ddp_rank"],
                shuffle=False,
            )
        else:
            self.test_samplers[custom_name] = None
        self.test_loaders[custom_name] = DataLoader(
            self.test_datasets[custom_name],
            batch_size=1,
            sampler=self.test_samplers[custom_name],
            shuffle=False,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            drop_last=False,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )

    def get_paths_and_sets(self, dataset_names_folds):
        """
        Set the right path for each data set
        """
        paths_and_sets = list()
        for dataset_name, fold in dataset_names_folds:
            path = self.params["datasets"][dataset_name]
            paths_and_sets.append({"path": path, "set_name": fold})
        return paths_and_sets

    def get_charset(self):
        """
        Merge the charsets of the different datasets used
        """
        if "charset" in self.params:
            return self.params["charset"]
        datasets = self.params["datasets"]
        charset = set()
        for key in datasets.keys():
            with open(os.path.join(datasets[key], "charset.pkl"), "rb") as f:
                charset = charset.union(set(pickle.load(f)))
        if "" in charset:
            charset.remove("")
        return sorted(list(charset))

    def get_tokens(self):
        """
        Get special tokens
        """
        # Special tokens are indexed right after the last charset entry
        return {
            "end": len(self.charset),
            "start": len(self.charset) + 1,
            "pad": len(self.charset) + 2,
        }


class OCRCollateFunction:
    """
    Merge sample data into mini-batch data for the OCR task
    """

    def __init__(self, config):
        self.label_padding_value = config["padding_token"]
        self.config = config

    def __call__(self, batch_data):
        labels = [batch_data[i]["token_label"] for i in range(len(batch_data))]
        labels = pad_sequences_1D(labels, padding_value=self.label_padding_value).long()

        imgs = [batch_data[i]["img"] for i in range(len(batch_data))]
        imgs = pad_images(imgs)

        formatted_batch_data = {
            formatted_key: [batch_data[i][initial_key] for i in range(len(batch_data))]
            for formatted_key, initial_key in zip(
                [
                    "names",
                    "labels_len",
                    "raw_labels",
                    "imgs_position",
                    "imgs_reduced_shape",
                ],
                ["name", "label_len", "label", "img_position", "img_reduced_shape"],
            )
        }
        formatted_batch_data.update(
            {
                "imgs": imgs,
                "labels": labels,
            }
        )
        return formatted_batch_data
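

# Illustrative usage sketch (not part of the manager itself): it shows how the
# `params` dictionary read by OCRDatasetManager could be assembled and in which
# order the loading methods are typically called. The dataset name "my_dataset"
# and the path "outputs/my_dataset" are placeholder assumptions; a real dataset
# folder is expected to contain a charset.pkl file and the extracted splits.
if __name__ == "__main__":
    example_params = {
        # Mapping from dataset name to its extracted folder (assumed path)
        "datasets": {"my_dataset": "outputs/my_dataset"},
        # Training uses a list of (dataset_name, fold) pairs
        "train": {"datasets": [("my_dataset", "train")]},
        # Validation maps a custom loader name to its (dataset_name, fold) pairs
        "val": {"my_dataset-val": [("my_dataset", "val")]},
        "config": {
            "load_in_memory": True,
            "augmentation": True,
            "preprocessings": [],
        },
        "use_ddp": False,
        "num_gpu": 1,
        "ddp_rank": 0,
        "batch_size": 2,
        "worker_per_gpu": 4,
    }
    manager = OCRDatasetManager(example_params, device="cpu")
    manager.load_datasets()
    manager.load_ddp_samplers()
    manager.load_dataloaders()
    manager.generate_test_loader("my_dataset-test", [("my_dataset", "test")])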