Automatic Text Recognition / DAN — Commits

Commit 8ba39efd (Verified)
Authored 1 year ago by Mélodie Boillet
Apply 2bb85c50
Parent: 5d35d16d

2 changed files, with 318 additions and 302 deletions:
  dan/manager/dataset.py   +90  −214
  dan/manager/ocr.py       +228 −88
dan/manager/dataset.py (+90, −214) @ 8ba39efd
The commit drops the generic DatasetManager and GenericDataset classes from this file and keeps a single OCRDataset class that receives its charset, tokens and transforms as constructor arguments. Reassembled from the diff view, with removed and added code grouped by class.

Imports shown in the diff (old and new combined):

# -*- coding: utf-8 -*-
import copy
import json
import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision.io import ImageReadMode, read_image

from dan.datasets.utils import natural_sort
from dan.transforms import get_augmentation_transforms, get_preprocessing_transforms
from dan.utils import token_to_ind

Removed in this commit (the generic DatasetManager and GenericDataset classes):

class DatasetManager:
    def __init__(self, params, device: str):
        self.params = params
        self.dataset_class = None
        self.my_collate_function = None
        # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
        self.pin_memory = device != "cpu"

        self.train_dataset = None
        self.valid_datasets = dict()
        self.test_datasets = dict()

        self.train_loader = None
        self.valid_loaders = dict()
        self.test_loaders = dict()

        self.train_sampler = None
        self.valid_samplers = dict()
        self.test_samplers = dict()

        self.generator = torch.Generator()
        self.generator.manual_seed(0)

        self.batch_size = {
            "train": self.params["batch_size"],
            "val": self.params["valid_batch_size"]
            if "valid_batch_size" in self.params
            else self.params["batch_size"],
            "test": self.params["test_batch_size"]
            if "test_batch_size" in self.params
            else 1,
        }

    def apply_specific_treatment_after_dataset_loading(self, dataset):
        raise NotImplementedError

    def load_datasets(self):
        """
        Load training and validation datasets
        """
        self.train_dataset = self.dataset_class(
            self.params,
            "train",
            self.params["train"]["name"],
            self.get_paths_and_sets(self.params["train"]["datasets"]),
            augmentation_transforms=(
                get_augmentation_transforms()
                if self.params["config"]["augmentation"]
                else None
            ),
        )
        (
            self.params["config"]["mean"],
            self.params["config"]["std"],
        ) = self.train_dataset.compute_std_mean()

        self.my_collate_function = self.train_dataset.collate_function(
            self.params["config"]
        )
        self.apply_specific_treatment_after_dataset_loading(self.train_dataset)

        for custom_name in self.params["val"].keys():
            self.valid_datasets[custom_name] = self.dataset_class(
                self.params,
                "val",
                custom_name,
                self.get_paths_and_sets(self.params["val"][custom_name]),
                augmentation_transforms=None,
            )
            self.apply_specific_treatment_after_dataset_loading(
                self.valid_datasets[custom_name]
            )

    def load_ddp_samplers(self):
        """
        Load training and validation data samplers
        """
        if self.params["use_ddp"]:
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=self.params["num_gpu"],
                rank=self.params["ddp_rank"],
                shuffle=True,
            )
            for custom_name in self.valid_datasets.keys():
                self.valid_samplers[custom_name] = DistributedSampler(
                    self.valid_datasets[custom_name],
                    num_replicas=self.params["num_gpu"],
                    rank=self.params["ddp_rank"],
                    shuffle=False,
                )
        else:
            for custom_name in self.valid_datasets.keys():
                self.valid_samplers[custom_name] = None

    def load_dataloaders(self):
        """
        Load training and validation data loaders
        """
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size["train"],
            shuffle=True if self.train_sampler is None else False,
            drop_last=False,
            batch_sampler=self.train_sampler,
            sampler=self.train_sampler,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )

        for key in self.valid_datasets.keys():
            self.valid_loaders[key] = DataLoader(
                self.valid_datasets[key],
                batch_size=self.batch_size["val"],
                sampler=self.valid_samplers[key],
                batch_sampler=self.valid_samplers[key],
                shuffle=False,
                num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
                pin_memory=self.pin_memory,
                drop_last=False,
                collate_fn=self.my_collate_function,
                worker_init_fn=self.seed_worker,
                generator=self.generator,
            )

    @staticmethod
    def seed_worker(worker_id):
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def generate_test_loader(self, custom_name, sets_list):
        """
        Load test dataset, data sampler and data loader
        """
        if custom_name in self.test_loaders.keys():
            return
        paths_and_sets = list()
        for set_info in sets_list:
            paths_and_sets.append(
                {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
            )
        self.test_datasets[custom_name] = self.dataset_class(
            self.params,
            "test",
            custom_name,
            paths_and_sets,
        )
        self.apply_specific_treatment_after_dataset_loading(
            self.test_datasets[custom_name]
        )
        if self.params["use_ddp"]:
            self.test_samplers[custom_name] = DistributedSampler(
                self.test_datasets[custom_name],
                num_replicas=self.params["num_gpu"],
                rank=self.params["ddp_rank"],
                shuffle=False,
            )
        else:
            self.test_samplers[custom_name] = None
        self.test_loaders[custom_name] = DataLoader(
            self.test_datasets[custom_name],
            batch_size=self.batch_size["test"],
            sampler=self.test_samplers[custom_name],
            shuffle=False,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            drop_last=False,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )

    def get_paths_and_sets(self, dataset_names_folds):
        paths_and_sets = list()
        for dataset_name, fold in dataset_names_folds:
            path = self.params["datasets"][dataset_name]
            paths_and_sets.append({"path": path, "set_name": fold})
        return paths_and_sets


class GenericDataset(Dataset):
    """
    Main class to handle dataset loading
    """

    def __init__(self, params, set_name, custom_name, paths_and_sets):
        self.params = params
        self.name = custom_name
        self.set_name = set_name
        self.mean = (
            np.array(params["config"]["mean"])
            if "mean" in params["config"].keys()
            else None
        )
        self.std = (
            np.array(params["config"]["std"])
            if "std" in params["config"].keys()
            else None
        )

        self.preprocessing_transforms = get_preprocessing_transforms(
            params["config"]["preprocessings"]
        )

        self.load_in_memory = (
            self.params["config"]["load_in_memory"]
            if "load_in_memory" in self.params["config"]
            else True
        )

        # Load samples and preprocess images if load_in_memory is True
        self.samples = self.load_samples(paths_and_sets)

        self.curriculum_config = None

    def __len__(self):
        return len(self.samples)

    def get_sample_img(self, i):
        """
        Get image by index
        """
        if self.load_in_memory:
            return self.samples[i]["img"]
        else:
            return self.preprocessing_transforms(
                self.load_image(self.samples[i]["path"])
            )

Added in this commit (the new OCRDataset class):

class OCRDataset(Dataset):
    """
    Dataset class to handle dataset loading
    """

    def __init__(
        self,
        set_name,
        paths_and_sets,
        charset,
        tokens,
        preprocessing_transforms,
        augmentation_transforms,
        load_in_memory=False,
        mean=None,
        std=None,
    ):
        self.set_name = set_name
        self.charset = charset
        self.tokens = tokens
        self.load_in_memory = load_in_memory
        self.mean = mean
        self.std = std

        # Pre-processing, augmentation
        self.preprocessing_transforms = preprocessing_transforms
        self.augmentation_transforms = augmentation_transforms

        # Factor to reduce the height and width of the feature vector before feeding the decoder.
        self.reduce_dims_factor = np.array([32, 8, 1])

        # Load samples and preprocess images if load_in_memory is True
        self.samples = self.load_samples(paths_and_sets)

        # Curriculum config
        self.curriculum_config = None

    def __len__(self):
        """
        Return the dataset size
        """
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Return an item from the dataset (image and label)
        """
        # Load preprocessed image
        sample = copy.deepcopy(self.samples[idx])
        if not self.load_in_memory:
            sample["img"] = self.get_sample_img(idx)

        # Convert to numpy
        sample["img"] = np.array(sample["img"])

        # Apply data augmentation
        if self.augmentation_transforms:
            sample["img"] = self.augmentation_transforms(image=sample["img"])["image"]

        # Image normalization
        sample["img"] = (sample["img"] - self.mean) / self.std

        # Get final height and width
        sample["img_reduced_shape"], sample["img_position"] = self.compute_final_size(
            sample["img"]
        )

        # Convert label into tokens
        sample["token_label"], sample["label_len"] = self.convert_sample_label(
            sample["label"]
        )

        return sample

    @staticmethod
    def load_image(path):
        ...

    # ... collapsed, unchanged lines (diff hunk @@ -273,6 +124,17 @@ class GenericDataset(Dataset):)
            )
        return samples

    def get_sample_img(self, i):
        """
        Get image by index
        """
        if self.load_in_memory:
            return self.samples[i]["img"]
        else:
            return self.preprocessing_transforms(
                self.load_image(self.samples[i]["path"])
            )

    def compute_std_mean(self):
        """
        Compute cumulated variance and mean of whole dataset
        ...

    # ... collapsed, unchanged lines (diff hunk @@ -299,13 +161,27 @@ class GenericDataset(Dataset):)
        self.std = np.sqrt(diff / nb_pixels)
        return self.mean, self.std

    def compute_final_size(self, img):
        """
        Compute the final image size and position after feature extraction
        """
        image_reduced_shape = np.ceil(img.shape / self.reduce_dims_factor).astype(int)
        if self.set_name == "train":
            image_reduced_shape = [max(1, t) for t in image_reduced_shape]
        image_position = [
            [0, img.shape[0]],
            [0, img.shape[1]],
        ]
        return image_reduced_shape, image_position

    def convert_sample_label(self, label):
        """
        Tokenize the label and return its length
        """
        token_label = token_to_ind(self.charset, label)
        token_label.append(self.tokens["end"])
        label_len = len(token_label)
        token_label.insert(0, self.tokens["start"])
        return token_label, label_len
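For context, a small standalone sketch of the arithmetic performed by the added compute_final_size method. The image shape used here is an invented example value, not from the repository; reduce_dims_factor is copied from the OCRDataset constructor above:

import numpy as np

# Same factor as OCRDataset.reduce_dims_factor: the encoder divides the image
# height by 32 and the width by 8; the channel dimension is left untouched.
reduce_dims_factor = np.array([32, 8, 1])

# Hypothetical preprocessed text-line image: 128 px high, 1024 px wide, 3 channels.
img_shape = np.array([128, 1024, 3])

# Mirrors compute_final_size: ceil-divide each dimension by the reduction factor.
img_reduced_shape = np.ceil(img_shape / reduce_dims_factor).astype(int)
print(img_reduced_shape)  # [  4 128   3] -> a 4 x 128 grid of feature positions

# On the train set, every reduced dimension is clamped to at least 1,
# so very small crops still produce a non-empty feature map.
img_reduced_shape = [max(1, t) for t in img_reduced_shape]

# img_position records the extent of the real image inside the (possibly padded) tensor.
img_position = [[0, img_shape[0]], [0, img_shape[1]]]
print(img_position)  # [[0, 128], [0, 1024]]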
dan/manager/ocr.py (+228, −88) @ 8ba39efd
The commit removes the DatasetManager-based OCRDatasetManager and the file-local OCRDataset(GenericDataset) class, and adds a standalone OCRDatasetManager that builds OCRDataset objects imported from dan.manager.dataset and itself owns the charset, tokens, samplers and data loaders. Reassembled from the diff view, with removed and added code grouped by class.

Imports shown in the diff (old and new combined):

# -*- coding: utf-8 -*-
import copy
import os
import pickle
import random

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

Removed in this commit (old imports and DatasetManager-based classes):

from dan.manager.dataset import DatasetManager, GenericDataset
from dan.utils import pad_images, pad_sequences_1D, token_to_ind


class OCRDatasetManager(DatasetManager):
    """
    Specific class to handle OCR/HTR tasks
    """

    def __init__(self, params, device: str):
        super(OCRDatasetManager, self).__init__(params, device)
        self.dataset_class = OCRDataset
        self.charset = (
            params["charset"] if "charset" in params else self.get_merged_charsets()
        )

        self.tokens = {"pad": len(self.charset) + 2}
        self.tokens["end"] = len(self.charset)
        self.tokens["start"] = len(self.charset) + 1

    def get_merged_charsets(self):
        ...

    def apply_specific_treatment_after_dataset_loading(self, dataset):
        dataset.charset = self.charset
        dataset.tokens = self.tokens
        dataset.convert_labels()


class OCRDataset(GenericDataset):
    """
    Specific class to handle OCR/HTR datasets
    """

    def __init__(
        self,
        params,
        set_name,
        custom_name,
        paths_and_sets,
        augmentation_transforms=None,
    ):
        super(OCRDataset, self).__init__(params, set_name, custom_name, paths_and_sets)
        self.charset = None
        self.tokens = None
        # Factor to reduce the height and width of the feature vector before feeding the decoder.
        self.reduce_dims_factor = np.array([32, 8, 1])
        self.collate_function = OCRCollateFunction
        self.augmentation_transforms = augmentation_transforms

    def __getitem__(self, idx):
        sample = copy.deepcopy(self.samples[idx])

        if not self.load_in_memory:
            sample["img"] = self.get_sample_img(idx)

        # Convert to numpy
        sample["img"] = np.array(sample["img"])

        # Data augmentation
        if self.augmentation_transforms:
            sample["img"] = self.augmentation_transforms(image=sample["img"])["image"]

        # Normalization
        sample["img"] = (sample["img"] - self.mean) / self.std

        sample["img_reduced_shape"] = np.ceil(
            sample["img"].shape / self.reduce_dims_factor
        ).astype(int)
        if self.set_name == "train":
            sample["img_reduced_shape"] = [
                max(1, t) for t in sample["img_reduced_shape"]
            ]
        sample["img_position"] = [
            [0, sample["img"].shape[0]],
            [0, sample["img"].shape[1]],
        ]
        return sample

    def convert_labels(self):
        """
        Label str to token at character level
        """
        for i in range(len(self.samples)):
            self.samples[i] = self.convert_sample_labels(self.samples[i])

    def convert_sample_labels(self, sample):
        label = sample["label"]
        sample["label"] = label
        sample["token_label"] = token_to_ind(self.charset, label)
        sample["token_label"].append(self.tokens["end"])
        sample["label_len"] = len(sample["token_label"])
        sample["token_label"].insert(0, self.tokens["start"])
        return sample

Added in this commit (new imports and the standalone OCRDatasetManager):

from dan.manager.dataset import OCRDataset
from dan.transforms import get_augmentation_transforms, get_preprocessing_transforms
from dan.utils import pad_images, pad_sequences_1D


class OCRDatasetManager:
    def __init__(self, params, device: str):
        self.params = params
        # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
        self.pin_memory = device != "cpu"

        self.train_dataset = None
        self.valid_datasets = dict()
        self.test_datasets = dict()

        self.train_loader = None
        self.valid_loaders = dict()
        self.test_loaders = dict()

        self.train_sampler = None
        self.valid_samplers = dict()
        self.test_samplers = dict()

        self.mean = (
            np.array(params["config"]["mean"])
            if "mean" in params["config"].keys()
            else None
        )
        self.std = (
            np.array(params["config"]["std"])
            if "std" in params["config"].keys()
            else None
        )

        self.generator = torch.Generator()
        self.generator.manual_seed(0)

        self.batch_size = self.get_batch_size_by_set()
        self.load_in_memory = (
            self.params["config"]["load_in_memory"]
            if "load_in_memory" in self.params["config"]
            else True
        )

        self.charset = self.get_charset()
        self.tokens = self.get_tokens()
        self.params["config"]["padding_token"] = self.tokens["pad"]

        self.my_collate_function = OCRCollateFunction(self.params["config"])
        self.augmentation = (
            get_augmentation_transforms()
            if self.params["config"]["augmentation"]
            else None
        )
        self.preprocessing = get_preprocessing_transforms(
            params["config"]["preprocessings"]
        )

    def load_datasets(self):
        """
        Load training and validation datasets
        """
        self.train_dataset = OCRDataset(
            set_name="train",
            paths_and_sets=self.get_paths_and_sets(self.params["train"]["datasets"]),
            charset=self.charset,
            tokens=self.tokens,
            preprocessing_transforms=self.preprocessing,
            augmentation_transforms=self.augmentation,
            load_in_memory=self.load_in_memory,
            mean=self.mean,
            std=self.std,
        )
        self.mean, self.std = self.train_dataset.compute_std_mean()

        for custom_name in self.params["val"].keys():
            self.valid_datasets[custom_name] = OCRDataset(
                set_name="val",
                paths_and_sets=self.get_paths_and_sets(self.params["val"][custom_name]),
                charset=self.charset,
                tokens=self.tokens,
                preprocessing_transforms=self.preprocessing,
                augmentation_transforms=None,
                load_in_memory=self.load_in_memory,
                mean=self.mean,
                std=self.std,
            )

    def load_ddp_samplers(self):
        """
        Load training and validation data samplers
        """
        if self.params["use_ddp"]:
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=self.params["num_gpu"],
                rank=self.params["ddp_rank"],
                shuffle=True,
            )
            for custom_name in self.valid_datasets.keys():
                self.valid_samplers[custom_name] = DistributedSampler(
                    self.valid_datasets[custom_name],
                    num_replicas=self.params["num_gpu"],
                    rank=self.params["ddp_rank"],
                    shuffle=False,
                )
        else:
            for custom_name in self.valid_datasets.keys():
                self.valid_samplers[custom_name] = None

    def load_dataloaders(self):
        """
        Load training and validation data loaders
        """
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size["train"],
            shuffle=True if self.train_sampler is None else False,
            drop_last=False,
            batch_sampler=self.train_sampler,
            sampler=self.train_sampler,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )

        for key in self.valid_datasets.keys():
            self.valid_loaders[key] = DataLoader(
                self.valid_datasets[key],
                batch_size=self.batch_size["val"],
                sampler=self.valid_samplers[key],
                batch_sampler=self.valid_samplers[key],
                shuffle=False,
                num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
                pin_memory=self.pin_memory,
                drop_last=False,
                collate_fn=self.my_collate_function,
                worker_init_fn=self.seed_worker,
                generator=self.generator,
            )

    @staticmethod
    def seed_worker(worker_id):
        """
        Set worker seed
        """
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def generate_test_loader(self, custom_name, sets_list):
        """
        Load test dataset, data sampler and data loader
        """
        if custom_name in self.test_loaders.keys():
            return
        paths_and_sets = list()
        for set_info in sets_list:
            paths_and_sets.append(
                {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
            )
        self.test_datasets[custom_name] = OCRDataset(
            set_name="test",
            paths_and_sets=paths_and_sets,
            charset=self.charset,
            tokens=self.tokens,
            preprocessing_transforms=self.preprocessing,
            augmentation_transforms=None,
            load_in_memory=self.load_in_memory,
            mean=self.mean,
            std=self.std,
        )
        if self.params["use_ddp"]:
            self.test_samplers[custom_name] = DistributedSampler(
                self.test_datasets[custom_name],
                num_replicas=self.params["num_gpu"],
                rank=self.params["ddp_rank"],
                shuffle=False,
            )
        else:
            self.test_samplers[custom_name] = None
        self.test_loaders[custom_name] = DataLoader(
            self.test_datasets[custom_name],
            batch_size=self.batch_size["test"],
            sampler=self.test_samplers[custom_name],
            shuffle=False,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            drop_last=False,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )

    def get_paths_and_sets(self, dataset_names_folds):
        """
        Set the right path for each data set
        """
        paths_and_sets = list()
        for dataset_name, fold in dataset_names_folds:
            path = self.params["datasets"][dataset_name]
            paths_and_sets.append({"path": path, "set_name": fold})
        return paths_and_sets

    def get_charset(self):
        """
        Merge the charset of the different datasets used
        """
        if "charset" in self.params:
            return self.params["charset"]
        datasets = self.params["datasets"]
        charset = set()
        for key in datasets.keys():
            ...

    # ... collapsed, unchanged lines (diff hunk @@ -41,81 +233,29 @@ class OCRDatasetManager(DatasetManager):)
            charset.remove("")
        return sorted(list(charset))

    def get_tokens(self):
        """
        Get special tokens
        """
        return {
            "end": len(self.charset),
            "start": len(self.charset) + 1,
            "pad": len(self.charset) + 2,
        }

    def get_batch_size_by_set(self):
        """
        Return batch size for each set
        """
        return {
            "train": self.params["batch_size"],
            "val": self.params["valid_batch_size"]
            if "valid_batch_size" in self.params
            else self.params["batch_size"],
            "test": self.params["test_batch_size"]
            if "test_batch_size" in self.params
            else 1,
        }


class OCRCollateFunction:
    ...  # unchanged, collapsed in the diff view
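For context, a small standalone sketch of how the special tokens built by get_tokens line up with convert_sample_label in dan/manager/dataset.py. The toy charset and the inline token_to_ind stand-in are assumptions made for illustration, not code from the repository:

# Toy charset; in DAN it is merged from the datasets' label files.
charset = ["a", "b", "c", " "]

# Special token indices as built by OCRDatasetManager.get_tokens():
# they sit directly after the charset indices.
tokens = {
    "end": len(charset),        # 4
    "start": len(charset) + 1,  # 5
    "pad": len(charset) + 2,    # 6
}

def token_to_ind(charset, label):
    # Stand-in for dan.utils.token_to_ind: map each character to its charset index.
    return [charset.index(c) for c in label]

# Mirrors OCRDataset.convert_sample_label: append <end>, measure the length,
# then prepend <start> (so label_len counts the characters plus <end> only).
label = "ab c"
token_label = token_to_ind(charset, label)
token_label.append(tokens["end"])
label_len = len(token_label)
token_label.insert(0, tokens["start"])

print(token_label)  # [5, 0, 1, 3, 2, 4]
print(label_len)    # 5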