Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
D
DAN
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Deploy
Releases
Package Registry
Container Registry
Operate
Terraform modules
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Automatic Text Recognition
DAN
Commits
a47466ef
Commit
a47466ef
authored
1 year ago
by
Yoann Schneider
Committed by
Manon Blanco
1 year ago
Browse files
Options
Downloads
Patches
Plain Diff
Metrics manager refactoring
parent
0f154c32
No related branches found
No related tags found
1 merge request
!249
Metrics manager refactoring
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
dan/ocr/manager/metrics.py
+110
-181
110 additions, 181 deletions
dan/ocr/manager/metrics.py
with
110 additions
and
181 deletions
dan/ocr/manager/metrics.py
+
110
−
181
View file @
a47466ef
# -*- coding: utf-8 -*-
import
re
from
collections
import
defaultdict
from
operator
import
attrgetter
from
pathlib
import
Path
from
typing
import
Optional
from
typing
import
Dict
,
List
,
Optional
import
editdistance
import
numpy
as
np
from
dan.utils
import
parse_tokens
# Matches one punctuation character; substituting it with "" strips punctuation
REGEX_PUNCTUATION: re.Pattern = re.compile(
    r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])"
)
# Matches a run of one or more line breaks
REGEX_CONSECUTIVE_LINEBREAKS: re.Pattern = re.compile(r"\n+")
# Matches a run of one or more space characters
REGEX_CONSECUTIVE_SPACES: re.Pattern = re.compile(r" +")
# Matches a run of any whitespace, so that it can be collapsed into one space
REGEX_ONLY_ONE_SPACE: re.Pattern = re.compile(r"\s+")
class
MetricManager
:
def
__init__
(
self
,
metric_names
,
dataset_name
,
tokens
:
Optional
[
Path
]):
self
.
dataset_name
=
dataset_name
def
__init__
(
self
,
metric_names
:
List
[
str
],
dataset_name
:
str
,
tokens
:
Optional
[
Path
]
):
self
.
dataset_name
:
str
=
dataset_name
self
.
remove_tokens
:
str
=
None
self
.
layout_tokens
=
None
if
tokens
:
tokens
=
parse_tokens
(
tokens
)
self
.
layout_tokens
=
""
.
join
(
layout_tokens
=
""
.
join
(
list
(
map
(
attrgetter
(
"
start
"
),
tokens
.
values
()))
+
list
(
map
(
attrgetter
(
"
end
"
),
tokens
.
values
()))
)
self
.
metric_names
=
metric_names
self
.
epoch_metrics
=
None
self
.
remove_tokens
:
re
.
Pattern
=
re
.
compile
(
r
"
([
"
+
layout_tokens
+
"
])
"
)
self
.
metric_names
:
List
[
str
]
=
metric_names
self
.
epoch_metrics
=
defaultdict
(
list
)
self
.
linked_metrics
=
{
"
cer
"
:
[
"
edit_chars
"
,
"
nb_chars
"
],
"
wer
"
:
[
"
edit_words
"
,
"
nb_words
"
],
"
wer_no_punct
"
:
[
"
edit_words_no_punct
"
,
"
nb_words_no_punct
"
],
}
def edit_cer_from_string(self, gt: str, pred: str):
    """
    Character-level edit distance between a ground truth and a prediction.

    Both strings go through `format_string_for_cer` before being compared
    with `editdistance.eval`.
    """
    formatted_gt, formatted_pred = map(self.format_string_for_cer, (gt, pred))
    return editdistance.eval(formatted_gt, formatted_pred)
def nb_chars_cer_from_string(self, gt: str) -> int:
    """
    Length of the ground truth string once formatted by `format_string_for_cer`
    """
    formatted_gt = self.format_string_for_cer(gt)
    return len(formatted_gt)
self
.
init_metrics
()
def format_string_for_wer(self, text: str, remove_punct: bool = False):
    """
    Prepare a string for WER computation and split it into a list of words.

    Punctuation is stripped first when `remove_punct` is set, then whatever
    `self.remove_tokens` matches (when it is not None). Every whitespace run
    is finally collapsed into a single space before splitting on spaces.
    """
    if remove_punct:
        text = REGEX_PUNCTUATION.sub("", text)

    token_pattern = self.remove_tokens
    if token_pattern is not None:
        text = token_pattern.sub("", text)

    normalized = REGEX_ONLY_ONE_SPACE.sub(" ", text)
    return normalized.strip().split(" ")
def
init_metrics
(
self
):
def
format_string_for_cer
(
self
,
text
:
str
):
"""
Initialization of the metrics specified in metrics_name
Format string for CER computation: remove layout tokens and extra spaces
"""
self
.
epoch_metrics
=
{
"
nb_samples
"
:
list
(),
"
names
"
:
list
(),
}
if
self
.
remove_tokens
is
not
None
:
text
=
self
.
remove_tokens
.
sub
(
""
,
text
)
for
metric_name
in
self
.
metric_names
:
if
metric_name
in
self
.
linked_metrics
:
for
linked_metric_name
in
self
.
linked_metrics
[
metric_name
]:
if
linked_metric_name
not
in
self
.
epoch_metrics
:
self
.
epoch_metrics
[
linked_metric_name
]
=
list
()
else
:
self
.
epoch_metrics
[
metric_name
]
=
list
()
text
=
REGEX_CONSECUTIVE_LINEBREAKS
.
sub
(
"
\n
"
,
text
)
return
REGEX_CONSECUTIVE_SPACES
.
sub
(
"
"
,
text
).
strip
()
def
update_metrics
(
self
,
batch_metrics
):
"""
Add batch metrics to the metrics
"""
for
key
in
batch_metrics
:
if
key
in
self
.
epoch_metrics
:
self
.
epoch_metrics
[
key
]
+=
batch_metrics
[
key
]
self
.
epoch_metrics
[
key
]
+=
batch_metrics
[
key
]
def
get_display_values
(
self
,
output
=
False
):
def
get_display_values
(
self
,
output
:
bool
=
False
):
"""
f
ormat metrics values for shell display purposes
F
ormat metrics values for shell display purposes
"""
metric_names
=
self
.
metric_names
.
copy
()
if
output
:
metric_names
.
ext
end
(
[
"
nb_samples
"
]
)
metric_names
.
app
end
(
"
nb_samples
"
)
display_values
=
dict
()
for
metric_name
in
metric_names
:
value
=
None
if
output
:
if
metric_name
==
"
nb_samples
"
:
value
=
int
(
np
.
sum
(
self
.
epoch_metrics
[
metric_name
]))
elif
metric_name
==
"
time
"
:
match
metric_name
:
case
"
time
"
|
"
nb_samples
"
:
if
not
output
:
continue
value
=
int
(
np
.
sum
(
self
.
epoch_metrics
[
metric_name
]))
sample_time
=
value
/
np
.
sum
(
self
.
epoch_metrics
[
"
nb_samples
"
])
display_values
[
"
sample_time
"
]
=
float
(
round
(
sample_time
,
4
))
if
metric_name
==
"
cer
"
:
value
=
float
(
np
.
sum
(
self
.
epoch_metrics
[
"
edit_chars
"
])
/
np
.
sum
(
self
.
epoch_metrics
[
"
nb_chars
"
])
)
if
output
:
display_values
[
"
nb_chars
"
]
=
int
(
np
.
sum
(
self
.
epoch_metrics
[
"
nb_chars
"
])
if
metric_name
==
"
time
"
:
sample_time
=
value
/
np
.
sum
(
self
.
epoch_metrics
[
"
nb_samples
"
])
display_values
[
"
sample_time
"
]
=
float
(
round
(
sample_time
,
4
))
display_values
[
metric_name
]
=
value
continue
case
"
cer
"
:
num_name
,
denom_name
=
"
edit_chars
"
,
"
nb_chars
"
case
"
wer
"
|
"
wer_no_punct
"
:
suffix
=
metric_name
[
3
:]
num_name
,
denom_name
=
"
edit_words
"
+
suffix
,
"
nb_words
"
+
suffix
case
"
loss
"
|
"
loss_ce
"
:
display_values
[
metric_name
]
=
round
(
float
(
np
.
average
(
self
.
epoch_metrics
[
metric_name
],
weights
=
np
.
array
(
self
.
epoch_metrics
[
"
nb_samples
"
]),
),
),
4
,
)
elif
metric_name
==
"
wer
"
:
value
=
float
(
np
.
sum
(
self
.
epoch_metrics
[
"
edit_words
"
])
/
np
.
sum
(
self
.
epoch_metrics
[
"
nb_words
"
])
)
if
output
:
display_values
[
"
nb_words
"
]
=
int
(
np
.
sum
(
self
.
epoch_metrics
[
"
nb_words
"
])
)
elif
metric_name
==
"
wer_no_punct
"
:
value
=
float
(
np
.
sum
(
self
.
epoch_metrics
[
"
edit_words_no_punct
"
])
/
np
.
sum
(
self
.
epoch_metrics
[
"
nb_words_no_punct
"
])
)
if
output
:
display_values
[
"
nb_words_no_punct
"
]
=
int
(
np
.
sum
(
self
.
epoch_metrics
[
"
nb_words_no_punct
"
])
)
elif
metric_name
in
[
"
loss
"
,
"
loss_ce
"
,
]:
value
=
float
(
np
.
average
(
self
.
epoch_metrics
[
metric_name
],
weights
=
np
.
array
(
self
.
epoch_metrics
[
"
nb_samples
"
]),
)
)
elif
value
is
None
:
continue
continue
case
_
:
continue
value
=
float
(
np
.
sum
(
self
.
epoch_metrics
[
num_name
])
/
np
.
sum
(
self
.
epoch_metrics
[
denom_name
])
)
if
output
:
display_values
[
denom_name
]
=
int
(
np
.
sum
(
self
.
epoch_metrics
[
denom_name
]))
display_values
[
metric_name
]
=
round
(
value
,
4
)
return
display_values
def
compute_metrics
(
self
,
values
,
metric_names
):
def
compute_metrics
(
self
,
values
:
Dict
[
str
,
int
|
float
],
metric_names
:
List
[
str
]
)
->
Dict
[
str
,
List
[
int
|
float
]]:
metrics
=
{
"
nb_samples
"
:
[
values
[
"
nb_samples
"
],
...
...
@@ -125,111 +135,30 @@ class MetricManager:
}
if
"
time
"
in
values
:
metrics
[
"
time
"
]
=
[
values
[
"
time
"
]]
gt
,
prediction
=
values
[
"
str_y
"
],
values
[
"
str_x
"
]
for
metric_name
in
metric_names
:
if
metric_name
==
"
cer
"
:
metrics
[
"
edit_chars
"
]
=
[
edit_cer_from_string
(
u
,
v
,
self
.
layout_tokens
)
for
u
,
v
in
zip
(
values
[
"
str_y
"
],
values
[
"
str_x
"
])
]
metrics
[
"
nb_chars
"
]
=
[
nb_chars_cer_from_string
(
gt
,
self
.
layout_tokens
)
for
gt
in
values
[
"
str_y
"
]
]
elif
metric_name
==
"
wer
"
:
split_gt
=
[
format_string_for_wer
(
gt
,
self
.
layout_tokens
)
for
gt
in
values
[
"
str_y
"
]
]
split_pred
=
[
format_string_for_wer
(
pred
,
self
.
layout_tokens
)
for
pred
in
values
[
"
str_x
"
]
]
metrics
[
"
edit_words
"
]
=
[
edit_wer_from_formatted_split_text
(
gt
,
pred
)
for
(
gt
,
pred
)
in
zip
(
split_gt
,
split_pred
)
]
metrics
[
"
nb_words
"
]
=
[
len
(
gt
)
for
gt
in
split_gt
]
elif
metric_name
==
"
wer_no_punct
"
:
split_gt
=
[
format_string_for_wer
(
gt
,
self
.
layout_tokens
,
remove_punct
=
True
)
for
gt
in
values
[
"
str_y
"
]
]
split_pred
=
[
format_string_for_wer
(
pred
,
self
.
layout_tokens
,
remove_punct
=
True
)
for
pred
in
values
[
"
str_x
"
]
]
metrics
[
"
edit_words_no_punct
"
]
=
[
edit_wer_from_formatted_split_text
(
gt
,
pred
)
for
(
gt
,
pred
)
in
zip
(
split_gt
,
split_pred
)
]
metrics
[
"
nb_words_no_punct
"
]
=
[
len
(
gt
)
for
gt
in
split_gt
]
elif
metric_name
in
[
"
loss_ce
"
,
"
loss
"
,
]:
metrics
[
metric_name
]
=
[
values
[
metric_name
],
]
match
metric_name
:
case
"
cer
"
:
metrics
[
"
edit_chars
"
]
=
list
(
map
(
self
.
edit_cer_from_string
,
gt
,
prediction
)
)
metrics
[
"
nb_chars
"
]
=
list
(
map
(
self
.
nb_chars_cer_from_string
,
gt
))
case
"
wer
"
|
"
wer_no_punct
"
:
suffix
=
metric_name
[
3
:]
split_gt
=
list
(
map
(
self
.
format_string_for_wer
,
gt
,
[
bool
(
suffix
)]))
split_pred
=
list
(
map
(
self
.
format_string_for_wer
,
prediction
,
[
bool
(
suffix
)])
)
metrics
[
"
edit_words
"
+
suffix
]
=
list
(
map
(
editdistance
.
eval
,
split_gt
,
split_pred
)
)
metrics
[
"
nb_words
"
+
suffix
]
=
list
(
map
(
len
,
split_gt
))
case
"
loss
"
|
"
loss_ce
"
:
metrics
[
metric_name
]
=
[
values
[
metric_name
],
]
return
metrics
def
get
(
self
,
name
):
def
get
(
self
,
name
:
str
):
return
self
.
epoch_metrics
[
name
]
def keep_all_but_ner_tokens(str, tokens):
    """
    Remove all ner tokens from string

    :param str: Text to clean. The name shadows the builtin; it is kept so
        that the signature stays unchanged for existing callers.
    :param tokens: String whose characters are the tokens to remove.
    :return: The text with every token character removed.
    """
    if not tokens:
        # "[]" is not a valid character class: there is nothing to remove
        return str
    # Escape the tokens so that characters such as "^", "]", "-" or "\" are
    # matched literally instead of changing the meaning of the character class
    return re.sub("[" + re.escape(tokens) + "]", "", str)
def edit_cer_from_string(gt, pred, layout_tokens=None):
    """
    Character-level edit distance between a ground truth and a prediction.

    Both strings are formatted with `format_string_for_cer` (ground truth
    first) before being compared with `editdistance.eval`.
    """
    formatted_gt = format_string_for_cer(gt, layout_tokens)
    formatted_pred = format_string_for_cer(pred, layout_tokens)
    distance = editdistance.eval(formatted_gt, formatted_pred)
    return distance
def nb_chars_cer_from_string(gt, layout_tokens=None):
    """
    Length of the ground truth string once formatted by `format_string_for_cer`
    """
    formatted_gt = format_string_for_cer(gt, layout_tokens)
    return len(formatted_gt)
def format_string_for_wer(str, layout_tokens, remove_punct=False):
    """
    Prepare a string for WER computation and split it into a list of words.

    Punctuation is stripped first when `remove_punct` is set, then the layout
    tokens when `layout_tokens` is not None. Runs of spaces and line breaks
    are finally collapsed into a single space before splitting on spaces.
    """
    text = str
    if remove_punct:
        # Drop every punctuation character
        text = re.sub(r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])", "", text)
    if layout_tokens is not None:
        # Layout tokens must not be counted as words
        text = keep_all_but_ner_tokens(text, layout_tokens)
    # Spaces and line breaks both become a single space separator
    collapsed = re.sub("([ \n])+", " ", text)
    return collapsed.strip().split(" ")
def format_string_for_cer(str, layout_tokens):
    """
    Prepare a string for CER computation.

    Layout tokens are removed when `layout_tokens` is not None, consecutive
    line breaks are merged into one, consecutive spaces are merged into one,
    and the result is stripped.
    """
    text = str
    if layout_tokens is not None:
        # Layout tokens must not be counted as characters
        text = keep_all_but_ner_tokens(text, layout_tokens)
    # Merge line breaks first, then spaces
    single_breaks = re.sub("([\n])+", "\n", text)
    return re.sub("([ ])+", " ", single_breaks).strip()
def edit_wer_from_formatted_split_text(gt, pred):
    """
    Word-level edit distance between two already formatted and split texts.

    Both arguments are passed as-is to `editdistance.eval`.
    """
    distance = editdistance.eval(gt, pred)
    return distance
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment