Skip to content
Snippets Groups Projects
Commit bbbeea3a authored by Mélodie Boillet, committed by Yoann Schneider
Browse files

Rename keep_only_tokens to keep_only_ner_tokens

parent df6bda94
No related branches found
No related tags found
1 merge request: !135 Rename keep_only_tokens to keep_only_ner_tokens
......@@ -245,7 +245,7 @@ class MetricManager:
pp_pred.append(pp_module.post_process(pred))
metrics["nb_pp_op_layout"].append(pp_module.num_op)
metrics["nb_gt_layout_token"] = [
len(keep_only_tokens(str_x, self.layout_tokens))
len(keep_only_ner_tokens(str_x, self.layout_tokens))
for str_x in values["str_x"]
]
edit_and_num_items = [
......@@ -261,16 +261,16 @@ class MetricManager:
return self.epoch_metrics[name]
def keep_only_ner_tokens(str, tokens):
    """
    Remove all but the given tokens from string

    Every character of `tokens` is treated as a literal single-character
    token (the callers pass NER or layout token characters).

    :param str: text to filter (parameter name kept for backward
        compatibility even though it shadows the builtin).
    :param tokens: string whose characters are the tokens to keep.
    :return: `str` reduced to the characters listed in `tokens`, in their
        original order; an empty string when `tokens` is empty.
    """
    if not tokens:
        # "[^]" is not a valid character class, and with no token to keep
        # nothing of the input survives anyway.
        return ""
    # Escape the tokens so that "^", "]", "-" or "\" are matched literally
    # instead of changing the meaning of the character class.
    return re.sub("[^" + re.escape(tokens) + "]", "", str)


# Pre-rename name, kept as an alias for callers that still use it.
keep_only_tokens = keep_only_ner_tokens
def keep_all_but_ner_tokens(str, tokens):
    """
    Remove all the given tokens from string

    Every character of `tokens` is treated as a literal single-character
    token (the callers pass NER or layout token characters).

    :param str: text to filter (parameter name kept for backward
        compatibility even though it shadows the builtin).
    :param tokens: string whose characters are the tokens to remove.
    :return: `str` without any character listed in `tokens`; `str`
        unchanged when `tokens` is empty.
    """
    if not tokens:
        # "[]" is not a valid character class; with no token to remove the
        # input is returned untouched.
        return str
    # Escape the tokens so that a leading "^" does not invert the class and
    # "]", "-" or "\" are matched literally.
    return re.sub("[" + re.escape(tokens) + "]", "", str)


# Pre-rename name, kept as an alias for callers that still use it.
keep_all_but_tokens = keep_all_but_ner_tokens
......@@ -309,7 +309,7 @@ def format_string_for_wer(str, layout_tokens, remove_punct=False):
r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])", "", str
) # remove punctuation
if layout_tokens is not None:
str = keep_all_but_tokens(
str = keep_all_but_ner_tokens(
str, layout_tokens
) # remove layout tokens from metric
str = re.sub("([ \n])+", " ", str).strip() # keep only one space character
......@@ -321,7 +321,7 @@ def format_string_for_cer(str, layout_tokens):
Format string for CER computation: remove layout tokens and extra spaces
"""
if layout_tokens is not None:
str = keep_all_but_tokens(
str = keep_all_but_ner_tokens(
str, layout_tokens
) # remove layout tokens from metric
str = re.sub("([\n])+", "\n", str) # remove consecutive line breaks
......@@ -377,8 +377,8 @@ def compute_layout_precision_per_threshold(
pred, begin_token, end_token, associated_score=score, order_by_score=True
)
gt_list = extract_by_tokens(gt, begin_token, end_token)
pred_list = [keep_all_but_tokens(p, layout_tokens) for p in pred_list]
gt_list = [keep_all_but_tokens(gt, layout_tokens) for gt in gt_list]
pred_list = [keep_all_but_ner_tokens(p, layout_tokens) for p in pred_list]
gt_list = [keep_all_but_ner_tokens(gt, layout_tokens) for gt in gt_list]
precision_per_threshold = [
compute_layout_AP_for_given_threshold(gt_list, pred_list, threshold / 100)
for threshold in range(5, 51, 5)
......@@ -513,7 +513,7 @@ def str_to_graph_simara(str):
Compute graph from string of layout tokens for the SIMARA dataset at page level
"""
begin_layout_tokens = "".join(list(SIMARA_MATCHING_TOKENS.keys()))
layout_token_sequence = keep_only_tokens(str, begin_layout_tokens)
layout_token_sequence = keep_only_ner_tokens(str, begin_layout_tokens)
g = nx.DiGraph()
g.add_node("D", type="document", level=2, page=0)
token_name_dict = {"": "I", "": "D", "": "S", "": "C", "": "P", "": "A"}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment