From 9d99164bfff53996574737f699fee479a158f113 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com> Date: Wed, 2 Aug 2023 15:51:30 +0200 Subject: [PATCH] Apply d580a769 --- dan/predict/attention.py | 10 +- dan/predict/prediction.py | 4 +- tests/test_prediction.py | 227 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 236 insertions(+), 5 deletions(-) diff --git a/dan/predict/attention.py b/dan/predict/attention.py index f1c4def0..0b06425f 100644 --- a/dan/predict/attention.py +++ b/dan/predict/attention.py @@ -70,14 +70,18 @@ def split_text_and_confidences( texts = list(text) offset = 0 elif level == "word": - texts, probs = compute_prob_by_separator(text, confidences, word_separators) + texts, confidences = compute_prob_by_separator( + text, confidences, word_separators + ) offset = 1 elif level == "line": - texts, probs = compute_prob_by_separator(text, confidences, line_separators) + texts, confidences = compute_prob_by_separator( + text, confidences, line_separators + ) offset = 1 else: logger.error("Level should be either 'char', 'word', or 'line'") - return texts, [np.around(num, 2) for num in probs], offset + return texts, [np.around(num, 2) for num in confidences], offset def get_predicted_polygons_with_confidence( diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py index 3726679e..0db9b6c9 100644 --- a/dan/predict/prediction.py +++ b/dan/predict/prediction.py @@ -282,9 +282,9 @@ def process_image( logger.debug("Image pre-processed.") # Convert to tensor of size (batch_size, channel, height, width) with batch_size=1 - input_tensor = im_p.unsqueeze(0) + input_tensor = torch.from_numpy(im_p).permute(2, 0, 1).unsqueeze(0) input_tensor = input_tensor.to(device) - input_sizes = [im_p.shape[1:]] + input_sizes = [im_p.shape[:2]] # Parse delimiters to regex word_separators = parse_delimiters(word_separators) diff --git a/tests/test_prediction.py b/tests/test_prediction.py index 1ad7c510..1a920275 100644 --- 
a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -1,9 +1,12 @@ # -*- coding: utf-8 -*- +import json + import pytest import torch from dan.predict.prediction import DAN +from dan.predict.prediction import run as run_prediction from dan.utils import read_image @@ -53,3 +56,227 @@ def test_predict( prediction = dan_model.predict(input_tensor, input_sizes) assert prediction == expected_prediction + + +@pytest.mark.parametrize( + "image_name, confidence_score, temperature, expected_prediction", + ( + ( + "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84", + None, + 1.0, + {"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"}, + ), + ( + "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84", + ["word"], + 1.0, + { + "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", + "confidences": { + "by ner token": [], + "total": 1.0, + "word": [ + {"text": "ⓈBellisson", "confidence": 1.0}, + {"text": "ⒻGeorges", "confidence": 1.0}, + {"text": "Ⓑ91", "confidence": 1.0}, + {"text": "ⓁP", "confidence": 1.0}, + {"text": "ⒸM", "confidence": 1.0}, + {"text": "ⓀCh", "confidence": 1.0}, + {"text": "ⓄPlombier", "confidence": 1.0}, + {"text": "ⓅPatron?12241", "confidence": 1.0}, + ], + }, + }, + ), + ( + "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84", + ["word"], + 3.5, + { + "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", + "confidences": { + "by ner token": [], + "total": 0.93, + "word": [ + {"text": "ⓈBellisson", "confidence": 0.93}, + {"text": "ⒻGeorges", "confidence": 0.94}, + {"text": "Ⓑ91", "confidence": 0.92}, + {"text": "ⓁP", "confidence": 0.94}, + {"text": "ⒸM", "confidence": 0.93}, + {"text": "ⓀCh", "confidence": 0.96}, + {"text": "ⓄPlombier", "confidence": 0.94}, + {"text": "ⓅPatron?12241", "confidence": 0.93}, + ], + }, + }, + ), + ( + "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84", + ["line"], + 1.0, + { + "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", + 
"confidences": { + "by ner token": [], + "total": 1.0, + "line": [ + { + "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", + "confidence": 1.0, + } + ], + }, + }, + ), + ( + "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84", + ["line"], + 3.5, + { + "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", + "confidences": { + "by ner token": [], + "total": 0.93, + "line": [ + { + "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", + "confidence": 0.93, + } + ], + }, + }, + ), + ( + "0dfe8bcd-ed0b-453e-bf19-cc697012296e", + None, + 1.0, + {"text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376"}, + ), + ( + "0dfe8bcd-ed0b-453e-bf19-cc697012296e", + ["char", "word", "line"], + 1.0, + { + "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376", + "confidences": { + "by ner token": [], + "total": 1.0, + "char": [ + {"text": "Ⓢ", "confidence": 1.0}, + {"text": "T", "confidence": 1.0}, + {"text": "e", "confidence": 1.0}, + {"text": "m", "confidence": 1.0}, + {"text": "p", "confidence": 1.0}, + {"text": "l", "confidence": 1.0}, + {"text": "i", "confidence": 1.0}, + {"text": "é", "confidence": 0.85}, + {"text": " ", "confidence": 1.0}, + {"text": "Ⓕ", "confidence": 1.0}, + {"text": "M", "confidence": 1.0}, + {"text": "a", "confidence": 1.0}, + {"text": "r", "confidence": 1.0}, + {"text": "c", "confidence": 1.0}, + {"text": "e", "confidence": 1.0}, + {"text": "l", "confidence": 1.0}, + {"text": "l", "confidence": 1.0}, + {"text": "e", "confidence": 1.0}, + {"text": " ", "confidence": 1.0}, + {"text": "Ⓑ", "confidence": 1.0}, + {"text": "9", "confidence": 1.0}, + {"text": "3", "confidence": 1.0}, + {"text": " ", "confidence": 1.0}, + {"text": "Ⓛ", "confidence": 1.0}, + {"text": "S", "confidence": 1.0}, + {"text": " ", "confidence": 1.0}, + {"text": "Ⓚ", "confidence": 1.0}, + {"text": "c", "confidence": 1.0}, + {"text": "h", "confidence": 1.0}, + {"text": " ", 
"confidence": 1.0}, + {"text": "Ⓞ", "confidence": 1.0}, + {"text": "E", "confidence": 1.0}, + {"text": " ", "confidence": 1.0}, + {"text": "d", "confidence": 1.0}, + {"text": "a", "confidence": 1.0}, + {"text": "c", "confidence": 1.0}, + {"text": "t", "confidence": 1.0}, + {"text": "y", "confidence": 1.0}, + {"text": "l", "confidence": 1.0}, + {"text": "o", "confidence": 1.0}, + {"text": " ", "confidence": 1.0}, + {"text": "Ⓟ", "confidence": 1.0}, + {"text": "1", "confidence": 1.0}, + {"text": "8", "confidence": 1.0}, + {"text": "3", "confidence": 1.0}, + {"text": "7", "confidence": 1.0}, + {"text": "6", "confidence": 1.0}, + ], + "word": [ + {"text": "ⓈTemplié", "confidence": 0.98}, + {"text": "ⒻMarcelle", "confidence": 1.0}, + {"text": "Ⓑ93", "confidence": 1.0}, + {"text": "ⓁS", "confidence": 1.0}, + {"text": "Ⓚch", "confidence": 1.0}, + {"text": "ⓄE", "confidence": 1.0}, + {"text": "dactylo", "confidence": 1.0}, + {"text": "Ⓟ18376", "confidence": 1.0}, + ], + "line": [ + { + "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376", + "confidence": 1.0, + } + ], + }, + }, + ), + ( + "2c242f5c-e979-43c4-b6f2-a6d4815b651d", + False, + 1.0, + {"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31"}, + ), + ( + "ffdec445-7f14-4f5f-be44-68d0844d0df1", + False, + 1.0, + {"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"}, + ), + ), +) +def test_run_prediction( + image_name, + confidence_score, + temperature, + expected_prediction, + prediction_data_path, + tmp_path, +): + run_prediction( + image=(prediction_data_path / "images" / image_name).with_suffix(".png"), + image_dir=None, + model=prediction_data_path / "popp_line_model.pt", + parameters=prediction_data_path / "parameters.yml", + charset=prediction_data_path / "charset.pkl", + output=tmp_path, + scale=1, + confidence_score=True if confidence_score else False, + confidence_score_levels=confidence_score if confidence_score else [], + attention_map=False, + 
attention_map_level=None, + attention_map_scale=0.5, + word_separators=[" ", "\n"], + line_separators=["\n"], + temperature=temperature, + image_max_width=None, + predict_objects=False, + threshold_method="otsu", + threshold_value=0, + image_extension=None, + gpu_device=None, + ) + + with (tmp_path / image_name).with_suffix(".json").open("r") as json_file: + prediction = json.load(json_file) + + assert prediction == expected_prediction -- GitLab