Skip to content
Snippets Groups Projects
Commit 387f094b authored by Nolan's avatar Nolan Committed by Yoann Schneider
Browse files

Allow to specify tag and description when publishing a model version

parent d078a1b4
No related branches found
No related tags found
1 merge request!200Allow to specify tag and description when publishing a model version
Pipeline #79489 passed
......@@ -80,7 +80,9 @@ class TrainingMixin(object):
Mixin for the Training workers to add Model and ModelVersion helpers
"""
def publish_model_version(self, model_path: DirPath, model_id: str):
def publish_model_version(
self, model_path: DirPath, model_id: str, tag: str = None, description: str = ""
):
"""
This method creates a model archive and its associated hash,
to create a unique version that will be stored on a bucket and published on arkindex.
......@@ -99,6 +101,8 @@ class TrainingMixin(object):
hash=hash,
size=size,
archive_hash=archive_hash,
tag=tag,
description=description,
)
if model_version_details is None:
return
......@@ -118,6 +122,8 @@ class TrainingMixin(object):
hash: str,
size: int,
archive_hash: str,
tag: str,
description: str,
) -> dict:
"""
Create a new version of the specified model with the given information (hashes and size).
......@@ -131,7 +137,13 @@ class TrainingMixin(object):
model_version_details = self.request(
"CreateModelVersion",
id=model_id,
body={"hash": hash, "size": size, "archive_hash": archive_hash},
body={
"hash": hash,
"size": size,
"archive_hash": archive_hash,
"tag": tag,
"description": description,
},
)
except ErrorResponse as e:
if e.status_code >= 500:
......@@ -167,9 +179,7 @@ class TrainingMixin(object):
def update_model_version(
self,
model_version_details: dict,
description: str = "",
configuration: dict = {},
tag: str = None,
) -> None:
"""
Update the specified model version to the state `Available` and use the given information"
......@@ -182,9 +192,9 @@ class TrainingMixin(object):
id=model_version_details.get("id"),
body={
"state": "available",
"description": description,
"description": model_version_details.get("description"),
"configuration": configuration,
"tag": tag,
"tag": model_version_details.get("tag"),
},
)
except ErrorResponse as e:
......
......@@ -35,11 +35,22 @@ def test_create_archive(model_file_dir):
assert not os.path.exists(zst_archive_path), "Auto removal failed"
def test_create_model_version():
@pytest.mark.parametrize(
"tag, description",
[
("tag", "description"),
(None, "description"),
("", "description"),
("tag", ""),
("", ""),
(None, ""),
],
)
def test_create_model_version(tag, description):
"""A new model version is returned"""
model_id = "fake_model_id"
model_version_id = "fake_model_version_id"
model_id = "fake_model_id"
training = TrainingWorker()
training.api_client = MockApiClient()
model_hash = "hash"
......@@ -51,6 +62,8 @@ def test_create_model_version():
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
"tag": tag,
"description": description,
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
}
......@@ -59,10 +72,18 @@ def test_create_model_version():
"CreateModelVersion",
id=model_id,
response=model_version_details,
body={"hash": model_hash, "archive_hash": archive_hash, "size": size},
body={
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
"tag": tag,
"description": description,
},
)
assert (
training.create_model_version(model_id, model_hash, size, archive_hash)
training.create_model_version(
model_id, model_hash, size, archive_hash, tag, description
)
== model_version_details
)
......@@ -78,6 +99,24 @@ def test_create_model_version():
"hash": "hash",
"archive_hash": "archive_hash",
"size": "size",
"tag": "tag",
"description": "description",
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
}
},
400,
),
(
{
"hash": {
"id": "fake_model_version_id",
"model_id": "fake_model_id",
"hash": "hash",
"archive_hash": "archive_hash",
"size": "size",
"tag": None,
"description": "",
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
}
......@@ -101,21 +140,33 @@ def test_retrieve_created_model_version(content, status_code):
model_hash = "hash"
archive_hash = "archive_hash"
size = "30"
tag = "tag"
description = "description"
training.api_client.add_error_response(
"CreateModelVersion",
id=model_id,
status_code=status_code,
body={"hash": model_hash, "archive_hash": archive_hash, "size": size},
body={
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
"tag": tag,
"description": description,
},
content=content,
)
if status_code == 400:
assert (
training.create_model_version(model_id, model_hash, size, archive_hash)
training.create_model_version(
model_id, model_hash, size, archive_hash, tag, description
)
== content["hash"]
)
elif status_code == 403:
assert (
training.create_model_version(model_id, model_hash, size, archive_hash)
training.create_model_version(
model_id, model_hash, size, archive_hash, tag, description
)
is None
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment