Skip to content
Snippets Groups Projects
Verified Commit c7afdc56 authored by Yoann Schneider's avatar Yoann Schneider :tennis:
Browse files

do not add null field to the payload

parent a20994cb
No related branches found
No related tags found
1 merge request!214Do not add null fields to the payload when creating model version
Pipeline #79561 passed
...@@ -81,7 +81,11 @@ class TrainingMixin(object): ...@@ -81,7 +81,11 @@ class TrainingMixin(object):
""" """
def publish_model_version( def publish_model_version(
self, model_path: DirPath, model_id: str, tag: str = None, description: str = "" self,
model_path: DirPath,
model_id: str,
tag: str = None,
description: str = None,
): ):
""" """
This method creates a model archive and its associated hash, This method creates a model archive and its associated hash,
...@@ -144,30 +148,29 @@ class TrainingMixin(object): ...@@ -144,30 +148,29 @@ class TrainingMixin(object):
# Create a new model version with hash and size # Create a new model version with hash and size
try: try:
payload = {"hash": hash, "size": size, "archive_hash": archive_hash}
if tag:
payload["tag"] = tag
if description:
payload["description"] = description
model_version_details = self.request( model_version_details = self.request(
"CreateModelVersion", "CreateModelVersion",
id=model_id, id=model_id,
body={ body=payload,
"hash": hash,
"size": size,
"archive_hash": archive_hash,
"tag": tag,
"description": description,
},
) )
logger.info( logger.info(
f"Model version ({model_version_details['id']}) was created successfully" f"Model version ({model_version_details['id']}) was created successfully"
) )
except ErrorResponse as e: except ErrorResponse as e:
if e.status_code >= 500: model_version_details = (
e.content.get("hash") if hasattr(e, "content") else None
)
if e.status_code >= 500 or model_version_details is None:
logger.error(f"Failed to create model version: {e.content}") logger.error(f"Failed to create model version: {e.content}")
raise e
model_version_details = e.content.get("hash")
# If the existing model is in Created state, this model is returned as a dict. # If the existing model is in Created state, this model is returned as a dict.
# Else an error in a list is returned. # Else an error in a list is returned.
if model_version_details and isinstance( if isinstance(model_version_details, (list, tuple)):
model_version_details, (list, tuple)
):
logger.error(model_version_details[0]) logger.error(model_version_details[0])
return return
......
...@@ -53,7 +53,7 @@ def test_create_archive(model_file_dir): ...@@ -53,7 +53,7 @@ def test_create_archive(model_file_dir):
("", "description"), ("", "description"),
("tag", ""), ("tag", ""),
("", ""), ("", ""),
(None, ""), (None, None),
], ],
) )
def test_create_model_version(mock_training_worker, tag, description): def test_create_model_version(mock_training_worker, tag, description):
...@@ -76,17 +76,21 @@ def test_create_model_version(mock_training_worker, tag, description): ...@@ -76,17 +76,21 @@ def test_create_model_version(mock_training_worker, tag, description):
"s3_put_url": "http://hehehe.com", "s3_put_url": "http://hehehe.com",
} }
expected_payload = {
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
}
if description:
expected_payload["description"] = description
if tag:
expected_payload["tag"] = tag
mock_training_worker.api_client.add_response( mock_training_worker.api_client.add_response(
"CreateModelVersion", "CreateModelVersion",
id=model_id, id=model_id,
response=model_version_details, response=model_version_details,
body={ body=expected_payload,
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
"tag": tag,
"description": description,
},
) )
assert ( assert (
mock_training_worker.create_model_version( mock_training_worker.create_model_version(
...@@ -101,37 +105,19 @@ def test_create_model_version(mock_training_worker, tag, description): ...@@ -101,37 +105,19 @@ def test_create_model_version(mock_training_worker, tag, description):
[ [
( (
{ {
"hash": { "id": "fake_model_version_id",
"id": "fake_model_version_id", "model_id": "fake_model_id",
"model_id": "fake_model_id", "hash": "hash",
"hash": "hash", "archive_hash": "archive_hash",
"archive_hash": "archive_hash", "size": "size",
"size": "size", "tag": "tag",
"tag": "tag", "description": "description",
"description": "description", "s3_url": "http://hehehe.com",
"s3_url": "http://hehehe.com", "s3_put_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
}
}, },
400, 400,
), ),
( (["A version for this model with this hash already exists."], 403),
{
"hash": {
"id": "fake_model_version_id",
"model_id": "fake_model_id",
"hash": "hash",
"archive_hash": "archive_hash",
"size": "size",
"tag": None,
"description": "",
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
}
},
400,
),
({"hash": ["A version for this model with this hash already exists."]}, 403),
], ],
) )
def test_retrieve_created_model_version(mock_training_worker, content, status_code): def test_retrieve_created_model_version(mock_training_worker, content, status_code):
...@@ -141,13 +127,10 @@ def test_retrieve_created_model_version(mock_training_worker, content, status_co ...@@ -141,13 +127,10 @@ def test_retrieve_created_model_version(mock_training_worker, content, status_co
Else if an existing model version in Available mode, Else if an existing model version in Available mode,
403 was raised, but None will be returned 403 was raised, but None will be returned
""" """
model_id = "fake_model_id" model_id = "fake_model_id"
model_hash = "hash" model_hash = "hash"
archive_hash = "archive_hash" archive_hash = "archive_hash"
size = "30" size = "30"
tag = "tag"
description = "description"
mock_training_worker.api_client.add_error_response( mock_training_worker.api_client.add_error_response(
"CreateModelVersion", "CreateModelVersion",
id=model_id, id=model_id,
...@@ -156,27 +139,57 @@ def test_retrieve_created_model_version(mock_training_worker, content, status_co ...@@ -156,27 +139,57 @@ def test_retrieve_created_model_version(mock_training_worker, content, status_co
"hash": model_hash, "hash": model_hash,
"archive_hash": archive_hash, "archive_hash": archive_hash,
"size": size, "size": size,
"tag": tag,
"description": description,
}, },
content=content, content={"hash": content},
) )
if status_code == 400: if status_code == 400:
assert ( assert (
mock_training_worker.create_model_version( mock_training_worker.create_model_version(
model_id, model_hash, size, archive_hash, tag, description model_id, model_hash, size, archive_hash, tag=None, description=None
) )
== content["hash"] == content
) )
elif status_code == 403: elif status_code == 403:
assert ( assert (
mock_training_worker.create_model_version( mock_training_worker.create_model_version(
model_id, model_hash, size, archive_hash, tag, description model_id, model_hash, size, archive_hash, tag=None, description=None
) )
is None is None
) )
@pytest.mark.parametrize(
"content, status_code",
(
# error 500
({"id": "fake_id"}, 500),
# model_version details is None
({}, 403),
(None, 403),
),
)
def test_handle_500_create_model_version(mock_training_worker, content, status_code):
model_id = "fake_model_id"
model_hash = "hash"
archive_hash = "archive_hash"
size = "30"
mock_training_worker.api_client.add_error_response(
"CreateModelVersion",
id=model_id,
status_code=status_code,
body={
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
},
content=content,
)
with pytest.raises(Exception):
mock_training_worker.create_model_version(
model_id, model_hash, size, archive_hash, tag=None, description=None
)
def test_handle_s3_uploading_errors(mock_training_worker, model_file_dir): def test_handle_s3_uploading_errors(mock_training_worker, model_file_dir):
s3_endpoint_url = "http://s3.localhost.com" s3_endpoint_url = "http://s3.localhost.com"
responses.add_passthru(s3_endpoint_url) responses.add_passthru(s3_endpoint_url)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment