From 504de3b61a68414b3697f2427cdff8d786e66f6d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Wed, 27 Sep 2023 11:12:30 +0200
Subject: [PATCH] Improve code

---
 dan/ocr/predict/prediction.py | 9 ++++-----
 dan/utils.py                  | 2 +-
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/dan/ocr/predict/prediction.py b/dan/ocr/predict/prediction.py
index b2b4c8ca..4ce2898f 100644
--- a/dan/ocr/predict/prediction.py
+++ b/dan/ocr/predict/prediction.py
@@ -360,11 +360,10 @@ def process_batch(
 
         # Return LM results
         if use_language_model:
-            result["language_model"] = {}
-            result["language_model"]["text"] = prediction["language_model"]["text"][idx]
-            result["language_model"]["confidence"] = prediction["language_model"][
-                "confidence"
-            ][idx]
+            result["language_model"] = {
+                "text": prediction["language_model"]["text"][idx],
+                "confidence": prediction["language_model"]["confidence"][idx],
+            }
 
         # Return extracted objects (coordinates, text, confidence)
         if predict_objects:
diff --git a/dan/utils.py b/dan/utils.py
index fcd7e7af..c65df263 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -36,7 +36,7 @@ class LMTokenMapping(NamedTuple):
         return {a.display: a.encoded for a in self}
 
     def encode_token(self, token: str) -> str:
-        return self.encode[token] if token in self.encode else token
+        return self.encode.get(token, token)
 
 
 class EntityType(NamedTuple):
-- 
GitLab