From deb015a0ca79efc54285042d952a338ab6d7521b Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Mon, 23 Oct 2023 15:02:12 +0000
Subject: [PATCH] Support the case where DAN does not output anything

---
 dan/ocr/predict/inference.py | 108 +++++++++++++++++------------------
 tests/test_prediction.py     |  53 ++++++++++++++---
 2 files changed, 100 insertions(+), 61 deletions(-)

diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py
index 0719f492..e33d0ba2 100644
--- a/dan/ocr/predict/inference.py
+++ b/dan/ocr/predict/inference.py
@@ -356,62 +356,62 @@ def process_batch(
     logger.info("Prediction parsing...")
     for idx, image_path in enumerate(image_batch):
         predicted_text = prediction["text"][idx]
-        result = {"text": predicted_text}
-
-        # Return LM results
-        if use_language_model:
-            result["language_model"] = {
-                "text": prediction["language_model"]["text"][idx],
-                "confidence": prediction["language_model"]["confidence"][idx],
-            }
-
-        # Return extracted objects (coordinates, text, confidence)
-        if predict_objects:
-            result["objects"] = prediction["objects"][idx]
-
-        # Return mean confidence score
-        if confidence_score:
-            result["confidences"] = {}
-            char_confidences = prediction["confidences"][idx]
-            result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
-
-            for level in confidence_score_levels:
-                result["confidences"][level.value] = []
-                texts, confidences, _ = split_text_and_confidences(
-                    predicted_text,
-                    char_confidences,
-                    level,
-                    word_separators,
-                    line_separators,
-                    tokens,
-                )
-
-                for text, conf in zip(texts, confidences):
-                    result["confidences"][level.value].append(
-                        {"text": text, "confidence": conf}
+        result = {"text": predicted_text, "confidences": {}, "language_model": {}}
+
+        if predicted_text:
+            # Return LM results
+            if use_language_model:
+                result["language_model"] = {
+                    "text": prediction["language_model"]["text"][idx],
+                    "confidence": prediction["language_model"]["confidence"][idx],
+                }
+
+            # Return extracted objects (coordinates, text, confidence)
+            if predict_objects:
+                result["objects"] = prediction["objects"][idx]
+
+            # Return mean confidence score
+            if confidence_score:
+                char_confidences = prediction["confidences"][idx]
+                result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
+
+                for level in confidence_score_levels:
+                    result["confidences"][level.value] = []
+                    texts, confidences, _ = split_text_and_confidences(
+                        predicted_text,
+                        char_confidences,
+                        level,
+                        word_separators,
+                        line_separators,
+                        tokens,
                     )
 
-        # Save gif with attention map
-        if attention_map:
-            attentions = prediction["attentions"][idx]
-            gif_filename = f"{output}/{image_path.stem}_{attention_map_level}.gif"
-            logger.info(f"Creating attention GIF in {gif_filename}")
-            plot_attention(
-                image=visu_tensor[idx],
-                text=predicted_text,
-                weights=attentions,
-                level=attention_map_level,
-                scale=attention_map_scale,
-                word_separators=word_separators,
-                line_separators=line_separators,
-                tokens=tokens,
-                display_polygons=predict_objects,
-                threshold_method=threshold_method,
-                threshold_value=threshold_value,
-                max_object_height=max_object_height,
-                outname=gif_filename,
-            )
-            result["attention_gif"] = gif_filename
+                    for text, conf in zip(texts, confidences):
+                        result["confidences"][level.value].append(
+                            {"text": text, "confidence": conf}
+                        )
+
+            # Save gif with attention map
+            if attention_map:
+                attentions = prediction["attentions"][idx]
+                gif_filename = f"{output}/{image_path.stem}_{attention_map_level}.gif"
+                logger.info(f"Creating attention GIF in {gif_filename}")
+                plot_attention(
+                    image=visu_tensor[idx],
+                    text=predicted_text,
+                    weights=attentions,
+                    level=attention_map_level,
+                    scale=attention_map_scale,
+                    word_separators=word_separators,
+                    line_separators=line_separators,
+                    tokens=tokens,
+                    display_polygons=predict_objects,
+                    threshold_method=threshold_method,
+                    threshold_value=threshold_value,
+                    max_object_height=max_object_height,
+                    outname=gif_filename,
+                )
+                result["attention_gif"] = gif_filename
 
         json_filename = Path(output, image_path.stem).with_suffix(".json")
         logger.info(f"Saving JSON prediction in {json_filename}")
diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index 1b6bb607..30576b1b 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -67,7 +67,11 @@ def test_predict(image_name, expected_prediction):
             "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
             None,
             1.0,
-            {"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"},
+            {
+                "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                "language_model": {},
+                "confidences": {},
+            },
         ),
         (
             "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
@@ -75,6 +79,7 @@ def test_predict(image_name, expected_prediction):
             1.0,
             {
                 "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                "language_model": {},
                 "confidences": {
                     "total": 1.0,
                     "word": [
@@ -96,6 +101,7 @@ def test_predict(image_name, expected_prediction):
             3.5,
             {
                 "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                "language_model": {},
                 "confidences": {
                     "total": 0.93,
                     "ner": [
@@ -127,6 +133,7 @@ def test_predict(image_name, expected_prediction):
             1.0,
             {
                 "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                "language_model": {},
                 "confidences": {
                     "total": 1.0,
                     "line": [
@@ -144,6 +151,7 @@ def test_predict(image_name, expected_prediction):
             3.5,
             {
                 "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                "language_model": {},
                 "confidences": {
                     "total": 0.93,
                     "ner": [
@@ -169,7 +177,11 @@ def test_predict(image_name, expected_prediction):
             "0dfe8bcd-ed0b-453e-bf19-cc697012296e",
             None,
             1.0,
-            {"text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376"},
+            {
+                "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
+                "language_model": {},
+                "confidences": {},
+            },
         ),
         (
             "0dfe8bcd-ed0b-453e-bf19-cc697012296e",
@@ -177,6 +189,7 @@ def test_predict(image_name, expected_prediction):
             1.0,
             {
                 "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
+                "language_model": {},
                 "confidences": {
                     "total": 1.0,
                     "ner": [
@@ -260,13 +273,21 @@ def test_predict(image_name, expected_prediction):
             "2c242f5c-e979-43c4-b6f2-a6d4815b651d",
             False,
             1.0,
-            {"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31"},
+            {
+                "text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",
+                "language_model": {},
+                "confidences": {},
+            },
         ),
         (
             "ffdec445-7f14-4f5f-be44-68d0844d0df1",
             False,
             1.0,
-            {"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"},
+            {
+                "text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
+                "language_model": {},
+                "confidences": {},
+            },
         ),
     ),
 )
@@ -315,7 +336,13 @@ def test_run_prediction(
             ["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"],
             None,
             1.0,
-            [{"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"}],
+            [
+                {
+                    "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                    "language_model": {},
+                    "confidences": {},
+                }
+            ],
         ),
         (
             ["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"],
@@ -324,6 +351,7 @@ def test_run_prediction(
             [
                 {
                     "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                    "language_model": {},
                     "confidences": {
                         "total": 1.0,
                         "word": [
@@ -350,6 +378,7 @@ def test_run_prediction(
             [
                 {
                     "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                    "language_model": {},
                     "confidences": {
                         "total": 1.0,
                         "ner": [
@@ -376,6 +405,7 @@ def test_run_prediction(
                 },
                 {
                     "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                    "language_model": {},
                     "confidences": {
                         "total": 1.0,
                         "ner": [
@@ -409,6 +439,7 @@ def test_run_prediction(
             [
                 {
                     "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                    "language_model": {},
                     "confidences": {
                         "total": 1.0,
                         "word": [
@@ -433,8 +464,16 @@ def test_run_prediction(
             False,
             1.0,
             [
-                {"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31"},
-                {"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"},
+                {
+                    "text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",
+                    "language_model": {},
+                    "confidences": {},
+                },
+                {
+                    "text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
+                    "language_model": {},
+                    "confidences": {},
+                },
             ],
         ),
     ),
-- 
GitLab