Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
D
DAN
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Deploy
Releases
Package registry
Container Registry
Operate
Terraform modules
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Automatic Text Recognition
DAN
Commits
1cccd99b
Commit
1cccd99b
authored
1 year ago
by
Solene Tarride
Browse files
Options
Downloads
Patches
Plain Diff
Fix shape of tot_prob
parent
1a96389b
No related branches found
No related tags found
1 merge request
!287
Support subword and word language models
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
dan/ocr/decoder.py
+43
-37
43 additions, 37 deletions
dan/ocr/decoder.py
dan/ocr/predict/inference.py
+10
-4
10 additions, 4 deletions
dan/ocr/predict/inference.py
with
53 additions
and
41 deletions
dan/ocr/decoder.py
+
43
−
37
View file @
1cccd99b
...
...
@@ -7,7 +7,7 @@ from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, Modu
from
torch.nn.init
import
xavier_uniform_
from
torchaudio.models.decoder
import
ctc_decoder
from
dan.utils
import
read_txt
from
dan.utils
import
LM_MAPPING
,
read_txt
class
PositionalEncoding1D
(
Module
):
...
...
@@ -487,56 +487,65 @@ class CTCLanguageDecoder:
lexicon_path
:
str
,
tokens_path
:
str
,
language_model_weight
:
float
=
1.0
,
blank_token
:
str
=
"
<ctc>
"
,
unk_token
:
str
=
"
<unk>
"
,
sil_token
:
str
=
"
<space>
"
,
temperature
:
float
=
1.0
,
):
self
.
space_token
=
LM_MAPPING
[
"
"
]
self
.
unknown_token
=
LM_MAPPING
[
"
<unk>
"
]
self
.
blank_token
=
LM_MAPPING
[
"
<ctc>
"
]
self
.
language_model_weight
=
language_model_weight
self
.
temperature
=
temperature
self
.
tokens_to_index
=
{
token
:
i
for
i
,
token
in
enumerate
(
read_txt
(
tokens_path
).
split
(
"
\n
"
))
}
self
.
blank_token_id
=
self
.
tokens_to_index
[
self
.
blank_token
]
self
.
decoder
=
ctc_decoder
(
lm
=
language_model_path
,
lexicon
=
lexicon_path
,
tokens
=
tokens_path
,
lm_weight
=
language_model_weight
,
blank_token
=
blank_token
,
unk_word
=
unk
_token
,
sil_token
=
s
il
_token
,
lm_weight
=
self
.
language_model_weight
,
blank_token
=
self
.
blank_token
,
unk_word
=
self
.
unknown
_token
,
sil_token
=
s
elf
.
space
_token
,
nbest
=
1
,
)
self
.
temperature
=
temperature
self
.
space_token
=
sil_token
self
.
tokens_to_idx
=
read_txt
(
tokens_path
).
split
(
"
\n
"
)
self
.
blank_id
=
self
.
tokens_to_idx
.
index
(
blank_token
)
# No GPU support
self
.
device
=
torch
.
device
(
"
cpu
"
)
def
add_ctc_frames
(
self
,
batch_features
):
def
add_ctc_frames
(
self
,
batch_features
,
batch_frames
):
"""
Add a CTC frame between consecutive characters to avoid duplicate removal
"""
batch_size
,
n_frames
,
n_tokens
=
batch_features
.
shape
batch_size
,
_
,
n_tokens
=
batch_features
.
shape
torch
.
clone
(
batch_features
)
# visualize_debug(batch_features.exp()[0, :batch_frames[0], :].numpy(), "probs.jpg", False)
# Create tensor with high probability CTC token
high_prob
=
0.99
low_prob
=
0.01
low_prob
=
1
-
high_prob
ctc_probs
=
(
torch
.
ones
((
batch_size
,
1
,
n_tokens
),
dtype
=
torch
.
float32
)
*
low_prob
/
(
n_tokens
-
1
)
)
ctc_probs
[:,
:,
self
.
blank_id
]
=
high_prob
ctc_probs
[:,
:,
self
.
blank_
token_
id
]
=
high_prob
ctc_probs
=
ctc_probs
.
log
()
# Insert CTC tensor between frames
for
i
in
range
(
n
_frames
):
for
fn
in
range
(
batch
_frames
[
0
]
-
1
):
batch_features
=
torch
.
cat
(
[
batch_features
[:,
:
2
*
i
+
1
,
:],
batch_features
[:,
:
2
*
fn
+
1
,
:],
ctc_probs
,
batch_features
[:,
2
*
i
+
1
:,
:],
batch_features
[:,
2
*
fn
+
1
:,
:],
],
dim
=
1
,
)
return
batch_features
def
post_process
(
self
,
hypotheses
):
# Update the number of frames
batch_frames
=
2
*
batch_frames
-
1
return
batch_features
,
batch_frames
def
post_process
(
self
,
hypotheses
,
batch_sizes
):
"""
Post-process hypotheses to output JSON. Exports only the best hypothesis for each image.
"""
...
...
@@ -548,12 +557,14 @@ class CTCLanguageDecoder:
]
# Normalize confidence score
out
[
"
confidence
"
]
=
[
np
.
exp
(
hypothesis
[
0
].
score
/
hypothesis
[
0
].
timesteps
[
-
1
].
item
())
for
hypothesis
in
hypotheses
np
.
exp
(
hypothesis
[
0
].
score
/
((
self
.
language_model_weight
+
1
)
*
length
.
item
())
)
for
hypothesis
,
length
in
zip
(
hypotheses
,
batch_sizes
)
]
return
out
def
__call__
(
self
,
batch_features
,
batch_
siz
es
):
def
__call__
(
self
,
batch_features
,
batch_
fram
es
):
"""
Decode a feature vector using n-gram language modelling.
Args:
...
...
@@ -565,21 +576,16 @@ class CTCLanguageDecoder:
# Reshape from (batch_size, n_tokens, n_frames) to (batch_size, n_frames, n_tokens)
batch_features
=
batch_features
.
permute
((
0
,
2
,
1
))
# Apply temperature scaling
batch_features
=
batch_features
/
self
.
temperature
# Apply log softmax
batch_features
=
torch
.
nn
.
functional
.
log_softmax
(
batch_features
,
dim
=-
1
)
batch_features
=
self
.
add_ctc_frames
(
batch_features
)
batch_sizes
*=
2
batch_features
=
torch
.
nn
.
functional
.
log_softmax
(
batch_features
/
self
.
temperature
,
dim
=-
1
)
batch_features
,
batch_frames
=
self
.
add_ctc_frames
(
batch_features
,
batch_frames
)
# No GPU support for torchaudio's ctc_decoder
device
=
torch
.
device
(
"
cpu
"
)
batch_features
=
batch_features
.
to
(
device
)
if
isinstance
(
batch_sizes
,
list
):
batch_sizes
=
torch
.
tensor
(
batch_sizes
)
batch_sizes
.
to
(
device
)
batch_features
=
batch_features
.
to
(
self
.
device
)
batch_frames
=
batch_frames
.
to
(
self
.
device
)
# Decode
hypotheses
=
self
.
decoder
(
batch_features
,
batch_
siz
es
)
return
self
.
post_process
(
hypotheses
)
hypotheses
=
self
.
decoder
(
batch_features
,
batch_
fram
es
)
return
self
.
post_process
(
hypotheses
,
batch_frames
)
This diff is collapsed.
Click to expand it.
dan/ocr/predict/inference.py
+
10
−
4
View file @
1cccd99b
...
...
@@ -90,6 +90,7 @@ class DAN:
self
.
encoder
=
encoder
self
.
decoder
=
decoder
self
.
lm_decoder
=
None
if
use_language_model
:
self
.
lm_decoder
=
CTCLanguageDecoder
(
...
...
@@ -97,9 +98,6 @@ class DAN:
lexicon_path
=
parameters
[
"
lm_decoder
"
][
"
lexicon_path
"
],
tokens_path
=
parameters
[
"
lm_decoder
"
][
"
tokens_path
"
],
language_model_weight
=
parameters
[
"
lm_decoder
"
][
"
language_model_weight
"
],
blank_token
=
parameters
[
"
lm_decoder
"
][
"
blank_token
"
],
unk_token
=
parameters
[
"
lm_decoder
"
][
"
unk_token
"
],
sil_token
=
parameters
[
"
lm_decoder
"
][
"
sil_token
"
],
)
self
.
mean
,
self
.
std
=
(
...
...
@@ -178,6 +176,7 @@ class DAN:
(
batch_size
,),
dtype
=
torch
.
int
,
device
=
self
.
device
)
# end token index will be used for ctc
tot_pred
=
torch
.
zeros
(
(
batch_size
,
len
(
self
.
charset
)
+
1
,
self
.
max_chars
),
dtype
=
torch
.
float
,
...
...
@@ -268,7 +267,7 @@ class DAN:
out
[
"
text
"
]
=
predicted_text
if
use_language_model
:
out
[
"
language_model
"
]
=
self
.
lm_decoder
(
tot_pred
,
predict
ed_tokens
_len
)
out
[
"
language_model
"
]
=
self
.
lm_decoder
(
tot_pred
,
predict
ion
_len
)
if
confidences
:
out
[
"
confidences
"
]
=
confidence_scores
if
attentions
:
...
...
@@ -474,7 +473,14 @@ def run(
cuda_device
=
f
"
:
{
gpu_device
}
"
if
gpu_device
is
not
None
else
""
device
=
f
"
cuda
{
cuda_device
}
"
if
torch
.
cuda
.
is_available
()
else
"
cpu
"
dan_model
=
DAN
(
device
,
temperature
)
<<<<<<<
HEAD
dan_model
.
load
(
model
,
parameters
,
charset
,
mode
=
"
eval
"
)
=======
dan_model
.
load
(
model
,
parameters
,
charset
,
mode
=
"
eval
"
,
use_language_model
=
use_language_model
)
batch_size
=
1
if
use_language_model
else
batch_size
>>>>>>>
e7c611f
(
Fix
shape
of
tot_prob
)
images
=
image_dir
.
rglob
(
f
"
*
{
image_extension
}
"
)
if
not
image
else
[
image
]
for
image_batch
in
list_to_batches
(
images
,
n
=
batch_size
):
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment