Skip to content
Snippets Groups Projects
Verified Commit cbe1fde4 authored by Yoann Schneider's avatar Yoann Schneider :tennis:
Browse files

fix lint

parent 3277e063
No related branches found
No related tags found
1 merge request!59Robust mlflow requests
......@@ -4,7 +4,6 @@ from contextlib import contextmanager
import mlflow
import requests
from mlflow.exceptions import MlflowException
from dan import logger
......@@ -54,9 +53,9 @@ def logging_metrics(
"""
if mlflow_logging and is_master:
make_mlflow_request(
mlflow_method=mlflow.log_metrics, metrics={
f"{step}_{name}": value for name, value in display_values.items()
}, step=epoch
mlflow_method=mlflow.log_metrics,
metrics={f"{step}_{name}": value for name, value in display_values.items()},
step=epoch,
)
......@@ -75,10 +74,10 @@ def logging_tags_metrics(
:param is_master: bool, makes sure you're on the right thread, defaults to False
"""
if mlflow_logging and is_master:
make_mlflow_request(mlflow_method=mlflow.set_tags, tags=
{
f"{step}_{name}": value for name, value in display_values.items()
})
make_mlflow_request(
mlflow_method=mlflow.set_tags,
tags={f"{step}_{name}": value for name, value in display_values.items()},
)
@contextmanager
......
......@@ -24,7 +24,7 @@ try:
MLFLOW = True
logger.info("MLflow Logging available.")
from dan.mlflow import start_mlflow_run, make_mlflow_request
from dan.mlflow import make_mlflow_request, start_mlflow_run
except ImportError:
MLFLOW = False
......@@ -70,9 +70,9 @@ def get_config():
"""
Retrieve model configuration
"""
dataset_name = "synist"
dataset_level = "manual_text_lines"
dataset_variant = ""
dataset_name = "esposalles"
dataset_level = "record"
dataset_variant = "_debug"
dataset_path = "."
params = {
"mlflow": {
......@@ -288,10 +288,9 @@ def run():
)
with start_mlflow_run(config["mlflow"]) as run:
logger.info(f"Started MLflow run with ID ({run.info.run_id})")
make_mlflow_request(
mlflow_method=mlflow.set_tags,
tags={"Dataset": dataset_name}
mlflow_method=mlflow.set_tags, tags={"Dataset": dataset_name}
)
# Get the labels json file
......@@ -299,13 +298,16 @@ def run():
labels_artifact = json.load(json_file)
# Log MLflow artifacts
for artifact, filename in [(config_artifact, "config.json"), (labels_artifact, "labels.json")]:
for artifact, filename in [
(config_artifact, "config.json"),
(labels_artifact, "labels.json"),
]:
make_mlflow_request(
mlflow_method=mlflow.log_dict,
dictionary=artifact,
artifact_file=filename,
)
if (
config["training_params"]["use_ddp"]
and not config["training_params"]["force_cpu"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment