Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
D
DAN
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Deploy
Releases
Package Registry
Container Registry
Operate
Terraform modules
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Automatic Text Recognition
DAN
Commits
389c3505
Commit
389c3505
authored
1 year ago
by
Yoann Schneider
Browse files
Options
Downloads
Plain Diff
Merge branch 'train-model-properties' into 'main'
Create properties to access encoder/decoder Closes
#251
See merge request
!346
parents
9aae0c16
6f2c5cb4
No related branches found
Branches containing commit
No related tags found
Tags containing commit
1 merge request
!346
Create properties to access encoder/decoder
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
dan/ocr/manager/training.py
+24
-16
24 additions, 16 deletions
dan/ocr/manager/training.py
with
24 additions
and
16 deletions
dan/ocr/manager/training.py
+
24
−
16
View file @
389c3505
...
@@ -21,6 +21,8 @@ from torch.nn.parallel import DistributedDataParallel as DDP
...
@@ -21,6 +21,8 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from
torch.utils.tensorboard
import
SummaryWriter
from
torch.utils.tensorboard
import
SummaryWriter
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
dan.ocr.decoder
import
GlobalHTADecoder
from
dan.ocr.encoder
import
FCN_Encoder
from
dan.ocr.manager.metrics
import
Inference
,
MetricManager
from
dan.ocr.manager.metrics
import
Inference
,
MetricManager
from
dan.ocr.manager.ocr
import
OCRDatasetManager
from
dan.ocr.manager.ocr
import
OCRDatasetManager
from
dan.ocr.mlflow
import
MLFLOW_AVAILABLE
,
logging_metrics
,
logging_tags_metrics
from
dan.ocr.mlflow
import
MLFLOW_AVAILABLE
,
logging_metrics
,
logging_tags_metrics
...
@@ -31,7 +33,9 @@ if MLFLOW_AVAILABLE:
...
@@ -31,7 +33,9 @@ if MLFLOW_AVAILABLE:
import
mlflow
import
mlflow
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
MODEL_NAMES
=
(
"
encoder
"
,
"
decoder
"
)
MODEL_NAME_ENCODER
=
"
encoder
"
MODEL_NAME_DECODER
=
"
decoder
"
MODEL_NAMES
=
(
MODEL_NAME_ENCODER
,
MODEL_NAME_DECODER
)
class
GenericTrainingManager
:
class
GenericTrainingManager
:
...
@@ -69,6 +73,14 @@ class GenericTrainingManager:
...
@@ -69,6 +73,14 @@ class GenericTrainingManager:
self
.
init_paths
()
self
.
init_paths
()
self
.
load_dataset
()
self
.
load_dataset
()
@property
def
encoder
(
self
)
->
FCN_Encoder
|
None
:
return
self
.
models
.
get
(
MODEL_NAME_ENCODER
)
@property
def
decoder
(
self
)
->
GlobalHTADecoder
|
None
:
return
self
.
models
.
get
(
MODEL_NAME_DECODER
)
def
init_paths
(
self
):
def
init_paths
(
self
):
"""
"""
Create output folders for results and checkpoints
Create output folders for results and checkpoints
...
@@ -985,20 +997,18 @@ class Manager(GenericTrainingManager):
...
@@ -985,20 +997,18 @@ class Manager(GenericTrainingManager):
hidden_predict
=
None
hidden_predict
=
None
cache
=
None
cache
=
None
features
=
self
.
models
[
"
encoder
"
]
(
batch_data
[
"
imgs
"
].
to
(
self
.
device
))
features
=
self
.
encoder
(
batch_data
[
"
imgs
"
].
to
(
self
.
device
))
features_size
=
features
.
size
()
features_size
=
features
.
size
()
if
self
.
device_params
[
"
use_ddp
"
]:
if
self
.
device_params
[
"
use_ddp
"
]:
features
=
self
.
models
[
features
=
self
.
decoder
.
module
.
features_updater
.
get_pos_features
(
"
decoder
"
].
module
.
features_updater
.
get_pos_features
(
features
)
else
:
features
=
self
.
models
[
"
decoder
"
].
features_updater
.
get_pos_features
(
features
features
)
)
else
:
features
=
self
.
decoder
.
features_updater
.
get_pos_features
(
features
)
features
=
torch
.
flatten
(
features
,
start_dim
=
2
,
end_dim
=
3
).
permute
(
2
,
0
,
1
)
features
=
torch
.
flatten
(
features
,
start_dim
=
2
,
end_dim
=
3
).
permute
(
2
,
0
,
1
)
output
,
pred
,
hidden_predict
,
cache
,
weights
=
self
.
models
[
"
decoder
"
]
(
output
,
pred
,
hidden_predict
,
cache
,
weights
=
self
.
decoder
(
features
,
features
,
simulated_y_pred
[:,
:
-
1
],
simulated_y_pred
[:,
:
-
1
],
[
s
[:
2
]
for
s
in
batch_data
[
"
imgs_reduced_shape
"
]],
[
s
[:
2
]
for
s
in
batch_data
[
"
imgs_reduced_shape
"
]],
...
@@ -1058,7 +1068,7 @@ class Manager(GenericTrainingManager):
...
@@ -1058,7 +1068,7 @@ class Manager(GenericTrainingManager):
for
i
in
range
(
b
):
for
i
in
range
(
b
):
pos
=
batch_data
[
"
imgs_position
"
]
pos
=
batch_data
[
"
imgs_position
"
]
features_list
.
append
(
features_list
.
append
(
self
.
models
[
"
encoder
"
]
(
self
.
encoder
(
x
[
x
[
i
:
i
+
1
,
i
:
i
+
1
,
:,
:,
...
@@ -1079,21 +1089,19 @@ class Manager(GenericTrainingManager):
...
@@ -1079,21 +1089,19 @@ class Manager(GenericTrainingManager):
i
,
:,
:
features_list
[
i
].
size
(
2
),
:
features_list
[
i
].
size
(
3
)
i
,
:,
:
features_list
[
i
].
size
(
2
),
:
features_list
[
i
].
size
(
3
)
]
=
features_list
[
i
]
]
=
features_list
[
i
]
else
:
else
:
features
=
self
.
models
[
"
encoder
"
]
(
x
)
features
=
self
.
encoder
(
x
)
features_size
=
features
.
size
()
features_size
=
features
.
size
()
if
self
.
device_params
[
"
use_ddp
"
]:
if
self
.
device_params
[
"
use_ddp
"
]:
features
=
self
.
models
[
features
=
self
.
decoder
.
module
.
features_updater
.
get_pos_features
(
"
decoder
"
].
module
.
features_updater
.
get_pos_features
(
features
)
else
:
features
=
self
.
models
[
"
decoder
"
].
features_updater
.
get_pos_features
(
features
features
)
)
else
:
features
=
self
.
decoder
.
features_updater
.
get_pos_features
(
features
)
features
=
torch
.
flatten
(
features
,
start_dim
=
2
,
end_dim
=
3
).
permute
(
2
,
0
,
1
)
features
=
torch
.
flatten
(
features
,
start_dim
=
2
,
end_dim
=
3
).
permute
(
2
,
0
,
1
)
for
i
in
range
(
0
,
max_chars
):
for
i
in
range
(
0
,
max_chars
):
output
,
pred
,
hidden_predict
,
cache
,
weights
=
self
.
models
[
"
decoder
"
]
(
output
,
pred
,
hidden_predict
,
cache
,
weights
=
self
.
decoder
(
features
,
features
,
predicted_tokens
,
predicted_tokens
,
[
s
[:
2
]
for
s
in
batch_data
[
"
imgs_reduced_shape
"
]],
[
s
[:
2
]
for
s
in
batch_data
[
"
imgs_reduced_shape
"
]],
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment