Skip to content
Snippets Groups Projects

Allow to specify tag and description when publishing a model version

All threads resolved!
1 file
+ 56
16
Compare changes
  • Side-by-side
  • Inline
@@ -35,11 +35,22 @@ def test_create_archive(model_file_dir):
assert not os.path.exists(zst_archive_path), "Auto removal failed"
def test_create_model_version():
@pytest.mark.parametrize(
"tag, description",
[
("tag", "description"),
(None, "description"),
("", "description"),
("tag", ""),
("", ""),
(None, ""),
],
)
def test_create_model_version(tag, description):
"""A new model version is returned"""
model_id = "fake_model_id"
model_version_id = "fake_model_version_id"
model_id = "fake_model_id"
training = TrainingWorker()
training.api_client = MockApiClient()
model_hash = "hash"
@@ -51,6 +62,8 @@ def test_create_model_version():
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
"tag": tag,
"description": description,
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
}
@@ -59,10 +72,18 @@ def test_create_model_version():
"CreateModelVersion",
id=model_id,
response=model_version_details,
body={"hash": model_hash, "archive_hash": archive_hash, "size": size},
body={
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
"tag": tag,
"description": description,
},
)
assert (
training.create_model_version(model_id, model_hash, size, archive_hash)
training.create_model_version(
model_id, model_hash, size, archive_hash, tag, description
)
== model_version_details
)
@@ -78,6 +99,24 @@ def test_create_model_version():
"hash": "hash",
"archive_hash": "archive_hash",
"size": "size",
"tag": "tag",
"description": "description",
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
}
},
400,
),
(
{
"hash": {
"id": "fake_model_version_id",
"model_id": "fake_model_id",
"hash": "hash",
"archive_hash": "archive_hash",
"size": "size",
"tag": None,
"description": "",
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
}
@@ -101,21 +140,33 @@ def test_retrieve_created_model_version(content, status_code):
model_hash = "hash"
archive_hash = "archive_hash"
size = "30"
tag = "tag"
description = "description"
training.api_client.add_error_response(
"CreateModelVersion",
id=model_id,
status_code=status_code,
body={"hash": model_hash, "archive_hash": archive_hash, "size": size},
body={
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
"tag": tag,
"description": description,
},
content=content,
)
if status_code == 400:
assert (
training.create_model_version(model_id, model_hash, size, archive_hash)
training.create_model_version(
model_id, model_hash, size, archive_hash, tag, description
)
== content["hash"]
)
elif status_code == 403:
assert (
training.create_model_version(model_id, model_hash, size, archive_hash)
training.create_model_version(
model_id, model_hash, size, archive_hash, tag, description
)
is None
)
Loading