diff --git a/arkindex_worker/worker/training.py b/arkindex_worker/worker/training.py index 5fd9e4555091d0ee5e223996841d0cd1f51c5f46..2501e0eb3eb9cac7681e6ce1ede839830d9e4126 100644 --- a/arkindex_worker/worker/training.py +++ b/arkindex_worker/worker/training.py @@ -80,7 +80,9 @@ class TrainingMixin(object): Mixin for the Training workers to add Model and ModelVersion helpers """ - def publish_model_version(self, model_path: DirPath, model_id: str): + def publish_model_version( + self, model_path: DirPath, model_id: str, tag: str = None, description: str = "" + ): """ This method creates a model archive and its associated hash, to create a unique version that will be stored on a bucket and published on arkindex. @@ -99,6 +101,8 @@ class TrainingMixin(object): hash=hash, size=size, archive_hash=archive_hash, + tag=tag, + description=description, ) if model_version_details is None: return @@ -118,6 +122,8 @@ class TrainingMixin(object): hash: str, size: int, archive_hash: str, + tag: str, + description: str, ) -> dict: """ Create a new version of the specified model with the given information (hashes and size). @@ -131,7 +137,13 @@ class TrainingMixin(object): model_version_details = self.request( "CreateModelVersion", id=model_id, - body={"hash": hash, "size": size, "archive_hash": archive_hash}, + body={ + "hash": hash, + "size": size, + "archive_hash": archive_hash, + "tag": tag, + "description": description, + }, ) except ErrorResponse as e: if e.status_code >= 500: @@ -167,9 +179,7 @@ class TrainingMixin(object): def update_model_version( self, model_version_details: dict, - description: str = "", configuration: dict = {}, - tag: str = None, ) -> None: """ Update the specified model version to the state `Available` and use the given information" @@ -182,9 +192,9 @@ class TrainingMixin(object): id=model_version_details.get("id"), body={ "state": "available", - "description": description, + "description": model_version_details.get("description"), "configuration": configuration, - "tag": tag, + "tag": model_version_details.get("tag"), }, ) except ErrorResponse as e: diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py index 078d0af763a0633c95aa277c96248417ff56c172..9b6f014a8fc19c1a5bd5ba13425f3cbb5fbb82d6 100644 --- a/tests/test_elements_worker/test_training.py +++ b/tests/test_elements_worker/test_training.py @@ -35,11 +35,22 @@ def test_create_archive(model_file_dir): assert not os.path.exists(zst_archive_path), "Auto removal failed" -def test_create_model_version(): +@pytest.mark.parametrize( + "tag, description", + [ + ("tag", "description"), + (None, "description"), + ("", "description"), + ("tag", ""), + ("", ""), + (None, ""), + ], +) +def test_create_model_version(tag, description): """A new model version is returned""" - model_id = "fake_model_id" model_version_id = "fake_model_version_id" + model_id = "fake_model_id" training = TrainingWorker() training.api_client = MockApiClient() model_hash = "hash" @@ -51,6 +62,8 @@ def test_create_model_version(): "hash": model_hash, "archive_hash": archive_hash, "size": size, + "tag": tag, + "description": description, "s3_url": "http://hehehe.com", "s3_put_url": "http://hehehe.com", } @@ -59,10 +72,18 @@ def test_create_model_version(): "CreateModelVersion", id=model_id, response=model_version_details, - body={"hash": model_hash, "archive_hash": archive_hash, "size": size}, + body={ + "hash": model_hash, + "archive_hash": archive_hash, + "size": size, + "tag": tag, + "description": description, + }, ) assert ( - training.create_model_version(model_id, model_hash, size, archive_hash) + training.create_model_version( + model_id, model_hash, size, archive_hash, tag, description + ) == model_version_details ) @@ -78,6 +99,24 @@ def test_create_model_version(): "hash": "hash", "archive_hash": "archive_hash", "size": "size", + "tag": "tag", + "description": "description", + "s3_url": "http://hehehe.com", + "s3_put_url": "http://hehehe.com", + } + }, + 400, + ), + ( + { + "hash": { + "id": "fake_model_version_id", + "model_id": "fake_model_id", + "hash": "hash", + "archive_hash": "archive_hash", + "size": "size", + "tag": None, + "description": "", "s3_url": "http://hehehe.com", "s3_put_url": "http://hehehe.com", } @@ -101,21 +140,33 @@ def test_retrieve_created_model_version(content, status_code): model_hash = "hash" archive_hash = "archive_hash" size = "30" + tag = "tag" + description = "description" training.api_client.add_error_response( "CreateModelVersion", id=model_id, status_code=status_code, - body={"hash": model_hash, "archive_hash": archive_hash, "size": size}, + body={ + "hash": model_hash, + "archive_hash": archive_hash, + "size": size, + "tag": tag, + "description": description, + }, content=content, ) if status_code == 400: assert ( - training.create_model_version(model_id, model_hash, size, archive_hash) + training.create_model_version( + model_id, model_hash, size, archive_hash, tag, description + ) == content["hash"] ) elif status_code == 403: assert ( - training.create_model_version(model_id, model_hash, size, archive_hash) + training.create_model_version( + model_id, model_hash, size, archive_hash, tag, description + ) is None )