From deb015a0ca79efc54285042d952a338ab6d7521b Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Mon, 23 Oct 2023 15:02:12 +0000 Subject: [PATCH] Support the case where DAN does not output anything --- dan/ocr/predict/inference.py | 108 +++++++++++++++++------------------ tests/test_prediction.py | 53 ++++++++++++++--- 2 files changed, 100 insertions(+), 61 deletions(-) diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py index 0719f492..e33d0ba2 100644 --- a/dan/ocr/predict/inference.py +++ b/dan/ocr/predict/inference.py @@ -356,62 +356,62 @@ def process_batch( logger.info("Prediction parsing...") for idx, image_path in enumerate(image_batch): predicted_text = prediction["text"][idx] - result = {"text": predicted_text} - - # Return LM results - if use_language_model: - result["language_model"] = { - "text": prediction["language_model"]["text"][idx], - "confidence": prediction["language_model"]["confidence"][idx], - } - - # Return extracted objects (coordinates, text, confidence) - if predict_objects: - result["objects"] = prediction["objects"][idx] - - # Return mean confidence score - if confidence_score: - result["confidences"] = {} - char_confidences = prediction["confidences"][idx] - result["confidences"]["total"] = np.around(np.mean(char_confidences), 2) - - for level in confidence_score_levels: - result["confidences"][level.value] = [] - texts, confidences, _ = split_text_and_confidences( - predicted_text, - char_confidences, - level, - word_separators, - line_separators, - tokens, - ) - - for text, conf in zip(texts, confidences): - result["confidences"][level.value].append( - {"text": text, "confidence": conf} + result = {"text": predicted_text, "confidences": {}, "language_model": {}} + + if predicted_text: + # Return LM results + if use_language_model: + result["language_model"] = { + "text": prediction["language_model"]["text"][idx], + "confidence": prediction["language_model"]["confidence"][idx], + } + + # Return extracted objects (coordinates, text, confidence) + if predict_objects: + result["objects"] = prediction["objects"][idx] + + # Return mean confidence score + if confidence_score: + char_confidences = prediction["confidences"][idx] + result["confidences"]["total"] = np.around(np.mean(char_confidences), 2) + + for level in confidence_score_levels: + result["confidences"][level.value] = [] + texts, confidences, _ = split_text_and_confidences( + predicted_text, + char_confidences, + level, + word_separators, + line_separators, + tokens, ) - # Save gif with attention map - if attention_map: - attentions = prediction["attentions"][idx] - gif_filename = f"{output}/{image_path.stem}_{attention_map_level}.gif" - logger.info(f"Creating attention GIF in {gif_filename}") - plot_attention( - image=visu_tensor[idx], - text=predicted_text, - weights=attentions, - level=attention_map_level, - scale=attention_map_scale, - word_separators=word_separators, - line_separators=line_separators, - tokens=tokens, - display_polygons=predict_objects, - threshold_method=threshold_method, - threshold_value=threshold_value, - max_object_height=max_object_height, - outname=gif_filename, - ) - result["attention_gif"] = gif_filename + for text, conf in zip(texts, confidences): + result["confidences"][level.value].append( + {"text": text, "confidence": conf} + ) + + # Save gif with attention map + if attention_map: + attentions = prediction["attentions"][idx] + gif_filename = f"{output}/{image_path.stem}_{attention_map_level}.gif" + logger.info(f"Creating attention GIF in {gif_filename}") + plot_attention( + image=visu_tensor[idx], + text=predicted_text, + weights=attentions, + level=attention_map_level, + scale=attention_map_scale, + word_separators=word_separators, + line_separators=line_separators, + tokens=tokens, + display_polygons=predict_objects, + threshold_method=threshold_method, + threshold_value=threshold_value, + max_object_height=max_object_height, + outname=gif_filename, + ) + result["attention_gif"] = gif_filename json_filename = Path(output, image_path.stem).with_suffix(".json") logger.info(f"Saving JSON prediction in {json_filename}") diff --git a/tests/test_prediction.py b/tests/test_prediction.py index 1b6bb607..30576b1b 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -67,7 +67,11 @@ def test_predict(image_name, expected_prediction): "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84", None, 1.0, - {"text": "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241"}, + { + "text": "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", + "language_model": {}, + "confidences": {}, + }, ), ( "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84", @@ -75,6 +79,7 @@ def test_predict(image_name, expected_prediction): 1.0, { "text": "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", + "language_model": {}, "confidences": { "total": 1.0, "word": [ @@ -96,6 +101,7 @@ def test_predict(image_name, expected_prediction): 3.5, { "text": "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", + "language_model": {}, "confidences": { "total": 0.93, "ner": [ @@ -127,6 +133,7 @@ def test_predict(image_name, expected_prediction): 1.0, { "text": "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", + "language_model": {}, "confidences": { "total": 1.0, "line": [ @@ -144,6 +151,7 @@ def test_predict(image_name, expected_prediction): 3.5, { "text": "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", + "language_model": {}, "confidences": { "total": 0.93, "ner": [ @@ -169,7 +177,11 @@ def test_predict(image_name, expected_prediction): "0dfe8bcd-ed0b-453e-bf19-cc697012296e", None, 1.0, - {"text": "ⓈTemplié â’»Marcelle â’·93 â“S â“€ch â“„E dactylo â“…18376"}, + { + "text": "ⓈTemplié â’»Marcelle â’·93 â“S â“€ch â“„E dactylo â“…18376", + "language_model": {}, + "confidences": {}, + }, ), ( "0dfe8bcd-ed0b-453e-bf19-cc697012296e", @@ -177,6 +189,7 @@ def test_predict(image_name, expected_prediction): 1.0, { "text": "ⓈTemplié â’»Marcelle â’·93 â“S â“€ch â“„E dactylo â“…18376", + "language_model": {}, "confidences": { "total": 1.0, "ner": [ @@ -260,13 +273,21 @@ def test_predict(image_name, expected_prediction): "2c242f5c-e979-43c4-b6f2-a6d4815b651d", False, 1.0, - {"text": "Ⓢd â’»Charles â’·11 â“P â’¸C â“€F â“„d â“…14 31"}, + { + "text": "Ⓢd â’»Charles â’·11 â“P â’¸C â“€F â“„d â“…14 31", + "language_model": {}, + "confidences": {}, + }, ), ( "ffdec445-7f14-4f5f-be44-68d0844d0df1", False, 1.0, - {"text": "ⓈNaudin â’»Marie â’·53 â“S â’¸v â“€Belle mère"}, + { + "text": "ⓈNaudin â’»Marie â’·53 â“S â’¸v â“€Belle mère", + "language_model": {}, + "confidences": {}, + }, ), ), ) @@ -315,7 +336,13 @@ def test_run_prediction( ["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"], None, 1.0, - [{"text": "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241"}], + [ + { + "text": "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", + "language_model": {}, + "confidences": {}, + } + ], ), ( ["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"], @@ -324,6 +351,7 @@ def test_run_prediction( [ { "text": "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", + "language_model": {}, "confidences": { "total": 1.0, "word": [ @@ -350,6 +378,7 @@ def test_run_prediction( [ { "text": "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", + "language_model": {}, "confidences": { "total": 1.0, "ner": [ @@ -376,6 +405,7 @@ def test_run_prediction( }, { "text": "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", + "language_model": {}, "confidences": { "total": 1.0, "ner": [ @@ -409,6 +439,7 @@ def test_run_prediction( [ { "text": "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", + "language_model": {}, "confidences": { "total": 1.0, "word": [ @@ -433,8 +464,16 @@ def test_run_prediction( False, 1.0, [ - {"text": "Ⓢd â’»Charles â’·11 â“P â’¸C â“€F â“„d â“…14 31"}, - {"text": "ⓈNaudin â’»Marie â’·53 â“S â’¸v â“€Belle mère"}, + { + "text": "Ⓢd â’»Charles â’·11 â“P â’¸C â“€F â“„d â“…14 31", + "language_model": {}, + "confidences": {}, + }, + { + "text": "ⓈNaudin â’»Marie â’·53 â“S â’¸v â“€Belle mère", + "language_model": {}, + "confidences": {}, + }, ], ), ), -- GitLab