Skip to content
Snippets Groups Projects
Commit b64c4327 authored by NolanB's avatar NolanB
Browse files

Add tag and description to training methods

parent d078a1b4
No related branches found
No related tags found
1 merge request!200Allow to specify tag and description when publishing a model version
......@@ -80,7 +80,9 @@ class TrainingMixin(object):
Mixin for the Training workers to add Model and ModelVersion helpers
"""
def publish_model_version(self, model_path: DirPath, model_id: str):
def publish_model_version(
self, model_path: DirPath, model_id: str, tag: str = None, description: str = ""
):
"""
This method creates a model archive and its associated hash,
to create a unique version that will be stored on a bucket and published on arkindex.
......@@ -99,6 +101,8 @@ class TrainingMixin(object):
hash=hash,
size=size,
archive_hash=archive_hash,
tag=tag,
description=description,
)
if model_version_details is None:
return
......@@ -118,6 +122,8 @@ class TrainingMixin(object):
hash: str,
size: int,
archive_hash: str,
tag: str,
description: str,
) -> dict:
"""
Create a new version of the specified model with the given information (hashes and size).
......@@ -131,7 +137,13 @@ class TrainingMixin(object):
model_version_details = self.request(
"CreateModelVersion",
id=model_id,
body={"hash": hash, "size": size, "archive_hash": archive_hash},
body={
"hash": hash,
"size": size,
"archive_hash": archive_hash,
"tag": tag,
"description": description,
},
)
except ErrorResponse as e:
if e.status_code >= 500:
......@@ -167,9 +179,7 @@ class TrainingMixin(object):
def update_model_version(
self,
model_version_details: dict,
description: str = "",
configuration: dict = {},
tag: str = None,
) -> None:
"""
Update the specified model version to the state `Available` and use the given information"
......@@ -182,9 +192,9 @@ class TrainingMixin(object):
id=model_version_details.get("id"),
body={
"state": "available",
"description": description,
"description": model_version_details.get("description"),
"configuration": configuration,
"tag": tag,
"tag": model_version_details.get("tag"),
},
)
except ErrorResponse as e:
......
......@@ -45,12 +45,16 @@ def test_create_model_version():
model_hash = "hash"
archive_hash = "archive_hash"
size = "30"
tag = "tag"
description = "description"
model_version_details = {
"id": model_version_id,
"model_id": model_id,
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
"tag": tag,
"description": description,
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
}
......@@ -59,10 +63,18 @@ def test_create_model_version():
"CreateModelVersion",
id=model_id,
response=model_version_details,
body={"hash": model_hash, "archive_hash": archive_hash, "size": size},
body={
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
"tag": tag,
"description": description,
},
)
assert (
training.create_model_version(model_id, model_hash, size, archive_hash)
training.create_model_version(
model_id, model_hash, size, archive_hash, tag, description
)
== model_version_details
)
......@@ -101,21 +113,33 @@ def test_retrieve_created_model_version(content, status_code):
model_hash = "hash"
archive_hash = "archive_hash"
size = "30"
tag = "tag"
description = "description"
training.api_client.add_error_response(
"CreateModelVersion",
id=model_id,
status_code=status_code,
body={"hash": model_hash, "archive_hash": archive_hash, "size": size},
body={
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
"tag": tag,
"description": description,
},
content=content,
)
if status_code == 400:
assert (
training.create_model_version(model_id, model_hash, size, archive_hash)
training.create_model_version(
model_id, model_hash, size, archive_hash, tag, description
)
== content["hash"]
)
elif status_code == 403:
assert (
training.create_model_version(model_id, model_hash, size, archive_hash)
training.create_model_version(
model_id, model_hash, size, archive_hash, tag, description
)
is None
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment