Skip to content
Snippets Groups Projects
Commit bb09ae0c authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Mélodie Boillet
Browse files

Move parse tokens function

parent 08a96bf6
No related branches found
No related tags found
1 merge request!240Move parse tokens function
......@@ -24,11 +24,10 @@ from dan.datasets.extract.exceptions import (
ProcessingError,
)
from dan.datasets.extract.utils import (
EntityType,
download_image,
insert_token,
parse_tokens,
)
from dan.utils import EntityType, parse_tokens
from line_image_extractor.extractor import extract, read_img, save_img
from line_image_extractor.image_utils import Extraction, polygon_to_bbox, resize
......
# -*- coding: utf-8 -*-
import logging
from io import BytesIO
from pathlib import Path
from typing import NamedTuple
import requests
import yaml
from PIL import Image
from tenacity import (
retry,
......@@ -14,6 +11,8 @@ from tenacity import (
wait_exponential,
)
from dan.utils import EntityType
logger = logging.getLogger(__name__)
# See http://docs.python-requests.org/en/master/user/advanced/#timeouts
......@@ -27,15 +26,6 @@ def _retry_log(retry_state, *args, **kwargs):
)
class EntityType(NamedTuple):
start: str
end: str = ""
@property
def offset(self):
return len(self.start) + len(self.end)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2),
......@@ -80,10 +70,3 @@ def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -
# End token
+ (entity_type.end if entity_type else "")
)
def parse_tokens(filename: Path) -> dict:
return {
name: EntityType(**tokens)
for name, tokens in yaml.safe_load(filename.read_text()).items()
}
......@@ -7,7 +7,7 @@ from typing import Optional
import editdistance
import numpy as np
from dan.datasets.extract.utils import parse_tokens
from dan.utils import parse_tokens
class MetricManager:
......
......@@ -11,7 +11,6 @@ import torch
import yaml
from dan import logger
from dan.datasets.extract.utils import parse_tokens
from dan.ocr.decoder import GlobalHTADecoder
from dan.ocr.encoder import FCN_Encoder
from dan.ocr.predict.attention import (
......@@ -21,7 +20,13 @@ from dan.ocr.predict.attention import (
split_text_and_confidences,
)
from dan.ocr.transforms import get_preprocessing_transforms
from dan.utils import ind_to_token, list_to_batches, pad_images, read_image
from dan.utils import (
ind_to_token,
list_to_batches,
pad_images,
parse_tokens,
read_image,
)
class DAN:
......
# -*- coding: utf-8 -*-
from itertools import islice
from pathlib import Path
from typing import NamedTuple
import torch
import torchvision.io as torchvision
import yaml
class MLflowNotInstalled(Exception):
......@@ -11,6 +14,15 @@ class MLflowNotInstalled(Exception):
"""
class EntityType(NamedTuple):
start: str
end: str = ""
@property
def offset(self):
return len(self.start) + len(self.end)
def pad_sequences_1D(data, padding_value):
"""
Pad data with padding_value to get same length
......@@ -92,3 +104,10 @@ def list_to_batches(iterable, n):
it = iter(iterable)
while batch := tuple(islice(it, n)):
yield batch
def parse_tokens(filename: Path) -> dict:
return {
name: EntityType(**tokens)
for name, tokens in yaml.safe_load(filename.read_text()).items()
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment