From bbbeea3a3ed9bee12cbb00c8f9acccc19f9c933a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com> Date: Wed, 24 May 2023 13:30:18 +0200 Subject: [PATCH] Rename keep_only_tokens to keep_only_ner_tokens --- dan/manager/metrics.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/dan/manager/metrics.py b/dan/manager/metrics.py index fc202a21..565f3072 100644 --- a/dan/manager/metrics.py +++ b/dan/manager/metrics.py @@ -245,7 +245,7 @@ class MetricManager: pp_pred.append(pp_module.post_process(pred)) metrics["nb_pp_op_layout"].append(pp_module.num_op) metrics["nb_gt_layout_token"] = [ - len(keep_only_tokens(str_x, self.layout_tokens)) + len(keep_only_ner_tokens(str_x, self.layout_tokens)) for str_x in values["str_x"] ] edit_and_num_items = [ @@ -261,16 +261,16 @@ class MetricManager: return self.epoch_metrics[name] -def keep_only_tokens(str, tokens): +def keep_only_ner_tokens(str, tokens): """ - Remove all but layout tokens from string + Remove all but ner tokens from string """ return re.sub("([^" + tokens + "])", "", str) -def keep_all_but_tokens(str, tokens): +def keep_all_but_ner_tokens(str, tokens): """ - Remove all layout tokens from string + Remove all ner tokens from string """ return re.sub("([" + tokens + "])", "", str) @@ -309,7 +309,7 @@ def format_string_for_wer(str, layout_tokens, remove_punct=False): r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])", "", str ) # remove punctuation if layout_tokens is not None: - str = keep_all_but_tokens( + str = keep_all_but_ner_tokens( str, layout_tokens ) # remove layout tokens from metric str = re.sub("([ \n])+", " ", str).strip() # keep only one space character @@ -321,7 +321,7 @@ def format_string_for_cer(str, layout_tokens): Format string for CER computation: remove layout tokens and extra spaces """ if layout_tokens is not None: - str = keep_all_but_tokens( + str = keep_all_but_ner_tokens( str, layout_tokens ) # remove layout tokens from metric str 
= re.sub("([\n])+", "\n", str) # remove consecutive line breaks @@ -377,8 +377,8 @@ def compute_layout_precision_per_threshold( pred, begin_token, end_token, associated_score=score, order_by_score=True ) gt_list = extract_by_tokens(gt, begin_token, end_token) - pred_list = [keep_all_but_tokens(p, layout_tokens) for p in pred_list] - gt_list = [keep_all_but_tokens(gt, layout_tokens) for gt in gt_list] + pred_list = [keep_all_but_ner_tokens(p, layout_tokens) for p in pred_list] + gt_list = [keep_all_but_ner_tokens(gt, layout_tokens) for gt in gt_list] precision_per_threshold = [ compute_layout_AP_for_given_threshold(gt_list, pred_list, threshold / 100) for threshold in range(5, 51, 5) @@ -513,7 +513,7 @@ def str_to_graph_simara(str): Compute graph from string of layout tokens for the SIMARA dataset at page level """ begin_layout_tokens = "".join(list(SIMARA_MATCHING_TOKENS.keys())) - layout_token_sequence = keep_only_tokens(str, begin_layout_tokens) + layout_token_sequence = keep_only_ner_tokens(str, begin_layout_tokens) g = nx.DiGraph() g.add_node("D", type="document", level=2, page=0) token_name_dict = {"ⓘ": "I", "ⓓ": "D", "ⓢ": "S", "ⓒ": "C", "ⓟ": "P", "ⓐ": "A"} -- GitLab