Newer
Older
# -*- coding: utf-8 -*-
import uuid
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.models import Artifact
from tests.conftest import FIXTURES_DIR
from tests.test_elements_worker import BASE_API_CALLS
# Fixed task UUID used by the tests in this file, both as the ``task_id``
# argument and to build the mocked API URLs.
TASK_ID = uuid.UUID("cafecafe-cafe-cafe-cafe-cafecafecafe")
@pytest.mark.parametrize(
    "payload, error",
    [
        # Task ID is missing
        ({"task_id": None}, "task_id shouldn't be null and should be an UUID"),
        # Task ID is given as a plain string
        (
            {"task_id": "12341234-1234-1234-1234-123412341234"},
            "task_id shouldn't be null and should be an UUID",
        ),
    ],
)
def test_list_artifacts_wrong_param_task_id(mock_dataset_worker, payload, error):
    """Check that ``list_artifacts`` raises an ``AssertionError`` for a bad ``task_id``.

    Each parametrized case supplies the keyword arguments to pass and the
    message the assertion error is expected to match.
    """
    with pytest.raises(AssertionError, match=error):
        mock_dataset_worker.list_artifacts(**payload)
def test_list_artifacts_api_error(responses, mock_dataset_worker):
    """Check that ``list_artifacts`` raises ``ErrorResponse`` when the API fails.

    The mocked artifacts endpoint always answers with an HTTP 500. The test
    expects an ``ErrorResponse`` and exactly five requests on that endpoint in
    addition to the base API calls.
    """
    artifacts_url = f"http://testserver/api/v1/task/{TASK_ID}/artifacts/"
    responses.add(
        responses.GET,
        artifacts_url,
        status=500,
    )

    with pytest.raises(ErrorResponse):
        mock_dataset_worker.list_artifacts(task_id=TASK_ID)

    # The API call is retried 5 times
    expected_calls = BASE_API_CALLS + [("GET", artifacts_url)] * 5

    assert len(responses.calls) == len(BASE_API_CALLS) + 5
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == expected_calls
def test_list_artifacts(
    responses,
    mock_dataset_worker,
):
    """Check that ``list_artifacts`` returns the artifacts served by the API.

    The mocked endpoint returns two artifact payloads. Every returned item
    must be an ``Artifact`` equal to the matching payload, and exactly one
    request must be made on top of the base API calls.
    """
    expected_results = [
        {
            "id": "artifact_1",
            "path": "dataset_id.zstd",
            "size": 42,
            "content_type": "application/zstd",
            "s3_put_url": None,
            "created": "2000-01-01T00:00:00Z",
            "updated": "2000-01-01T00:00:00Z",
        },
        {
            "id": "artifact_2",
            "path": "logs.log",
            "size": 42,
            "content_type": "text/plain",
            "s3_put_url": None,
            "created": "2000-01-01T00:00:00Z",
            "updated": "2000-01-01T00:00:00Z",
        },
    ]
    responses.add(
        responses.GET,
        f"http://testserver/api/v1/task/{TASK_ID}/artifacts/",
        status=200,
        json=expected_results,
    )

    artifacts = list(mock_dataset_worker.list_artifacts(task_id=TASK_ID))

    # Guard against a vacuous pass: with no artifact returned, the per-item
    # checks below would never run.
    assert len(artifacts) == len(expected_results)
    for artifact, expected in zip(artifacts, expected_results):
        assert isinstance(artifact, Artifact)
        assert artifact == expected

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == BASE_API_CALLS + [
        ("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifacts/"),
    ]
@pytest.mark.parametrize(
    "payload, error",
    (
        # Task ID
        (
            {"task_id": None},
            "task_id shouldn't be null and should be an UUID",
        ),
        (
            {"task_id": "12341234-1234-1234-1234-123412341234"},
            "task_id shouldn't be null and should be an UUID",
        ),
    ),
)
def test_download_artifact_wrong_param_task_id(
    mock_dataset_worker, default_artifact, payload, error
):
    """Check that ``download_artifact`` raises an ``AssertionError`` for a bad ``task_id``.

    The call starts from valid arguments; the parametrized payload then
    overrides ``task_id`` with the invalid value under test.
    """
    api_payload = {
        "task_id": TASK_ID,
        "artifact": default_artifact,
        # Parametrized values take precedence over the valid defaults above
        **payload,
    }

    with pytest.raises(AssertionError, match=error):
        mock_dataset_worker.download_artifact(**api_payload)
@pytest.mark.parametrize(
    "payload, error",
    [
        # Artifact is missing
        ({"artifact": None}, "artifact shouldn't be null and should be an Artifact"),
        # Artifact is given as a plain string
        (
            {"artifact": "not artifact type"},
            "artifact shouldn't be null and should be an Artifact",
        ),
    ],
)
def test_download_artifact_wrong_param_artifact(
    mock_dataset_worker, default_artifact, payload, error
):
    """Check that ``download_artifact`` raises an ``AssertionError`` for a bad ``artifact``.

    Valid arguments are prepared first, then the parametrized payload
    replaces ``artifact`` with the invalid value under test.
    """
    call_kwargs = {"task_id": TASK_ID, "artifact": default_artifact}
    call_kwargs.update(payload)

    with pytest.raises(AssertionError, match=error):
        mock_dataset_worker.download_artifact(**call_kwargs)
def test_download_artifact_api_error(responses, mock_dataset_worker, default_artifact):
    """Check that ``download_artifact`` raises ``ErrorResponse`` when the API fails.

    The mocked download endpoint always answers with an HTTP 500. The test
    expects an ``ErrorResponse`` and exactly five requests on that endpoint in
    addition to the base API calls.
    """
    download_url = f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd"
    responses.add(responses.GET, download_url, status=500)

    with pytest.raises(ErrorResponse):
        mock_dataset_worker.download_artifact(
            task_id=TASK_ID, artifact=default_artifact
        )

    # The API call is retried 5 times
    retried_calls = [("GET", download_url) for _ in range(5)]
    observed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]

    assert len(responses.calls) == len(BASE_API_CALLS) + 5
    assert observed_calls == BASE_API_CALLS + retried_calls
def test_download_artifact(
responses,
mock_dataset_worker,
default_artifact,
):
archive_path = (
FIXTURES_DIR / "extract_parent_archives" / "first_parent" / "arkindex_data.zstd"
)
responses.add(
responses.GET,
f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd",
status=200,
body=archive_path.read_bytes(),
content_type="application/zstd",
)
assert (
mock_dataset_worker.download_artifact(
task_id=TASK_ID, artifact=default_artifact
).read()
== archive_path.read_bytes()
)
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd"),