From d580a769396d8fd1ed073a511e781c78215f1b6b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com>
Date: Mon, 10 Jul 2023 15:54:16 +0200
Subject: [PATCH] Add tests for prediction code

---
 dan/predict/attention.py |  10 +-
 tests/test_prediction.py | 227 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 234 insertions(+), 3 deletions(-)

diff --git a/dan/predict/attention.py b/dan/predict/attention.py
index f1c4def0..0b06425f 100644
--- a/dan/predict/attention.py
+++ b/dan/predict/attention.py
@@ -70,14 +70,18 @@ def split_text_and_confidences(
         texts = list(text)
         offset = 0
     elif level == "word":
-        texts, probs = compute_prob_by_separator(text, confidences, word_separators)
+        texts, confidences = compute_prob_by_separator(
+            text, confidences, word_separators
+        )
         offset = 1
     elif level == "line":
-        texts, probs = compute_prob_by_separator(text, confidences, line_separators)
+        texts, confidences = compute_prob_by_separator(
+            text, confidences, line_separators
+        )
         offset = 1
     else:
         logger.error("Level should be either 'char', 'word', or 'line'")
-    return texts, [np.around(num, 2) for num in probs], offset
+    return texts, [np.around(num, 2) for num in confidences], offset
 
 
 def get_predicted_polygons_with_confidence(
diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index 28df938b..d677d63f 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -1,8 +1,11 @@
 # -*- coding: utf-8 -*-
 
+import json
+
 import pytest
 
 from dan.predict.prediction import DAN
+from dan.predict.prediction import run as run_prediction
 from dan.utils import read_image
 
 
@@ -52,3 +55,227 @@ def test_predict(
     prediction = dan_model.predict(input_tensor, input_sizes)
 
     assert prediction == expected_prediction
+
+
+@pytest.mark.parametrize(
+    "image_name, confidence_score, temperature, expected_prediction",
+    (
+        (
+            "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
+            None,
+            1.0,
+            {"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"},
+        ),
+        (
+            "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
+            ["word"],
+            1.0,
+            {
+                "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                "confidences": {
+                    "by ner token": [],
+                    "total": 1.0,
+                    "word": [
+                        {"text": "ⓈBellisson", "confidence": 1.0},
+                        {"text": "ⒻGeorges", "confidence": 1.0},
+                        {"text": "Ⓑ91", "confidence": 1.0},
+                        {"text": "ⓁP", "confidence": 1.0},
+                        {"text": "ⒸM", "confidence": 1.0},
+                        {"text": "ⓀCh", "confidence": 1.0},
+                        {"text": "ⓄPlombier", "confidence": 1.0},
+                        {"text": "ⓅPatron?12241", "confidence": 1.0},
+                    ],
+                },
+            },
+        ),
+        (
+            "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
+            ["word"],
+            3.5,
+            {
+                "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                "confidences": {
+                    "by ner token": [],
+                    "total": 0.93,
+                    "word": [
+                        {"text": "ⓈBellisson", "confidence": 0.92},
+                        {"text": "ⒻGeorges", "confidence": 0.94},
+                        {"text": "Ⓑ91", "confidence": 0.92},
+                        {"text": "ⓁP", "confidence": 0.94},
+                        {"text": "ⒸM", "confidence": 0.93},
+                        {"text": "ⓀCh", "confidence": 0.96},
+                        {"text": "ⓄPlombier", "confidence": 0.94},
+                        {"text": "ⓅPatron?12241", "confidence": 0.93},
+                    ],
+                },
+            },
+        ),
+        (
+            "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
+            ["line"],
+            1.0,
+            {
+                "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                "confidences": {
+                    "by ner token": [],
+                    "total": 1.0,
+                    "line": [
+                        {
+                            "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                            "confidence": 1.0,
+                        }
+                    ],
+                },
+            },
+        ),
+        (
+            "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
+            ["line"],
+            3.5,
+            {
+                "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                "confidences": {
+                    "by ner token": [],
+                    "total": 0.93,
+                    "line": [
+                        {
+                            "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                            "confidence": 0.93,
+                        }
+                    ],
+                },
+            },
+        ),
+        (
+            "0dfe8bcd-ed0b-453e-bf19-cc697012296e",
+            None,
+            1.0,
+            {"text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376"},
+        ),
+        (
+            "0dfe8bcd-ed0b-453e-bf19-cc697012296e",
+            ["char", "word", "line"],
+            1.0,
+            {
+                "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
+                "confidences": {
+                    "by ner token": [],
+                    "total": 1.0,
+                    "char": [
+                        {"text": "Ⓢ", "confidence": 1.0},
+                        {"text": "T", "confidence": 1.0},
+                        {"text": "e", "confidence": 1.0},
+                        {"text": "m", "confidence": 1.0},
+                        {"text": "p", "confidence": 1.0},
+                        {"text": "l", "confidence": 1.0},
+                        {"text": "i", "confidence": 1.0},
+                        {"text": "é", "confidence": 0.86},
+                        {"text": " ", "confidence": 1.0},
+                        {"text": "Ⓕ", "confidence": 1.0},
+                        {"text": "M", "confidence": 1.0},
+                        {"text": "a", "confidence": 1.0},
+                        {"text": "r", "confidence": 1.0},
+                        {"text": "c", "confidence": 1.0},
+                        {"text": "e", "confidence": 1.0},
+                        {"text": "l", "confidence": 1.0},
+                        {"text": "l", "confidence": 1.0},
+                        {"text": "e", "confidence": 1.0},
+                        {"text": " ", "confidence": 1.0},
+                        {"text": "Ⓑ", "confidence": 1.0},
+                        {"text": "9", "confidence": 1.0},
+                        {"text": "3", "confidence": 1.0},
+                        {"text": " ", "confidence": 1.0},
+                        {"text": "Ⓛ", "confidence": 1.0},
+                        {"text": "S", "confidence": 1.0},
+                        {"text": " ", "confidence": 1.0},
+                        {"text": "Ⓚ", "confidence": 1.0},
+                        {"text": "c", "confidence": 1.0},
+                        {"text": "h", "confidence": 1.0},
+                        {"text": " ", "confidence": 1.0},
+                        {"text": "Ⓞ", "confidence": 1.0},
+                        {"text": "E", "confidence": 1.0},
+                        {"text": " ", "confidence": 1.0},
+                        {"text": "d", "confidence": 1.0},
+                        {"text": "a", "confidence": 1.0},
+                        {"text": "c", "confidence": 1.0},
+                        {"text": "t", "confidence": 1.0},
+                        {"text": "y", "confidence": 1.0},
+                        {"text": "l", "confidence": 1.0},
+                        {"text": "o", "confidence": 1.0},
+                        {"text": " ", "confidence": 1.0},
+                        {"text": "Ⓟ", "confidence": 1.0},
+                        {"text": "1", "confidence": 1.0},
+                        {"text": "8", "confidence": 1.0},
+                        {"text": "3", "confidence": 1.0},
+                        {"text": "7", "confidence": 1.0},
+                        {"text": "6", "confidence": 1.0},
+                    ],
+                    "word": [
+                        {"text": "ⓈTemplié", "confidence": 0.98},
+                        {"text": "ⒻMarcelle", "confidence": 1.0},
+                        {"text": "Ⓑ93", "confidence": 1.0},
+                        {"text": "ⓁS", "confidence": 1.0},
+                        {"text": "Ⓚch", "confidence": 1.0},
+                        {"text": "ⓄE", "confidence": 1.0},
+                        {"text": "dactylo", "confidence": 1.0},
+                        {"text": "Ⓟ18376", "confidence": 1.0},
+                    ],
+                    "line": [
+                        {
+                            "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
+                            "confidence": 1.0,
+                        }
+                    ],
+                },
+            },
+        ),
+        (
+            "2c242f5c-e979-43c4-b6f2-a6d4815b651d",
+            False,
+            1.0,
+            {"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31"},
+        ),
+        (
+            "ffdec445-7f14-4f5f-be44-68d0844d0df1",
+            False,
+            1.0,
+            {"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"},
+        ),
+    ),
+)
+def test_run_prediction(
+    image_name,
+    confidence_score,
+    temperature,
+    expected_prediction,
+    prediction_data_path,
+    tmp_path,
+):
+    run_prediction(
+        image=(prediction_data_path / "images" / image_name).with_suffix(".png"),
+        image_dir=None,
+        model=prediction_data_path / "popp_line_model.pt",
+        parameters=prediction_data_path / "parameters.yml",
+        charset=prediction_data_path / "charset.pkl",
+        output=tmp_path,
+        scale=1,
+        confidence_score=True if confidence_score else False,
+        confidence_score_levels=confidence_score if confidence_score else [],
+        attention_map=False,
+        attention_map_level=None,
+        attention_map_scale=0.5,
+        word_separators=[" ", "\n"],
+        line_separators=["\n"],
+        temperature=temperature,
+        image_max_width=None,
+        predict_objects=False,
+        threshold_method="otsu",
+        threshold_value=0,
+        image_extension=None,
+        gpu_device=None,
+    )
+
+    with (tmp_path / image_name).with_suffix(".json").open("r") as json_file:
+        prediction = json.load(json_file)
+
+    assert prediction == expected_prediction
-- 
GitLab