Skip to content
Snippets Groups Projects
test_task.py 6.12 KiB
Newer Older
# -*- coding: utf-8 -*-
import uuid

import pytest
from apistar.exceptions import ErrorResponse

from arkindex_worker.models import Artifact
from tests.conftest import FIXTURES_DIR
from tests.test_elements_worker import BASE_API_CALLS

# Fixed task UUID shared by the tests in this module; it is interpolated into
# every mocked task artifact endpoint URL and passed as the `task_id` argument.
TASK_ID = uuid.UUID("cafecafe-cafe-cafe-cafe-cafecafecafe")


@pytest.mark.parametrize(
    ("payload", "error"),
    [
        # Task ID is missing
        ({"task_id": None}, "task_id shouldn't be null and should be an UUID"),
        # Task ID is a plain string instead of a uuid.UUID instance
        (
            {"task_id": "12341234-1234-1234-1234-123412341234"},
            "task_id shouldn't be null and should be an UUID",
        ),
    ],
)
def test_list_artifacts_wrong_param_task_id(mock_dataset_worker, payload, error):
    """Check that ``list_artifacts`` rejects an invalid ``task_id``.

    Each parametrized payload must trigger an ``AssertionError`` whose
    message matches ``error``.
    """
    expect_failure = pytest.raises(AssertionError, match=error)
    with expect_failure:
        mock_dataset_worker.list_artifacts(**payload)


def test_list_artifacts_api_error(responses, mock_dataset_worker):
    """Check that a failing artifacts endpoint surfaces as ``ErrorResponse``.

    The mocked endpoint always answers with HTTP 500; the test expects the
    call to be attempted 5 times before the error is raised.
    """
    artifacts_url = f"http://testserver/api/v1/task/{TASK_ID}/artifacts/"
    responses.add(responses.GET, artifacts_url, status=500)

    with pytest.raises(ErrorResponse):
        mock_dataset_worker.list_artifacts(task_id=TASK_ID)

    # The API call is retried 5 times
    expected_calls = BASE_API_CALLS + [("GET", artifacts_url)] * 5
    observed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]

    assert len(responses.calls) == len(BASE_API_CALLS) + 5
    assert observed_calls == expected_calls


def test_list_artifacts(
    responses,
    mock_dataset_worker,
):
    """Check that ``list_artifacts`` returns every artifact sent by the API.

    The artifacts endpoint is mocked with two entries. Each returned item must
    be an ``Artifact`` equal to the matching payload, and exactly one extra
    API call must have been made on top of the base calls.
    """
    expected_results = [
        {
            "id": "artifact_1",
            "path": "dataset_id.zstd",
            "size": 42,
            "content_type": "application/zstd",
            "s3_put_url": None,
            "created": "2000-01-01T00:00:00Z",
            "updated": "2000-01-01T00:00:00Z",
        },
        {
            "id": "artifact_2",
            "path": "logs.log",
            "size": 42,
            "content_type": "text/plain",
            "s3_put_url": None,
            "created": "2000-01-01T00:00:00Z",
            "updated": "2000-01-01T00:00:00Z",
        },
    ]
    responses.add(
        responses.GET,
        f"http://testserver/api/v1/task/{TASK_ID}/artifacts/",
        status=200,
        json=expected_results,
    )

    # Materialize the results before asserting: checks placed only inside a
    # loop over the returned iterable would pass vacuously if it yielded
    # nothing, or fewer items than expected.
    artifacts = list(mock_dataset_worker.list_artifacts(task_id=TASK_ID))
    assert len(artifacts) == len(expected_results)

    for artifact, expected in zip(artifacts, expected_results):
        assert isinstance(artifact, Artifact)
        assert artifact == expected

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == BASE_API_CALLS + [
        ("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifacts/"),
    ]


@pytest.mark.parametrize(
    ("payload", "error"),
    [
        # Task ID is missing
        ({"task_id": None}, "task_id shouldn't be null and should be an UUID"),
        # Task ID is a plain string instead of a uuid.UUID instance
        (
            {"task_id": "12341234-1234-1234-1234-123412341234"},
            "task_id shouldn't be null and should be an UUID",
        ),
    ],
)
def test_download_artifact_wrong_param_task_id(
    mock_dataset_worker, default_artifact, payload, error
):
    """Check that ``download_artifact`` rejects an invalid ``task_id``.

    The call is made with otherwise valid arguments, so the expected
    ``AssertionError`` can only come from the overridden ``task_id``.
    """
    # Start from valid arguments, then override them with the faulty value.
    call_kwargs = dict(task_id=TASK_ID, artifact=default_artifact)
    call_kwargs.update(payload)

    expect_failure = pytest.raises(AssertionError, match=error)
    with expect_failure:
        mock_dataset_worker.download_artifact(**call_kwargs)


@pytest.mark.parametrize(
    ("payload", "error"),
    [
        # Artifact is missing
        ({"artifact": None}, "artifact shouldn't be null and should be an Artifact"),
        # Artifact is a plain string instead of an Artifact instance
        (
            {"artifact": "not artifact type"},
            "artifact shouldn't be null and should be an Artifact",
        ),
    ],
)
def test_download_artifact_wrong_param_artifact(
    mock_dataset_worker, default_artifact, payload, error
):
    """Check that ``download_artifact`` rejects an invalid ``artifact``.

    The call is made with otherwise valid arguments, so the expected
    ``AssertionError`` can only come from the overridden ``artifact``.
    """
    # Start from valid arguments, then override them with the faulty value.
    call_kwargs = dict(task_id=TASK_ID, artifact=default_artifact)
    call_kwargs.update(payload)

    expect_failure = pytest.raises(AssertionError, match=error)
    with expect_failure:
        mock_dataset_worker.download_artifact(**call_kwargs)


def test_download_artifact_api_error(responses, mock_dataset_worker, default_artifact):
    """Check that a failing download endpoint surfaces as ``ErrorResponse``.

    The mocked endpoint always answers with HTTP 500; the test expects the
    call to be attempted 5 times before the error is raised.
    """
    download_url = f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd"
    responses.add(responses.GET, download_url, status=500)

    with pytest.raises(ErrorResponse):
        mock_dataset_worker.download_artifact(
            task_id=TASK_ID, artifact=default_artifact
        )

    # The API call is retried 5 times
    expected_calls = BASE_API_CALLS + [("GET", download_url)] * 5
    observed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]

    assert len(responses.calls) == len(BASE_API_CALLS) + 5
    assert observed_calls == expected_calls


def test_download_artifact(
    responses,
    mock_dataset_worker,
    default_artifact,
):
    archive_path = (
        FIXTURES_DIR / "extract_parent_archives" / "first_parent" / "arkindex_data.zstd"
    )
    responses.add(
        responses.GET,
        f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd",
        status=200,
        body=archive_path.read_bytes(),
        content_type="application/zstd",
    )

    assert (
        mock_dataset_worker.download_artifact(
            task_id=TASK_ID, artifact=default_artifact
        ).read()
        == archive_path.read_bytes()
    )

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == BASE_API_CALLS + [
        ("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd"),