Skip to content
Snippets Groups Projects
Commit f43cea7f authored by NolanB's avatar NolanB
Browse files

Add tag and description to training methods

parent 35202cb4
No related branches found
No related tags found
1 merge request!200Allow to specify tag and description when publishing a model version
Pipeline #79480 passed
...@@ -80,7 +80,9 @@ class TrainingMixin(object): ...@@ -80,7 +80,9 @@ class TrainingMixin(object):
Mixin for the Training workers to add Model and ModelVersion helpers Mixin for the Training workers to add Model and ModelVersion helpers
""" """
def publish_model_version(self, model_path: DirPath, model_id: str): def publish_model_version(
self, model_path: DirPath, model_id: str, tag: str = None, description: str = ""
):
""" """
This method creates a model archive and its associated hash, This method creates a model archive and its associated hash,
to create a unique version that will be stored on a bucket and published on arkindex. to create a unique version that will be stored on a bucket and published on arkindex.
...@@ -99,6 +101,8 @@ class TrainingMixin(object): ...@@ -99,6 +101,8 @@ class TrainingMixin(object):
hash=hash, hash=hash,
size=size, size=size,
archive_hash=archive_hash, archive_hash=archive_hash,
tag=tag,
description=description,
) )
if model_version_details is None: if model_version_details is None:
return return
...@@ -118,6 +122,8 @@ class TrainingMixin(object): ...@@ -118,6 +122,8 @@ class TrainingMixin(object):
hash: str, hash: str,
size: int, size: int,
archive_hash: str, archive_hash: str,
tag: str,
description: str,
) -> dict: ) -> dict:
""" """
Create a new version of the specified model with the given information (hashes and size). Create a new version of the specified model with the given information (hashes and size).
...@@ -131,7 +137,13 @@ class TrainingMixin(object): ...@@ -131,7 +137,13 @@ class TrainingMixin(object):
model_version_details = self.request( model_version_details = self.request(
"CreateModelVersion", "CreateModelVersion",
id=model_id, id=model_id,
body={"hash": hash, "size": size, "archive_hash": archive_hash}, body={
"hash": hash,
"size": size,
"archive_hash": archive_hash,
"tag": tag,
"description": description,
},
) )
except ErrorResponse as e: except ErrorResponse as e:
if e.status_code >= 500: if e.status_code >= 500:
...@@ -167,9 +179,7 @@ class TrainingMixin(object): ...@@ -167,9 +179,7 @@ class TrainingMixin(object):
def update_model_version( def update_model_version(
self, self,
model_version_details: dict, model_version_details: dict,
description: str = "",
configuration: dict = {}, configuration: dict = {},
tag: str = None,
) -> None: ) -> None:
""" """
Update the specified model version to the state `Available` and use the given information" Update the specified model version to the state `Available` and use the given information"
...@@ -182,9 +192,9 @@ class TrainingMixin(object): ...@@ -182,9 +192,9 @@ class TrainingMixin(object):
id=model_version_details.get("id"), id=model_version_details.get("id"),
body={ body={
"state": "available", "state": "available",
"description": description, "description": model_version_details.get("description"),
"configuration": configuration, "configuration": configuration,
"tag": tag, "tag": model_version_details.get("tag"),
}, },
) )
except ErrorResponse as e: except ErrorResponse as e:
......
...@@ -45,12 +45,16 @@ def test_create_model_version(): ...@@ -45,12 +45,16 @@ def test_create_model_version():
model_hash = "hash" model_hash = "hash"
archive_hash = "archive_hash" archive_hash = "archive_hash"
size = "30" size = "30"
tag = "tag"
description = "description"
model_version_details = { model_version_details = {
"id": model_version_id, "id": model_version_id,
"model_id": model_id, "model_id": model_id,
"hash": model_hash, "hash": model_hash,
"archive_hash": archive_hash, "archive_hash": archive_hash,
"size": size, "size": size,
"tag": tag,
"description": description,
"s3_url": "http://hehehe.com", "s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com", "s3_put_url": "http://hehehe.com",
} }
...@@ -59,10 +63,18 @@ def test_create_model_version(): ...@@ -59,10 +63,18 @@ def test_create_model_version():
"CreateModelVersion", "CreateModelVersion",
id=model_id, id=model_id,
response=model_version_details, response=model_version_details,
body={"hash": model_hash, "archive_hash": archive_hash, "size": size}, body={
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
"tag": tag,
"description": description,
},
) )
assert ( assert (
training.create_model_version(model_id, model_hash, size, archive_hash) training.create_model_version(
model_id, model_hash, size, archive_hash, tag, description
)
== model_version_details == model_version_details
) )
...@@ -101,21 +113,33 @@ def test_retrieve_created_model_version(content, status_code): ...@@ -101,21 +113,33 @@ def test_retrieve_created_model_version(content, status_code):
model_hash = "hash" model_hash = "hash"
archive_hash = "archive_hash" archive_hash = "archive_hash"
size = "30" size = "30"
tag = "tag"
description = "description"
training.api_client.add_error_response( training.api_client.add_error_response(
"CreateModelVersion", "CreateModelVersion",
id=model_id, id=model_id,
status_code=status_code, status_code=status_code,
body={"hash": model_hash, "archive_hash": archive_hash, "size": size}, body={
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
"tag": tag,
"description": description,
},
content=content, content=content,
) )
if status_code == 400: if status_code == 400:
assert ( assert (
training.create_model_version(model_id, model_hash, size, archive_hash) training.create_model_version(
model_id, model_hash, size, archive_hash, tag, description
)
== content["hash"] == content["hash"]
) )
elif status_code == 403: elif status_code == 403:
assert ( assert (
training.create_model_version(model_id, model_hash, size, archive_hash) training.create_model_version(
model_id, model_hash, size, archive_hash, tag, description
)
is None is None
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment