Automatic Text Recognition / DAN / Commits

Commit 8ba39efd (Verified)
Authored 1 year ago by Mélodie Boillet
Apply: 2bb85c50
Parent: 5d35d16d
No related branches, tags, or merge requests found.

Showing 2 changed files
dan/manager/dataset.py: 90 additions, 214 deletions
dan/manager/ocr.py: 228 additions, 88 deletions
Total: 318 additions and 302 deletions
dan/manager/dataset.py (+90, −214), view file @ 8ba39efd
# -*- coding: utf-8 -*-
import copy
import json
import os

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.io import ImageReadMode, read_image

from dan.datasets.utils import natural_sort
from dan.utils import token_to_ind


class OCRDataset(Dataset):
    """
    Dataset class to handle dataset loading
    """

    def __init__(
        self,
        set_name,
        paths_and_sets,
        charset,
        tokens,
        preprocessing_transforms,
        augmentation_transforms,
        load_in_memory=False,
        mean=None,
        std=None,
    ):
        self.set_name = set_name
        self.charset = charset
        self.tokens = tokens
        self.load_in_memory = load_in_memory
        self.mean = mean
        self.std = std

        # Pre-processing, augmentation
        self.preprocessing_transforms = preprocessing_transforms
        self.augmentation_transforms = augmentation_transforms

        # Factor to reduce the height and width of the feature vector before feeding the decoder.
        self.reduce_dims_factor = np.array([32, 8, 1])

        # Load samples and preprocess images if load_in_memory is True
        self.samples = self.load_samples(paths_and_sets)

        # Curriculum config
        self.curriculum_config = None

    def __len__(self):
        """
        Return the dataset size
        """
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Return an item from the dataset (image and label)
        """
        # Load preprocessed image
        sample = copy.deepcopy(self.samples[idx])
        if not self.load_in_memory:
            sample["img"] = self.get_sample_img(idx)

        # Convert to numpy
        sample["img"] = np.array(sample["img"])

        # Apply data augmentation
        if self.augmentation_transforms:
            sample["img"] = self.augmentation_transforms(image=sample["img"])["image"]

        # Image normalization
        sample["img"] = (sample["img"] - self.mean) / self.std

        # Get final height and width
        sample["img_reduced_shape"], sample["img_position"] = self.compute_final_size(
            sample["img"]
        )

        # Convert label into tokens
        sample["token_label"], sample["label_len"] = self.convert_sample_label(
            sample["label"]
        )

        return sample

    @staticmethod
    def load_image(path):
        ...

...
@@ -273,6 +124,17 @@ class GenericDataset(Dataset):
            )
        return samples

    def get_sample_img(self, i):
        """
        Get image by index
        """
        if self.load_in_memory:
            return self.samples[i]["img"]
        else:
            return self.preprocessing_transforms(
                self.load_image(self.samples[i]["path"])
            )

    def compute_std_mean(self):
        """
        Compute cumulated variance and mean of whole dataset
...
@@ -299,13 +161,27 @@ class GenericDataset(Dataset):
        self.std = np.sqrt(diff / nb_pixels)
        return self.mean, self.std

    def compute_final_size(self, img):
        """
        Compute the final image size and position after feature extraction
        """
        image_reduced_shape = np.ceil(img.shape / self.reduce_dims_factor).astype(int)

        if self.set_name == "train":
            image_reduced_shape = [max(1, t) for t in image_reduced_shape]

        image_position = [
            [0, img.shape[0]],
            [0, img.shape[1]],
        ]
        return image_reduced_shape, image_position

    def convert_sample_label(self, label):
        """
        Tokenize the label and return its length
        """
        token_label = token_to_ind(self.charset, label)
        token_label.append(self.tokens["end"])
        label_len = len(token_label)
        token_label.insert(0, self.tokens["start"])
        return token_label, label_len
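For a concrete sense of what compute_final_size returns, here is a minimal sketch of the same arithmetic, assuming a hypothetical 1000x800 RGB image (the shape values are illustrative, not taken from the repository):

import numpy as np

# Same factor as OCRDataset.reduce_dims_factor
reduce_dims_factor = np.array([32, 8, 1])

# Hypothetical image shape: (height, width, channels)
img_height, img_width = 1000, 800
img_shape = np.array([img_height, img_width, 3])

# Ceil of the element-wise ratio, as in compute_final_size
img_reduced_shape = np.ceil(img_shape / reduce_dims_factor).astype(int)
print(img_reduced_shape)  # -> [32, 100, 3]

# Position of the original image inside the (possibly padded) batch tensor
img_position = [[0, img_height], [0, img_width]]
print(img_position)  # -> [[0, 1000], [0, 800]]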
dan/manager/ocr.py (+228, −88), view file @ 8ba39efd
# -*- coding: utf-8 -*-
import os
import pickle
import random

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from dan.manager.dataset import OCRDataset
from dan.transforms import get_augmentation_transforms, get_preprocessing_transforms
from dan.utils import pad_images, pad_sequences_1D


class OCRDatasetManager:
    """
    Specific class to handle OCR/HTR tasks
    """

    def __init__(self, params, device: str):
        self.params = params
        # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
        self.pin_memory = device != "cpu"

        self.train_dataset = None
        self.valid_datasets = dict()
        self.test_datasets = dict()

        self.train_loader = None
        self.valid_loaders = dict()
        self.test_loaders = dict()

        self.train_sampler = None
        self.valid_samplers = dict()
        self.test_samplers = dict()

        self.mean = (
            np.array(params["config"]["mean"])
            if "mean" in params["config"].keys()
            else None
        )
        self.std = (
            np.array(params["config"]["std"])
            if "std" in params["config"].keys()
            else None
        )

        self.generator = torch.Generator()
        self.generator.manual_seed(0)

        self.batch_size = self.get_batch_size_by_set()
        self.load_in_memory = (
            self.params["config"]["load_in_memory"]
            if "load_in_memory" in self.params["config"]
            else True
        )
        self.charset = self.get_charset()
        self.tokens = self.get_tokens()
        self.params["config"]["padding_token"] = self.tokens["pad"]

        self.my_collate_function = OCRCollateFunction(self.params["config"])
        self.augmentation = (
            get_augmentation_transforms()
            if self.params["config"]["augmentation"]
            else None
        )
        self.preprocessing = get_preprocessing_transforms(
            params["config"]["preprocessings"]
        )

    def load_datasets(self):
        """
        Load training and validation datasets
        """
        self.train_dataset = OCRDataset(
            set_name="train",
            paths_and_sets=self.get_paths_and_sets(self.params["train"]["datasets"]),
            charset=self.charset,
            tokens=self.tokens,
            preprocessing_transforms=self.preprocessing,
            augmentation_transforms=self.augmentation,
            load_in_memory=self.load_in_memory,
            mean=self.mean,
            std=self.std,
        )

        self.mean, self.std = self.train_dataset.compute_std_mean()

        for custom_name in self.params["val"].keys():
            self.valid_datasets[custom_name] = OCRDataset(
                set_name="val",
                paths_and_sets=self.get_paths_and_sets(self.params["val"][custom_name]),
                charset=self.charset,
                tokens=self.tokens,
                preprocessing_transforms=self.preprocessing,
                augmentation_transforms=None,
                load_in_memory=self.load_in_memory,
                mean=self.mean,
                std=self.std,
            )

    def load_ddp_samplers(self):
        """
        Load training and validation data samplers
        """
        if self.params["use_ddp"]:
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=self.params["num_gpu"],
                rank=self.params["ddp_rank"],
                shuffle=True,
            )
            for custom_name in self.valid_datasets.keys():
                self.valid_samplers[custom_name] = DistributedSampler(
                    self.valid_datasets[custom_name],
                    num_replicas=self.params["num_gpu"],
                    rank=self.params["ddp_rank"],
                    shuffle=False,
                )
        else:
            for custom_name in self.valid_datasets.keys():
                self.valid_samplers[custom_name] = None

    def load_dataloaders(self):
        """
        Load training and validation data loaders
        """
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size["train"],
            shuffle=True if self.train_sampler is None else False,
            drop_last=False,
            batch_sampler=self.train_sampler,
            sampler=self.train_sampler,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )

        for key in self.valid_datasets.keys():
            self.valid_loaders[key] = DataLoader(
                self.valid_datasets[key],
                batch_size=self.batch_size["val"],
                sampler=self.valid_samplers[key],
                batch_sampler=self.valid_samplers[key],
                shuffle=False,
                num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
                pin_memory=self.pin_memory,
                drop_last=False,
                collate_fn=self.my_collate_function,
                worker_init_fn=self.seed_worker,
                generator=self.generator,
            )

    @staticmethod
    def seed_worker(worker_id):
        """
        Set worker seed
        """
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def generate_test_loader(self, custom_name, sets_list):
        """
        Load test dataset, data sampler and data loader
        """
        if custom_name in self.test_loaders.keys():
            return
        paths_and_sets = list()
        for set_info in sets_list:
            paths_and_sets.append(
                {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
            )
        self.test_datasets[custom_name] = OCRDataset(
            set_name="test",
            paths_and_sets=paths_and_sets,
            charset=self.charset,
            tokens=self.tokens,
            preprocessing_transforms=self.preprocessing,
            augmentation_transforms=None,
            load_in_memory=self.load_in_memory,
            mean=self.mean,
            std=self.std,
        )
        if self.params["use_ddp"]:
            self.test_samplers[custom_name] = DistributedSampler(
                self.test_datasets[custom_name],
                num_replicas=self.params["num_gpu"],
                rank=self.params["ddp_rank"],
                shuffle=False,
            )
        else:
            self.test_samplers[custom_name] = None
        self.test_loaders[custom_name] = DataLoader(
            self.test_datasets[custom_name],
            batch_size=self.batch_size["test"],
            sampler=self.test_samplers[custom_name],
            shuffle=False,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            drop_last=False,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )

    def get_paths_and_sets(self, dataset_names_folds):
        """
        Set the right path for each data set
        """
        paths_and_sets = list()
        for dataset_name, fold in dataset_names_folds:
            path = self.params["datasets"][dataset_name]
            paths_and_sets.append({"path": path, "set_name": fold})
        return paths_and_sets

    def get_charset(self):
        """
        Merge the charset of the different datasets used
        """
        if "charset" in self.params:
            return self.params["charset"]
        datasets = self.params["datasets"]
        charset = set()
        for key in datasets.keys():
...
@@ -41,81 +233,29 @@ class OCRDatasetManager(DatasetManager):
        charset.remove("")
        return sorted(list(charset))

    def get_tokens(self):
        """
        Get special tokens
        """
        return {
            "end": len(self.charset),
            "start": len(self.charset) + 1,
            "pad": len(self.charset) + 2,
        }

    def get_batch_size_by_set(self):
        """
        Return batch size for each set
        """
        return {
            "train": self.params["batch_size"],
            "val": self.params["valid_batch_size"]
            if "valid_batch_size" in self.params
            else self.params["batch_size"],
            "test": self.params["test_batch_size"]
            if "test_batch_size" in self.params
            else 1,
        }


class OCRCollateFunction:
    ...
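As a quick illustration of the special-token layout and label conversion, here is a minimal sketch assuming a hypothetical five-character charset; it mirrors OCRDatasetManager.get_tokens and OCRDataset.convert_sample_label, with token_to_ind assumed to map each character to its index in the charset:

# Hypothetical charset, for illustration only
charset = ["a", "b", "c", "d", "e"]

# Special tokens are placed right after the charset, as in get_tokens()
tokens = {
    "end": len(charset),        # 5
    "start": len(charset) + 1,  # 6
    "pad": len(charset) + 2,    # 7
}

# Label conversion, as in convert_sample_label()
label = "bad"
token_label = [charset.index(c) for c in label]  # assumed behaviour of token_to_ind -> [1, 0, 3]
token_label.append(tokens["end"])                # [1, 0, 3, 5]
label_len = len(token_label)                     # 4 (label characters plus the end token)
token_label.insert(0, tokens["start"])           # [6, 1, 0, 3, 5]
print(token_label, label_len)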