Commit 2bb85c50, authored 1 year ago by Solene Tarride, committed by Mélodie Boillet 1 year ago
Merge DatasetManager / GenericDataset / OCRDatasetManager / OCRDataset classes
Parent: fdaf48f4
1 merge request: !191 Merge DatasetManager / GenericDataset / OCRDatasetManager / OCRDataset classes
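For context, a minimal sketch of how the merged manager is meant to be driven after this change. It is not part of the commit: the params keys are taken from the lookups visible in the diff, the dataset name and path are hypothetical, charset is passed explicitly so the sketch does not depend on the collapsed body of get_charset(), and the empty preprocessings list assumes get_preprocessing_transforms accepts one.

    from dan.manager.ocr import OCRDatasetManager

    # Hypothetical configuration; every key below is one the diff actually reads.
    params = {
        "datasets": {"my_dataset": "path/to/my_dataset"},
        "train": {"datasets": [("my_dataset", "train")]},
        "val": {"validation": [("my_dataset", "val")]},
        "charset": ["a", "b", "c", " "],
        "batch_size": 4,
        "use_ddp": False,
        "num_gpu": 1,
        "worker_per_gpu": 4,
        "config": {"augmentation": False, "preprocessings": [], "load_in_memory": True},
    }

    manager = OCRDatasetManager(params, device="cpu")
    manager.load_datasets()      # builds the train/val OCRDataset instances
    manager.load_ddp_samplers()  # DistributedSampler only when params["use_ddp"] is True
    manager.load_dataloaders()   # wraps them in DataLoaders with OCRCollateFunction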
Changes: 2 changed files with 302 additions and 292 deletions
dan/manager/dataset.py: 82 additions, 200 deletions
dan/manager/ocr.py: 220 additions, 92 deletions
dan/manager/dataset.py  +82 −200
The DatasetManager class (constructor bookkeeping, apply_specific_treatment_after_dataset_loading, load_datasets, load_ddp_samplers, load_dataloaders, seed_worker, generate_test_loader, get_paths_and_sets, the per-set batch_size dict and the seeded torch.Generator) and the GenericDataset class (constructor and __len__) are removed, along with the imports only they needed (random, DataLoader, DistributedSampler, the dan.transforms helpers). That manager logic now lives in OCRDatasetManager in dan/manager/ocr.py, while GenericDataset is merged into the new OCRDataset. The file now reads:

# -*- coding: utf-8 -*-
import json
import os

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.io import ImageReadMode, read_image

from dan.datasets.utils import natural_sort
from dan.utils import token_to_ind


class OCRDataset(Dataset):
    """
    Dataset class to handle dataset loading
    """

    def __init__(
        self,
        set_name,
        paths_and_sets,
        charset,
        tokens,
        preprocessing_transforms,
        normalization_transforms,
        augmentation_transforms,
        load_in_memory=False,
    ):
        self.set_name = set_name
        self.charset = charset
        self.tokens = tokens
        self.load_in_memory = load_in_memory

        # Pre-processing, augmentation, normalization
        self.preprocessing_transforms = preprocessing_transforms
        self.normalization_transforms = normalization_transforms
        self.augmentation_transforms = augmentation_transforms

        # Factor to reduce the height and width of the feature vector before feeding the decoder.
        self.reduce_dims_factor = np.array([32, 8, 1])

        # Load samples and preprocess images if load_in_memory is True
        self.samples = self.load_samples(paths_and_sets)

        # Curriculum config
        self.curriculum_config = None

    def __len__(self):
        """
        Return the dataset size
        """
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Return an item from the dataset (image and label)
        """
        # Load preprocessed image
        sample = dict(**self.samples[idx])
        if not self.load_in_memory:
            sample["img"] = self.get_sample_img(idx)

        # Apply data augmentation
        if self.augmentation_transforms:
            sample["img"] = self.augmentation_transforms(image=np.array(sample["img"]))[
                "image"
            ]

        # Image normalization
        sample["img"] = self.normalization_transforms(sample["img"])

        # Get final height and width
        sample["img_reduced_shape"], sample["img_position"] = self.compute_final_size(
            sample["img"]
        )

        # Convert label into tokens
        sample["token_label"], sample["label_len"] = self.convert_sample_label(
            sample["label"]
        )

        return sample

    @staticmethod
    def load_image(path):
...
...
@@ -276,3 +130,31 @@ class GenericDataset(Dataset):
        return self.preprocessing_transforms(
            self.load_image(self.samples[i]["path"])
        )

    def compute_final_size(self, img):
        """
        Compute the final image size and position after feature extraction
        """
        final_c, final_h, final_w = img.shape
        image_reduced_shape = np.ceil(
            [final_h, final_w, final_c] / self.reduce_dims_factor
        ).astype(int)

        if self.set_name == "train":
            image_reduced_shape = [max(1, t) for t in image_reduced_shape]

        image_position = [
            [0, final_h],
            [0, final_w],
        ]
        return image_reduced_shape, image_position

    def convert_sample_label(self, label):
        """
        Tokenize the label and return its length
        """
        token_label = token_to_ind(self.charset, label)
        token_label.append(self.tokens["end"])
        label_len = len(token_label)
        token_label.insert(0, self.tokens["start"])
        return token_label, label_len
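OCRDatasetManager (in dan/manager/ocr.py below) normally constructs these datasets itself; as a reference, here is a sketch of building the merged OCRDataset directly. The dataset path is hypothetical, the toy charset and tokens follow the end/start/pad convention of get_tokens(), and passing an empty preprocessing list is an assumption about get_preprocessing_transforms.

    from dan.manager.dataset import OCRDataset
    from dan.transforms import get_normalization_transforms, get_preprocessing_transforms

    charset = ["a", "b", "c", " "]
    tokens = {"end": len(charset), "start": len(charset) + 1, "pad": len(charset) + 2}

    dataset = OCRDataset(
        set_name="train",
        paths_and_sets=[{"path": "path/to/my_dataset", "set_name": "train"}],  # hypothetical
        charset=charset,
        tokens=tokens,
        preprocessing_transforms=get_preprocessing_transforms([]),
        normalization_transforms=get_normalization_transforms(),
        augmentation_transforms=None,
        load_in_memory=False,
    )
    # Each item is a dict with "img", "img_reduced_shape", "img_position",
    # "token_label" and "label_len" filled in by __getitem__.
    sample = dataset[0]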
dan/manager/ocr.py  +220 −92
The imports of DatasetManager, GenericDataset and token_to_ind are removed, as is the old OCRDatasetManager(DatasetManager) definition (a constructor delegating to super().__init__, dataset_class = OCRDataset, a charset lookup through get_merged_charsets, and an inline tokens dict). OCRDatasetManager is now a standalone class that absorbs the manager logic:

# -*- coding: utf-8 -*-
import os
import pickle
import random

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from dan.manager.dataset import OCRDataset
from dan.transforms import (
    get_augmentation_transforms,
    get_normalization_transforms,
    get_preprocessing_transforms,
)
from dan.utils import pad_images, pad_sequences_1D


class OCRDatasetManager:
    """
    Specific class to handle OCR/HTR tasks
    """

    def __init__(self, params, device: str):
        self.params = params
        # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
        self.pin_memory = device != "cpu"

        self.train_dataset = None
        self.valid_datasets = dict()
        self.test_datasets = dict()

        self.train_loader = None
        self.valid_loaders = dict()
        self.test_loaders = dict()

        self.train_sampler = None
        self.valid_samplers = dict()
        self.test_samplers = dict()

        self.generator = torch.Generator()
        self.generator.manual_seed(0)

        self.batch_size = self.get_batch_size_by_set()
        self.load_in_memory = (
            self.params["config"]["load_in_memory"]
            if "load_in_memory" in self.params["config"]
            else True
        )
        self.charset = self.get_charset()
        self.tokens = self.get_tokens()
        self.params["config"]["padding_token"] = self.tokens["pad"]

        self.my_collate_function = OCRCollateFunction(self.params["config"])
        self.normalization = get_normalization_transforms()
        self.augmentation = (
            get_augmentation_transforms()
            if self.params["config"]["augmentation"]
            else None
        )
        self.preprocessing = get_preprocessing_transforms(
            params["config"]["preprocessings"]
        )

    def load_datasets(self):
        """
        Load training and validation datasets
        """
        self.train_dataset = OCRDataset(
            set_name="train",
            paths_and_sets=self.get_paths_and_sets(self.params["train"]["datasets"]),
            charset=self.charset,
            tokens=self.tokens,
            preprocessing_transforms=self.preprocessing,
            normalization_transforms=self.normalization,
            augmentation_transforms=self.augmentation,
            load_in_memory=self.load_in_memory,
        )

        for custom_name in self.params["val"].keys():
            self.valid_datasets[custom_name] = OCRDataset(
                set_name="val",
                paths_and_sets=self.get_paths_and_sets(self.params["val"][custom_name]),
                charset=self.charset,
                tokens=self.tokens,
                preprocessing_transforms=self.preprocessing,
                normalization_transforms=self.normalization,
                augmentation_transforms=None,
                load_in_memory=self.load_in_memory,
            )

    def load_ddp_samplers(self):
        """
        Load training and validation data samplers
        """
        if self.params["use_ddp"]:
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=self.params["num_gpu"],
                rank=self.params["ddp_rank"],
                shuffle=True,
            )
            for custom_name in self.valid_datasets.keys():
                self.valid_samplers[custom_name] = DistributedSampler(
                    self.valid_datasets[custom_name],
                    num_replicas=self.params["num_gpu"],
                    rank=self.params["ddp_rank"],
                    shuffle=False,
                )
        else:
            for custom_name in self.valid_datasets.keys():
                self.valid_samplers[custom_name] = None

    def load_dataloaders(self):
        """
        Load training and validation data loaders
        """
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size["train"],
            shuffle=True if self.train_sampler is None else False,
            drop_last=False,
            batch_sampler=self.train_sampler,
            sampler=self.train_sampler,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )

        for key in self.valid_datasets.keys():
            self.valid_loaders[key] = DataLoader(
                self.valid_datasets[key],
                batch_size=self.batch_size["val"],
                sampler=self.valid_samplers[key],
                batch_sampler=self.valid_samplers[key],
                shuffle=False,
                num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
                pin_memory=self.pin_memory,
                drop_last=False,
                collate_fn=self.my_collate_function,
                worker_init_fn=self.seed_worker,
                generator=self.generator,
            )

    @staticmethod
    def seed_worker(worker_id):
        """
        Set worker seed
        """
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def generate_test_loader(self, custom_name, sets_list):
        """
        Load test dataset, data sampler and data loader
        """
        if custom_name in self.test_loaders.keys():
            return
        paths_and_sets = list()
        for set_info in sets_list:
            paths_and_sets.append(
                {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
            )
        self.test_datasets[custom_name] = OCRDataset(
            set_name="test",
            paths_and_sets=paths_and_sets,
            charset=self.charset,
            tokens=self.tokens,
            preprocessing_transforms=self.preprocessing,
            normalization_transforms=self.normalization,
            augmentation_transforms=None,
            load_in_memory=self.load_in_memory,
        )
        if self.params["use_ddp"]:
            self.test_samplers[custom_name] = DistributedSampler(
                self.test_datasets[custom_name],
                num_replicas=self.params["num_gpu"],
                rank=self.params["ddp_rank"],
                shuffle=False,
            )
        else:
            self.test_samplers[custom_name] = None
        self.test_loaders[custom_name] = DataLoader(
            self.test_datasets[custom_name],
            batch_size=self.batch_size["test"],
            sampler=self.test_samplers[custom_name],
            shuffle=False,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            drop_last=False,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )

    def get_paths_and_sets(self, dataset_names_folds):
        """
        Set the right path for each data set
        """
        paths_and_sets = list()
        for dataset_name, fold in dataset_names_folds:
            path = self.params["datasets"][dataset_name]
            paths_and_sets.append({"path": path, "set_name": fold})
        return paths_and_sets

    def get_charset(self):
        """
        Merge the charset of the different datasets used
        """
        if "charset" in self.params:
            return self.params["charset"]
        datasets = self.params["datasets"]
        charset = set()
        for key in datasets.keys():
...
...
@@ -39,83 +221,29 @@ class OCRDatasetManager(DatasetManager):
The apply_specific_treatment_after_dataset_loading hook and the old OCRDataset(GenericDataset) subclass (constructor, __getitem__, convert_labels, convert_sample_labels) are removed; their behaviour is folded into OCRDataset in dan/manager/dataset.py. The remaining code closes get_charset and adds the get_tokens and get_batch_size_by_set helpers, ahead of the unchanged OCRCollateFunction:

        charset.remove("")
        return sorted(list(charset))

    def get_tokens(self):
        """
        Get special tokens
        """
        return {
            "end": len(self.charset),
            "start": len(self.charset) + 1,
            "pad": len(self.charset) + 2,
        }

    def get_batch_size_by_set(self):
        """
        Return batch size for each set
        """
        return {
            "train": self.params["batch_size"],
            "val": self.params["valid_batch_size"]
            if "valid_batch_size" in self.params
            else self.params["batch_size"],
            "test": self.params["test_batch_size"]
            if "test_batch_size" in self.params
            else 1,
        }


class OCRCollateFunction:
...
...
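A small worked example of the token convention shared by get_tokens() and convert_sample_label(), assuming token_to_ind maps each character to its index in the charset, as its use in the diff suggests; the charset here is a toy example.

    # Characters occupy indices 0..len(charset)-1; the special tokens follow.
    charset = ["a", "b", "c", " "]
    tokens = {
        "end": len(charset),        # 4
        "start": len(charset) + 1,  # 5
        "pad": len(charset) + 2,    # 6
    }
    # convert_sample_label("ab c") would then return
    # token_label = [5, 0, 1, 3, 2, 4] and label_len = 5
    # (label_len is computed before the start token is inserted).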
Mélodie Boillet @mboillet mentioned in commit 8ba39efd · 1 year ago
Mélodie Boillet @mboillet mentioned in commit 59ab6de9 · 1 year ago