Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
D
DAN
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Deploy
Releases
Package registry
Container Registry
Operate
Terraform modules
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Automatic Text Recognition
DAN
Commits
2538568c
Commit
2538568c
authored
1 year ago
by
Solene Tarride
Browse files
Options
Downloads
Patches
Plain Diff
Implement CTCLanguageDecoder
parent
095667f4
No related branches found
No related tags found
1 merge request
!287
Support subword and word language models
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
dan/ocr/decoder.py
+61
-102
61 additions, 102 deletions
dan/ocr/decoder.py
dan/ocr/predict/__init__.py
+1
-1
1 addition, 1 deletion
dan/ocr/predict/__init__.py
dan/ocr/predict/inference.py
+12
-3
12 additions, 3 deletions
dan/ocr/predict/inference.py
dan/utils.py
+1
-3
1 addition, 3 deletions
dan/utils.py
with
75 additions
and
109 deletions
dan/ocr/decoder.py
+
61
−
102
View file @
2538568c
# -*- coding: utf-8 -*-
from
typing
import
Dict
,
List
,
Union
import
numpy
as
np
import
torch
from
torch
import
relu
,
softmax
from
torch.nn
import
Conv1d
,
Dropout
,
Embedding
,
LayerNorm
,
Linear
,
Module
,
ModuleList
from
torch.nn.init
import
xavier_uniform_
from
torchaudio.models.decoder
import
CTCHypothesis
,
ctc_decoder
from
torchaudio.models.decoder
import
ctc_decoder
from
dan.utils
import
LMTokenMapping
,
read_txt
from
dan.utils
import
read_txt
class
PositionalEncoding1D
(
Module
):
...
...
@@ -470,13 +468,17 @@ class GlobalHTADecoder(Module):
class
CTCLanguageDecoder
:
"""
Initialize a CTC decoder with n-gram language modeling.
:param language_model_path: Path to a KenLM or ARPA language model.
:param lexicon_path: Path to a lexicon file containing the possible words and corresponding spellings.
Each line consists of a word and its space separated spelling. If `None`, uses lexicon-free decoding.
:param tokens_path: Path to a file containing valid tokens. If using a file, the expected
format is for tokens mapping to the same index to be on the same line.
:param language_model_weight: Weight of the language model.
:param temperature: Temperature for model calibration.
Args:
language_model_path (str): path to a KenLM or ARPA language model
lexicon_path (str): path to a lexicon file containing the possible words and corresponding spellings.
Each line consists of a word and its space separated spelling. If `None`, uses lexicon-free
decoding.
tokens_path (str): path to a file containing valid tokens. If using a file, the expected
format is for tokens mapping to the same index to be on the same line
language_model_weight (float): weight of the language model.
blank_token (str): token representing the blank/ctc symbol
unk_token (str): token representing unknown characters
sil_token (str): token representing the space character
"""
def
__init__
(
...
...
@@ -485,138 +487,95 @@ class CTCLanguageDecoder:
lexicon_path
:
str
,
tokens_path
:
str
,
language_model_weight
:
float
=
1.0
,
blank_token
:
str
=
"
<ctc>
"
,
unk_token
:
str
=
"
<unk>
"
,
sil_token
:
str
=
"
<space>
"
,
temperature
:
float
=
1.0
,
):
self
.
mapping
=
LMTokenMapping
()
self
.
language_model_weight
=
language_model_weight
self
.
temperature
=
temperature
self
.
tokens_to_index
=
{
token
:
i
for
i
,
token
in
enumerate
(
read_txt
(
tokens_path
).
split
(
"
\n
"
))
}
self
.
index_to_token
=
{
i
:
token
for
token
,
i
in
self
.
tokens_to_index
.
items
()}
self
.
blank_token_id
=
self
.
tokens_to_index
[
self
.
mapping
.
ctc
.
encoded
]
# Torchaudio's decoder
# https://pytorch.org/audio/master/generated/torchaudio.models.decoder.ctc_decoder.html
self
.
decoder
=
ctc_decoder
(
lm
=
language_model_path
,
lexicon
=
lexicon_path
,
tokens
=
tokens_path
,
lm_weight
=
self
.
language_model_weight
,
blank_token
=
self
.
mapping
.
ctc
.
encoded
,
sil_token
=
self
.
mapping
.
space
.
encoded
,
unk_word
=
"
⁇
"
,
lm_weight
=
language_model_weight
,
blank_token
=
blank_token
,
unk_word
=
unk_token
,
sil_token
=
sil_token
,
nbest
=
1
,
)
# No GPU support
self
.
device
=
torch
.
device
(
"
cpu
"
)
self
.
temperature
=
temperature
def
add_ctc_frames
(
self
,
batch_features
:
torch
.
FloatTensor
,
batch_frames
:
torch
.
LongTensor
)
->
tuple
[
torch
.
FloatTensor
,
torch
.
LongTensor
]:
self
.
tokens_to_idx
=
read_txt
(
tokens_path
).
split
(
"
\n
"
)
self
.
ctc_id
=
self
.
tokens_to_idx
.
index
(
blank_token
)
self
.
space_token
=
sil_token
def
add_ctc_frames
(
self
,
batch_features
):
"""
Add CTC frames between consecutive characters to avoid duplicate removal
.
Add CTC frames between consecutive characters to avoid duplicate removal
"""
high_prob
=
batch_features
.
max
()
low_prob
=
batch_features
.
min
()
batch_size
,
n_frames
,
n_tokens
=
batch_features
.
shape
# Reset probabilities for the CTC token
batch_features
[:,
:,
-
1
]
=
(
torch
.
ones
(
(
batch_size
,
n_frames
),
dtype
=
torch
.
float32
,
device
=
batch_features
.
device
,
)
*
low_prob
)
#
Create a frame
with
high
probability CTC token
#
column
with
1
probability
on
CTC token
ctc_probs
=
(
torch
.
ones
(
(
batch_size
,
1
,
n_tokens
),
dtype
=
torch
.
float32
,
device
=
batch_features
.
device
,
)
*
low_prob
torch
.
ones
((
batch_size
,
1
,
n_tokens
),
dtype
=
torch
.
float32
)
*
0.1
/
n_tokens
)
ctc_probs
[:,
:,
self
.
blank_token_id
]
=
high_prob
ctc_probs
=
ctc_probs
ctc_probs
[:,
:,
self
.
ctc_id
]
=
0.9
ctc_probs
=
ctc_probs
.
log
()
# Insert the CTC frame between regular frames
for
fn
in
range
(
batch_frames
.
max
()
-
1
):
for
i
in
range
(
n_frames
-
1
):
batch_features
=
torch
.
cat
(
[
batch_features
[:,
:
2
*
fn
+
1
,
:],
batch_features
[:,
2
*
i
+
1
:
,
:],
ctc_probs
,
batch_features
[:,
2
*
fn
+
1
:
,
:],
batch_features
[:,
:
2
*
i
+
1
,
:],
],
dim
=
1
,
)
return
batch_features
# Update the number of frames
batch_frames
=
2
*
batch_frames
-
1
return
batch_features
,
batch_frames
def
post_process
(
self
,
hypotheses
:
List
[
CTCHypothesis
],
batch_sizes
:
torch
.
LongTensor
)
->
Dict
[
str
,
List
[
Union
[
str
,
float
]]]:
def
post_process
(
self
,
hypotheses
):
"""
Post-process hypotheses to output JSON. Exports only the best hypothesis for each image.
:param hypotheses: List of hypotheses returned by the decoder.
:param batch_sizes: Prediction length of size batch_size.
:return: A dictionary containing the hypotheses and their confidences.
Post-process hypotheses to output JSON
"""
out
=
{}
#
Replace <space> by an actual space and format string
#
Export only the best hypothesis
out
[
"
text
"
]
=
[
""
.
join
(
[
self
.
mapping
.
display
[
self
.
index_to_token
[
token
]]
if
self
.
index_to_token
[
token
]
in
self
.
mapping
.
display
else
self
.
index_to_token
[
token
]
for
token
in
hypothesis
[
0
].
tokens
.
tolist
()
]
).
strip
()
""
.
join
(
hypothesis
[
0
].
words
).
replace
(
self
.
space_token
,
"
"
)
for
hypothesis
in
hypotheses
]
# Normalize confidence score
out
[
"
confidence
"
]
=
[
np
.
around
(
np
.
exp
(
hypothesis
[
0
].
score
/
((
self
.
language_model_weight
+
1
)
*
length
.
item
())
),
2
,
)
for
hypothesis
,
length
in
zip
(
hypotheses
,
batch_sizes
)
np
.
exp
(
hypothesis
[
0
].
score
/
hypothesis
[
0
].
timesteps
[
-
1
].
item
())
for
hypothesis
in
hypotheses
]
return
out
def
__call__
(
self
,
batch_features
:
torch
.
FloatTensor
,
batch_frames
:
torch
.
LongTensor
)
->
Dict
[
str
,
List
[
Union
[
str
,
float
]]]:
def
__call__
(
self
,
batch_features
,
batch_sizes
):
"""
Decode a feature vector using n-gram language modelling.
:param batch_features: Feature vector of size (batch_size, n_tokens, n_frames).
:param batch_frames: Prediction length of size batch_size.
:return: A dictionary containing the hypotheses and their confidences.
Args:
features (Any): feature vector of size (n_frame, batch_size, n_tokens).
Can be either a torch.tensor or a torch.nn.utils.rnn.PackedSequence
Returns:
out (Dict[str, List]): a dictionary containing the hypothesis (the list of decoded tokens).
There is no character-based probability.
"""
# Reshape from (batch_size, n_tokens
, n_frames
) to (batch_size, n_frame
s
, n_tokens)
# Reshape from (
n_frame,
batch_size, n_tokens) to (batch_size, n_frame, n_tokens)
batch_features
=
batch_features
.
permute
((
0
,
2
,
1
))
# Insert CTC frames to avoid getting rid of duplicates
# Make sure that the CTC token has low probs for other frames
batch_features
,
batch_frames
=
self
.
add_ctc_frames
(
batch_features
,
batch_frames
)
# Apply temperature scaling
batch_features
=
batch_features
/
self
.
temperature
# Apply log softmax
batch_features
=
torch
.
nn
.
functional
.
log_softmax
(
batch_features
/
self
.
temperature
,
dim
=-
1
)
batch_features
=
torch
.
nn
.
functional
.
log_softmax
(
batch_features
,
dim
=-
1
)
#
batch_features
=
self.
add_ctc_frames(batch_features)
# batch_sizes = batch_features.shape[0]
# No GPU support for torchaudio's ctc_decoder
batch_features
=
batch_features
.
to
(
self
.
device
)
batch_frames
=
batch_frames
.
to
(
self
.
device
)
device
=
torch
.
device
(
"
cpu
"
)
batch_features
=
batch_features
.
to
(
device
)
if
isinstance
(
batch_sizes
,
list
):
batch_sizes
=
torch
.
tensor
(
batch_sizes
)
batch_sizes
.
to
(
device
)
# Decode
hypotheses
=
self
.
decoder
(
batch_features
,
batch_
fram
es
)
return
self
.
post_process
(
hypotheses
,
batch_frames
)
hypotheses
=
self
.
decoder
(
batch_features
,
batch_
siz
es
)
return
self
.
post_process
(
hypotheses
)
This diff is collapsed.
Click to expand it.
dan/ocr/predict/__init__.py
+
1
−
1
View file @
2538568c
...
...
@@ -169,7 +169,7 @@ def add_predict_parser(subcommands) -> None:
)
parser
.
add_argument
(
"
--use-language-model
"
,
help
=
"
Whether to use an explicit language model to rescore text hypothes
e
s.
"
,
help
=
"
Whether to use an explicit language model to rescore text hypothes
i
s.
"
,
action
=
"
store_true
"
,
required
=
False
,
)
...
...
This diff is collapsed.
Click to expand it.
dan/ocr/predict/inference.py
+
12
−
3
View file @
2538568c
...
...
@@ -77,6 +77,16 @@ class DAN:
decoder
=
GlobalHTADecoder
(
parameters
[
"
decoder
"
]).
to
(
self
.
device
)
decoder
.
load_state_dict
(
checkpoint
[
"
decoder_state_dict
"
],
strict
=
True
)
self
.
lm_decoder
=
CTCLanguageDecoder
(
language_model_path
=
parameters
[
"
lm_decoder
"
][
"
language_model_path
"
],
lexicon_path
=
parameters
[
"
lm_decoder
"
][
"
lexicon_path
"
],
tokens_path
=
parameters
[
"
lm_decoder
"
][
"
tokens_path
"
],
language_model_weight
=
parameters
[
"
lm_decoder
"
][
"
language_model_weight
"
],
blank_token
=
parameters
[
"
lm_decoder
"
][
"
blank_token
"
],
unk_token
=
parameters
[
"
lm_decoder
"
][
"
unk_token
"
],
sil_token
=
parameters
[
"
lm_decoder
"
][
"
sil_token
"
],
)
logger
.
debug
(
f
"
Loaded model
{
model_path
}
"
)
if
mode
==
"
train
"
:
...
...
@@ -179,7 +189,6 @@ class DAN:
(
batch_size
,),
dtype
=
torch
.
int
,
device
=
self
.
device
)
# end token index will be used for ctc
tot_pred
=
torch
.
zeros
(
(
batch_size
,
len
(
self
.
charset
)
+
1
,
self
.
max_chars
),
dtype
=
torch
.
float
,
...
...
@@ -270,7 +279,7 @@ class DAN:
out
[
"
text
"
]
=
predicted_text
if
use_language_model
:
out
[
"
language_model
"
]
=
self
.
lm_decoder
(
tot_pred
,
predict
ion
_len
)
out
[
"
language_model
"
]
=
self
.
lm_decoder
(
tot_pred
,
predict
ed_tokens
_len
)
if
confidences
:
out
[
"
confidences
"
]
=
confidence_scores
if
attentions
:
...
...
@@ -466,7 +475,7 @@ def run(
:param batch_size: Size of the batches for prediction.
:param tokens: NER tokens used.
:param start_token: Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages.
:param use_language_model: Whether to use an explicit language model to rescore text hypothes
e
s.
:param use_language_model: Whether to use an explicit language model to rescore text hypothes
i
s.
"""
# Create output directory if necessary
if
not
output
.
exists
():
...
...
This diff is collapsed.
Click to expand it.
dan/utils.py
+
1
−
3
View file @
2538568c
...
...
@@ -163,9 +163,7 @@ def read_json(json_path: str) -> Dict:
def
read_txt
(
txt_path
:
str
)
->
str
:
"""
Read TXT file.
:param txt_path: Path of the text file to read.
:return: The content of the read file.
Read TXT file
"""
filename
=
Path
(
txt_path
)
assert
filename
.
exists
(),
f
"
{
txt_path
}
does not resolve.
"
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment