Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: atr/dan
Commits on Source (6)
......@@ -106,9 +106,10 @@ class ArkindexExtractor:
raise NoTranscriptionError(element.id)
transcription = random.choice(transcriptions)
stripped_text = transcription.text.strip()
if not self.tokens:
return transcription.text.strip()
return stripped_text
entities = get_transcription_entities(
transcription.id,
......@@ -116,6 +117,9 @@ class ArkindexExtractor:
supported_types=list(self.tokens),
)
if not entities.count():
return stripped_text
return self.translate(
entities_to_xml(
transcription.text, entities, entity_separators=self.entity_separators
......
......@@ -32,6 +32,7 @@ class Inference(NamedTuple):
image: str
ground_truth: str
prediction: str
lm_prediction: str
wer: float
......
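The field order of this tuple is what the evaluation code and tests rely on: each row is built as `(image, ground_truth, prediction, lm_prediction, wer)`, with `lm_prediction` left empty when no language model is configured. A minimal sketch with illustrative values taken from the test fixtures below, assuming the tuple has exactly the five fields shown:

```python
from dan.ocr.manager.metrics import Inference

# Illustrative row; lm_prediction stays "" when LM decoding is disabled.
row = Inference(
    "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png",                   # image
    "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier Ⓟ12241",         # ground truth
    "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",  # prediction
    "",                                                           # LM prediction
    0.125,                                                        # WER
)
```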
......@@ -21,6 +21,8 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from dan.ocr.decoder import CTCLanguageDecoder, GlobalHTADecoder
from dan.ocr.encoder import FCN_Encoder
from dan.ocr.manager.metrics import Inference, MetricManager
from dan.ocr.manager.ocr import OCRDatasetManager
from dan.ocr.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
......@@ -31,7 +33,10 @@ if MLFLOW_AVAILABLE:
import mlflow
logger = logging.getLogger(__name__)
MODEL_NAMES = ("encoder", "decoder")
MODEL_NAME_ENCODER = "encoder"
MODEL_NAME_DECODER = "decoder"
MODEL_NAMES = (MODEL_NAME_ENCODER, MODEL_NAME_DECODER)
class GenericTrainingManager:
......@@ -69,6 +74,14 @@ class GenericTrainingManager:
self.init_paths()
self.load_dataset()
@property
def encoder(self) -> FCN_Encoder | None:
return self.models.get(MODEL_NAME_ENCODER)
@property
def decoder(self) -> GlobalHTADecoder | None:
return self.models.get(MODEL_NAME_DECODER)
def init_paths(self):
"""
Create output folders for results and checkpoints
......@@ -183,6 +196,28 @@ class GenericTrainingManager:
output_device=self.ddp_config["rank"],
)
# Instantiate LM decoder
self.lm_decoder = None
if self.params["model"].get("lm") and self.params["model"]["lm"]["weight"] > 0:
logger.info(
f"Decoding with a language model (weight={self.params['model']['lm']['weight']})."
)
# Check files
model_path = self.params["model"]["lm"]["path"]
assert model_path.is_file(), f"File {model_path} not found"
base_path = model_path.parent
lexicon_path = base_path / "lexicon.txt"
assert lexicon_path.is_file(), f"File {lexicon_path} not found"
tokens_path = base_path / "tokens.txt"
assert tokens_path.is_file(), f"File {tokens_path} not found"
# Load LM decoder
self.lm_decoder = CTCLanguageDecoder(
language_model_path=str(model_path),
lexicon_path=str(lexicon_path),
tokens_path=str(tokens_path),
language_model_weight=self.params["model"]["lm"]["weight"],
)
# Handle curriculum dropout
self.dropout_scheduler = DropoutScheduler(self.models)
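For reference, the language-model branch above only reads two keys from the training configuration, `model.lm.path` and `model.lm.weight`, and expects `lexicon.txt` and `tokens.txt` to sit next to the ARPA file. A minimal sketch of the corresponding `params` excerpt; the file locations are placeholders, not part of this diff:

```python
from pathlib import Path

# Hypothetical excerpt of self.params["model"]. Only "path" and "weight" are
# read here; lexicon.txt and tokens.txt must live in the same folder as the
# ARPA model, and a weight of 0 (or a missing "lm" block) disables LM decoding.
params = {
    "model": {
        "lm": {
            "path": Path("language_model/model.arpa"),
            "weight": 1.0,
        },
    },
}
```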
......@@ -804,6 +839,7 @@ class GenericTrainingManager:
batch_data["names"],
batch_values["str_y"],
batch_values["str_x"],
batch_values.get("str_lm", repeat("")),
repeat(display_values["wer"]),
)
)
......@@ -985,20 +1021,18 @@ class Manager(GenericTrainingManager):
hidden_predict = None
cache = None
features = self.models["encoder"](batch_data["imgs"].to(self.device))
features = self.encoder(batch_data["imgs"].to(self.device))
features_size = features.size()
if self.device_params["use_ddp"]:
features = self.models[
"decoder"
].module.features_updater.get_pos_features(features)
else:
features = self.models["decoder"].features_updater.get_pos_features(
features = self.decoder.module.features_updater.get_pos_features(
features
)
else:
features = self.decoder.features_updater.get_pos_features(features)
features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
output, pred, hidden_predict, cache, weights = self.models["decoder"](
output, pred, hidden_predict, cache, weights = self.decoder(
features,
simulated_y_pred[:, :-1],
[s[:2] for s in batch_data["imgs_reduced_shape"]],
......@@ -1049,6 +1083,13 @@ class Manager(GenericTrainingManager):
)
predicted_tokens_len = torch.ones((b,), dtype=torch.int, device=self.device)
# end token index will be used for ctc
tot_pred = torch.zeros(
(b, len(self.dataset.charset) + 1, max_chars),
dtype=torch.float,
device=self.device,
)
whole_output = list()
confidence_scores = list()
cache = None
......@@ -1058,7 +1099,7 @@ class Manager(GenericTrainingManager):
for i in range(b):
pos = batch_data["imgs_position"]
features_list.append(
self.models["encoder"](
self.encoder(
x[
i : i + 1,
:,
......@@ -1079,21 +1120,19 @@ class Manager(GenericTrainingManager):
i, :, : features_list[i].size(2), : features_list[i].size(3)
] = features_list[i]
else:
features = self.models["encoder"](x)
features = self.encoder(x)
features_size = features.size()
if self.device_params["use_ddp"]:
features = self.models[
"decoder"
].module.features_updater.get_pos_features(features)
else:
features = self.models["decoder"].features_updater.get_pos_features(
features = self.decoder.module.features_updater.get_pos_features(
features
)
else:
features = self.decoder.features_updater.get_pos_features(features)
features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
for i in range(0, max_chars):
output, pred, hidden_predict, cache, weights = self.models["decoder"](
output, pred, hidden_predict, cache, weights = self.decoder(
features,
predicted_tokens,
[s[:2] for s in batch_data["imgs_reduced_shape"]],
......@@ -1104,6 +1143,10 @@ class Manager(GenericTrainingManager):
cache=cache,
num_pred=1,
)
# Store this step's logits in the full prediction tensor used for LM decoding
tot_pred[:, :, i : i + 1] = pred
whole_output.append(output)
confidence_scores.append(
torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values
......@@ -1150,4 +1193,7 @@ class Manager(GenericTrainingManager):
"confidence_score": confidence_scores,
"time": process_time,
}
if self.lm_decoder:
values["str_lm"] = self.lm_decoder(tot_pred, prediction_len)["text"]
return values
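Putting the two new pieces together: `tot_pred` accumulates one column of logits per decoding step, so after the loop it is a `(batch, len(charset) + 1, max_chars)` tensor that the CTC language-model decoder rescores into `str_lm`. A minimal sketch of that call, assuming the constructor arguments shown earlier and a `prediction_len` tensor of per-sample lengths; the model files are placeholders:

```python
import torch

from dan.ocr.decoder import CTCLanguageDecoder

# Placeholder files; lexicon.txt and tokens.txt are expected next to the ARPA model.
lm_decoder = CTCLanguageDecoder(
    language_model_path="language_model/model.arpa",
    lexicon_path="language_model/lexicon.txt",
    tokens_path="language_model/tokens.txt",
    language_model_weight=1.0,
)

batch_size, charset_size, max_chars = 2, 107, 30
tot_pred = torch.zeros((batch_size, charset_size + 1, max_chars))
prediction_len = torch.tensor([30, 30], dtype=torch.int)

# The decoder returns a dict; "text" holds one LM-rescored string per sample.
lm_texts = lm_decoder(tot_pred, prediction_len)["text"]
```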
......@@ -37,6 +37,10 @@ def update_config(config: dict):
# .model.decoder.class = GlobalHTADecoder
config["model"]["decoder"]["class"] = GlobalHTADecoder
# .model.lm.path to Path
if config["model"].get("lm", {}).get("path"):
config["model"]["lm"]["path"] = Path(config["model"]["lm"]["path"])
# Update preprocessing type
for prepro in config["training"]["data"]["preprocessings"]:
prepro["type"] = Preprocessing(prepro["type"])
......
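A small usage sketch of the conversion above: after `update_config`, an optional `lm.path` entry is a `pathlib.Path`, which is what the training manager's `.is_file()` and `.parent` checks expect. The example dictionary is illustrative, not taken from the diff:

```python
from pathlib import Path

config = {"model": {"lm": {"path": "language_model/model.arpa", "weight": 0.5}}}

# Same conversion as above: only touch lm.path when it is actually set.
if config["model"].get("lm", {}).get("path"):
    config["model"]["lm"]["path"] = Path(config["model"]["lm"]["path"])

assert isinstance(config["model"]["lm"]["path"], Path)
```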
......@@ -41,4 +41,4 @@ To train a DAN model, please refer to the [documentation of the training command
## 3. Predict
Once the training is complete, you can apply a trained DAN model on an image using the [predict command](../usage/predict/index.md) and the `inference_parameters.yml` file, located in `{training.output_folder}/results`.
......@@ -166,14 +166,13 @@ It will create the following JSON file named after the image and a GIF showing a
This example assumes that you have already [trained a language model](../train/language_model.md).
Note that:
- the `weight` parameter defines how much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions.
- linebreaks are treated as spaces by language models, as a result predictions will not include linebreaks.
!!! note
- the `weight` parameter defines how much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions.
- linebreaks are treated as spaces by language models; as a result, predictions will not include linebreaks.
#### Language model at character level
First, update the `inference_parameters.yml` file obtained during DAN training.
Update the `parameters.yml` file obtained during DAN training.
```yaml
parameters:
......@@ -185,8 +184,6 @@ parameters:
weight: 0.5
```
Note that the `weight` parameter defines how much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions.
Then, run this command:
```shell
......@@ -211,7 +208,7 @@ It will create the following JSON file named after the image in the `predict_cha
#### Language model at subword level
Update the `inference_parameters.yml` file obtained during DAN training.
Update the `parameters.yml` file obtained during DAN training.
```yaml
parameters:
......@@ -247,7 +244,7 @@ It will create the following JSON file named after the image in the `predict_sub
#### Language model at word level
Update the `inference_parameters.yml` file obtained during DAN training.
Update the `parameters.yml` file obtained during DAN training.
```yaml
parameters:
......
This diff is collapsed.
......@@ -24,6 +24,6 @@ parameters:
max_width: 1500
language_model:
model: tests/data/prediction/language_model.arpa
lexicon: tests/data/prediction/language_lexicon.txt
tokens: tests/data/prediction/language_tokens.txt
lexicon: tests/data/prediction/lexicon.txt
tokens: tests/data/prediction/tokens.txt
weight: 1.0
......@@ -8,9 +8,12 @@ import yaml
from prettytable import PrettyTable
from dan.ocr import evaluate
from dan.ocr.manager.metrics import Inference
from dan.ocr.utils import add_metrics_table_row, create_metrics_table
from tests import FIXTURES
PREDICTION_DATA_PATH = FIXTURES / "prediction"
def test_create_metrics_table():
metric_names = ["ignored", "wer", "cer", "time", "ner"]
......@@ -115,14 +118,228 @@ def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
/ f"predict_training-{split_name}_1685.yaml"
)
with filename.open() as f:
res = {
assert {
metric: value
for metric, value in yaml.safe_load(filename.read_bytes()).items()
# Remove the times from the results as they vary
if "time" not in metric
} == expected_res
# Remove results files
shutil.rmtree(evaluate_config["training"]["output_folder"] / "results")
# Check the metrics Markdown table
captured_std = capsys.readouterr()
last_printed_lines = captured_std.out.split("\n")[10:]
assert (
"\n".join(last_printed_lines)
== Path(FIXTURES / "evaluate" / "metrics_table.md").read_text()
)
@pytest.mark.parametrize(
"language_model_weight, expected_inferences",
(
(
0.0,
[
(
"0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png", # Image
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier Ⓟ12241", # Ground truth
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", # Prediction
"", # LM prediction
0.125, # WER
),
(
"0dfe8bcd-ed0b-453e-bf19-cc697012296e.png", # Image
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁJ Ⓚch ⓄE dachyle", # Ground truth
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376", # Prediction
"", # LM prediction
0.2667, # WER
),
(
"2c242f5c-e979-43c4-b6f2-a6d4815b651d.png", # Image
"ⓈA ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF ⓄA Ⓟ14331", # Ground truth
"Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31", # Prediction
"", # LM prediction
0.5, # WER
),
(
"ffdec445-7f14-4f5f-be44-68d0844d0df1.png", # Image
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS ⒸV ⓀBelle mère", # Ground truth
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère", # Prediction
"", # LM prediction
0.1429, # WER
),
],
),
(
1.0,
[
(
"0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png", # Image
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier Ⓟ12241", # Ground truth
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", # Prediction
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", # LM prediction
0.125, # WER
),
(
"0dfe8bcd-ed0b-453e-bf19-cc697012296e.png", # Image
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁJ Ⓚch ⓄE dachyle", # Ground truth
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376", # Prediction
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376", # LM prediction
0.2667, # WER
),
(
"2c242f5c-e979-43c4-b6f2-a6d4815b651d.png", # Image
"ⓈA ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF ⓄA Ⓟ14331", # Ground truth
"Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31", # Prediction
"Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31", # LM prediction
0.5, # WER
),
(
"ffdec445-7f14-4f5f-be44-68d0844d0df1.png", # Image
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS ⒸV ⓀBelle mère", # Ground truth
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère", # Prediction
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère", # LM prediction
0.1429, # WER
),
],
),
(
2.0,
[
(
"0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png", # Image
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier Ⓟ12241", # Ground truth
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", # Prediction
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", # LM prediction
0.125, # WER
),
(
"0dfe8bcd-ed0b-453e-bf19-cc697012296e.png", # Image
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁJ Ⓚch ⓄE dachyle", # Ground truth
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376", # Prediction
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376", # LM prediction
0.2667, # WER
),
(
"2c242f5c-e979-43c4-b6f2-a6d4815b651d.png", # Image
"ⓈA ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF ⓄA Ⓟ14331", # Ground truth
"Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31", # Prediction
"Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14331", # LM prediction
0.5, # WER
),
(
"ffdec445-7f14-4f5f-be44-68d0844d0df1.png", # Image
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS ⒸV ⓀBelle mère", # Ground truth
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère", # Prediction
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère", # LM prediction
0.1429, # WER
),
],
),
),
)
def test_evaluate_language_model(
capsys, evaluate_config, language_model_weight, expected_inferences, monkeypatch
):
# LM predictions are never used/displayed
# We mock the `Inference` class to temporarily check the results
global nb_inferences
nb_inferences = 0
class MockInference(Inference):
def __new__(cls, *args, **kwargs):
global nb_inferences
assert args == expected_inferences[nb_inferences]
nb_inferences += 1
return super().__new__(cls, *args, **kwargs)
monkeypatch.setattr("dan.ocr.manager.training.Inference", MockInference)
# Use the tmp_path as base folder
evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate"
# Use a LM decoder
evaluate_config["model"]["lm"] = {
"path": PREDICTION_DATA_PATH / "language_model.arpa",
"weight": language_model_weight,
}
evaluate.run(evaluate_config, evaluate.NERVAL_THRESHOLD)
# Check that the evaluation results are correct
for split_name, expected_res in [
(
"train",
{
"nb_chars": 90,
"cer": 0.1889,
"nb_chars_no_token": 76,
"cer_no_token": 0.2105,
"nb_words": 15,
"wer": 0.2667,
"nb_words_no_punct": 15,
"wer_no_punct": 0.2667,
"nb_words_no_token": 15,
"wer_no_token": 0.2667,
"nb_tokens": 14,
"ner": 0.0714,
"nb_samples": 2,
},
),
(
"val",
{
"nb_chars": 34,
"cer": 0.0882,
"nb_chars_no_token": 26,
"cer_no_token": 0.1154,
"nb_words": 8,
"wer": 0.5,
"nb_words_no_punct": 8,
"wer_no_punct": 0.5,
"nb_words_no_token": 8,
"wer_no_token": 0.5,
"nb_tokens": 8,
"ner": 0.0,
"nb_samples": 1,
},
),
(
"test",
{
"nb_chars": 36,
"cer": 0.0278,
"nb_chars_no_token": 30,
"cer_no_token": 0.0333,
"nb_words": 7,
"wer": 0.1429,
"nb_words_no_punct": 7,
"wer_no_punct": 0.1429,
"nb_words_no_token": 7,
"wer_no_token": 0.1429,
"nb_tokens": 6,
"ner": 0.0,
"nb_samples": 1,
},
),
]:
filename = (
evaluate_config["training"]["output_folder"]
/ "results"
/ f"predict_training-{split_name}_1685.yaml"
)
with filename.open() as f:
assert {
metric: value
for metric, value in yaml.safe_load(f).items()
# Remove the times from the results as they vary
if "time" not in metric
} == expected_res
# Remove results files
shutil.rmtree(evaluate_config["training"]["output_folder"] / "results")
......
......@@ -425,6 +425,32 @@ def test_empty_transcription(allow_empty, mock_database):
extractor.extract_transcription(element_no_transcription)
@pytest.mark.parametrize("tokens", (None, EXTRACTION_DATA_PATH / "tokens.yml"))
def test_extract_transcription_no_translation(mock_database, tokens):
extractor = ArkindexExtractor(
element_type=["text_line"],
entity_separators=None,
tokens=tokens,
)
element = Element.get_by_id("test-page_1-line_1")
# Deleting one of the two transcriptions from the element
Transcription.get(
Transcription.element == element,
Transcription.worker_version_id == "worker_version_id",
).delete_instance(recursive=True)
# Deleting all entities on the element's remaining transcription while leaving the transcription itself intact
if tokens:
TranscriptionEntity.delete().where(
TranscriptionEntity.transcription
== Transcription.select().where(Transcription.element == element).get()
).execute()
# Early return with only the element transcription text instead of a translation
assert extractor.extract_transcription(element) == "Coupez Bouis 7.12.14"
@pytest.mark.parametrize(
"nestation, xml_output, separators",
(
......