Skip to content
Snippets Groups Projects
Commit 6afc3251 authored by NolanB's avatar NolanB
Browse files

Add tests

parent 2b22fe69
No related branches found
No related tags found
No related merge requests found
Pipeline #79378 passed
# -*- coding: utf-8 -*-
# import json
import os
import pytest
import responses
from apistar.exceptions import ErrorResponse
from arkindex.mock import MockApiClient
from arkindex_worker.worker import BaseWorker
from arkindex_worker.worker.training import TrainingMixin, create_archive
# from responses import matchers
class TrainingWorker(BaseWorker, TrainingMixin):
pass
......@@ -65,27 +61,24 @@ def test_create_model_version():
@pytest.mark.parametrize(
"process_exception, status_code",
"model_version_details, status_code",
[
(
ErrorResponse(
title="Mock error response",
status_code=400,
content="Mock error response",
),
{
"id": "fake_model_version_id",
"model_id": "fake_model_id",
"hash": "hash",
"archive_hash": "archive_hash",
"size": "size",
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
},
400,
),
(
ErrorResponse(
title="Mock error response",
status_code=403,
content="Mock error response",
),
403,
),
({"hash": ["A version for this model with this hash already exists."]}, 403),
],
)
def test_retrieve_created_model_version(process_exception, status_code):
def test_retrieve_created_model_version(model_version_details, status_code):
"""There is an existing model version in Created mode, A 400 was raised.
But the model is still returned in error content
"""
......@@ -95,70 +88,24 @@ def test_retrieve_created_model_version(process_exception, status_code):
hash = "hash"
archive_hash = "archive_hash"
size = "30"
model_version_id = "fake_model_version_id"
model_version_details = {
"id": model_version_id,
"model_id": model_id,
"hash": hash,
"archive_hash": archive_hash,
"size": size,
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
}
training.api_client.add_error_response(
"CreateModelVersion",
id=model_id,
status_code=status_code,
body={"hash": hash, "archive_hash": archive_hash, "size": size},
response={"hash": model_version_details},
content={"hash": model_version_details},
)
# training.create_model_version(model_id, hash, size, archive_hash)
# assert training.api_client.responses[0][1] == {"hash": model_version_details}
with pytest.raises(Exception) as e:
training.create_model_version(model_id, hash, size, archive_hash)
assert e.value
# if process_exception.status_code == status_code:
# assert (
# training.create_model_version(model_id, hash, size, archive_hash)
# == process_exception
# )
# try:
# training.create_model_version(model_id, hash, size, archive_hash)
# except Exception as e:
# assert e == process_exception
# def test_retrieve_available_model_version():
# """Raise error when there is an existing model version in Available mode"""
# model_id = "fake_model_id"
# model_version_id = "fake_model_version_id"
# training = TrainingWorker()
# training.api_client = MockApiClient()
# hash = "hash"
# archive_hash = "archive_hash"
# size = "30"
# model_version_details = {
# "id": model_version_id,
# "model_id": model_id,
# "hash": hash,
# "archive_hash": archive_hash,
# "size": size,
# "s3_url": "http://hehehe.com",
# "s3_put_url": "http://hehehe.com",
# }
# training.api_client.add_error_response(
# "CreateModelVersion",
# id=model_id,
# status_code=403,
# body={"hash": hash, "archive_hash": archive_hash, "size": size}
# )
# with pytest.raises(Exception):
# training.create_model_version(model_id, hash, size, archive_hash)
if status_code == 400:
assert (
training.create_model_version(model_id, hash, size, archive_hash)
== model_version_details
)
elif status_code == 403:
assert (
training.create_model_version(model_id, hash, size, archive_hash)
== model_version_details
)
def test_handle_s3_uploading_errors(model_file_dir):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment