Skip to content
Snippets Groups Projects
Commit bb09ae0c authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Mélodie Boillet
Browse files

Move parse tokens function

parent 08a96bf6
No related branches found
No related tags found
1 merge request!240Move parse tokens function
...@@ -24,11 +24,10 @@ from dan.datasets.extract.exceptions import ( ...@@ -24,11 +24,10 @@ from dan.datasets.extract.exceptions import (
ProcessingError, ProcessingError,
) )
from dan.datasets.extract.utils import ( from dan.datasets.extract.utils import (
EntityType,
download_image, download_image,
insert_token, insert_token,
parse_tokens,
) )
from dan.utils import EntityType, parse_tokens
from line_image_extractor.extractor import extract, read_img, save_img from line_image_extractor.extractor import extract, read_img, save_img
from line_image_extractor.image_utils import Extraction, polygon_to_bbox, resize from line_image_extractor.image_utils import Extraction, polygon_to_bbox, resize
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import logging import logging
from io import BytesIO from io import BytesIO
from pathlib import Path
from typing import NamedTuple
import requests import requests
import yaml
from PIL import Image from PIL import Image
from tenacity import ( from tenacity import (
retry, retry,
...@@ -14,6 +11,8 @@ from tenacity import ( ...@@ -14,6 +11,8 @@ from tenacity import (
wait_exponential, wait_exponential,
) )
from dan.utils import EntityType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# See http://docs.python-requests.org/en/master/user/advanced/#timeouts # See http://docs.python-requests.org/en/master/user/advanced/#timeouts
...@@ -27,15 +26,6 @@ def _retry_log(retry_state, *args, **kwargs): ...@@ -27,15 +26,6 @@ def _retry_log(retry_state, *args, **kwargs):
) )
class EntityType(NamedTuple):
    """Pair of opening/closing tokens wrapped around a named entity.

    ``start`` is mandatory; ``end`` defaults to the empty string for
    entity types marked by a start token only.
    """

    start: str
    end: str = ""

    @property
    def offset(self):
        # Number of extra characters the tokens add around an entity.
        return sum(len(token) for token in (self.start, self.end))
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2), wait=wait_exponential(multiplier=2),
...@@ -80,10 +70,3 @@ def insert_token(text: str, entity_type: EntityType, offset: int, length: int) - ...@@ -80,10 +70,3 @@ def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -
# End token # End token
+ (entity_type.end if entity_type else "") + (entity_type.end if entity_type else "")
) )
def parse_tokens(filename: Path) -> dict:
    """Load entity token definitions from a YAML file.

    :param filename: Path to a YAML file mapping each entity name to its
        ``start``/``end`` token strings.
    :return: Mapping of entity name to :class:`EntityType`.
    """
    # ``yaml.safe_load`` returns None for an empty document; fall back to
    # an empty mapping so an empty token file yields {} instead of raising
    # ``AttributeError: 'NoneType' object has no attribute 'items'``.
    return {
        name: EntityType(**tokens)
        for name, tokens in (yaml.safe_load(filename.read_text()) or {}).items()
    }
...@@ -7,7 +7,7 @@ from typing import Optional ...@@ -7,7 +7,7 @@ from typing import Optional
import editdistance import editdistance
import numpy as np import numpy as np
from dan.datasets.extract.utils import parse_tokens from dan.utils import parse_tokens
class MetricManager: class MetricManager:
......
...@@ -11,7 +11,6 @@ import torch ...@@ -11,7 +11,6 @@ import torch
import yaml import yaml
from dan import logger from dan import logger
from dan.datasets.extract.utils import parse_tokens
from dan.ocr.decoder import GlobalHTADecoder from dan.ocr.decoder import GlobalHTADecoder
from dan.ocr.encoder import FCN_Encoder from dan.ocr.encoder import FCN_Encoder
from dan.ocr.predict.attention import ( from dan.ocr.predict.attention import (
...@@ -21,7 +20,13 @@ from dan.ocr.predict.attention import ( ...@@ -21,7 +20,13 @@ from dan.ocr.predict.attention import (
split_text_and_confidences, split_text_and_confidences,
) )
from dan.ocr.transforms import get_preprocessing_transforms from dan.ocr.transforms import get_preprocessing_transforms
from dan.utils import ind_to_token, list_to_batches, pad_images, read_image from dan.utils import (
ind_to_token,
list_to_batches,
pad_images,
parse_tokens,
read_image,
)
class DAN: class DAN:
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from itertools import islice from itertools import islice
from pathlib import Path
from typing import NamedTuple
import torch import torch
import torchvision.io as torchvision import torchvision.io as torchvision
import yaml
class MLflowNotInstalled(Exception): class MLflowNotInstalled(Exception):
...@@ -11,6 +14,15 @@ class MLflowNotInstalled(Exception): ...@@ -11,6 +14,15 @@ class MLflowNotInstalled(Exception):
""" """
class EntityType(NamedTuple):
    """Opening and closing tokens delimiting one named-entity type.

    Only ``start`` is required; ``end`` is empty when the entity type
    uses a single marker token.
    """

    start: str
    end: str = ""

    @property
    def offset(self):
        # Combined length of both tokens, i.e. how many characters the
        # markers contribute on top of the entity text itself.
        return len(self.start + self.end)
def pad_sequences_1D(data, padding_value): def pad_sequences_1D(data, padding_value):
""" """
Pad data with padding_value to get same length Pad data with padding_value to get same length
...@@ -92,3 +104,10 @@ def list_to_batches(iterable, n): ...@@ -92,3 +104,10 @@ def list_to_batches(iterable, n):
it = iter(iterable) it = iter(iterable)
while batch := tuple(islice(it, n)): while batch := tuple(islice(it, n)):
yield batch yield batch
def parse_tokens(filename: Path) -> dict:
    """Parse a YAML token-definition file into EntityType objects.

    :param filename: Path to a YAML file mapping each entity name to its
        ``start``/``end`` token strings.
    :return: Mapping of entity name to :class:`EntityType`.
    """
    # An empty YAML file makes ``yaml.safe_load`` return None; substitute
    # an empty dict so the comprehension yields {} rather than crashing
    # with AttributeError on ``.items()``.
    return {
        name: EntityType(**tokens)
        for name, tokens in (yaml.safe_load(filename.read_text()) or {}).items()
    }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment