Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • atr/dan
1 result
Show changes
Commits on Source (6)
...@@ -16,7 +16,6 @@ class MetricManager: ...@@ -16,7 +16,6 @@ class MetricManager:
if "simara" in dataset_name and "page" in dataset_name: if "simara" in dataset_name and "page" in dataset_name:
self.post_processing_module = PostProcessingModuleSIMARA self.post_processing_module = PostProcessingModuleSIMARA
self.matching_tokens = SIMARA_MATCHING_TOKENS self.matching_tokens = SIMARA_MATCHING_TOKENS
self.edit_and_num_edge_nodes = edit_and_num_items_for_ged_from_str_simara
else: else:
self.matching_tokens = dict() self.matching_tokens = dict()
...@@ -150,7 +149,6 @@ class MetricManager: ...@@ -150,7 +149,6 @@ class MetricManager:
) )
elif metric_name in [ elif metric_name in [
"loss", "loss",
"loss_ctc",
"loss_ce", "loss_ce",
]: ]:
value = float( value = float(
...@@ -220,7 +218,6 @@ class MetricManager: ...@@ -220,7 +218,6 @@ class MetricManager:
] ]
metrics["nb_words_no_punct"] = [len(gt) for gt in split_gt] metrics["nb_words_no_punct"] = [len(gt) for gt in split_gt]
elif metric_name in [ elif metric_name in [
"loss_ctc",
"loss_ce", "loss_ce",
"loss", "loss",
]: ]:
...@@ -246,7 +243,7 @@ class MetricManager: ...@@ -246,7 +243,7 @@ class MetricManager:
pp_pred.append(pp_module.post_process(pred)) pp_pred.append(pp_module.post_process(pred))
metrics["nb_pp_op_layout"].append(pp_module.num_op) metrics["nb_pp_op_layout"].append(pp_module.num_op)
metrics["nb_gt_layout_token"] = [ metrics["nb_gt_layout_token"] = [
len(keep_only_tokens(str_x, self.layout_tokens)) len(keep_only_ner_tokens(str_x, self.layout_tokens))
for str_x in values["str_x"] for str_x in values["str_x"]
] ]
edit_and_num_items = [ edit_and_num_items = [
...@@ -262,16 +259,16 @@ class MetricManager: ...@@ -262,16 +259,16 @@ class MetricManager:
return self.epoch_metrics[name] return self.epoch_metrics[name]
def keep_only_tokens(str, tokens): def keep_only_ner_tokens(str, tokens):
""" """
Remove all but layout tokens from string Remove all but ner tokens from string
""" """
return re.sub("([^" + tokens + "])", "", str) return re.sub("([^" + tokens + "])", "", str)
def keep_all_but_tokens(str, tokens): def keep_all_but_ner_tokens(str, tokens):
""" """
Remove all layout tokens from string Remove all ner tokens from string
""" """
return re.sub("([" + tokens + "])", "", str) return re.sub("([" + tokens + "])", "", str)
...@@ -310,7 +307,7 @@ def format_string_for_wer(str, layout_tokens, remove_punct=False): ...@@ -310,7 +307,7 @@ def format_string_for_wer(str, layout_tokens, remove_punct=False):
r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])", "", str r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])", "", str
) # remove punctuation ) # remove punctuation
if layout_tokens is not None: if layout_tokens is not None:
str = keep_all_but_tokens( str = keep_all_but_ner_tokens(
str, layout_tokens str, layout_tokens
) # remove layout tokens from metric ) # remove layout tokens from metric
str = re.sub("([ \n])+", " ", str).strip() # keep only one space character str = re.sub("([ \n])+", " ", str).strip() # keep only one space character
...@@ -322,7 +319,7 @@ def format_string_for_cer(str, layout_tokens): ...@@ -322,7 +319,7 @@ def format_string_for_cer(str, layout_tokens):
Format string for CER computation: remove layout tokens and extra spaces Format string for CER computation: remove layout tokens and extra spaces
""" """
if layout_tokens is not None: if layout_tokens is not None:
str = keep_all_but_tokens( str = keep_all_but_ner_tokens(
str, layout_tokens str, layout_tokens
) # remove layout tokens from metric ) # remove layout tokens from metric
str = re.sub("([\n])+", "\n", str) # remove consecutive line breaks str = re.sub("([\n])+", "\n", str) # remove consecutive line breaks
...@@ -378,8 +375,8 @@ def compute_layout_precision_per_threshold( ...@@ -378,8 +375,8 @@ def compute_layout_precision_per_threshold(
pred, begin_token, end_token, associated_score=score, order_by_score=True pred, begin_token, end_token, associated_score=score, order_by_score=True
) )
gt_list = extract_by_tokens(gt, begin_token, end_token) gt_list = extract_by_tokens(gt, begin_token, end_token)
pred_list = [keep_all_but_tokens(p, layout_tokens) for p in pred_list] pred_list = [keep_all_but_ner_tokens(p, layout_tokens) for p in pred_list]
gt_list = [keep_all_but_tokens(gt, layout_tokens) for gt in gt_list] gt_list = [keep_all_but_ner_tokens(gt, layout_tokens) for gt in gt_list]
precision_per_threshold = [ precision_per_threshold = [
compute_layout_AP_for_given_threshold(gt_list, pred_list, threshold / 100) compute_layout_AP_for_given_threshold(gt_list, pred_list, threshold / 100)
for threshold in range(5, 51, 5) for threshold in range(5, 51, 5)
...@@ -514,7 +511,7 @@ def str_to_graph_simara(str): ...@@ -514,7 +511,7 @@ def str_to_graph_simara(str):
Compute graph from string of layout tokens for the SIMARA dataset at page level Compute graph from string of layout tokens for the SIMARA dataset at page level
""" """
begin_layout_tokens = "".join(list(SIMARA_MATCHING_TOKENS.keys())) begin_layout_tokens = "".join(list(SIMARA_MATCHING_TOKENS.keys()))
layout_token_sequence = keep_only_tokens(str, begin_layout_tokens) layout_token_sequence = keep_only_ner_tokens(str, begin_layout_tokens)
g = nx.DiGraph() g = nx.DiGraph()
g.add_node("D", type="document", level=2, page=0) g.add_node("D", type="document", level=2, page=0)
token_name_dict = {"": "I", "": "D", "": "S", "": "C", "": "P", "": "A"} token_name_dict = {"": "I", "": "D", "": "S", "": "C", "": "P", "": "A"}
...@@ -549,16 +546,3 @@ def graph_edit_distance(g1, g2): ...@@ -549,16 +546,3 @@ def graph_edit_distance(g1, g2):
): ):
new_edit = v new_edit = v
return new_edit return new_edit
def edit_and_num_items_for_ged_from_str_simara(str_gt, str_pred):
"""
Compute graph edit distance and num nodes/edges for normalized graph edit distance
For the SIMARA dataset
"""
g_gt = str_to_graph_simara(str_gt)
g_pred = str_to_graph_simara(str_pred)
return (
graph_edit_distance(g_gt, g_pred),
g_gt.number_of_nodes() + g_gt.number_of_edges(),
)
...@@ -14,7 +14,6 @@ from dan.ocr.utils import LM_str_to_ind ...@@ -14,7 +14,6 @@ from dan.ocr.utils import LM_str_to_ind
from dan.utils import ( from dan.utils import (
pad_image, pad_image,
pad_image_width_random, pad_image_width_random,
pad_image_width_right,
pad_images, pad_images,
pad_sequences_1D, pad_sequences_1D,
rand, rand,
...@@ -45,19 +44,12 @@ class OCRDatasetManager(DatasetManager): ...@@ -45,19 +44,12 @@ class OCRDatasetManager(DatasetManager):
self.tokens = { self.tokens = {
"pad": params["config"]["padding_token"], "pad": params["config"]["padding_token"],
} }
if self.params["config"]["charset_mode"].lower() == "ctc": self.tokens["end"] = len(self.charset)
self.tokens["blank"] = len(self.charset) self.tokens["start"] = len(self.charset) + 1
self.tokens["pad"] = ( self.tokens["pad"] = (
self.tokens["pad"] if self.tokens["pad"] else len(self.charset) + 1 self.tokens["pad"] if self.tokens["pad"] else len(self.charset) + 2
) )
self.params["config"]["padding_token"] = self.tokens["pad"] self.params["config"]["padding_token"] = self.tokens["pad"]
elif self.params["config"]["charset_mode"] == "seq2seq":
self.tokens["end"] = len(self.charset)
self.tokens["start"] = len(self.charset) + 1
self.tokens["pad"] = (
self.tokens["pad"] if self.tokens["pad"] else len(self.charset) + 2
)
self.params["config"]["padding_token"] = self.tokens["pad"]
def get_merged_charsets(self): def get_merged_charsets(self):
""" """
...@@ -161,37 +153,7 @@ class OCRDataset(GenericDataset): ...@@ -161,37 +153,7 @@ class OCRDataset(GenericDataset):
sample["img_shape"] / self.reduce_dims_factor sample["img_shape"] / self.reduce_dims_factor
).astype(int) ).astype(int)
# Padding to handle CTC requirements
if self.set_name == "train": if self.set_name == "train":
max_label_len = 0
height = 1
ctc_padding = False
if "CTC_line" in self.params["config"]["constraints"]:
max_label_len = sample["label_len"]
ctc_padding = True
if "CTC_va" in self.params["config"]["constraints"]:
max_label_len = max(sample["line_label_len"])
ctc_padding = True
if "CTC_pg" in self.params["config"]["constraints"]:
max_label_len = sample["label_len"]
height = max(sample["img_reduced_shape"][0], 1)
ctc_padding = True
if (
ctc_padding
and 2 * max_label_len + 1 > sample["img_reduced_shape"][1] * height
):
sample["img"] = pad_image_width_right(
sample["img"],
int(
np.ceil((2 * max_label_len + 1) / height)
* self.reduce_dims_factor[1]
),
self.padding_value,
)
sample["img_shape"] = sample["img"].shape
sample["img_reduced_shape"] = np.ceil(
sample["img_shape"] / self.reduce_dims_factor
).astype(int)
sample["img_reduced_shape"] = [ sample["img_reduced_shape"] = [
max(1, t) for t in sample["img_reduced_shape"] max(1, t) for t in sample["img_reduced_shape"]
] ]
......
...@@ -4,8 +4,6 @@ import json ...@@ -4,8 +4,6 @@ import json
import os import os
import pickle import pickle
import random import random
import sys
from datetime import date
from time import time from time import time
import numpy as np import numpy as np
...@@ -343,12 +341,6 @@ class GenericTrainingManager: ...@@ -343,12 +341,6 @@ class GenericTrainingManager:
if c in old_charset: if c in old_charset:
new_weights[i] = weights[old_charset.index(c)] new_weights[i] = weights[old_charset.index(c)]
pretrained_chars.append(c) pretrained_chars.append(c)
if (
"transfered_charset_last_is_ctc_blank" in self.params["model_params"]
and self.params["model_params"]["transfered_charset_last_is_ctc_blank"]
):
new_weights[-1] = weights[-1]
pretrained_chars.append("<blank>")
checkpoint["{}_state_dict".format(state_dict_name)][key] = new_weights checkpoint["{}_state_dict".format(state_dict_name)][key] = new_weights
self.models[model_name].load_state_dict( self.models[model_name].load_state_dict(
{key: checkpoint["{}_state_dict".format(state_dict_name)][key]}, {key: checkpoint["{}_state_dict".format(state_dict_name)][key]},
...@@ -523,7 +515,6 @@ class GenericTrainingManager: ...@@ -523,7 +515,6 @@ class GenericTrainingManager:
return return
params = copy.deepcopy(self.params) params = copy.deepcopy(self.params)
params = class_to_str_dict(params) params = class_to_str_dict(params)
params["date"] = date.today().strftime("%d/%m/%Y")
total_params = 0 total_params = 0
for model_name in self.models.keys(): for model_name in self.models.keys():
current_params = compute_nb_params(self.models[model_name]) current_params = compute_nb_params(self.models[model_name])
...@@ -533,21 +524,6 @@ class GenericTrainingManager: ...@@ -533,21 +524,6 @@ class GenericTrainingManager:
] ]
total_params += current_params total_params += current_params
params["model_params"]["total_params"] = "{:,}".format(total_params) params["model_params"]["total_params"] = "{:,}".format(total_params)
params["hardware"] = dict()
if self.device != "cpu":
for i in range(self.params["training_params"]["nb_gpu"]):
params["hardware"][str(i)] = "{} {}".format(
torch.cuda.get_device_name(i), torch.cuda.get_device_properties(i)
)
else:
params["hardware"]["0"] = "CPU"
params["software"] = {
"python_version": sys.version,
"pytorch_version": torch.__version__,
"cuda_version": torch.version.cuda,
"cudnn_version": torch.backends.cudnn.version(),
}
with open(path, "w") as f: with open(path, "w") as f:
json.dump(params, f, indent=4) json.dump(params, f, indent=4)
...@@ -871,8 +847,9 @@ class GenericTrainingManager: ...@@ -871,8 +847,9 @@ class GenericTrainingManager:
with open(path, "w") as f: with open(path, "w") as f:
yaml.dump(metrics, stream=f) yaml.dump(metrics, stream=f)
# Log mlflow artifacts if mlflow_logging:
mlflow.log_artifact(path, "predictions") # Log mlflow artifacts
mlflow.log_artifact(path, "predictions")
def output_pred(self, name): def output_pred(self, name):
path = os.path.join( path = os.path.join(
...@@ -920,7 +897,6 @@ class GenericTrainingManager: ...@@ -920,7 +897,6 @@ class GenericTrainingManager:
"nb_samples", "nb_samples",
"loss", "loss",
"loss_ce", "loss_ce",
"loss_ctc",
"loss_ce_end", "loss_ce_end",
]: ]:
metrics[metric_name] = self.sum_ddp_metric( metrics[metric_name] = self.sum_ddp_metric(
...@@ -1104,14 +1080,7 @@ class Manager(OCRManager): ...@@ -1104,14 +1080,7 @@ class Manager(OCRManager):
reduced_size = [s[:2] for s in batch_data["imgs_reduced_shape"]] reduced_size = [s[:2] for s in batch_data["imgs_reduced_shape"]]
y_len = batch_data["labels_len"] y_len = batch_data["labels_len"]
# add errors in teacher forcing if "label_noise_scheduler" in self.params["training_params"]:
if (
"teacher_forcing_error_rate" in self.params["training_params"]
and self.params["training_params"]["teacher_forcing_error_rate"] is not None
):
error_rate = self.params["training_params"]["teacher_forcing_error_rate"]
simulated_y_pred, y_len = self.add_label_noise(y, y_len, error_rate)
elif "label_noise_scheduler" in self.params["training_params"]:
error_rate = ( error_rate = (
self.params["training_params"]["label_noise_scheduler"][ self.params["training_params"]["label_noise_scheduler"][
"min_error_rate" "min_error_rate"
......
...@@ -114,7 +114,6 @@ def get_config(): ...@@ -114,7 +114,6 @@ def get_config():
"height_divisor": 32, # Image height will be divided by 32 "height_divisor": 32, # Image height will be divided by 32
"padding_value": 0, # Image padding value "padding_value": 0, # Image padding value
"padding_token": None, # Label padding value "padding_token": None, # Label padding value
"charset_mode": "seq2seq", # add end-of-transcription and start-of-transcription tokens to charset
"constraints": [ "constraints": [
"add_eot", "add_eot",
"add_sot", "add_sot",
......
...@@ -18,7 +18,6 @@ All hyperparameters are specified and editable in the training scripts (meaning ...@@ -18,7 +18,6 @@ All hyperparameters are specified and editable in the training scripts (meaning
| `dataset_params.config.width_divisor` | Factor to reduce the width of the feature vector before feeding the decoder. | `int` | `32` | | `dataset_params.config.width_divisor` | Factor to reduce the width of the feature vector before feeding the decoder. | `int` | `32` |
| `dataset_params.config.padding_value` | Image padding value. | `int` | `0` | | `dataset_params.config.padding_value` | Image padding value. | `int` | `0` |
| `dataset_params.config.padding_token` | Transcription padding value. | `int` | `None` | | `dataset_params.config.padding_token` | Transcription padding value. | `int` | `None` |
| `dataset_params.config.charset_mode` | Whether to add end-of-transcription and start-of-transcription tokens to charset. | `str` | `seq2seq` |
| `dataset_params.config.constraints` | Whether to add end-of-transcription and start-of-transcription tokens in labels. | `list` | `["add_eot", "add_sot"]` | | `dataset_params.config.constraints` | Whether to add end-of-transcription and start-of-transcription tokens in labels. | `list` | `["add_eot", "add_sot"]` |
| `dataset_params.config.normalize` | Normalize with mean and variance of training dataset. | `bool` | `True` | | `dataset_params.config.normalize` | Normalize with mean and variance of training dataset. | `bool` | `True` |
| `dataset_params.config.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#data-preprocessing)) | | `dataset_params.config.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#data-preprocessing)) |
......
...@@ -72,7 +72,6 @@ def training_config(): ...@@ -72,7 +72,6 @@ def training_config():
"height_divisor": 32, # Image height will be divided by 32 "height_divisor": 32, # Image height will be divided by 32
"padding_value": 0, # Image padding value "padding_value": 0, # Image padding value
"padding_token": None, # Label padding value "padding_token": None, # Label padding value
"charset_mode": "seq2seq", # add end-of-transcription and start-of-transcription tokens to charset
"constraints": [ "constraints": [
"add_eot", "add_eot",
"add_sot", "add_sot",
......