diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py index 34bb1da1c7c2484e3cab194123613ed8b5ba7444..87ba0c7717e7ac478428b7edac6ea38bb2666c0d 100644 --- a/dan/ocr/predict/inference.py +++ b/dan/ocr/predict/inference.py @@ -341,8 +341,6 @@ def process_batch( image_batch: List[Path], dan_model: DAN, device: str, - font: Path, - maximum_font_size: int, output: Path, confidence_score: bool, confidence_score_levels: List[Level], @@ -357,6 +355,8 @@ def process_batch( max_object_height: int, tokens: Dict[str, EntityType], start_token: str, + font: Path | None = None, + maximum_font_size: int | None = None, ) -> None: input_images, visu_images, input_sizes = [], [], [] logger.info("Loading images...") @@ -436,7 +436,7 @@ def process_batch( ) # Save gif with attention map - if attention_map: + if attention_map and font and maximum_font_size: attentions = prediction["attentions"][idx] gif_filename = ( f"{output}/{image_path.stem}_{attention_map_level.value}.gif" @@ -538,10 +538,12 @@ def run( dynamic_mode=dynamic_mode, ) - try: - load_font(font, maximum_font_size) - except OSError: - raise FileNotFoundError(f"The font file is missing at path {str(font)}") + # Load font if the attention map is drawn + if attention_map: + try: + load_font(font, maximum_font_size) + except OSError: + raise FileNotFoundError(f"The font file is missing at path `{str(font)}`") images = image_dir.rglob(f"*{image_extension}") for image_batch in list_to_batches(images, n=batch_size): @@ -549,8 +551,6 @@ def run( image_batch, dan_model, device, - font, - maximum_font_size, output, confidence_score, confidence_score_levels, @@ -565,4 +565,6 @@ def run( max_object_height, tokens, start_token, + font, + maximum_font_size, )