diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index b70b9364bad373acf6ceb09bd5d4e5f3b453a700..3d8ebdd489abdd26ea53cea5ddb05fc520eb4e40 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -9,8 +9,7 @@ import torch
 from torch import randint
 
 from dan.manager.dataset import DatasetManager, GenericDataset, apply_preprocessing
-from dan.ocr.utils import LM_str_to_ind
-from dan.utils import pad_image, pad_images, pad_sequences_1D
+from dan.utils import pad_image, pad_images, pad_sequences_1D, token_to_ind
 
 
 class OCRDatasetManager(DatasetManager):
@@ -174,7 +173,7 @@
             full_label = label
 
         sample["label"] = full_label
-        sample["token_label"] = LM_str_to_ind(self.charset, full_label)
+        sample["token_label"] = token_to_ind(self.charset, full_label)
         if "add_eot" in self.params["config"]["constraints"]:
             sample["token_label"].append(self.tokens["end"])
         sample["label_len"] = len(sample["token_label"])
diff --git a/dan/manager/training.py b/dan/manager/training.py
index fa3fbe12a562c7445dc8a281fb80909f4993e3f7..27bdf224b3095a14d4269ee27b8a41a670226ac5 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -20,8 +20,8 @@ from tqdm import tqdm
 from dan.manager.metrics import MetricManager
 from dan.manager.ocr import OCRDatasetManager
 from dan.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
-from dan.ocr.utils import LM_ind_to_str
 from dan.schedulers import DropoutScheduler
+from dan.utils import ind_to_token
 
 if MLFLOW_AVAILABLE:
     import mlflow
@@ -1010,7 +1010,7 @@ class Manager(OCRManager):
         predicted_tokens = torch.argmax(pred, dim=1).detach().cpu().numpy()
         predicted_tokens = [predicted_tokens[i, : y_len[i]] for i in range(b)]
         str_x = [
-            LM_ind_to_str(self.dataset.charset, t, oov_symbol="")
+            ind_to_token(self.dataset.charset, t, oov_symbol="")
             for t in predicted_tokens
         ]
 
@@ -1130,7 +1130,7 @@ class Manager(OCRManager):
             confidence_scores[i, : prediction_len[i]].tolist() for i in range(b)
         ]
         str_x = [
-            LM_ind_to_str(self.dataset.charset, t, oov_symbol="")
+            ind_to_token(self.dataset.charset, t, oov_symbol="")
             for t in predicted_tokens
         ]
 
diff --git a/dan/ocr/utils.py b/dan/ocr/utils.py
deleted file mode 100644
index 8038f6f9cab2bc21b80fda7b36b0fea66b445c96..0000000000000000000000000000000000000000
--- a/dan/ocr/utils.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# -*- coding: utf-8 -*-
-# Charset / labels conversion
-def LM_str_to_ind(labels, str):
-    return [labels.index(c) for c in str]
-
-
-def LM_ind_to_str(labels, ind, oov_symbol=None):
-    if oov_symbol is not None:
-        res = []
-        for i in ind:
-            if i < len(labels):
-                res.append(labels[i])
-            else:
-                res.append(oov_symbol)
-    else:
-        res = [labels[i] for i in ind]
-    return "".join(res)
diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py
index b1e8a356edecaf980aa687b4fad15115c87f7462..357bb4df1854a64af8b0e65794960c1fc628e2a8 100644
--- a/dan/predict/prediction.py
+++ b/dan/predict/prediction.py
@@ -14,14 +14,13 @@ from dan import logger
 from dan.datasets.extract.utils import save_json
 from dan.decoder import GlobalHTADecoder
 from dan.encoder import FCN_Encoder
-from dan.ocr.utils import LM_ind_to_str
 from dan.predict.attention import (
     get_predicted_polygons_with_confidence,
     parse_delimiters,
     plot_attention,
     split_text_and_confidences,
 )
-from dan.utils import read_image
+from dan.utils import ind_to_token, read_image
 
 
 class DAN:
@@ -221,7 +220,7 @@
 
         # Transform tokens to characters
         predicted_text = [
-            LM_ind_to_str(self.charset, t, oov_symbol="") for t in predicted_tokens
+            ind_to_token(self.charset, t, oov_symbol="") for t in predicted_tokens
         ]
 
         logger.info("Images processed")
diff --git a/dan/utils.py b/dan/utils.py
index e4f18b733b26d4d65c10ed7ac4ecc260c3fc7459..8bffd18a71b002651f5875dd5d37f2324d1ce015 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -123,3 +123,28 @@
     height = int(image.shape[0] * scale)
     image = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
     return image
+
+
+def round_floats(float_list, decimals=2):
+    """
+    Round list of floats with fixed decimals
+    """
+    return [np.around(num, decimals) for num in float_list]
+
+
+# Charset / labels conversion
+def token_to_ind(labels, str):
+    return [labels.index(c) for c in str]
+
+
+def ind_to_token(labels, ind, oov_symbol=None):
+    if oov_symbol is not None:
+        res = []
+        for i in ind:
+            if i < len(labels):
+                res.append(labels[i])
+            else:
+                res.append(oov_symbol)
+    else:
+        res = [labels[i] for i in ind]
+    return "".join(res)
diff --git a/docs/ref/ocr/utils.md b/docs/ref/ocr/utils.md
deleted file mode 100644
index adeec07327016e1240b904610c8b5aec282a720e..0000000000000000000000000000000000000000
--- a/docs/ref/ocr/utils.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Utils
-
-::: dan.ocr.utils
diff --git a/mkdocs.yml b/mkdocs.yml
index 2a5b6f42f6b39b719bb0115e69fc241003e6ca1f..8aa69a2efaf9e80ca4a4e470ccc4798a2b56dcf3 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -86,7 +86,6 @@ nav:
       - Training managers: ref/managers/training.md
   - OCR:
       - ref/ocr/index.md
-      - Utils: ref/ocr/utils.md
       - Document:
           - ref/ocr/document/index.md
           - Training: ref/ocr/document/train.md