Skip to content
Snippets Groups Projects
Commit d36473ff authored by Mélodie's avatar Mélodie
Browse files

Apply c8bd6c91

parent 2cbbfb7a
No related branches found
No related tags found
No related merge requests found
...@@ -9,8 +9,7 @@ import torch ...@@ -9,8 +9,7 @@ import torch
from torch import randint from torch import randint
from dan.manager.dataset import DatasetManager, GenericDataset, apply_preprocessing from dan.manager.dataset import DatasetManager, GenericDataset, apply_preprocessing
from dan.ocr.utils import LM_str_to_ind from dan.utils import pad_image, pad_images, pad_sequences_1D, token_to_ind
from dan.utils import pad_image, pad_images, pad_sequences_1D
class OCRDatasetManager(DatasetManager): class OCRDatasetManager(DatasetManager):
...@@ -174,7 +173,7 @@ class OCRDataset(GenericDataset): ...@@ -174,7 +173,7 @@ class OCRDataset(GenericDataset):
full_label = label full_label = label
sample["label"] = full_label sample["label"] = full_label
sample["token_label"] = LM_str_to_ind(self.charset, full_label) sample["token_label"] = token_to_ind(self.charset, full_label)
if "add_eot" in self.params["config"]["constraints"]: if "add_eot" in self.params["config"]["constraints"]:
sample["token_label"].append(self.tokens["end"]) sample["token_label"].append(self.tokens["end"])
sample["label_len"] = len(sample["token_label"]) sample["label_len"] = len(sample["token_label"])
......
...@@ -20,8 +20,8 @@ from tqdm import tqdm ...@@ -20,8 +20,8 @@ from tqdm import tqdm
from dan.manager.metrics import MetricManager from dan.manager.metrics import MetricManager
from dan.manager.ocr import OCRDatasetManager from dan.manager.ocr import OCRDatasetManager
from dan.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics from dan.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
from dan.ocr.utils import LM_ind_to_str
from dan.schedulers import DropoutScheduler from dan.schedulers import DropoutScheduler
from dan.utils import ind_to_token
if MLFLOW_AVAILABLE: if MLFLOW_AVAILABLE:
import mlflow import mlflow
...@@ -1010,7 +1010,7 @@ class Manager(OCRManager): ...@@ -1010,7 +1010,7 @@ class Manager(OCRManager):
predicted_tokens = torch.argmax(pred, dim=1).detach().cpu().numpy() predicted_tokens = torch.argmax(pred, dim=1).detach().cpu().numpy()
predicted_tokens = [predicted_tokens[i, : y_len[i]] for i in range(b)] predicted_tokens = [predicted_tokens[i, : y_len[i]] for i in range(b)]
str_x = [ str_x = [
LM_ind_to_str(self.dataset.charset, t, oov_symbol="") ind_to_token(self.dataset.charset, t, oov_symbol="")
for t in predicted_tokens for t in predicted_tokens
] ]
...@@ -1130,7 +1130,7 @@ class Manager(OCRManager): ...@@ -1130,7 +1130,7 @@ class Manager(OCRManager):
confidence_scores[i, : prediction_len[i]].tolist() for i in range(b) confidence_scores[i, : prediction_len[i]].tolist() for i in range(b)
] ]
str_x = [ str_x = [
LM_ind_to_str(self.dataset.charset, t, oov_symbol="") ind_to_token(self.dataset.charset, t, oov_symbol="")
for t in predicted_tokens for t in predicted_tokens
] ]
......
# -*- coding: utf-8 -*-
# Charset / labels conversion
def LM_str_to_ind(labels, str):
return [labels.index(c) for c in str]
def LM_ind_to_str(labels, ind, oov_symbol=None):
if oov_symbol is not None:
res = []
for i in ind:
if i < len(labels):
res.append(labels[i])
else:
res.append(oov_symbol)
else:
res = [labels[i] for i in ind]
return "".join(res)
...@@ -14,14 +14,13 @@ from dan import logger ...@@ -14,14 +14,13 @@ from dan import logger
from dan.datasets.extract.utils import save_json from dan.datasets.extract.utils import save_json
from dan.decoder import GlobalHTADecoder from dan.decoder import GlobalHTADecoder
from dan.encoder import FCN_Encoder from dan.encoder import FCN_Encoder
from dan.ocr.utils import LM_ind_to_str
from dan.predict.attention import ( from dan.predict.attention import (
get_predicted_polygons_with_confidence, get_predicted_polygons_with_confidence,
parse_delimiters, parse_delimiters,
plot_attention, plot_attention,
split_text_and_confidences, split_text_and_confidences,
) )
from dan.utils import read_image from dan.utils import ind_to_token, read_image
class DAN: class DAN:
...@@ -221,7 +220,7 @@ class DAN: ...@@ -221,7 +220,7 @@ class DAN:
# Transform tokens to characters # Transform tokens to characters
predicted_text = [ predicted_text = [
LM_ind_to_str(self.charset, t, oov_symbol="") for t in predicted_tokens ind_to_token(self.charset, t, oov_symbol="") for t in predicted_tokens
] ]
logger.info("Images processed") logger.info("Images processed")
......
...@@ -123,3 +123,28 @@ def read_image(filename, scale=1.0): ...@@ -123,3 +123,28 @@ def read_image(filename, scale=1.0):
height = int(image.shape[0] * scale) height = int(image.shape[0] * scale)
image = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA) image = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
return image return image
def round_floats(float_list, decimals=2):
"""
Round list of floats with fixed decimals
"""
return [np.around(num, decimals) for num in float_list]
# Charset / labels conversion
def token_to_ind(labels, str):
return [labels.index(c) for c in str]
def ind_to_token(labels, ind, oov_symbol=None):
if oov_symbol is not None:
res = []
for i in ind:
if i < len(labels):
res.append(labels[i])
else:
res.append(oov_symbol)
else:
res = [labels[i] for i in ind]
return "".join(res)
# Utils
::: dan.ocr.utils
...@@ -86,7 +86,6 @@ nav: ...@@ -86,7 +86,6 @@ nav:
- Training managers: ref/managers/training.md - Training managers: ref/managers/training.md
- OCR: - OCR:
- ref/ocr/index.md - ref/ocr/index.md
- Utils: ref/ocr/utils.md
- Document: - Document:
- ref/ocr/document/index.md - ref/ocr/document/index.md
- Training: ref/ocr/document/train.md - Training: ref/ocr/document/train.md
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment