# -*- coding: utf-8 -*-
import json
import logging

import pytest
from apistar.exceptions import ErrorResponse

from arkindex_worker.models import Dataset
from arkindex_worker.worker.dataset import DatasetState
from tests.conftest import PROCESS_ID
from tests.test_elements_worker import BASE_API_CALLS


def test_list_process_datasets_readonly_error(mock_dataset_worker):
    """Listing the process datasets must be refused on a read-only worker."""
    # Dropping the worker run ID switches the worker to read-only mode
    mock_dataset_worker.worker_run_id = None
    assert mock_dataset_worker.is_read_only

    with pytest.raises(AssertionError) as exc_info:
        mock_dataset_worker.list_process_datasets()

    assert str(exc_info.value) == "This helper is not available in read-only mode."


def test_list_process_datasets_api_error(responses, mock_dataset_worker):
    """A failing datasets endpoint stops pagination after all retries are used."""
    datasets_url = f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"
    responses.add(responses.GET, datasets_url, status=500)

    with pytest.raises(
        Exception, match="Stopping pagination as data will be incomplete"
    ):
        next(mock_dataset_worker.list_process_datasets())

    # The failing API call is attempted 5 times in total
    expected_calls = BASE_API_CALLS + [("GET", datasets_url)] * 5
    assert len(responses.calls) == len(expected_calls)
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == expected_calls


def test_list_process_datasets(
    responses,
    mock_dataset_worker,
):
    """Process datasets are returned as `Dataset` objects, in API order.

    The mocked endpoint returns a single page of three datasets; the helper
    must yield exactly those three, each wrapped in a `Dataset`, using a
    single API call.
    """
    expected_results = [
        {
            "id": "dataset_1",
            "name": "Dataset 1",
            "description": "My first great dataset",
            "sets": ["train", "val", "test"],
            "state": "open",
            "corpus_id": "corpus_id",
            "creator": "test@teklia.com",
            "task_id": "task_id_1",
        },
        {
            "id": "dataset_2",
            "name": "Dataset 2",
            "description": "My second great dataset",
            "sets": ["train", "val"],
            "state": "complete",
            "corpus_id": "corpus_id",
            "creator": "test@teklia.com",
            "task_id": "task_id_2",
        },
        {
            "id": "dataset_3",
            "name": "Dataset 3 (TRASHME)",
            "description": "My third dataset, in error",
            "sets": ["nonsense", "random set"],
            "state": "error",
            "corpus_id": "corpus_id",
            "creator": "test@teklia.com",
            "task_id": "task_id_3",
        },
    ]
    responses.add(
        responses.GET,
        f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/",
        status=200,
        json={
            "count": 3,
            "next": None,
            "results": expected_results,
        },
    )

    # Consume the whole iterator and check its size first: a bare loop over the
    # results would pass silently if fewer datasets than expected were yielded
    datasets = list(mock_dataset_worker.list_process_datasets())
    assert len(datasets) == len(expected_results)
    for dataset, expected in zip(datasets, expected_results):
        assert isinstance(dataset, Dataset)
        assert dataset == expected

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == BASE_API_CALLS + [
        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
    ]


@pytest.mark.parametrize(
    "payload, error",
    [
        # Invalid values for the `dataset` argument
        ({"dataset": None}, "dataset shouldn't be null and should be a Dataset"),
        (
            {"dataset": "not Dataset type"},
            "dataset shouldn't be null and should be a Dataset",
        ),
    ],
)
def test_list_dataset_elements_wrong_param_dataset(mock_dataset_worker, payload, error):
    """`list_dataset_elements` rejects a null or non-`Dataset` dataset argument."""
    with pytest.raises(AssertionError) as exc_info:
        mock_dataset_worker.list_dataset_elements(**payload)

    assert str(exc_info.value) == error


def test_list_dataset_elements_api_error(
    responses, mock_dataset_worker, default_dataset
):
    """A failing elements endpoint stops pagination after all retries are used."""
    elements_url = f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/"
    responses.add(responses.GET, elements_url, status=500)

    with pytest.raises(
        Exception, match="Stopping pagination as data will be incomplete"
    ):
        next(mock_dataset_worker.list_dataset_elements(dataset=default_dataset))

    # The failing API call is attempted 5 times in total
    expected_calls = BASE_API_CALLS + [("GET", elements_url)] * 5
    assert len(responses.calls) == len(expected_calls)
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == expected_calls


def test_list_dataset_elements(
    responses,
    mock_dataset_worker,
    default_dataset,
):
    """Dataset elements are returned as `(set, element)` pairs, in API order.

    The mocked endpoint returns a single page of four elements spread over
    three sets; the helper must yield exactly those four pairs using a single
    API call.
    """
    expected_results = [
        {
            "set": "set_1",
            "element": {
                "id": "0000",
                "type": "page",
                "name": "Test",
                "corpus": {},
                "thumbnail_url": None,
                "zone": {},
                "best_classes": None,
                "has_children": None,
                "worker_version_id": None,
                "worker_run_id": None,
            },
        },
        {
            "set": "set_1",
            "element": {
                "id": "1111",
                "type": "page",
                "name": "Test 2",
                "corpus": {},
                "thumbnail_url": None,
                "zone": {},
                "best_classes": None,
                "has_children": None,
                "worker_version_id": None,
                "worker_run_id": None,
            },
        },
        {
            "set": "set_2",
            "element": {
                "id": "2222",
                "type": "page",
                "name": "Test 3",
                "corpus": {},
                "thumbnail_url": None,
                "zone": {},
                "best_classes": None,
                "has_children": None,
                "worker_version_id": None,
                "worker_run_id": None,
            },
        },
        {
            "set": "set_3",
            "element": {
                "id": "3333",
                "type": "page",
                "name": "Test 4",
                "corpus": {},
                "thumbnail_url": None,
                "zone": {},
                "best_classes": None,
                "has_children": None,
                "worker_version_id": None,
                "worker_run_id": None,
            },
        },
    ]
    responses.add(
        responses.GET,
        f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/",
        status=200,
        json={
            "count": 4,
            "next": None,
            "results": expected_results,
        },
    )

    # Consume the whole iterator and check its size first: a bare loop over the
    # results would pass silently if fewer elements than expected were yielded
    elements = list(mock_dataset_worker.list_dataset_elements(dataset=default_dataset))
    assert len(elements) == len(expected_results)
    for element, expected in zip(elements, expected_results):
        assert element == (expected["set"], expected["element"])

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == BASE_API_CALLS + [
        ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/"),
    ]


@pytest.mark.parametrize(
    "payload, error",
    [
        # Invalid values for the `dataset` argument
        ({"dataset": None}, "dataset shouldn't be null and should be a Dataset"),
        (
            {"dataset": "not dataset type"},
            "dataset shouldn't be null and should be a Dataset",
        ),
    ],
)
def test_update_dataset_state_wrong_param_dataset(
    mock_dataset_worker, default_dataset, payload, error
):
    """`update_dataset_state` rejects a null or non-`Dataset` dataset argument."""
    # Start from valid arguments, then override them with the invalid case
    api_payload = {
        "dataset": Dataset(**default_dataset),
        "state": DatasetState.Building,
    }
    api_payload.update(payload)

    with pytest.raises(AssertionError) as exc_info:
        mock_dataset_worker.update_dataset_state(**api_payload)

    assert str(exc_info.value) == error


@pytest.mark.parametrize(
    "payload, error",
    [
        # Invalid values for the `state` argument
        (
            {"state": None},
            "state shouldn't be null and should be a str from DatasetState",
        ),
        (
            {"state": "not dataset type"},
            "state shouldn't be null and should be a str from DatasetState",
        ),
    ],
)
def test_update_dataset_state_wrong_param_state(
    mock_dataset_worker, default_dataset, payload, error
):
    """`update_dataset_state` rejects a null or non-`DatasetState` state argument."""
    # Start from valid arguments, then override them with the invalid case
    api_payload = {
        "dataset": Dataset(**default_dataset),
        "state": DatasetState.Building,
    }
    api_payload.update(payload)

    with pytest.raises(AssertionError) as exc_info:
        mock_dataset_worker.update_dataset_state(**api_payload)

    assert str(exc_info.value) == error


def test_update_dataset_state_readonly_error(
    caplog, mock_dev_dataset_worker, default_dataset
):
    """The update returns a falsy value and logs a single read-only warning."""
    result = mock_dev_dataset_worker.update_dataset_state(
        dataset=Dataset(**default_dataset),
        state=DatasetState.Building,
    )
    assert not result

    # Only the level and message of each log record matter, not the logger name
    logged = [(level, message) for _, level, message in caplog.record_tuples]
    assert logged == [
        (
            logging.WARNING,
            "Cannot update dataset as this worker is in read-only mode",
        ),
    ]


def test_update_dataset_state_api_error(
    responses, mock_dataset_worker, default_dataset
):
    """A failing update endpoint raises `ErrorResponse` after all retries are used."""
    dataset_url = f"http://testserver/api/v1/datasets/{default_dataset.id}/"
    responses.add(responses.PATCH, dataset_url, status=500)

    with pytest.raises(ErrorResponse):
        mock_dataset_worker.update_dataset_state(
            dataset=default_dataset,
            state=DatasetState.Building,
        )

    # The failing API call is attempted 5 times in total
    expected_calls = BASE_API_CALLS + [("PATCH", dataset_url)] * 5
    assert len(responses.calls) == len(expected_calls)
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == expected_calls


def test_update_dataset_state(
    responses,
    mock_dataset_worker,
    default_dataset,
):
    """A successful update sends only the new state and returns the merged dataset."""
    dataset_url = f"http://testserver/api/v1/datasets/{default_dataset.id}/"
    dataset_response = {
        "name": "My dataset",
        "description": "A super dataset built by me",
        "sets": ["set_1", "set_2", "set_3"],
        "state": DatasetState.Building.value,
    }
    responses.add(responses.PATCH, dataset_url, status=200, json=dataset_response)

    updated_dataset = mock_dataset_worker.update_dataset_state(
        dataset=default_dataset,
        state=DatasetState.Building,
    )

    expected_calls = BASE_API_CALLS + [("PATCH", dataset_url)]
    assert len(responses.calls) == len(expected_calls)
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == expected_calls

    # The request body must carry the new state and nothing else
    last_request = responses.calls[-1].request
    assert json.loads(last_request.body) == {"state": DatasetState.Building.value}

    # The returned dataset is the original one overridden by the API response
    merged_attributes = {**default_dataset, **dataset_response}
    assert updated_dataset == Dataset(**merged_attributes)