Automatic Text Recognition / DAN · Commits

Commit a26125a5
Authored 1 year ago by Yoann Schneider, committed by Mélodie Boillet 1 year ago

Save metrics results to YAML instead of plain text
Parent: 24b5eb14
1 merge request: !122 Save metrics results to YAML instead of plain text
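
The practical effect of the change: the old predict files were one handwritten "metric: value" line per metric, which consumers had to parse themselves, while the new files load in a single call with native types. A minimal sketch of the difference (the sample values are hypothetical, not taken from this commit):

    import yaml

    old_text = "nb_chars: 43\ncer: 1.2791\n"

    # Before this commit: split each "metric: value" line by hand;
    # every parsed value comes back as a string.
    parsed = dict(line.split(": ", 1) for line in old_text.splitlines())
    assert parsed["cer"] == "1.2791"

    # After: the same content is valid YAML, so one call recovers
    # the metrics with their native types.
    assert yaml.safe_load(old_text) == {"nb_chars": 43, "cer": 1.2791}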
Changes: 3 changed files with 85 additions and 68 deletions

dan/manager/metrics.py   (+45 −30)
dan/manager/training.py  (+5 −5)
tests/test_training.py   (+35 −33)
dan/manager/metrics.py (+45 −30)
@@ -87,54 +87,66 @@ class MetricManager:
             value = None
             if output:
                 if metric_name in ["nb_samples", "weights"]:
-                    value = np.sum(self.epoch_metrics[metric_name])
+                    value = int(np.sum(self.epoch_metrics[metric_name]))
                 elif metric_name in [
                     "time",
                 ]:
-                    total_time = np.sum(self.epoch_metrics[metric_name])
-                    sample_time = total_time / np.sum(self.epoch_metrics["nb_samples"])
-                    display_values["sample_time"] = round(sample_time, 4)
-                    value = total_time
+                    value = int(np.sum(self.epoch_metrics[metric_name]))
+                    sample_time = value / np.sum(self.epoch_metrics["nb_samples"])
+                    display_values["sample_time"] = float(round(sample_time, 4))
                 elif metric_name == "loer":
-                    display_values["pper"] = round(
-                        np.sum(self.epoch_metrics["nb_pp_op_layout"])
-                        / np.sum(self.epoch_metrics["nb_gt_layout_token"]),
-                        4,
-                    )
+                    display_values["pper"] = float(
+                        round(
+                            np.sum(self.epoch_metrics["nb_pp_op_layout"])
+                            / np.sum(self.epoch_metrics["nb_gt_layout_token"]),
+                            4,
+                        )
+                    )
                 elif metric_name == "map_cer_per_class":
-                    value = compute_global_mAP_per_class(self.epoch_metrics["map_cer"])
+                    value = float(
+                        compute_global_mAP_per_class(self.epoch_metrics["map_cer"])
+                    )
                     for key in value.keys():
-                        display_values["map_cer_" + key] = round(value[key], 4)
+                        display_values["map_cer_" + key] = float(round(value[key], 4))
                     continue
                 elif metric_name == "layout_precision_per_class_per_threshold":
-                    value = compute_global_precision_per_class_per_threshold(
-                        self.epoch_metrics["map_cer"]
-                    )
+                    value = float(
+                        compute_global_precision_per_class_per_threshold(
+                            self.epoch_metrics["map_cer"]
+                        )
+                    )
                     for key_class in value.keys():
                         for threshold in value[key_class].keys():
                             display_values[
                                 "map_cer_{}_{}".format(key_class, threshold)
-                            ] = round(value[key_class][threshold], 4)
+                            ] = float(round(value[key_class][threshold], 4))
                     continue
             if metric_name == "cer":
-                value = np.sum(self.epoch_metrics["edit_chars"]) / np.sum(
-                    self.epoch_metrics["nb_chars"]
-                )
+                value = float(
+                    np.sum(self.epoch_metrics["edit_chars"])
+                    / np.sum(self.epoch_metrics["nb_chars"])
+                )
                 if output:
-                    display_values["nb_chars"] = np.sum(self.epoch_metrics["nb_chars"])
+                    display_values["nb_chars"] = int(
+                        np.sum(self.epoch_metrics["nb_chars"])
+                    )
             elif metric_name == "wer":
-                value = np.sum(self.epoch_metrics["edit_words"]) / np.sum(
-                    self.epoch_metrics["nb_words"]
-                )
+                value = float(
+                    np.sum(self.epoch_metrics["edit_words"])
+                    / np.sum(self.epoch_metrics["nb_words"])
+                )
                 if output:
-                    display_values["nb_words"] = np.sum(self.epoch_metrics["nb_words"])
+                    display_values["nb_words"] = int(
+                        np.sum(self.epoch_metrics["nb_words"])
+                    )
             elif metric_name == "wer_no_punct":
-                value = np.sum(self.epoch_metrics["edit_words_no_punct"]) / np.sum(
-                    self.epoch_metrics["nb_words_no_punct"]
-                )
+                value = float(
+                    np.sum(self.epoch_metrics["edit_words_no_punct"])
+                    / np.sum(self.epoch_metrics["nb_words_no_punct"])
+                )
                 if output:
-                    display_values["nb_words_no_punct"] = np.sum(
-                        self.epoch_metrics["nb_words_no_punct"]
-                    )
+                    display_values["nb_words_no_punct"] = int(
+                        np.sum(
+                            self.epoch_metrics["nb_words_no_punct"]
+                        )
+                    )
             elif metric_name in [
                 "loss",
@@ -143,15 +155,18 @@ class MetricManager:
                 "syn_max_lines",
                 "syn_prob_lines",
             ]:
-                value = np.average(
-                    self.epoch_metrics[metric_name],
-                    weights=np.array(self.epoch_metrics["nb_samples"]),
-                )
+                value = float(
+                    np.average(
+                        self.epoch_metrics[metric_name],
+                        weights=np.array(self.epoch_metrics["nb_samples"]),
+                    )
+                )
             elif metric_name == "map_cer":
-                value = compute_global_mAP(self.epoch_metrics[metric_name])
+                value = float(compute_global_mAP(self.epoch_metrics[metric_name]))
             elif metric_name == "loer":
-                value = np.sum(self.epoch_metrics["edit_graph"]) / np.sum(
-                    self.epoch_metrics["nb_nodes_and_edges"]
-                )
+                value = float(
+                    np.sum(self.epoch_metrics["edit_graph"])
+                    / np.sum(self.epoch_metrics["nb_nodes_and_edges"])
+                )
             elif value is None:
                 continue
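
The recurring int(...) and float(...) casts above are what make the YAML dump usable: np.sum and the mAP helpers return numpy scalar types, which PyYAML's safe dumper rejects outright, while the default dumper serializes them as Python-specific object tags rather than plain numbers. A minimal sketch of the failure mode, assuming nothing beyond numpy and PyYAML:

    import numpy as np
    import yaml

    cer = np.sum([1, 2]) / np.sum([10, 10])  # np.float64, not a plain float

    # SafeDumper has no representer for numpy scalar types:
    try:
        yaml.safe_dump({"cer": cer})
    except yaml.representer.RepresenterError as err:
        print(err)  # cannot represent an object

    # Casting to native types first, as this commit does throughout
    # get_display_values, keeps the output plain and safe_load-able:
    print(yaml.safe_dump({"cer": float(cer), "nb_chars": int(np.int64(43))}))
    # cer: 0.15
    # nb_chars: 43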
dan/manager/training.py (+5 −5)
@@ -12,6 +12,7 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+import yaml
 from PIL import Image
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn import CrossEntropyLoss
@@ -865,22 +866,21 @@ class GenericTrainingManager:
             metrics = self.metric_manager[custom_name].get_display_values(output=True)
             path = os.path.join(
                 self.paths["results"],
-                "predict_{}_{}.txt".format(custom_name, self.latest_epoch),
+                "predict_{}_{}.yaml".format(custom_name, self.latest_epoch),
             )
             with open(path, "w") as f:
-                for metric_name in metrics.keys():
-                    f.write("{}: {}\n".format(metric_name, metrics[metric_name]))
+                yaml.dump(metrics, stream=f)

             # Log mlflow artifacts
             mlflow.log_artifact(path, "predictions")

     def output_pred(self, name):
         path = os.path.join(
-            self.paths["results"], "pred_{}_{}.txt".format(name, self.latest_epoch)
+            self.paths["results"], "pred_{}_{}.yaml".format(name, self.latest_epoch)
         )
         pred = "\n".join(self.metric_manager[name].get("pred"))
         with open(path, "w") as f:
-            f.write(pred)
+            yaml.dump(pred, stream=f)

     def launch_ddp(self):
         """
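
Two behaviors of the new write path are worth noting: yaml.dump sorts keys alphabetically by default, so the on-disk order of metrics can differ from the old insertion-ordered write loop, and in output_pred the joined predictions are now dumped as a single YAML string scalar rather than raw lines. A small sketch of both, with hypothetical values:

    import yaml

    metrics = {"wer": 1.0, "cer": 1.2791, "nb_samples": 2}
    print(yaml.dump(metrics), end="")
    # cer: 1.2791        <- keys come back sorted, unlike the old f.write loop
    # nb_samples: 2
    # wer: 1.0

    pred = "\n".join(["prediction one", "prediction two"])
    # The predictions round-trip as one string scalar, not two raw lines:
    assert yaml.safe_load(yaml.dump(pred)) == pred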
tests/test_training.py (+35 −33)
@@ -2,6 +2,7 @@
 import pytest
 import torch
+import yaml
 from dan.ocr.document.train import train_and_test
 from tests.conftest import FIXTURES
@@ -13,33 +14,33 @@ from tests.conftest import FIXTURES
         (
             "best_0.pt",
             "last_3.pt",
-            [
-                "nb_chars: 43",
-                "cer: 1.2791",
-                "nb_words: 9",
-                "wer: 1.0",
-                "nb_words_no_punct: 9",
-                "wer_no_punct: 1.0",
-                "nb_samples: 2",
-            ],
+            {
+                "nb_chars": 43,
+                "cer": 1.2791,
+                "nb_words": 9,
+                "wer": 1.0,
+                "nb_words_no_punct": 9,
+                "wer_no_punct": 1.0,
+                "nb_samples": 2,
+            },
-            [
-                "nb_chars: 41",
-                "cer: 1.2683",
-                "nb_words: 9",
-                "wer: 1.0",
-                "nb_words_no_punct: 9",
-                "wer_no_punct: 1.0",
-                "nb_samples: 2",
-            ],
+            {
+                "nb_chars": 41,
+                "cer": 1.2683,
+                "nb_words": 9,
+                "wer": 1.0,
+                "nb_words_no_punct": 9,
+                "wer_no_punct": 1.0,
+                "nb_samples": 2,
+            },
-            [
-                "nb_chars: 49",
-                "cer: 1.1429",
-                "nb_words: 9",
-                "wer: 1.0",
-                "nb_words_no_punct: 9",
-                "wer_no_punct: 1.0",
-                "nb_samples: 2",
-            ],
+            {
+                "nb_chars": 49,
+                "cer": 1.1429,
+                "nb_words": 9,
+                "wer": 1.0,
+                "nb_words_no_punct": 9,
+                "wer_no_punct": 1.0,
+                "nb_samples": 2,
+            },
         ),
     ),
 )
@@ -136,11 +137,12 @@ def test_train_and_test(
         tmp_path
         / training_config["training_params"]["output_folder"]
         / "results"
-        / f"predict_training-{split_name}_0.txt"
-    ).open(
-        "r",
-    ) as f:
-        res = f.read().splitlines()
+        / f"predict_training-{split_name}_0.yaml"
+    ).open() as f:
         # Remove the times from the results as they vary
-        res = [result for result in res if "time" not in result]
+        res = {
+            metric: value
+            for metric, value in yaml.safe_load(f).items()
+            if "time" not in metric
+        }
         assert res == expected_res
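
The test no longer compares lists of preformatted strings: yaml.safe_load returns a typed dict, so the expected values are written as real ints and floats, and the time-dependent entries are filtered out by key instead of by substring match on whole lines. A minimal illustration of the equality being asserted, with hypothetical file contents:

    import yaml

    loaded = yaml.safe_load("nb_chars: 43\ncer: 1.2791\nsample_time: 0.01\n")
    res = {metric: value for metric, value in loaded.items() if "time" not in metric}
    assert res == {"nb_chars": 43, "cer": 1.2791}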