Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: atr/dan
Commits on Source (4)
@@ -16,7 +16,6 @@ class MetricManager:
         if "simara" in dataset_name and "page" in dataset_name:
             self.post_processing_module = PostProcessingModuleSIMARA
             self.matching_tokens = SIMARA_MATCHING_TOKENS
-            self.edit_and_num_edge_nodes = edit_and_num_items_for_ged_from_str_simara
         else:
             self.matching_tokens = dict()
@@ -549,16 +548,3 @@ def graph_edit_distance(g1, g2):
     ):
         new_edit = v
     return new_edit
-
-
-def edit_and_num_items_for_ged_from_str_simara(str_gt, str_pred):
-    """
-    Compute graph edit distance and num nodes/edges for normalized graph edit distance
-    For the SIMARA dataset
-    """
-    g_gt = str_to_graph_simara(str_gt)
-    g_pred = str_to_graph_simara(str_pred)
-    return (
-        graph_edit_distance(g_gt, g_pred),
-        g_gt.number_of_nodes() + g_gt.number_of_edges(),
-    )
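For reference, the removed helper paired the raw graph edit distance with the size of the ground-truth graph so callers could normalize the score. A minimal standalone sketch of that idea, assuming networkx graphs (the function name and the single normalized return value are illustrative, not this repository's API):

import networkx as nx


def normalized_graph_edit_distance(g_gt: nx.DiGraph, g_pred: nx.DiGraph) -> float:
    """Graph edit distance divided by the size (nodes + edges) of the ground-truth graph."""
    # Exact GED is expensive; the timeout makes networkx return the best distance found so far.
    ged = nx.graph_edit_distance(g_gt, g_pred, timeout=30)
    gt_size = g_gt.number_of_nodes() + g_gt.number_of_edges()
    return ged / gt_size if gt_size else 0.0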
@@ -4,8 +4,6 @@ import json
 import os
 import pickle
 import random
-import sys
-from datetime import date
 from time import time

 import numpy as np
@@ -523,7 +521,6 @@ class GenericTrainingManager:
             return
         params = copy.deepcopy(self.params)
         params = class_to_str_dict(params)
-        params["date"] = date.today().strftime("%d/%m/%Y")
         total_params = 0
         for model_name in self.models.keys():
             current_params = compute_nb_params(self.models[model_name])
@@ -533,21 +530,6 @@
             ]
             total_params += current_params
         params["model_params"]["total_params"] = "{:,}".format(total_params)
-        params["hardware"] = dict()
-        if self.device != "cpu":
-            for i in range(self.params["training_params"]["nb_gpu"]):
-                params["hardware"][str(i)] = "{} {}".format(
-                    torch.cuda.get_device_name(i), torch.cuda.get_device_properties(i)
-                )
-        else:
-            params["hardware"]["0"] = "CPU"
-        params["software"] = {
-            "python_version": sys.version,
-            "pytorch_version": torch.__version__,
-            "cuda_version": torch.version.cuda,
-            "cudnn_version": torch.backends.cudnn.version(),
-        }

         with open(path, "w") as f:
             json.dump(params, f, indent=4)
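The surviving bookkeeping still records a formatted total via compute_nb_params. A common way to implement such a counter in PyTorch, shown as a sketch (this is an assumption about the helper, not its actual definition in the repository):

import torch


def compute_nb_params(module: torch.nn.Module, trainable_only: bool = False) -> int:
    """Count the parameters of a module, optionally restricted to trainable ones."""
    params = module.parameters()
    if trainable_only:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)


# Usage, formatted the same way as in the training manager above:
# total_params = compute_nb_params(model)
# print("{:,}".format(total_params))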
@@ -871,8 +853,9 @@ class GenericTrainingManager:
         with open(path, "w") as f:
             yaml.dump(metrics, stream=f)
-        # Log mlflow artifacts
-        mlflow.log_artifact(path, "predictions")
+        if mlflow_logging:
+            # Log mlflow artifacts
+            mlflow.log_artifact(path, "predictions")

     def output_pred(self, name):
         path = os.path.join(
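The change above makes MLflow artifact logging opt-in through an mlflow_logging flag instead of always calling the tracking client. A standalone illustration of that pattern (the flag name and the run setup are taken as assumptions from this excerpt, not the project's full configuration):

import mlflow
import yaml


def dump_metrics(metrics: dict, path: str, mlflow_logging: bool = False) -> None:
    """Write metrics to a YAML file and optionally attach it to the active MLflow run."""
    with open(path, "w") as f:
        yaml.dump(metrics, stream=f)
    if mlflow_logging:
        # Stored under the "predictions" artifact directory of the current run.
        mlflow.log_artifact(path, "predictions")


# Only touches the tracking server when explicitly enabled:
# with mlflow.start_run():
#     dump_metrics({"cer": 0.042}, "metrics.yaml", mlflow_logging=True)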
@@ -1104,14 +1087,7 @@ class Manager(OCRManager):
         reduced_size = [s[:2] for s in batch_data["imgs_reduced_shape"]]
         y_len = batch_data["labels_len"]

         # add errors in teacher forcing
-        if (
-            "teacher_forcing_error_rate" in self.params["training_params"]
-            and self.params["training_params"]["teacher_forcing_error_rate"] is not None
-        ):
-            error_rate = self.params["training_params"]["teacher_forcing_error_rate"]
-            simulated_y_pred, y_len = self.add_label_noise(y, y_len, error_rate)
-        elif "label_noise_scheduler" in self.params["training_params"]:
+        if "label_noise_scheduler" in self.params["training_params"]:
             error_rate = (
                 self.params["training_params"]["label_noise_scheduler"][
                     "min_error_rate"
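After this change only the label_noise_scheduler branch drives the simulated teacher-forcing errors. A sketch of how such a scheduler is often ramped between a minimum and a maximum error rate over training (only min_error_rate appears in this excerpt; the other keys and the linear schedule are assumptions):

def scheduled_error_rate(step: int, scheduler: dict) -> float:
    """Linearly interpolate the teacher-forcing error rate between its min and max values."""
    min_rate = scheduler["min_error_rate"]
    max_rate = scheduler.get("max_error_rate", min_rate)
    total_steps = max(scheduler.get("total_num_steps", 1), 1)
    progress = min(step / total_steps, 1.0)
    return min_rate + progress * (max_rate - min_rate)


# Example: ramp from 10% to 30% simulated label errors over 50,000 steps.
# rate = scheduled_error_rate(step=25_000, scheduler={
#     "min_error_rate": 0.1, "max_error_rate": 0.3, "total_num_steps": 50_000,
# })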