diff --git a/dan/mlflow.py b/dan/mlflow.py
index 9fdd4a6512ef198946cfb5fa438397a37f6795f2..366a119af425f233dc58c42c286e81f6cd5fd856 100644
--- a/dan/mlflow.py
+++ b/dan/mlflow.py
@@ -3,6 +3,8 @@ import os
 from contextlib import contextmanager
 
 import mlflow
+from mlflow.environment_variables import MLFLOW_HTTP_REQUEST_MAX_RETRIES
+
 import requests
 
 from dan import logger
@@ -16,7 +18,7 @@ def make_mlflow_request(mlflow_method, *args, **kwargs):
         mlflow_method(*args, **kwargs)
     except requests.exceptions.ConnectionError as e:
         logger.error(f"Call to `{str(mlflow_method)}` failed with error: {str(e)}")
-
+        raise e
 
 def setup_environment(config: dict):
     """
@@ -34,6 +36,11 @@ def setup_environment(config: dict):
         if config_key in config:
             os.environ[variable_name] = config[config_key]
 
+    # Check max retry setting
+    max_retries = MLFLOW_HTTP_REQUEST_MAX_RETRIES.get()
+    if max_retries and int(max_retries) <= 1:
+        logger.warning(f"The maximum number of retries for MLflow HTTP requests is set to {max_retries}, which is low. Consider using a higher value.")
+
 
 def logging_metrics(
     display_values: dict,
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index 72abb8b8426cdd2f65ef7cefbfc21fe7afc83e92..4b005ec4bdbdd5512c7c2d37d825df573cb3f6e1 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -103,6 +103,11 @@ def get_config():
                 (dataset_name, "val"),
             ],
         },
+        "test": {
+            "{}-test".format(dataset_name): [
+                (dataset_name, "test"),
+            ],
+        },
         "config": {
             "load_in_memory": True,  # Load all images in CPU memory
             "worker_per_gpu": 4,  # Num of parallel processes per gpu for data loading