Verified Commit 8ba39efd authored by Mélodie Boillet

Apply 2bb85c50

parent 5d35d16d
dan/manager/dataset.py (as of this commit):

# -*- coding: utf-8 -*-
import copy
import json
import os

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.io import ImageReadMode, read_image

from dan.datasets.utils import natural_sort
from dan.utils import token_to_ind
class OCRDataset(Dataset):
    """
    Dataset class to handle dataset loading
    """

    def __init__(
        self,
        set_name,
        paths_and_sets,
        charset,
        tokens,
        preprocessing_transforms,
        augmentation_transforms,
        load_in_memory=False,
        mean=None,
        std=None,
    ):
        self.set_name = set_name
        self.charset = charset
        self.tokens = tokens
        self.load_in_memory = load_in_memory
        self.mean = mean
        self.std = std

        # Pre-processing, augmentation
        self.preprocessing_transforms = preprocessing_transforms
        self.augmentation_transforms = augmentation_transforms

        # Factor to reduce the height and width of the feature vector before feeding the decoder.
        self.reduce_dims_factor = np.array([32, 8, 1])

        # Load samples and preprocess images if load_in_memory is True
        self.samples = self.load_samples(paths_and_sets)

        # Curriculum config
        self.curriculum_config = None

    def __len__(self):
        """
        Return the dataset size
        """
        return len(self.samples)
    def __getitem__(self, idx):
        """
        Return an item from the dataset (image and label)
        """
        # Load preprocessed image
        sample = copy.deepcopy(self.samples[idx])
        if not self.load_in_memory:
            sample["img"] = self.get_sample_img(idx)

        # Convert to numpy
        sample["img"] = np.array(sample["img"])

        # Apply data augmentation
        if self.augmentation_transforms:
            sample["img"] = self.augmentation_transforms(image=sample["img"])["image"]

        # Image normalization
        sample["img"] = (sample["img"] - self.mean) / self.std

        # Get final height and width
        sample["img_reduced_shape"], sample["img_position"] = self.compute_final_size(
            sample["img"]
        )

        # Convert label into tokens
        sample["token_label"], sample["label_len"] = self.convert_sample_label(
            sample["label"]
        )

        return sample
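    # Illustration (not part of the commit): a sketch of the fields one item carries after
    # __getitem__, based only on the keys visible above. load_samples may attach additional
    # keys (for instance the image path), and the exact dtypes are assumptions.
    #
    # sample = {
    #     "img": ...,                # normalized numpy array, shape (H, W, C)
    #     "img_reduced_shape": ...,  # ceil((H, W, C) / (32, 8, 1)), clamped to >= 1 on train
    #     "img_position": ...,       # [[0, H], [0, W]]
    #     "label": ...,              # raw transcription string from load_samples
    #     "token_label": ...,        # [start] + character indices + [end]
    #     "label_len": ...,          # number of characters + 1 (the end token)
    # }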
    @staticmethod
    def load_image(path):
@@ -273,6 +124,17 @@ class GenericDataset(Dataset):
            )
        return samples
    def get_sample_img(self, i):
        """
        Get image by index
        """
        if self.load_in_memory:
            return self.samples[i]["img"]
        else:
            return self.preprocessing_transforms(
                self.load_image(self.samples[i]["path"])
            )

    def compute_std_mean(self):
        """
        Compute cumulated variance and mean of whole dataset
@@ -299,13 +161,27 @@ class GenericDataset(Dataset):
        self.std = np.sqrt(diff / nb_pixels)
        return self.mean, self.std
    def compute_final_size(self, img):
        """
        Compute the final image size and position after feature extraction
        """
        image_reduced_shape = np.ceil(img.shape / self.reduce_dims_factor).astype(int)

        if self.set_name == "train":
            image_reduced_shape = [max(1, t) for t in image_reduced_shape]

        image_position = [
            [0, img.shape[0]],
            [0, img.shape[1]],
        ]
        return image_reduced_shape, image_position
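    # Illustration (not part of the commit): for a hypothetical 128 x 1024 RGB image in the
    # train set, np.ceil([128, 1024, 3] / [32, 8, 1]) gives [4, 128, 3], so the decoder sees
    # a 4 x 128 feature map and img_position is [[0, 128], [0, 1024]].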
    def convert_sample_label(self, label):
        """
        Tokenize the label and return its length
        """
        token_label = token_to_ind(self.charset, label)
        token_label.append(self.tokens["end"])
        label_len = len(token_label)
        token_label.insert(0, self.tokens["start"])
        return token_label, label_len
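A hypothetical illustration (not part of the commit) of the token layout produced by convert_sample_label, using the special-token indices that OCRDatasetManager.get_tokens() defines in the second file below; the charset and label are invented:

# Hypothetical example, not part of the commit.
charset = ["a", "b", "c"]                   # len(charset) == 3
tokens = {"end": 3, "start": 4, "pad": 5}   # as built by OCRDatasetManager.get_tokens()

label = "cab"
token_label = [2, 0, 1]                     # token_to_ind(charset, label): one index per character
token_label.append(tokens["end"])           # [2, 0, 1, 3]
label_len = len(token_label)                # 4 = number of characters + the end token
token_label.insert(0, tokens["start"])      # final sequence: [4, 2, 0, 1, 3]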
Second modified file (as of this commit):

# -*- coding: utf-8 -*-
import os
import pickle
import random

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from dan.manager.dataset import OCRDataset
from dan.transforms import get_augmentation_transforms, get_preprocessing_transforms
from dan.utils import pad_images, pad_sequences_1D
class OCRDatasetManager:
    """
    Specific class to handle OCR/HTR tasks
    """

    def __init__(self, params, device: str):
        self.params = params
        # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
        self.pin_memory = device != "cpu"

        self.train_dataset = None
        self.valid_datasets = dict()
        self.test_datasets = dict()
        self.train_loader = None
        self.valid_loaders = dict()
        self.test_loaders = dict()
        self.train_sampler = None
        self.valid_samplers = dict()
        self.test_samplers = dict()

        self.mean = (
            np.array(params["config"]["mean"])
            if "mean" in params["config"].keys()
            else None
        )
        self.std = (
            np.array(params["config"]["std"])
            if "std" in params["config"].keys()
            else None
        )

        self.generator = torch.Generator()
        self.generator.manual_seed(0)

        self.batch_size = self.get_batch_size_by_set()
        self.load_in_memory = (
            self.params["config"]["load_in_memory"]
            if "load_in_memory" in self.params["config"]
            else True
        )
        self.charset = self.get_charset()
        self.tokens = self.get_tokens()
        self.params["config"]["padding_token"] = self.tokens["pad"]

        self.my_collate_function = OCRCollateFunction(self.params["config"])
        self.augmentation = (
            get_augmentation_transforms()
            if self.params["config"]["augmentation"]
            else None
        )
        self.preprocessing = get_preprocessing_transforms(
            params["config"]["preprocessings"]
        )
    def load_datasets(self):
        """
        Load training and validation datasets
        """
        self.train_dataset = OCRDataset(
            set_name="train",
            paths_and_sets=self.get_paths_and_sets(self.params["train"]["datasets"]),
            charset=self.charset,
            tokens=self.tokens,
            preprocessing_transforms=self.preprocessing,
            augmentation_transforms=self.augmentation,
            load_in_memory=self.load_in_memory,
            mean=self.mean,
            std=self.std,
        )
        self.mean, self.std = self.train_dataset.compute_std_mean()

        for custom_name in self.params["val"].keys():
            self.valid_datasets[custom_name] = OCRDataset(
                set_name="val",
                paths_and_sets=self.get_paths_and_sets(self.params["val"][custom_name]),
                charset=self.charset,
                tokens=self.tokens,
                preprocessing_transforms=self.preprocessing,
                augmentation_transforms=None,
                load_in_memory=self.load_in_memory,
                mean=self.mean,
                std=self.std,
            )
    def load_ddp_samplers(self):
        """
        Load training and validation data samplers
        """
        if self.params["use_ddp"]:
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=self.params["num_gpu"],
                rank=self.params["ddp_rank"],
                shuffle=True,
            )
            for custom_name in self.valid_datasets.keys():
                self.valid_samplers[custom_name] = DistributedSampler(
                    self.valid_datasets[custom_name],
                    num_replicas=self.params["num_gpu"],
                    rank=self.params["ddp_rank"],
                    shuffle=False,
                )
        else:
            for custom_name in self.valid_datasets.keys():
                self.valid_samplers[custom_name] = None
    def load_dataloaders(self):
        """
        Load training and validation data loaders
        """
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size["train"],
            shuffle=True if self.train_sampler is None else False,
            drop_last=False,
            batch_sampler=self.train_sampler,
            sampler=self.train_sampler,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )

        for key in self.valid_datasets.keys():
            self.valid_loaders[key] = DataLoader(
                self.valid_datasets[key],
                batch_size=self.batch_size["val"],
                sampler=self.valid_samplers[key],
                batch_sampler=self.valid_samplers[key],
                shuffle=False,
                num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
                pin_memory=self.pin_memory,
                drop_last=False,
                collate_fn=self.my_collate_function,
                worker_init_fn=self.seed_worker,
                generator=self.generator,
            )
    @staticmethod
    def seed_worker(worker_id):
        """
        Set worker seed
        """
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)
    def generate_test_loader(self, custom_name, sets_list):
        """
        Load test dataset, data sampler and data loader
        """
        if custom_name in self.test_loaders.keys():
            return
        paths_and_sets = list()
        for set_info in sets_list:
            paths_and_sets.append(
                {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
            )
        self.test_datasets[custom_name] = OCRDataset(
            set_name="test",
            paths_and_sets=paths_and_sets,
            charset=self.charset,
            tokens=self.tokens,
            preprocessing_transforms=self.preprocessing,
            augmentation_transforms=None,
            load_in_memory=self.load_in_memory,
            mean=self.mean,
            std=self.std,
        )
        if self.params["use_ddp"]:
            self.test_samplers[custom_name] = DistributedSampler(
                self.test_datasets[custom_name],
                num_replicas=self.params["num_gpu"],
                rank=self.params["ddp_rank"],
                shuffle=False,
            )
        else:
            self.test_samplers[custom_name] = None
        self.test_loaders[custom_name] = DataLoader(
            self.test_datasets[custom_name],
            batch_size=self.batch_size["test"],
            sampler=self.test_samplers[custom_name],
            shuffle=False,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            drop_last=False,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )
    def get_paths_and_sets(self, dataset_names_folds):
        """
        Set the right path for each data set
        """
        paths_and_sets = list()
        for dataset_name, fold in dataset_names_folds:
            path = self.params["datasets"][dataset_name]
            paths_and_sets.append({"path": path, "set_name": fold})
        return paths_and_sets
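    # Illustration (not part of the commit), with invented names: if params["datasets"] is
    # {"my_dataset": "path/to/my_dataset"} and dataset_names_folds is [["my_dataset", "train"]],
    # this returns [{"path": "path/to/my_dataset", "set_name": "train"}].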
    def get_charset(self):
        """
        Merge the charset of the different datasets used
        """
        if "charset" in self.params:
            return self.params["charset"]
        datasets = self.params["datasets"]
        charset = set()
        for key in datasets.keys():
@@ -41,81 +233,29 @@ class OCRDatasetManager(DatasetManager):
            charset.remove("")
        return sorted(list(charset))
    def get_tokens(self):
        """
        Get special tokens
        """
        return {
            "end": len(self.charset),
            "start": len(self.charset) + 1,
            "pad": len(self.charset) + 2,
        }
sample["label"] = label def get_batch_size_by_set(self):
sample["token_label"] = token_to_ind(self.charset, label) """
sample["token_label"].append(self.tokens["end"]) Return batch size for each set
sample["label_len"] = len(sample["token_label"]) """
sample["token_label"].insert(0, self.tokens["start"]) return {
return sample "train": self.params["batch_size"],
"val": self.params["valid_batch_size"]
if "valid_batch_size" in self.params
else self.params["batch_size"],
"test": self.params["test_batch_size"]
if "test_batch_size" in self.params
else 1,
}
class OCRCollateFunction:
    ...
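A minimal usage sketch, assuming a hypothetical params dictionary: only keys that the code above actually reads are listed, every value is invented, and the call order (datasets, then samplers, then loaders) is inferred from the attribute dependencies rather than stated anywhere in the commit. With invented paths it will not run as-is against a real dataset layout.

# Hypothetical sketch, not part of the commit.
params = {
    "datasets": {"my_dataset": "path/to/my_dataset"},   # invented path
    "train": {"datasets": [["my_dataset", "train"]]},
    "val": {"my_dataset-val": [["my_dataset", "val"]]},
    "charset": ["a", "b", "c"],                          # skips charset merging in get_charset()
    "batch_size": 2,                                     # "valid_batch_size" / "test_batch_size" are optional
    "use_ddp": False,                                    # "ddp_rank" is only needed when True
    "num_gpu": 1,
    "worker_per_gpu": 4,
    "config": {
        "load_in_memory": True,
        "augmentation": True,
        "preprocessings": [],
        # "mean" / "std" are optional; load_datasets() falls back to compute_std_mean()
    },
}

manager = OCRDatasetManager(params, device="cpu")
manager.load_datasets()
manager.load_ddp_samplers()
manager.load_dataloaders()
batch = next(iter(manager.train_loader))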