Skip to content
Snippets Groups Projects
Commit d36473ff authored by Mélodie's avatar Mélodie
Browse files

Apply c8bd6c91

parent 2cbbfb7a
No related branches found
No related tags found
No related merge requests found
......@@ -9,8 +9,7 @@ import torch
from torch import randint
from dan.manager.dataset import DatasetManager, GenericDataset, apply_preprocessing
from dan.ocr.utils import LM_str_to_ind
from dan.utils import pad_image, pad_images, pad_sequences_1D
from dan.utils import pad_image, pad_images, pad_sequences_1D, token_to_ind
class OCRDatasetManager(DatasetManager):
......@@ -174,7 +173,7 @@ class OCRDataset(GenericDataset):
full_label = label
sample["label"] = full_label
sample["token_label"] = LM_str_to_ind(self.charset, full_label)
sample["token_label"] = token_to_ind(self.charset, full_label)
if "add_eot" in self.params["config"]["constraints"]:
sample["token_label"].append(self.tokens["end"])
sample["label_len"] = len(sample["token_label"])
......
......@@ -20,8 +20,8 @@ from tqdm import tqdm
from dan.manager.metrics import MetricManager
from dan.manager.ocr import OCRDatasetManager
from dan.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
from dan.ocr.utils import LM_ind_to_str
from dan.schedulers import DropoutScheduler
from dan.utils import ind_to_token
if MLFLOW_AVAILABLE:
import mlflow
......@@ -1010,7 +1010,7 @@ class Manager(OCRManager):
predicted_tokens = torch.argmax(pred, dim=1).detach().cpu().numpy()
predicted_tokens = [predicted_tokens[i, : y_len[i]] for i in range(b)]
str_x = [
LM_ind_to_str(self.dataset.charset, t, oov_symbol="")
ind_to_token(self.dataset.charset, t, oov_symbol="")
for t in predicted_tokens
]
......@@ -1130,7 +1130,7 @@ class Manager(OCRManager):
confidence_scores[i, : prediction_len[i]].tolist() for i in range(b)
]
str_x = [
LM_ind_to_str(self.dataset.charset, t, oov_symbol="")
ind_to_token(self.dataset.charset, t, oov_symbol="")
for t in predicted_tokens
]
......
# -*- coding: utf-8 -*-
# Charset / labels conversion
def LM_str_to_ind(labels, str):
return [labels.index(c) for c in str]
def LM_ind_to_str(labels, ind, oov_symbol=None):
if oov_symbol is not None:
res = []
for i in ind:
if i < len(labels):
res.append(labels[i])
else:
res.append(oov_symbol)
else:
res = [labels[i] for i in ind]
return "".join(res)
......@@ -14,14 +14,13 @@ from dan import logger
from dan.datasets.extract.utils import save_json
from dan.decoder import GlobalHTADecoder
from dan.encoder import FCN_Encoder
from dan.ocr.utils import LM_ind_to_str
from dan.predict.attention import (
get_predicted_polygons_with_confidence,
parse_delimiters,
plot_attention,
split_text_and_confidences,
)
from dan.utils import read_image
from dan.utils import ind_to_token, read_image
class DAN:
......@@ -221,7 +220,7 @@ class DAN:
# Transform tokens to characters
predicted_text = [
LM_ind_to_str(self.charset, t, oov_symbol="") for t in predicted_tokens
ind_to_token(self.charset, t, oov_symbol="") for t in predicted_tokens
]
logger.info("Images processed")
......
......@@ -123,3 +123,28 @@ def read_image(filename, scale=1.0):
height = int(image.shape[0] * scale)
image = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
return image
def round_floats(float_list, decimals=2):
"""
Round list of floats with fixed decimals
"""
return [np.around(num, decimals) for num in float_list]
# Charset / labels conversion
def token_to_ind(labels, str):
return [labels.index(c) for c in str]
def ind_to_token(labels, ind, oov_symbol=None):
if oov_symbol is not None:
res = []
for i in ind:
if i < len(labels):
res.append(labels[i])
else:
res.append(oov_symbol)
else:
res = [labels[i] for i in ind]
return "".join(res)
# Utils
::: dan.ocr.utils
......@@ -86,7 +86,6 @@ nav:
- Training managers: ref/managers/training.md
- OCR:
- ref/ocr/index.md
- Utils: ref/ocr/utils.md
- Document:
- ref/ocr/document/index.md
- Training: ref/ocr/document/train.md
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment