diff --git a/arkindex_worker/worker/training.py b/arkindex_worker/worker/training.py index 5fd9e4555091d0ee5e223996841d0cd1f51c5f46..2501e0eb3eb9cac7681e6ce1ede839830d9e4126 100644 --- a/arkindex_worker/worker/training.py +++ b/arkindex_worker/worker/training.py @@ -80,7 +80,9 @@ class TrainingMixin(object): Mixin for the Training workers to add Model and ModelVersion helpers """ - def publish_model_version(self, model_path: DirPath, model_id: str): + def publish_model_version( + self, model_path: DirPath, model_id: str, tag: str = None, description: str = "" + ): """ This method creates a model archive and its associated hash, to create a unique version that will be stored on a bucket and published on arkindex. @@ -99,6 +101,8 @@ class TrainingMixin(object): hash=hash, size=size, archive_hash=archive_hash, + tag=tag, + description=description, ) if model_version_details is None: return @@ -118,6 +122,8 @@ class TrainingMixin(object): hash: str, size: int, archive_hash: str, + tag: str, + description: str, ) -> dict: """ Create a new version of the specified model with the given information (hashes and size). @@ -131,7 +137,13 @@ class TrainingMixin(object): model_version_details = self.request( "CreateModelVersion", id=model_id, - body={"hash": hash, "size": size, "archive_hash": archive_hash}, + body={ + "hash": hash, + "size": size, + "archive_hash": archive_hash, + "tag": tag, + "description": description, + }, ) except ErrorResponse as e: if e.status_code >= 500: @@ -167,9 +179,7 @@ class TrainingMixin(object): def update_model_version( self, model_version_details: dict, - description: str = "", configuration: dict = {}, - tag: str = None, ) -> None: """ Update the specified model version to the state `Available` and use the given information" @@ -182,9 +192,9 @@ class TrainingMixin(object): id=model_version_details.get("id"), body={ "state": "available", - "description": description, + "description": model_version_details.get("description"), "configuration": configuration, - "tag": tag, + "tag": model_version_details.get("tag"), }, ) except ErrorResponse as e: diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py index 078d0af763a0633c95aa277c96248417ff56c172..66067ee4901f38153f169e66546dd525560c3bd2 100644 --- a/tests/test_elements_worker/test_training.py +++ b/tests/test_elements_worker/test_training.py @@ -45,12 +45,16 @@ def test_create_model_version(): model_hash = "hash" archive_hash = "archive_hash" size = "30" + tag = "tag" + description = "description" model_version_details = { "id": model_version_id, "model_id": model_id, "hash": model_hash, "archive_hash": archive_hash, "size": size, + "tag": tag, + "description": description, "s3_url": "http://hehehe.com", "s3_put_url": "http://hehehe.com", } @@ -59,10 +63,18 @@ def test_create_model_version(): "CreateModelVersion", id=model_id, response=model_version_details, - body={"hash": model_hash, "archive_hash": archive_hash, "size": size}, + body={ + "hash": model_hash, + "archive_hash": archive_hash, + "size": size, + "tag": tag, + "description": description, + }, ) assert ( - training.create_model_version(model_id, model_hash, size, archive_hash) + training.create_model_version( + model_id, model_hash, size, archive_hash, tag, description + ) == model_version_details ) @@ -101,21 +113,33 @@ def test_retrieve_created_model_version(content, status_code): model_hash = "hash" archive_hash = "archive_hash" size = "30" + tag = "tag" + description = "description" training.api_client.add_error_response( "CreateModelVersion", id=model_id, status_code=status_code, - body={"hash": model_hash, "archive_hash": archive_hash, "size": size}, + body={ + "hash": model_hash, + "archive_hash": archive_hash, + "size": size, + "tag": tag, + "description": description, + }, content=content, ) if status_code == 400: assert ( - training.create_model_version(model_id, model_hash, size, archive_hash) + training.create_model_version( + model_id, model_hash, size, archive_hash, tag, description + ) == content["hash"] ) elif status_code == 403: assert ( - training.create_model_version(model_id, model_hash, size, archive_hash) + training.create_model_version( + model_id, model_hash, size, archive_hash, tag, description + ) is None )