Skip to content
Snippets Groups Projects
Commit 492d1d68 authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Bastien Abadie
Browse files

Do not add null fields to the payload when creating model version

parent a20994cb
No related branches found
No related tags found
1 merge request!214Do not add null fields to the payload when creating model version
Pipeline #79563 passed
......@@ -81,7 +81,11 @@ class TrainingMixin(object):
"""
def publish_model_version(
self, model_path: DirPath, model_id: str, tag: str = None, description: str = ""
self,
model_path: DirPath,
model_id: str,
tag: str = None,
description: str = None,
):
"""
This method creates a model archive and its associated hash,
......@@ -144,30 +148,29 @@ class TrainingMixin(object):
# Create a new model version with hash and size
try:
payload = {"hash": hash, "size": size, "archive_hash": archive_hash}
if tag:
payload["tag"] = tag
if description:
payload["description"] = description
model_version_details = self.request(
"CreateModelVersion",
id=model_id,
body={
"hash": hash,
"size": size,
"archive_hash": archive_hash,
"tag": tag,
"description": description,
},
body=payload,
)
logger.info(
f"Model version ({model_version_details['id']}) was created successfully"
)
except ErrorResponse as e:
if e.status_code >= 500:
model_version_details = (
e.content.get("hash") if hasattr(e, "content") else None
)
if e.status_code >= 500 or model_version_details is None:
logger.error(f"Failed to create model version: {e.content}")
model_version_details = e.content.get("hash")
raise e
# If the existing model is in Created state, this model is returned as a dict.
# Else an error in a list is returned.
if model_version_details and isinstance(
model_version_details, (list, tuple)
):
if isinstance(model_version_details, (list, tuple)):
logger.error(model_version_details[0])
return
......
......@@ -53,7 +53,7 @@ def test_create_archive(model_file_dir):
("", "description"),
("tag", ""),
("", ""),
(None, ""),
(None, None),
],
)
def test_create_model_version(mock_training_worker, tag, description):
......@@ -76,17 +76,21 @@ def test_create_model_version(mock_training_worker, tag, description):
"s3_put_url": "http://hehehe.com",
}
expected_payload = {
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
}
if description:
expected_payload["description"] = description
if tag:
expected_payload["tag"] = tag
mock_training_worker.api_client.add_response(
"CreateModelVersion",
id=model_id,
response=model_version_details,
body={
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
"tag": tag,
"description": description,
},
body=expected_payload,
)
assert (
mock_training_worker.create_model_version(
......@@ -101,37 +105,19 @@ def test_create_model_version(mock_training_worker, tag, description):
[
(
{
"hash": {
"id": "fake_model_version_id",
"model_id": "fake_model_id",
"hash": "hash",
"archive_hash": "archive_hash",
"size": "size",
"tag": "tag",
"description": "description",
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
}
"id": "fake_model_version_id",
"model_id": "fake_model_id",
"hash": "hash",
"archive_hash": "archive_hash",
"size": "size",
"tag": "tag",
"description": "description",
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
},
400,
),
(
{
"hash": {
"id": "fake_model_version_id",
"model_id": "fake_model_id",
"hash": "hash",
"archive_hash": "archive_hash",
"size": "size",
"tag": None,
"description": "",
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
}
},
400,
),
({"hash": ["A version for this model with this hash already exists."]}, 403),
(["A version for this model with this hash already exists."], 403),
],
)
def test_retrieve_created_model_version(mock_training_worker, content, status_code):
......@@ -141,13 +127,10 @@ def test_retrieve_created_model_version(mock_training_worker, content, status_co
Else if an existing model version in Available mode,
403 was raised, but None will be returned
"""
model_id = "fake_model_id"
model_hash = "hash"
archive_hash = "archive_hash"
size = "30"
tag = "tag"
description = "description"
mock_training_worker.api_client.add_error_response(
"CreateModelVersion",
id=model_id,
......@@ -156,27 +139,57 @@ def test_retrieve_created_model_version(mock_training_worker, content, status_co
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
"tag": tag,
"description": description,
},
content=content,
content={"hash": content},
)
if status_code == 400:
assert (
mock_training_worker.create_model_version(
model_id, model_hash, size, archive_hash, tag, description
model_id, model_hash, size, archive_hash, tag=None, description=None
)
== content["hash"]
== content
)
elif status_code == 403:
assert (
mock_training_worker.create_model_version(
model_id, model_hash, size, archive_hash, tag, description
model_id, model_hash, size, archive_hash, tag=None, description=None
)
is None
)
@pytest.mark.parametrize(
"content, status_code",
(
# error 500
({"id": "fake_id"}, 500),
# model_version details is None
({}, 403),
(None, 403),
),
)
def test_handle_500_create_model_version(mock_training_worker, content, status_code):
model_id = "fake_model_id"
model_hash = "hash"
archive_hash = "archive_hash"
size = "30"
mock_training_worker.api_client.add_error_response(
"CreateModelVersion",
id=model_id,
status_code=status_code,
body={
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
},
content=content,
)
with pytest.raises(Exception):
mock_training_worker.create_model_version(
model_id, model_hash, size, archive_hash, tag=None, description=None
)
def test_handle_s3_uploading_errors(mock_training_worker, model_file_dir):
s3_endpoint_url = "http://s3.localhost.com"
responses.add_passthru(s3_endpoint_url)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment