From bb09ae0c91ea7155a9da8913aee18d7d6c0757a2 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Wed, 16 Aug 2023 06:50:47 +0000
Subject: [PATCH] Move parse_tokens and EntityType to dan.utils

Move the parse_tokens helper and the EntityType named tuple from
dan/datasets/extract/utils.py to dan/utils.py, so that
dan.ocr.manager.metrics and dan.ocr.predict.prediction no longer import
them from the dataset-extraction package.

---
 dan/datasets/extract/extract.py |  3 +--
 dan/datasets/extract/utils.py   | 21 ++-------------------
 dan/ocr/manager/metrics.py      |  2 +-
 dan/ocr/predict/prediction.py   |  9 +++++++--
 dan/utils.py                    | 19 +++++++++++++++++++
 5 files changed, 30 insertions(+), 24 deletions(-)
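
A minimal usage sketch of the relocated helpers (illustration only, not part
of the patch; the "person" entry and its start/end token strings are made-up
values):

    from pathlib import Path

    from dan.utils import EntityType, parse_tokens

    # tokens.yml maps each entity name to the EntityType fields, since
    # parse_tokens builds EntityType(**tokens) for every entry.
    Path("tokens.yml").write_text('person:\n  start: "<p>"\n  end: "</p>"\n')

    tokens = parse_tokens(Path("tokens.yml"))
    assert tokens["person"] == EntityType(start="<p>", end="</p>")
    print(tokens["person"].offset)  # len(start) + len(end) == 7

Nothing else changes for downstream code: as the metrics.py and
prediction.py hunks below show, callers only update their imports from
dan.datasets.extract.utils to dan.utils.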

diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py
index b5e94a0a..a5d3fc6a 100644
--- a/dan/datasets/extract/extract.py
+++ b/dan/datasets/extract/extract.py
@@ -24,11 +24,10 @@ from dan.datasets.extract.exceptions import (
     ProcessingError,
 )
 from dan.datasets.extract.utils import (
-    EntityType,
     download_image,
     insert_token,
-    parse_tokens,
 )
+from dan.utils import EntityType, parse_tokens
 from line_image_extractor.extractor import extract, read_img, save_img
 from line_image_extractor.image_utils import Extraction, polygon_to_bbox, resize
 
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
index c39c7dca..5efcf6ea 100644
--- a/dan/datasets/extract/utils.py
+++ b/dan/datasets/extract/utils.py
@@ -1,11 +1,8 @@
 # -*- coding: utf-8 -*-
 import logging
 from io import BytesIO
-from pathlib import Path
-from typing import NamedTuple
 
 import requests
-import yaml
 from PIL import Image
 from tenacity import (
     retry,
@@ -14,6 +11,8 @@ from tenacity import (
     wait_exponential,
 )
 
+from dan.utils import EntityType
+
 logger = logging.getLogger(__name__)
 
 # See http://docs.python-requests.org/en/master/user/advanced/#timeouts
@@ -27,15 +26,6 @@ def _retry_log(retry_state, *args, **kwargs):
     )
 
 
-class EntityType(NamedTuple):
-    start: str
-    end: str = ""
-
-    @property
-    def offset(self):
-        return len(self.start) + len(self.end)
-
-
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=2),
@@ -80,10 +70,3 @@ def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -
         # End token
         + (entity_type.end if entity_type else "")
     )
-
-
-def parse_tokens(filename: Path) -> dict:
-    return {
-        name: EntityType(**tokens)
-        for name, tokens in yaml.safe_load(filename.read_text()).items()
-    }
diff --git a/dan/ocr/manager/metrics.py b/dan/ocr/manager/metrics.py
index 015c0234..f0c26675 100644
--- a/dan/ocr/manager/metrics.py
+++ b/dan/ocr/manager/metrics.py
@@ -7,7 +7,7 @@ from typing import Optional
 import editdistance
 import numpy as np
 
-from dan.datasets.extract.utils import parse_tokens
+from dan.utils import parse_tokens
 
 
 class MetricManager:
diff --git a/dan/ocr/predict/prediction.py b/dan/ocr/predict/prediction.py
index a3759648..c1ed1a03 100644
--- a/dan/ocr/predict/prediction.py
+++ b/dan/ocr/predict/prediction.py
@@ -11,7 +11,6 @@ import torch
 import yaml
 
 from dan import logger
-from dan.datasets.extract.utils import parse_tokens
 from dan.ocr.decoder import GlobalHTADecoder
 from dan.ocr.encoder import FCN_Encoder
 from dan.ocr.predict.attention import (
@@ -21,7 +20,13 @@ from dan.ocr.predict.attention import (
     split_text_and_confidences,
 )
 from dan.ocr.transforms import get_preprocessing_transforms
-from dan.utils import ind_to_token, list_to_batches, pad_images, read_image
+from dan.utils import (
+    ind_to_token,
+    list_to_batches,
+    pad_images,
+    parse_tokens,
+    read_image,
+)
 
 
 class DAN:
diff --git a/dan/utils.py b/dan/utils.py
index d58db8fa..b5d53cd2 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -1,8 +1,11 @@
 # -*- coding: utf-8 -*-
 from itertools import islice
+from pathlib import Path
+from typing import NamedTuple
 
 import torch
 import torchvision.io as torchvision
+import yaml
 
 
 class MLflowNotInstalled(Exception):
@@ -11,6 +14,15 @@ class MLflowNotInstalled(Exception):
     """
 
 
+class EntityType(NamedTuple):
+    start: str
+    end: str = ""
+
+    @property
+    def offset(self):
+        return len(self.start) + len(self.end)
+
+
 def pad_sequences_1D(data, padding_value):
     """
     Pad data with padding_value to get same length
@@ -92,3 +104,10 @@ def list_to_batches(iterable, n):
     it = iter(iterable)
     while batch := tuple(islice(it, n)):
         yield batch
+
+
+def parse_tokens(filename: Path) -> dict:
+    return {
+        name: EntityType(**tokens)
+        for name, tokens in yaml.safe_load(filename.read_text()).items()
+    }
-- 
GitLab