Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • atr/dan
1 result
Show changes
Commits on Source (11)
...@@ -27,9 +27,10 @@ from dan.datasets.extract.exceptions import ( ...@@ -27,9 +27,10 @@ from dan.datasets.extract.exceptions import (
from dan.datasets.extract.utils import ( from dan.datasets.extract.utils import (
download_image, download_image,
insert_token, insert_token,
remove_spaces, normalize_linebreaks,
normalize_spaces,
) )
from dan.utils import EntityType, parse_tokens from dan.utils import EntityType, LMTokenMapping, parse_tokens
from line_image_extractor.extractor import extract, read_img, save_img from line_image_extractor.extractor import extract, read_img, save_img
from line_image_extractor.image_utils import Extraction, polygon_to_bbox, resize from line_image_extractor.image_utils import Extraction, polygon_to_bbox, resize
...@@ -74,6 +75,7 @@ class ArkindexExtractor: ...@@ -74,6 +75,7 @@ class ArkindexExtractor:
self.max_width = max_width self.max_width = max_width
self.max_height = max_height self.max_height = max_height
self.image_extension = image_extension self.image_extension = image_extension
self.mapping = LMTokenMapping()
self.cache_dir = cache_dir self.cache_dir = cache_dir
# Create cache dir if non existent # Create cache dir if non existent
...@@ -83,6 +85,9 @@ class ArkindexExtractor: ...@@ -83,6 +85,9 @@ class ArkindexExtractor:
self.data: Dict = defaultdict(dict) self.data: Dict = defaultdict(dict)
self.charset = set() self.charset = set()
self.language_corpus = []
self.language_tokens = []
self.language_lexicon = []
def find_image_in_cache(self, image_id: str) -> Path: def find_image_in_cache(self, image_id: str) -> Path:
"""Images are cached to avoid downloading them twice. They are stored under a specific name, """Images are cached to avoid downloading them twice. They are stored under a specific name,
...@@ -223,10 +228,25 @@ class ArkindexExtractor: ...@@ -223,10 +228,25 @@ class ArkindexExtractor:
save_img(path=destination, img=image) save_img(path=destination, img=image)
def format_text(self, text: str): def format_text(self, text: str):
"""
Strip text and remove duplicate spaces and linebreaks if needed.
"""
if not self.keep_spaces: if not self.keep_spaces:
text = remove_spaces(text) text = normalize_spaces(text)
text = normalize_linebreaks(text)
return text.strip() return text.strip()
def format_text_language_model(self, text: str):
"""
Format text for the language model. Return the text tokenized at character-level.
"""
return " ".join(
[
self.mapping.encode[token] if token in self.mapping else token
for token in list(text.strip())
]
)
def process_element( def process_element(
self, self,
element: Element, element: Element,
...@@ -243,8 +263,11 @@ class ArkindexExtractor: ...@@ -243,8 +263,11 @@ class ArkindexExtractor:
).with_suffix(self.image_extension) ).with_suffix(self.image_extension)
self.get_image(element, image_path) self.get_image(element, image_path)
self.data[split][str(image_path)] = self.format_text(text) clean_text = self.format_text(text)
self.charset = self.charset.union(set(text)) self.data[split][str(image_path)] = clean_text
self.charset = self.charset.union(set(clean_text))
if split == "train":
self.language_corpus.append(self.format_text_language_model(clean_text))
def process_parent( def process_parent(
self, self,
...@@ -280,6 +303,27 @@ class ArkindexExtractor: ...@@ -280,6 +303,27 @@ class ArkindexExtractor:
except ProcessingError as e: except ProcessingError as e:
logger.warning(f"Skipping {element.id}: {str(e)}") logger.warning(f"Skipping {element.id}: {str(e)}")
def format_lm_files(self) -> None:
"""
Convert charset to a LM-compatible charset. Ensure that special LM tokens do not appear in the charset.
"""
for token in sorted(list(self.charset)):
assert (
token not in self.mapping.encode.values()
), f"Special token {token} is reserved for language modeling."
self.language_tokens.append(
self.mapping.encode[token]
) if token in self.mapping.encode else self.language_tokens.append(token)
# Add the special blank token
self.language_tokens.append(self.mapping.ctc.encoded)
# Build lexicon
assert all(
[len(token) == 1 for token in self.language_lexicon]
), "Tokens should be single characters."
self.language_lexicon = [f"{token} {token}" for token in self.language_tokens]
def export(self): def export(self):
(self.output / "labels.json").write_text( (self.output / "labels.json").write_text(
json.dumps( json.dumps(
...@@ -288,6 +332,15 @@ class ArkindexExtractor: ...@@ -288,6 +332,15 @@ class ArkindexExtractor:
indent=4, indent=4,
) )
) )
(self.output / "language_corpus.txt").write_text(
"\n".join(self.language_corpus)
)
(self.output / "language_tokens.txt").write_text(
"\n".join(self.language_tokens)
)
(self.output / "language_lexicon.txt").write_text(
"\n".join(self.language_lexicon)
)
(self.output / "charset.pkl").write_bytes( (self.output / "charset.pkl").write_bytes(
pickle.dumps(sorted(list(self.charset))) pickle.dumps(sorted(list(self.charset)))
) )
...@@ -312,6 +365,7 @@ class ArkindexExtractor: ...@@ -312,6 +365,7 @@ class ArkindexExtractor:
# Progress bar updates # Progress bar updates
pbar.update() pbar.update()
pbar.refresh() pbar.refresh()
self.format_lm_files()
self.export() self.export()
......
...@@ -20,7 +20,8 @@ logger = logging.getLogger(__name__) ...@@ -20,7 +20,8 @@ logger = logging.getLogger(__name__)
DOWNLOAD_TIMEOUT = (30, 60) DOWNLOAD_TIMEOUT = (30, 60)
# replace \t with regular space and consecutive spaces # replace \t with regular space and consecutive spaces
TRIM_REGEX = re.compile(r"\t?(?: +)") TRIM_SPACE_REGEX = re.compile(r"[\t| ]+")
TRIM_RETURN_REGEX = re.compile(r"[\r|\n]+")
def _retry_log(retry_state, *args, **kwargs): def _retry_log(retry_state, *args, **kwargs):
...@@ -76,7 +77,17 @@ def insert_token(text: str, entity_type: EntityType, offset: int, length: int) - ...@@ -76,7 +77,17 @@ def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -
) )
def remove_spaces(text: str) -> str: def normalize_linebreaks(text: str) -> str:
# remove begin/ending spaces """
# replace \t with regular space and consecutive spaces Remove begin/ending linebreaks
return TRIM_REGEX.sub(" ", text.strip()) Replace \r with regular linebreak and consecutive linebreaks
"""
return TRIM_RETURN_REGEX.sub("\n", text.strip())
def normalize_spaces(text: str) -> str:
"""
Remove begin/ending spaces
Replace \t with regular space and consecutive spaces
"""
return TRIM_SPACE_REGEX.sub(" ", text.strip())
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import numpy as np
import torch import torch
from torch import relu, softmax from torch import relu, softmax
from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, ModuleList from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, ModuleList
from torch.nn.init import xavier_uniform_ from torch.nn.init import xavier_uniform_
from torchaudio.models.decoder import ctc_decoder
from dan.utils import LMTokenMapping, read_txt
class PositionalEncoding1D(Module): class PositionalEncoding1D(Module):
...@@ -459,3 +463,132 @@ class GlobalHTADecoder(Module): ...@@ -459,3 +463,132 @@ class GlobalHTADecoder(Module):
), ),
) )
) )
class CTCLanguageDecoder:
"""
Initialize a CTC decoder with n-gram language modeling.
Args:
language_model_path (str): path to a KenLM or ARPA language model
lexicon_path (str): path to a lexicon file containing the possible words and corresponding spellings.
Each line consists of a word and its space separated spelling. If `None`, uses lexicon-free
decoding.
tokens_path (str): path to a file containing valid tokens. If using a file, the expected
format is for tokens mapping to the same index to be on the same line
language_model_weight (float): weight of the language model.
blank_token (str): token representing the blank/ctc symbol
unk_token (str): token representing unknown characters
sil_token (str): token representing the space character
"""
def __init__(
self,
language_model_path: str,
lexicon_path: str,
tokens_path: str,
language_model_weight: float = 1.0,
temperature: float = 1.0,
):
self.mapping = LMTokenMapping()
self.language_model_weight = language_model_weight
self.temperature = temperature
self.tokens_to_index = {
token: i for i, token in enumerate(read_txt(tokens_path).split("\n"))
}
self.blank_token_id = self.tokens_to_index[self.mapping.ctc.encoded]
self.decoder = ctc_decoder(
lm=language_model_path,
lexicon=lexicon_path,
tokens=tokens_path,
lm_weight=self.language_model_weight,
blank_token=self.mapping.ctc.encoded,
unk_word=self.mapping.unknown.encoded,
sil_token=self.mapping.space.encoded,
nbest=1,
)
# No GPU support
self.device = torch.device("cpu")
def add_ctc_frames(self, batch_features, batch_frames):
"""
Add CTC frames between each characters to avoid duplicate removal
"""
batch_size, _, n_tokens = batch_features.shape
# Create tensor with high probability CTC token
high_prob = 0.99
low_prob = 1 - high_prob
ctc_probs = (
torch.ones((batch_size, 1, n_tokens), dtype=torch.float32)
* low_prob
/ (n_tokens - 1)
)
ctc_probs[:, :, self.blank_token_id] = high_prob
ctc_probs = ctc_probs.log()
# Insert CTC tensor between frames
for fn in range(batch_frames[0] - 1):
batch_features = torch.cat(
[
batch_features[:, : 2 * fn + 1, :],
ctc_probs,
batch_features[:, 2 * fn + 1 :, :],
],
dim=1,
)
# Update the number of frames
batch_frames = 2 * batch_frames - 1
return batch_features, batch_frames
def post_process(self, hypotheses, batch_sizes):
"""
Post-process hypotheses to output JSON. Exports only the best hypothesis for each image.
"""
out = {}
# Replace <space> by an actual space and format string
out["text"] = [
"".join(
[
self.mapping.display[token]
if token in self.mapping.display
else token
for token in hypothesis[0].words
]
)
for hypothesis in hypotheses
]
# Normalize confidence score
out["confidence"] = [
np.exp(
hypothesis[0].score / ((self.language_model_weight + 1) * length.item())
)
for hypothesis, length in zip(hypotheses, batch_sizes)
]
return out
def __call__(self, batch_features, batch_frames):
"""
Decode a feature vector using n-gram language modelling.
Args:
features (torch.tensor): feature vector of size (batch_size, n_tokens, n_frame).
batch_sizes (Union[List, torch.tensor]): actual length of predictions
Returns:
out (Dict[List]): a dictionary containing the hypotheses.
"""
# Reshape from (batch_size, n_tokens, n_frames) to (batch_size, n_frames, n_tokens)
batch_features = batch_features.permute((0, 2, 1))
# Apply log softmax
batch_features = torch.nn.functional.log_softmax(
batch_features / self.temperature, dim=-1
)
batch_features, batch_frames = self.add_ctc_frames(batch_features, batch_frames)
# No GPU support for torchaudio's ctc_decoder
batch_features = batch_features.to(self.device)
batch_frames = batch_frames.to(self.device)
# Decode
hypotheses = self.decoder(batch_features, batch_frames)
return self.post_process(hypotheses, batch_frames)
...@@ -167,4 +167,10 @@ def add_predict_parser(subcommands) -> None: ...@@ -167,4 +167,10 @@ def add_predict_parser(subcommands) -> None:
type=str, type=str,
required=False, required=False,
) )
parser.add_argument(
"--use-language-model",
help="Whether to use an explicit language model to rescore text hypotheses.",
action="store_true",
required=False,
)
parser.set_defaults(func=run) parser.set_defaults(func=run)
...@@ -12,7 +12,7 @@ import numpy as np ...@@ -12,7 +12,7 @@ import numpy as np
import torch import torch
import yaml import yaml
from dan.ocr.decoder import GlobalHTADecoder from dan.ocr.decoder import CTCLanguageDecoder, GlobalHTADecoder
from dan.ocr.encoder import FCN_Encoder from dan.ocr.encoder import FCN_Encoder
from dan.ocr.predict.attention import ( from dan.ocr.predict.attention import (
Level, Level,
...@@ -54,6 +54,7 @@ class DAN: ...@@ -54,6 +54,7 @@ class DAN:
params_path: Path, params_path: Path,
charset_path: Path, charset_path: Path,
mode: str = "eval", mode: str = "eval",
use_language_model: bool = False,
) -> None: ) -> None:
""" """
Load a trained model. Load a trained model.
...@@ -61,6 +62,7 @@ class DAN: ...@@ -61,6 +62,7 @@ class DAN:
:param params_path: Path to the parameters. :param params_path: Path to the parameters.
:param charset_path: Path to the charset. :param charset_path: Path to the charset.
:param mode: The mode to load the model (train or eval). :param mode: The mode to load the model (train or eval).
:param use_language_model: Whether to use an explicit language model to rescore text hypotheses.
""" """
parameters = yaml.safe_load(params_path.read_text())["parameters"] parameters = yaml.safe_load(params_path.read_text())["parameters"]
parameters["decoder"]["device"] = self.device parameters["decoder"]["device"] = self.device
...@@ -75,6 +77,7 @@ class DAN: ...@@ -75,6 +77,7 @@ class DAN:
decoder = GlobalHTADecoder(parameters["decoder"]).to(self.device) decoder = GlobalHTADecoder(parameters["decoder"]).to(self.device)
decoder.load_state_dict(checkpoint["decoder_state_dict"], strict=True) decoder.load_state_dict(checkpoint["decoder_state_dict"], strict=True)
logger.debug(f"Loaded model {model_path}") logger.debug(f"Loaded model {model_path}")
if mode == "train": if mode == "train":
...@@ -88,6 +91,16 @@ class DAN: ...@@ -88,6 +91,16 @@ class DAN:
self.encoder = encoder self.encoder = encoder
self.decoder = decoder self.decoder = decoder
self.lm_decoder = None
if use_language_model:
self.lm_decoder = CTCLanguageDecoder(
language_model_path=parameters["lm_decoder"]["language_model_path"],
lexicon_path=parameters["lm_decoder"]["lexicon_path"],
tokens_path=parameters["lm_decoder"]["tokens_path"],
language_model_weight=parameters["lm_decoder"]["language_model_weight"],
)
self.mean, self.std = ( self.mean, self.std = (
torch.tensor(parameters["mean"]) / 255, torch.tensor(parameters["mean"]) / 255,
torch.tensor(parameters["std"]) / 255, torch.tensor(parameters["std"]) / 255,
...@@ -125,6 +138,7 @@ class DAN: ...@@ -125,6 +138,7 @@ class DAN:
threshold_method: str = "otsu", threshold_method: str = "otsu",
threshold_value: int = 0, threshold_value: int = 0,
max_object_height: int = 50, max_object_height: int = 50,
use_language_model: bool = False,
) -> dict: ) -> dict:
""" """
Run prediction on an input image. Run prediction on an input image.
...@@ -162,6 +176,13 @@ class DAN: ...@@ -162,6 +176,13 @@ class DAN:
(batch_size,), dtype=torch.int, device=self.device (batch_size,), dtype=torch.int, device=self.device
) )
# end token index will be used for ctc
tot_pred = torch.zeros(
(batch_size, len(self.charset) + 1, self.max_chars),
dtype=torch.float,
device=self.device,
)
whole_output = list() whole_output = list()
confidence_scores = list() confidence_scores = list()
attention_maps = list() attention_maps = list()
...@@ -192,6 +213,9 @@ class DAN: ...@@ -192,6 +213,9 @@ class DAN:
num_pred=1, num_pred=1,
) )
# output total logit prediction
tot_pred[:, :, i : i + 1] = pred
pred = pred / self.temperature pred = pred / self.temperature
whole_output.append(output) whole_output.append(output)
attention_maps.append(weights) attention_maps.append(weights)
...@@ -242,6 +266,8 @@ class DAN: ...@@ -242,6 +266,8 @@ class DAN:
out = {} out = {}
out["text"] = predicted_text out["text"] = predicted_text
if use_language_model:
out["language_model"] = self.lm_decoder(tot_pred, prediction_len)
if confidences: if confidences:
out["confidences"] = confidence_scores out["confidences"] = confidence_scores
if attentions: if attentions:
...@@ -296,6 +322,7 @@ def process_batch( ...@@ -296,6 +322,7 @@ def process_batch(
max_object_height: int, max_object_height: int,
tokens: Dict[str, EntityType], tokens: Dict[str, EntityType],
start_token: str, start_token: str,
use_language_model: bool,
) -> None: ) -> None:
input_images, visu_images, input_sizes = [], [], [] input_images, visu_images, input_sizes = [], [], []
logger.info("Loading images...") logger.info("Loading images...")
...@@ -330,6 +357,7 @@ def process_batch( ...@@ -330,6 +357,7 @@ def process_batch(
threshold_value=threshold_value, threshold_value=threshold_value,
max_object_height=max_object_height, max_object_height=max_object_height,
start_token=start_token, start_token=start_token,
use_language_model=use_language_model,
) )
logger.info("Prediction parsing...") logger.info("Prediction parsing...")
...@@ -337,6 +365,14 @@ def process_batch( ...@@ -337,6 +365,14 @@ def process_batch(
predicted_text = prediction["text"][idx] predicted_text = prediction["text"][idx]
result = {"text": predicted_text} result = {"text": predicted_text}
# Return LM results
if use_language_model:
result["language_model"] = {}
result["language_model"]["text"] = prediction["language_model"]["text"][idx]
result["language_model"]["confidence"] = prediction["language_model"][
"confidence"
][idx]
# Return extracted objects (coordinates, text, confidence) # Return extracted objects (coordinates, text, confidence)
if predict_objects: if predict_objects:
result["objects"] = prediction["objects"][idx] result["objects"] = prediction["objects"][idx]
...@@ -435,6 +471,7 @@ def run( ...@@ -435,6 +471,7 @@ def run(
batch_size: int, batch_size: int,
tokens: Dict[str, EntityType], tokens: Dict[str, EntityType],
start_token: str, start_token: str,
use_language_model: bool,
) -> None: ) -> None:
""" """
Predict a single image save the output Predict a single image save the output
...@@ -458,6 +495,7 @@ def run( ...@@ -458,6 +495,7 @@ def run(
:param batch_size: Size of the batches for prediction. :param batch_size: Size of the batches for prediction.
:param tokens: NER tokens used. :param tokens: NER tokens used.
:param start_token: Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages. :param start_token: Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages.
:param use_language_model: Whether to use an explicit language model to rescore text hypotheses.
""" """
# Create output directory if necessary # Create output directory if necessary
if not output.exists(): if not output.exists():
...@@ -467,7 +505,10 @@ def run( ...@@ -467,7 +505,10 @@ def run(
cuda_device = f":{gpu_device}" if gpu_device is not None else "" cuda_device = f":{gpu_device}" if gpu_device is not None else ""
device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu" device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
dan_model = DAN(device, temperature) dan_model = DAN(device, temperature)
dan_model.load(model, parameters, charset, mode="eval") dan_model.load(
model, parameters, charset, mode="eval", use_language_model=use_language_model
)
batch_size = 1 if use_language_model else batch_size
images = image_dir.rglob(f"*{image_extension}") if not image else [image] images = image_dir.rglob(f"*{image_extension}") if not image else [image]
for image_batch in list_to_batches(images, n=batch_size): for image_batch in list_to_batches(images, n=batch_size):
...@@ -489,4 +530,5 @@ def run( ...@@ -489,4 +530,5 @@ def run(
max_object_height, max_object_height,
tokens, tokens,
start_token, start_token,
use_language_model,
) )
...@@ -16,6 +16,26 @@ class MLflowNotInstalled(Exception): ...@@ -16,6 +16,26 @@ class MLflowNotInstalled(Exception):
""" """
class Token(NamedTuple):
encoded: str
display: str
class LMTokenMapping(NamedTuple):
space: Token = Token("", " ")
linebreak: Token = Token("", "\n")
ctc: Token = Token("", "<ctc>")
unknown: Token = Token("", "<unk>")
@property
def display(self):
return {a.encoded: a.display for a in self}
@property
def encode(self):
return {a.display: a.encoded for a in self}
class EntityType(NamedTuple): class EntityType(NamedTuple):
start: str start: str
end: str = "" end: str = ""
...@@ -137,3 +157,12 @@ def read_json(json_path: str) -> Dict: ...@@ -137,3 +157,12 @@ def read_json(json_path: str) -> Dict:
return json.loads(filename.read_text()) return json.loads(filename.read_text())
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise ArgumentTypeError(e) raise ArgumentTypeError(e)
def read_txt(txt_path: str) -> str:
"""
Read TXT file
"""
filename = Path(txt_path)
assert filename.exists(), f"{txt_path} does not resolve."
return filename.read_text()
...@@ -13,5 +13,6 @@ teklia-line-image-extractor==0.2.8rc4 ...@@ -13,5 +13,6 @@ teklia-line-image-extractor==0.2.8rc4
tenacity==8.2.3 tenacity==8.2.3
tensorboard==2.12.2 tensorboard==2.12.2
torch==2.0.0 torch==2.0.0
torchaudio==2.0.1
torchvision==0.15.1 torchvision==0.15.1
tqdm==4.65.0 tqdm==4.65.0
...@@ -10,13 +10,20 @@ import pytest ...@@ -10,13 +10,20 @@ import pytest
from dan.datasets.extract.exceptions import NoEndTokenError from dan.datasets.extract.exceptions import NoEndTokenError
from dan.datasets.extract.extract import ArkindexExtractor from dan.datasets.extract.extract import ArkindexExtractor
from dan.datasets.extract.utils import EntityType, insert_token, remove_spaces from dan.datasets.extract.utils import (
EntityType,
insert_token,
normalize_linebreaks,
normalize_spaces,
)
from dan.utils import parse_tokens from dan.utils import parse_tokens
from tests import FIXTURES from tests import FIXTURES
EXTRACTION_DATA_PATH = FIXTURES / "extraction" EXTRACTION_DATA_PATH = FIXTURES / "extraction"
TWO_SPACES_REGEX = re.compile(r" {2}") TWO_SPACES_REGEX = re.compile(r" {2}")
ENTITY_TOKEN_SPACE = re.compile(r"[ⓢ|ⓕ|ⓑ] ")
TWO_SPACES_LM_REGEX = re.compile(r"⎵ ⎵")
# NamedTuple to mock actual database result # NamedTuple to mock actual database result
Entity = NamedTuple("Entity", offset=int, length=int, type=str, value=str) Entity = NamedTuple("Entity", offset=int, length=int, type=str, value=str)
...@@ -135,8 +142,25 @@ def test_reconstruct_text(entity_separators, tokens, expected, text_before, text ...@@ -135,8 +142,25 @@ def test_reconstruct_text(entity_separators, tokens, expected, text_before, text
(" \ttab", "tab"), (" \ttab", "tab"),
), ),
) )
def test_remove_spaces(text, trimmed): def test_normalize_spaces(text, trimmed):
assert remove_spaces(text) == trimmed assert normalize_spaces(text) == trimmed
@pytest.mark.parametrize(
"text,trimmed",
(
("no_linebreaks", "no_linebreaks"),
("\nbeginning", "beginning"),
("ending\n", "ending"),
("\nboth\n", "both"),
("\n\n\nconsecutive", "consecutive"),
("\rcarriage_return", "carriage_return"),
("\r\ncarriage_return+linebreak", "carriage_return+linebreak"),
("\n\r\r\n\ncarriage_return+linebreak", "carriage_return+linebreak"),
),
)
def test_normalize_linebreaks(text, trimmed):
assert normalize_linebreaks(text) == trimmed
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -306,6 +330,9 @@ def test_extract( ...@@ -306,6 +330,9 @@ def test_extract(
VAL_DIR / "text_line_val-page_1-line_2.jpg", VAL_DIR / "text_line_val-page_1-line_2.jpg",
VAL_DIR / "text_line_val-page_1-line_3.jpg", VAL_DIR / "text_line_val-page_1-line_3.jpg",
output / "labels.json", output / "labels.json",
output / "language_corpus.txt",
output / "language_lexicon.txt",
output / "language_tokens.txt",
] ]
assert sorted(filter(methodcaller("is_file"), output.rglob("*"))) == expected_paths assert sorted(filter(methodcaller("is_file"), output.rglob("*"))) == expected_paths
...@@ -402,6 +429,47 @@ def test_extract( ...@@ -402,6 +429,47 @@ def test_extract(
expected_charset.update(tokens) expected_charset.update(tokens)
assert set(pickle.loads((output / "charset.pkl").read_bytes())) == expected_charset assert set(pickle.loads((output / "charset.pkl").read_bytes())) == expected_charset
# Check "language_corpus.txt"
expected_language_corpus = """ⓢ C a i l l e t ⎵ ⎵ ⓕ M a u r i c e ⎵ ⎵ ⓑ 2 8 . 9 . 0 6
ⓢ R e b o u l ⎵ ⎵ ⓕ J e a n ⎵ ⎵ ⓑ 3 0 . 9 . 0 2
ⓢ B a r e y r e ⎵ ⎵ ⓕ J e a n ⎵ ⎵ ⓑ 2 8 . 3 . 1 1
ⓢ R o u s s y ⎵ ⎵ ⓕ J e a n ⎵ ⎵ ⓑ 4 . 1 1 . 1 4
ⓢ M a r i n ⎵ ⎵ ⓕ M a r c e l ⎵ ⎵ ⓑ 1 0 . 8 . 0 6
ⓢ R o q u e s ⎵ ⎵ ⓕ E l o i ⎵ ⎵ ⓑ 1 1 . 1 0 . 0 4
ⓢ G i r o s ⎵ ⎵ ⓕ P a u l ⎵ ⎵ ⓑ 3 0 . 1 0 . 1 0"""
# Transcriptions with worker version are in lowercase
if transcription_entities_worker_version:
expected_language_corpus = expected_language_corpus.lower()
# If we do not load entities, remove tokens
if not load_entities:
token_translations = {f"{token} ": "" for token in tokens}
expected_language_corpus = ENTITY_TOKEN_SPACE.sub("", expected_language_corpus)
# Replace double spaces with regular space
if not keep_spaces:
expected_language_corpus = TWO_SPACES_LM_REGEX.sub(
"", expected_language_corpus
)
assert (output / "language_corpus.txt").read_text() == expected_language_corpus
# Check "language_tokens.txt"
expected_language_tokens = [
t if t != " " else "" for t in sorted(list(expected_charset))
]
expected_language_tokens.append("")
assert (output / "language_tokens.txt").read_text() == "\n".join(
expected_language_tokens
)
# Check "language_lexicon.txt"
expected_language_lexicon = [f"{t} {t}" for t in expected_language_tokens]
assert (output / "language_lexicon.txt").read_text() == "\n".join(
expected_language_lexicon
)
# Check cropped images # Check cropped images
for expected_path in expected_paths: for expected_path in expected_paths:
if expected_path.suffix != ".jpg": if expected_path.suffix != ".jpg":
......
...@@ -319,6 +319,7 @@ def test_run_prediction( ...@@ -319,6 +319,7 @@ def test_run_prediction(
batch_size=1, batch_size=1,
tokens=parse_tokens(PREDICTION_DATA_PATH / "tokens.yml"), tokens=parse_tokens(PREDICTION_DATA_PATH / "tokens.yml"),
start_token=None, start_token=None,
use_language_model=False,
) )
prediction = json.loads((tmp_path / image_name).with_suffix(".json").read_text()) prediction = json.loads((tmp_path / image_name).with_suffix(".json").read_text())
...@@ -518,6 +519,7 @@ def test_run_prediction_batch( ...@@ -518,6 +519,7 @@ def test_run_prediction_batch(
batch_size=batch_size, batch_size=batch_size,
tokens=parse_tokens(PREDICTION_DATA_PATH / "tokens.yml"), tokens=parse_tokens(PREDICTION_DATA_PATH / "tokens.yml"),
start_token=None, start_token=None,
use_language_model=False,
) )
for image_name, expected_prediction in zip(image_names, expected_predictions): for image_name, expected_prediction in zip(image_names, expected_predictions):
......