Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
D
DAN
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Deploy
Releases
Package registry
Container Registry
Operate
Terraform modules
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Automatic Text Recognition
DAN
Compare revisions
ec31e049d3975795bcad40c7ca1a1951fea121e7 to 389c35051ed6568245820c9ee78e77642d8a9816
Compare revisions
Changes are shown as if the
source
revision was being merged into the
target
revision.
Learn more about comparing revisions.
Source
atr/dan
Select target project
No results found
389c35051ed6568245820c9ee78e77642d8a9816
Select Git revision
Swap
Target
atr/dan
Select target project
atr/dan
1 result
ec31e049d3975795bcad40c7ca1a1951fea121e7
Select Git revision
Show changes
Only incoming changes from source
Include changes to target since source was created
Compare
Commits on Source (2)
Create properties to access encoder/decoder
· 6f2c5cb4
Manon Blanco
authored
1 year ago
and
Yoann Schneider
committed
1 year ago
6f2c5cb4
Merge branch 'train-model-properties' into 'main'
· 389c3505
Yoann Schneider
authored
1 year ago
Create properties to access encoder/decoder Closes
#251
See merge request
!346
389c3505
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
dan/ocr/manager/training.py
+24
-16
24 additions, 16 deletions
dan/ocr/manager/training.py
with
24 additions
and
16 deletions
dan/ocr/manager/training.py
View file @
389c3505
...
...
@@ -21,6 +21,8 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from
torch.utils.tensorboard
import
SummaryWriter
from
tqdm
import
tqdm
from
dan.ocr.decoder
import
GlobalHTADecoder
from
dan.ocr.encoder
import
FCN_Encoder
from
dan.ocr.manager.metrics
import
Inference
,
MetricManager
from
dan.ocr.manager.ocr
import
OCRDatasetManager
from
dan.ocr.mlflow
import
MLFLOW_AVAILABLE
,
logging_metrics
,
logging_tags_metrics
...
...
@@ -31,7 +33,9 @@ if MLFLOW_AVAILABLE:
import
mlflow
logger
=
logging
.
getLogger
(
__name__
)
MODEL_NAMES
=
(
"
encoder
"
,
"
decoder
"
)
MODEL_NAME_ENCODER
=
"
encoder
"
MODEL_NAME_DECODER
=
"
decoder
"
MODEL_NAMES
=
(
MODEL_NAME_ENCODER
,
MODEL_NAME_DECODER
)
class
GenericTrainingManager
:
...
...
@@ -69,6 +73,14 @@ class GenericTrainingManager:
self
.
init_paths
()
self
.
load_dataset
()
@property
def
encoder
(
self
)
->
FCN_Encoder
|
None
:
return
self
.
models
.
get
(
MODEL_NAME_ENCODER
)
@property
def
decoder
(
self
)
->
GlobalHTADecoder
|
None
:
return
self
.
models
.
get
(
MODEL_NAME_DECODER
)
def
init_paths
(
self
):
"""
Create output folders for results and checkpoints
...
...
@@ -985,20 +997,18 @@ class Manager(GenericTrainingManager):
hidden_predict
=
None
cache
=
None
features
=
self
.
models
[
"
encoder
"
]
(
batch_data
[
"
imgs
"
].
to
(
self
.
device
))
features
=
self
.
encoder
(
batch_data
[
"
imgs
"
].
to
(
self
.
device
))
features_size
=
features
.
size
()
if
self
.
device_params
[
"
use_ddp
"
]:
features
=
self
.
models
[
"
decoder
"
].
module
.
features_updater
.
get_pos_features
(
features
)
else
:
features
=
self
.
models
[
"
decoder
"
].
features_updater
.
get_pos_features
(
features
=
self
.
decoder
.
module
.
features_updater
.
get_pos_features
(
features
)
else
:
features
=
self
.
decoder
.
features_updater
.
get_pos_features
(
features
)
features
=
torch
.
flatten
(
features
,
start_dim
=
2
,
end_dim
=
3
).
permute
(
2
,
0
,
1
)
output
,
pred
,
hidden_predict
,
cache
,
weights
=
self
.
models
[
"
decoder
"
]
(
output
,
pred
,
hidden_predict
,
cache
,
weights
=
self
.
decoder
(
features
,
simulated_y_pred
[:,
:
-
1
],
[
s
[:
2
]
for
s
in
batch_data
[
"
imgs_reduced_shape
"
]],
...
...
@@ -1058,7 +1068,7 @@ class Manager(GenericTrainingManager):
for
i
in
range
(
b
):
pos
=
batch_data
[
"
imgs_position
"
]
features_list
.
append
(
self
.
models
[
"
encoder
"
]
(
self
.
encoder
(
x
[
i
:
i
+
1
,
:,
...
...
@@ -1079,21 +1089,19 @@ class Manager(GenericTrainingManager):
i
,
:,
:
features_list
[
i
].
size
(
2
),
:
features_list
[
i
].
size
(
3
)
]
=
features_list
[
i
]
else
:
features
=
self
.
models
[
"
encoder
"
]
(
x
)
features
=
self
.
encoder
(
x
)
features_size
=
features
.
size
()
if
self
.
device_params
[
"
use_ddp
"
]:
features
=
self
.
models
[
"
decoder
"
].
module
.
features_updater
.
get_pos_features
(
features
)
else
:
features
=
self
.
models
[
"
decoder
"
].
features_updater
.
get_pos_features
(
features
=
self
.
decoder
.
module
.
features_updater
.
get_pos_features
(
features
)
else
:
features
=
self
.
decoder
.
features_updater
.
get_pos_features
(
features
)
features
=
torch
.
flatten
(
features
,
start_dim
=
2
,
end_dim
=
3
).
permute
(
2
,
0
,
1
)
for
i
in
range
(
0
,
max_chars
):
output
,
pred
,
hidden_predict
,
cache
,
weights
=
self
.
models
[
"
decoder
"
]
(
output
,
pred
,
hidden_predict
,
cache
,
weights
=
self
.
decoder
(
features
,
predicted_tokens
,
[
s
[:
2
]
for
s
in
batch_data
[
"
imgs_reduced_shape
"
]],
...
...
This diff is collapsed.
Click to expand it.