diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py
index b5e94a0af3186c92fd56242b1ddb99f789c15c7b..a5d3fc6ab878b5aeaaf4091d7401adaef85b8a90 100644
--- a/dan/datasets/extract/extract.py
+++ b/dan/datasets/extract/extract.py
@@ -24,11 +24,10 @@ from dan.datasets.extract.exceptions import (
     ProcessingError,
 )
 from dan.datasets.extract.utils import (
-    EntityType,
     download_image,
     insert_token,
-    parse_tokens,
 )
+from dan.utils import EntityType, parse_tokens
 from line_image_extractor.extractor import extract, read_img, save_img
 from line_image_extractor.image_utils import Extraction, polygon_to_bbox, resize
 
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
index c39c7dca7b371b26228eab16c3928d888e8c4dcc..5efcf6eaadc20d0ef0839d35b595441098385512 100644
--- a/dan/datasets/extract/utils.py
+++ b/dan/datasets/extract/utils.py
@@ -1,11 +1,8 @@
 # -*- coding: utf-8 -*-
 import logging
 from io import BytesIO
-from pathlib import Path
-from typing import NamedTuple
 
 import requests
-import yaml
 from PIL import Image
 from tenacity import (
     retry,
@@ -14,6 +11,8 @@ from tenacity import (
     wait_exponential,
 )
 
+from dan.utils import EntityType
+
 logger = logging.getLogger(__name__)
 
 # See http://docs.python-requests.org/en/master/user/advanced/#timeouts
@@ -27,15 +26,6 @@ def _retry_log(retry_state, *args, **kwargs):
     )
 
 
-class EntityType(NamedTuple):
-    start: str
-    end: str = ""
-
-    @property
-    def offset(self):
-        return len(self.start) + len(self.end)
-
-
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=2),
@@ -80,10 +70,3 @@ def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -
         # End token
         + (entity_type.end if entity_type else "")
     )
-
-
-def parse_tokens(filename: Path) -> dict:
-    return {
-        name: EntityType(**tokens)
-        for name, tokens in yaml.safe_load(filename.read_text()).items()
-    }
diff --git a/dan/ocr/manager/metrics.py b/dan/ocr/manager/metrics.py
index 015c023408e809aedffd9eb781bfe40f206552e0..f0c266752bc940247792b1d7f232daafa41e98fd 100644
--- a/dan/ocr/manager/metrics.py
+++ b/dan/ocr/manager/metrics.py
@@ -7,7 +7,7 @@ from typing import Optional
 import editdistance
 import numpy as np
 
-from dan.datasets.extract.utils import parse_tokens
+from dan.utils import parse_tokens
 
 
 class MetricManager:
diff --git a/dan/ocr/predict/prediction.py b/dan/ocr/predict/prediction.py
index a37596488463b8956407f314b294a2bd7a454600..c1ed1a037ec863460928ca200bc2e83371066593 100644
--- a/dan/ocr/predict/prediction.py
+++ b/dan/ocr/predict/prediction.py
@@ -11,7 +11,6 @@ import torch
 import yaml
 
 from dan import logger
-from dan.datasets.extract.utils import parse_tokens
 from dan.ocr.decoder import GlobalHTADecoder
 from dan.ocr.encoder import FCN_Encoder
 from dan.ocr.predict.attention import (
@@ -21,7 +20,13 @@ from dan.ocr.predict.attention import (
     split_text_and_confidences,
 )
 from dan.ocr.transforms import get_preprocessing_transforms
-from dan.utils import ind_to_token, list_to_batches, pad_images, read_image
+from dan.utils import (
+    ind_to_token,
+    list_to_batches,
+    pad_images,
+    parse_tokens,
+    read_image,
+)
 
 
 class DAN:
diff --git a/dan/utils.py b/dan/utils.py
index d58db8fa088d29e08dae619f64f01dd017791b94..b5d53cd2d5acf1ca157bbd465849b0d2de17e15a 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -1,8 +1,11 @@
 # -*- coding: utf-8 -*-
 from itertools import islice
+from pathlib import Path
+from typing import NamedTuple
 
 import torch
 import torchvision.io as torchvision
+import yaml
 
 
 class MLflowNotInstalled(Exception):
@@ -11,6 +14,15 @@ class MLflowNotInstalled(Exception):
     """
 
 
+class EntityType(NamedTuple):
+    start: str
+    end: str = ""
+
+    @property
+    def offset(self):
+        return len(self.start) + len(self.end)
+
+
 def pad_sequences_1D(data, padding_value):
     """
     Pad data with padding_value to get same length
@@ -92,3 +104,10 @@ def list_to_batches(iterable, n):
     it = iter(iterable)
     while batch := tuple(islice(it, n)):
         yield batch
+
+
+def parse_tokens(filename: Path) -> dict:
+    return {
+        name: EntityType(**tokens)
+        for name, tokens in yaml.safe_load(filename.read_text()).items()
+    }
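
For context, a minimal sketch (not part of the patch) of how the relocated helpers behave. The entity names (`person`, `date`) and bracket tokens are made-up sample values; the YAML shape follows from `EntityType(**tokens)` in `parse_tokens`:

```python
# Illustrative only: exercising EntityType and parse_tokens from their
# new home in dan.utils. Entity names and tokens here are hypothetical.
from pathlib import Path
from tempfile import NamedTemporaryFile

from dan.utils import EntityType, parse_tokens

# A token file mapping entity names to EntityType keyword arguments.
with NamedTemporaryFile("w", suffix=".yml", delete=False) as f:
    f.write('person:\n  start: "["\n  end: "]"\ndate:\n  start: "<"\n')

tokens = parse_tokens(Path(f.name))
assert tokens["person"] == EntityType(start="[", end="]")
assert tokens["person"].offset == 2  # len("[") + len("]")
assert tokens["date"].offset == 1    # end defaults to ""
```

With both helpers living in `dan.utils`, `dan/ocr/manager/metrics.py` and `dan/ocr/predict/prediction.py` no longer need to reach into the `dan.datasets.extract` package just to parse token files.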