From 6fb9efdff94de69b9d516f6d8d08e390108695f8 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Tue, 3 Sep 2024 08:32:04 +0000
Subject: [PATCH] Font parameters should only be required if we draw the
 attention map

---
 dan/ocr/predict/inference.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py
index 34bb1da1..87ba0c77 100644
--- a/dan/ocr/predict/inference.py
+++ b/dan/ocr/predict/inference.py
@@ -341,8 +341,6 @@ def process_batch(
     image_batch: List[Path],
     dan_model: DAN,
     device: str,
-    font: Path,
-    maximum_font_size: int,
     output: Path,
     confidence_score: bool,
     confidence_score_levels: List[Level],
@@ -357,6 +355,8 @@ def process_batch(
     max_object_height: int,
     tokens: Dict[str, EntityType],
     start_token: str,
+    font: Path | None = None,
+    maximum_font_size: int | None = None,
 ) -> None:
     input_images, visu_images, input_sizes = [], [], []
     logger.info("Loading images...")
@@ -436,7 +436,7 @@ def process_batch(
                         )
 
             # Save gif with attention map
-            if attention_map:
+            if attention_map and font and maximum_font_size:
                 attentions = prediction["attentions"][idx]
                 gif_filename = (
                     f"{output}/{image_path.stem}_{attention_map_level.value}.gif"
@@ -538,10 +538,12 @@ def run(
         dynamic_mode=dynamic_mode,
     )
 
-    try:
-        load_font(font, maximum_font_size)
-    except OSError:
-        raise FileNotFoundError(f"The font file is missing at path {str(font)}")
+    # Load font if the attention map is drawn
+    if attention_map:
+        try:
+            load_font(font, maximum_font_size)
+        except OSError:
+            raise FileNotFoundError(f"The font file is missing at path `{str(font)}`")
 
     images = image_dir.rglob(f"*{image_extension}")
     for image_batch in list_to_batches(images, n=batch_size):
@@ -549,8 +551,6 @@ def run(
             image_batch,
             dan_model,
             device,
-            font,
-            maximum_font_size,
             output,
             confidence_score,
             confidence_score_levels,
@@ -565,4 +565,6 @@ def run(
             max_object_height,
             tokens,
             start_token,
+            font,
+            maximum_font_size,
         )
-- 
GitLab