From 504de3b61a68414b3697f2427cdff8d786e66f6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com> Date: Wed, 27 Sep 2023 11:12:30 +0200 Subject: [PATCH] Improve code --- dan/ocr/predict/prediction.py | 9 ++++----- dan/utils.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/dan/ocr/predict/prediction.py b/dan/ocr/predict/prediction.py index b2b4c8ca..4ce2898f 100644 --- a/dan/ocr/predict/prediction.py +++ b/dan/ocr/predict/prediction.py @@ -360,11 +360,10 @@ def process_batch( # Return LM results if use_language_model: - result["language_model"] = {} - result["language_model"]["text"] = prediction["language_model"]["text"][idx] - result["language_model"]["confidence"] = prediction["language_model"][ - "confidence" - ][idx] + result["language_model"] = { + "text": prediction["language_model"]["text"][idx], + "confidence": prediction["language_model"]["confidence"][idx], + } # Return extracted objects (coordinates, text, confidence) if predict_objects: diff --git a/dan/utils.py b/dan/utils.py index fcd7e7af..c65df263 100644 --- a/dan/utils.py +++ b/dan/utils.py @@ -36,7 +36,7 @@ class LMTokenMapping(NamedTuple): return {a.display: a.encoded for a in self} def encode_token(self, token: str) -> str: - return self.encode[token] if token in self.encode else token + return self.encode.get(token, token) class EntityType(NamedTuple): -- GitLab