Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • atr/dan
1 result
Show changes
Commits on Source (11)
......@@ -27,9 +27,10 @@ from dan.datasets.extract.exceptions import (
from dan.datasets.extract.utils import (
download_image,
insert_token,
remove_spaces,
normalize_linebreaks,
normalize_spaces,
)
from dan.utils import EntityType, parse_tokens
from dan.utils import EntityType, LMTokenMapping, parse_tokens
from line_image_extractor.extractor import extract, read_img, save_img
from line_image_extractor.image_utils import Extraction, polygon_to_bbox, resize
......@@ -74,6 +75,7 @@ class ArkindexExtractor:
self.max_width = max_width
self.max_height = max_height
self.image_extension = image_extension
self.mapping = LMTokenMapping()
self.cache_dir = cache_dir
# Create cache dir if non existent
......@@ -83,6 +85,9 @@ class ArkindexExtractor:
self.data: Dict = defaultdict(dict)
self.charset = set()
self.language_corpus = []
self.language_tokens = []
self.language_lexicon = []
def find_image_in_cache(self, image_id: str) -> Path:
"""Images are cached to avoid downloading them twice. They are stored under a specific name,
......@@ -223,10 +228,25 @@ class ArkindexExtractor:
save_img(path=destination, img=image)
def format_text(self, text: str):
"""
Strip text and remove duplicate spaces and linebreaks if needed.
"""
if not self.keep_spaces:
text = remove_spaces(text)
text = normalize_spaces(text)
text = normalize_linebreaks(text)
return text.strip()
def format_text_language_model(self, text: str):
"""
Format text for the language model. Return the text tokenized at character-level.
"""
return " ".join(
[
self.mapping.encode[token] if token in self.mapping else token
for token in list(text.strip())
]
)
def process_element(
self,
element: Element,
......@@ -243,8 +263,11 @@ class ArkindexExtractor:
).with_suffix(self.image_extension)
self.get_image(element, image_path)
self.data[split][str(image_path)] = self.format_text(text)
self.charset = self.charset.union(set(text))
clean_text = self.format_text(text)
self.data[split][str(image_path)] = clean_text
self.charset = self.charset.union(set(clean_text))
if split == "train":
self.language_corpus.append(self.format_text_language_model(clean_text))
def process_parent(
self,
......@@ -280,6 +303,27 @@ class ArkindexExtractor:
except ProcessingError as e:
logger.warning(f"Skipping {element.id}: {str(e)}")
def format_lm_files(self) -> None:
"""
Convert charset to a LM-compatible charset. Ensure that special LM tokens do not appear in the charset.
"""
for token in sorted(list(self.charset)):
assert (
token not in self.mapping.encode.values()
), f"Special token {token} is reserved for language modeling."
self.language_tokens.append(
self.mapping.encode[token]
) if token in self.mapping.encode else self.language_tokens.append(token)
# Add the special blank token
self.language_tokens.append(self.mapping.ctc.encoded)
# Build lexicon
assert all(
[len(token) == 1 for token in self.language_lexicon]
), "Tokens should be single characters."
self.language_lexicon = [f"{token} {token}" for token in self.language_tokens]
def export(self):
(self.output / "labels.json").write_text(
json.dumps(
......@@ -288,6 +332,15 @@ class ArkindexExtractor:
indent=4,
)
)
(self.output / "language_corpus.txt").write_text(
"\n".join(self.language_corpus)
)
(self.output / "language_tokens.txt").write_text(
"\n".join(self.language_tokens)
)
(self.output / "language_lexicon.txt").write_text(
"\n".join(self.language_lexicon)
)
(self.output / "charset.pkl").write_bytes(
pickle.dumps(sorted(list(self.charset)))
)
......@@ -312,6 +365,7 @@ class ArkindexExtractor:
# Progress bar updates
pbar.update()
pbar.refresh()
self.format_lm_files()
self.export()
......
......@@ -20,7 +20,8 @@ logger = logging.getLogger(__name__)
DOWNLOAD_TIMEOUT = (30, 60)
# replace \t with regular space and consecutive spaces
TRIM_REGEX = re.compile(r"\t?(?: +)")
TRIM_SPACE_REGEX = re.compile(r"[\t| ]+")
TRIM_RETURN_REGEX = re.compile(r"[\r|\n]+")
def _retry_log(retry_state, *args, **kwargs):
......@@ -76,7 +77,17 @@ def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -
)
def remove_spaces(text: str) -> str:
# remove begin/ending spaces
# replace \t with regular space and consecutive spaces
return TRIM_REGEX.sub(" ", text.strip())
def normalize_linebreaks(text: str) -> str:
"""
Remove begin/ending linebreaks
Replace \r with regular linebreak and consecutive linebreaks
"""
return TRIM_RETURN_REGEX.sub("\n", text.strip())
def normalize_spaces(text: str) -> str:
"""
Remove begin/ending spaces
Replace \t with regular space and consecutive spaces
"""
return TRIM_SPACE_REGEX.sub(" ", text.strip())
# -*- coding: utf-8 -*-
import numpy as np
import torch
from torch import relu, softmax
from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, ModuleList
from torch.nn.init import xavier_uniform_
from torchaudio.models.decoder import ctc_decoder
from dan.utils import LMTokenMapping, read_txt
class PositionalEncoding1D(Module):
......@@ -459,3 +463,132 @@ class GlobalHTADecoder(Module):
),
)
)
class CTCLanguageDecoder:
"""
Initialize a CTC decoder with n-gram language modeling.
Args:
language_model_path (str): path to a KenLM or ARPA language model
lexicon_path (str): path to a lexicon file containing the possible words and corresponding spellings.
Each line consists of a word and its space separated spelling. If `None`, uses lexicon-free
decoding.
tokens_path (str): path to a file containing valid tokens. If using a file, the expected
format is for tokens mapping to the same index to be on the same line
language_model_weight (float): weight of the language model.
blank_token (str): token representing the blank/ctc symbol
unk_token (str): token representing unknown characters
sil_token (str): token representing the space character
"""
def __init__(
self,
language_model_path: str,
lexicon_path: str,
tokens_path: str,
language_model_weight: float = 1.0,
temperature: float = 1.0,
):
self.mapping = LMTokenMapping()
self.language_model_weight = language_model_weight
self.temperature = temperature
self.tokens_to_index = {
token: i for i, token in enumerate(read_txt(tokens_path).split("\n"))
}
self.blank_token_id = self.tokens_to_index[self.mapping.ctc.encoded]
self.decoder = ctc_decoder(
lm=language_model_path,
lexicon=lexicon_path,
tokens=tokens_path,
lm_weight=self.language_model_weight,
blank_token=self.mapping.ctc.encoded,
unk_word=self.mapping.unknown.encoded,
sil_token=self.mapping.space.encoded,
nbest=1,
)
# No GPU support
self.device = torch.device("cpu")
def add_ctc_frames(self, batch_features, batch_frames):
"""
Add CTC frames between each characters to avoid duplicate removal
"""
batch_size, _, n_tokens = batch_features.shape
# Create tensor with high probability CTC token
high_prob = 0.99
low_prob = 1 - high_prob
ctc_probs = (
torch.ones((batch_size, 1, n_tokens), dtype=torch.float32)
* low_prob
/ (n_tokens - 1)
)
ctc_probs[:, :, self.blank_token_id] = high_prob
ctc_probs = ctc_probs.log()
# Insert CTC tensor between frames
for fn in range(batch_frames[0] - 1):
batch_features = torch.cat(
[
batch_features[:, : 2 * fn + 1, :],
ctc_probs,
batch_features[:, 2 * fn + 1 :, :],
],
dim=1,
)
# Update the number of frames
batch_frames = 2 * batch_frames - 1
return batch_features, batch_frames
def post_process(self, hypotheses, batch_sizes):
"""
Post-process hypotheses to output JSON. Exports only the best hypothesis for each image.
"""
out = {}
# Replace <space> by an actual space and format string
out["text"] = [
"".join(
[
self.mapping.display[token]
if token in self.mapping.display
else token
for token in hypothesis[0].words
]
)
for hypothesis in hypotheses
]
# Normalize confidence score
out["confidence"] = [
np.exp(
hypothesis[0].score / ((self.language_model_weight + 1) * length.item())
)
for hypothesis, length in zip(hypotheses, batch_sizes)
]
return out
def __call__(self, batch_features, batch_frames):
"""
Decode a feature vector using n-gram language modelling.
Args:
features (torch.tensor): feature vector of size (batch_size, n_tokens, n_frame).
batch_sizes (Union[List, torch.tensor]): actual length of predictions
Returns:
out (Dict[List]): a dictionary containing the hypotheses.
"""
# Reshape from (batch_size, n_tokens, n_frames) to (batch_size, n_frames, n_tokens)
batch_features = batch_features.permute((0, 2, 1))
# Apply log softmax
batch_features = torch.nn.functional.log_softmax(
batch_features / self.temperature, dim=-1
)
batch_features, batch_frames = self.add_ctc_frames(batch_features, batch_frames)
# No GPU support for torchaudio's ctc_decoder
batch_features = batch_features.to(self.device)
batch_frames = batch_frames.to(self.device)
# Decode
hypotheses = self.decoder(batch_features, batch_frames)
return self.post_process(hypotheses, batch_frames)
......@@ -167,4 +167,10 @@ def add_predict_parser(subcommands) -> None:
type=str,
required=False,
)
parser.add_argument(
"--use-language-model",
help="Whether to use an explicit language model to rescore text hypotheses.",
action="store_true",
required=False,
)
parser.set_defaults(func=run)
......@@ -12,7 +12,7 @@ import numpy as np
import torch
import yaml
from dan.ocr.decoder import GlobalHTADecoder
from dan.ocr.decoder import CTCLanguageDecoder, GlobalHTADecoder
from dan.ocr.encoder import FCN_Encoder
from dan.ocr.predict.attention import (
Level,
......@@ -54,6 +54,7 @@ class DAN:
params_path: Path,
charset_path: Path,
mode: str = "eval",
use_language_model: bool = False,
) -> None:
"""
Load a trained model.
......@@ -61,6 +62,7 @@ class DAN:
:param params_path: Path to the parameters.
:param charset_path: Path to the charset.
:param mode: The mode to load the model (train or eval).
:param use_language_model: Whether to use an explicit language model to rescore text hypotheses.
"""
parameters = yaml.safe_load(params_path.read_text())["parameters"]
parameters["decoder"]["device"] = self.device
......@@ -75,6 +77,7 @@ class DAN:
decoder = GlobalHTADecoder(parameters["decoder"]).to(self.device)
decoder.load_state_dict(checkpoint["decoder_state_dict"], strict=True)
logger.debug(f"Loaded model {model_path}")
if mode == "train":
......@@ -88,6 +91,16 @@ class DAN:
self.encoder = encoder
self.decoder = decoder
self.lm_decoder = None
if use_language_model:
self.lm_decoder = CTCLanguageDecoder(
language_model_path=parameters["lm_decoder"]["language_model_path"],
lexicon_path=parameters["lm_decoder"]["lexicon_path"],
tokens_path=parameters["lm_decoder"]["tokens_path"],
language_model_weight=parameters["lm_decoder"]["language_model_weight"],
)
self.mean, self.std = (
torch.tensor(parameters["mean"]) / 255,
torch.tensor(parameters["std"]) / 255,
......@@ -125,6 +138,7 @@ class DAN:
threshold_method: str = "otsu",
threshold_value: int = 0,
max_object_height: int = 50,
use_language_model: bool = False,
) -> dict:
"""
Run prediction on an input image.
......@@ -162,6 +176,13 @@ class DAN:
(batch_size,), dtype=torch.int, device=self.device
)
# end token index will be used for ctc
tot_pred = torch.zeros(
(batch_size, len(self.charset) + 1, self.max_chars),
dtype=torch.float,
device=self.device,
)
whole_output = list()
confidence_scores = list()
attention_maps = list()
......@@ -192,6 +213,9 @@ class DAN:
num_pred=1,
)
# output total logit prediction
tot_pred[:, :, i : i + 1] = pred
pred = pred / self.temperature
whole_output.append(output)
attention_maps.append(weights)
......@@ -242,6 +266,8 @@ class DAN:
out = {}
out["text"] = predicted_text
if use_language_model:
out["language_model"] = self.lm_decoder(tot_pred, prediction_len)
if confidences:
out["confidences"] = confidence_scores
if attentions:
......@@ -296,6 +322,7 @@ def process_batch(
max_object_height: int,
tokens: Dict[str, EntityType],
start_token: str,
use_language_model: bool,
) -> None:
input_images, visu_images, input_sizes = [], [], []
logger.info("Loading images...")
......@@ -330,6 +357,7 @@ def process_batch(
threshold_value=threshold_value,
max_object_height=max_object_height,
start_token=start_token,
use_language_model=use_language_model,
)
logger.info("Prediction parsing...")
......@@ -337,6 +365,14 @@ def process_batch(
predicted_text = prediction["text"][idx]
result = {"text": predicted_text}
# Return LM results
if use_language_model:
result["language_model"] = {}
result["language_model"]["text"] = prediction["language_model"]["text"][idx]
result["language_model"]["confidence"] = prediction["language_model"][
"confidence"
][idx]
# Return extracted objects (coordinates, text, confidence)
if predict_objects:
result["objects"] = prediction["objects"][idx]
......@@ -435,6 +471,7 @@ def run(
batch_size: int,
tokens: Dict[str, EntityType],
start_token: str,
use_language_model: bool,
) -> None:
"""
Predict a single image save the output
......@@ -458,6 +495,7 @@ def run(
:param batch_size: Size of the batches for prediction.
:param tokens: NER tokens used.
:param start_token: Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages.
:param use_language_model: Whether to use an explicit language model to rescore text hypotheses.
"""
# Create output directory if necessary
if not output.exists():
......@@ -467,7 +505,10 @@ def run(
cuda_device = f":{gpu_device}" if gpu_device is not None else ""
device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
dan_model = DAN(device, temperature)
dan_model.load(model, parameters, charset, mode="eval")
dan_model.load(
model, parameters, charset, mode="eval", use_language_model=use_language_model
)
batch_size = 1 if use_language_model else batch_size
images = image_dir.rglob(f"*{image_extension}") if not image else [image]
for image_batch in list_to_batches(images, n=batch_size):
......@@ -489,4 +530,5 @@ def run(
max_object_height,
tokens,
start_token,
use_language_model,
)
......@@ -16,6 +16,26 @@ class MLflowNotInstalled(Exception):
"""
class Token(NamedTuple):
encoded: str
display: str
class LMTokenMapping(NamedTuple):
space: Token = Token("", " ")
linebreak: Token = Token("", "\n")
ctc: Token = Token("", "<ctc>")
unknown: Token = Token("", "<unk>")
@property
def display(self):
return {a.encoded: a.display for a in self}
@property
def encode(self):
return {a.display: a.encoded for a in self}
class EntityType(NamedTuple):
start: str
end: str = ""
......@@ -137,3 +157,12 @@ def read_json(json_path: str) -> Dict:
return json.loads(filename.read_text())
except json.JSONDecodeError as e:
raise ArgumentTypeError(e)
def read_txt(txt_path: str) -> str:
"""
Read TXT file
"""
filename = Path(txt_path)
assert filename.exists(), f"{txt_path} does not resolve."
return filename.read_text()
......@@ -13,5 +13,6 @@ teklia-line-image-extractor==0.2.8rc4
tenacity==8.2.3
tensorboard==2.12.2
torch==2.0.0
torchaudio==2.0.1
torchvision==0.15.1
tqdm==4.65.0
......@@ -10,13 +10,20 @@ import pytest
from dan.datasets.extract.exceptions import NoEndTokenError
from dan.datasets.extract.extract import ArkindexExtractor
from dan.datasets.extract.utils import EntityType, insert_token, remove_spaces
from dan.datasets.extract.utils import (
EntityType,
insert_token,
normalize_linebreaks,
normalize_spaces,
)
from dan.utils import parse_tokens
from tests import FIXTURES
EXTRACTION_DATA_PATH = FIXTURES / "extraction"
TWO_SPACES_REGEX = re.compile(r" {2}")
ENTITY_TOKEN_SPACE = re.compile(r"[ⓢ|ⓕ|ⓑ] ")
TWO_SPACES_LM_REGEX = re.compile(r"⎵ ⎵")
# NamedTuple to mock actual database result
Entity = NamedTuple("Entity", offset=int, length=int, type=str, value=str)
......@@ -135,8 +142,25 @@ def test_reconstruct_text(entity_separators, tokens, expected, text_before, text
(" \ttab", "tab"),
),
)
def test_remove_spaces(text, trimmed):
assert remove_spaces(text) == trimmed
def test_normalize_spaces(text, trimmed):
assert normalize_spaces(text) == trimmed
@pytest.mark.parametrize(
"text,trimmed",
(
("no_linebreaks", "no_linebreaks"),
("\nbeginning", "beginning"),
("ending\n", "ending"),
("\nboth\n", "both"),
("\n\n\nconsecutive", "consecutive"),
("\rcarriage_return", "carriage_return"),
("\r\ncarriage_return+linebreak", "carriage_return+linebreak"),
("\n\r\r\n\ncarriage_return+linebreak", "carriage_return+linebreak"),
),
)
def test_normalize_linebreaks(text, trimmed):
assert normalize_linebreaks(text) == trimmed
@pytest.mark.parametrize(
......@@ -306,6 +330,9 @@ def test_extract(
VAL_DIR / "text_line_val-page_1-line_2.jpg",
VAL_DIR / "text_line_val-page_1-line_3.jpg",
output / "labels.json",
output / "language_corpus.txt",
output / "language_lexicon.txt",
output / "language_tokens.txt",
]
assert sorted(filter(methodcaller("is_file"), output.rglob("*"))) == expected_paths
......@@ -402,6 +429,47 @@ def test_extract(
expected_charset.update(tokens)
assert set(pickle.loads((output / "charset.pkl").read_bytes())) == expected_charset
# Check "language_corpus.txt"
expected_language_corpus = """ⓢ C a i l l e t ⎵ ⎵ ⓕ M a u r i c e ⎵ ⎵ ⓑ 2 8 . 9 . 0 6
ⓢ R e b o u l ⎵ ⎵ ⓕ J e a n ⎵ ⎵ ⓑ 3 0 . 9 . 0 2
ⓢ B a r e y r e ⎵ ⎵ ⓕ J e a n ⎵ ⎵ ⓑ 2 8 . 3 . 1 1
ⓢ R o u s s y ⎵ ⎵ ⓕ J e a n ⎵ ⎵ ⓑ 4 . 1 1 . 1 4
ⓢ M a r i n ⎵ ⎵ ⓕ M a r c e l ⎵ ⎵ ⓑ 1 0 . 8 . 0 6
ⓢ R o q u e s ⎵ ⎵ ⓕ E l o i ⎵ ⎵ ⓑ 1 1 . 1 0 . 0 4
ⓢ G i r o s ⎵ ⎵ ⓕ P a u l ⎵ ⎵ ⓑ 3 0 . 1 0 . 1 0"""
# Transcriptions with worker version are in lowercase
if transcription_entities_worker_version:
expected_language_corpus = expected_language_corpus.lower()
# If we do not load entities, remove tokens
if not load_entities:
token_translations = {f"{token} ": "" for token in tokens}
expected_language_corpus = ENTITY_TOKEN_SPACE.sub("", expected_language_corpus)
# Replace double spaces with regular space
if not keep_spaces:
expected_language_corpus = TWO_SPACES_LM_REGEX.sub(
"", expected_language_corpus
)
assert (output / "language_corpus.txt").read_text() == expected_language_corpus
# Check "language_tokens.txt"
expected_language_tokens = [
t if t != " " else "" for t in sorted(list(expected_charset))
]
expected_language_tokens.append("")
assert (output / "language_tokens.txt").read_text() == "\n".join(
expected_language_tokens
)
# Check "language_lexicon.txt"
expected_language_lexicon = [f"{t} {t}" for t in expected_language_tokens]
assert (output / "language_lexicon.txt").read_text() == "\n".join(
expected_language_lexicon
)
# Check cropped images
for expected_path in expected_paths:
if expected_path.suffix != ".jpg":
......
......@@ -319,6 +319,7 @@ def test_run_prediction(
batch_size=1,
tokens=parse_tokens(PREDICTION_DATA_PATH / "tokens.yml"),
start_token=None,
use_language_model=False,
)
prediction = json.loads((tmp_path / image_name).with_suffix(".json").read_text())
......@@ -518,6 +519,7 @@ def test_run_prediction_batch(
batch_size=batch_size,
tokens=parse_tokens(PREDICTION_DATA_PATH / "tokens.yml"),
start_token=None,
use_language_model=False,
)
for image_name, expected_prediction in zip(image_names, expected_predictions):
......