Commit 7fd01367 authored by Eva Bardou, committed by Yoann Schneider

Implement DatasetWorker

parent 50ca57b0
1 merge request !411: Implement DatasetWorker
Pipeline #138095 passed
@@ -8,15 +8,18 @@ import os
import sys
import uuid
from enum import Enum
from typing import Iterable, List, Union
from itertools import groupby
from operator import itemgetter
from typing import Iterable, Iterator, List, Tuple, Union
from apistar.exceptions import ErrorResponse
from arkindex_worker import logger
from arkindex_worker.cache import CachedElement
from arkindex_worker.models import Element
from arkindex_worker.models import Dataset, Element
from arkindex_worker.worker.base import BaseWorker
from arkindex_worker.worker.classification import ClassificationMixin
from arkindex_worker.worker.dataset import DatasetMixin, DatasetState
from arkindex_worker.worker.element import ElementMixin
from arkindex_worker.worker.entity import EntityMixin # noqa: F401
from arkindex_worker.worker.metadata import MetaDataMixin, MetaType # noqa: F401
@@ -297,3 +300,150 @@ class ElementsWorker(
logger.debug(f"Updated activity of element {element_id} to {state}")
return True
class DatasetWorker(BaseWorker, DatasetMixin):
def __init__(
self,
description: str = "Arkindex Dataset Worker",
support_cache: bool = False,
generator: bool = False,
):
super().__init__(description, support_cache)
self.parser.add_argument(
"--dataset",
type=uuid.UUID,
nargs="+",
help="One or more Arkindex dataset ID",
)
self.generator = generator
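# Illustrative invocation (hedged example, not part of the diff): with the
# --dataset argument defined above, the worker can be run in developer mode
# on specific datasets, e.g.
#   worker --dev --dataset 11111111-1111-1111-1111-111111111111 22222222-2222-2222-2222-222222222222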
def configure(self):
"""
Setup the worker using CLI arguments and environment variables.
"""
# CLI args are stored on the instance so that implementations can access them
self.args = self.parser.parse_args()
if self.is_read_only:
super().configure_for_developers()
else:
super().configure()
super().configure_cache()
def list_dataset_elements_per_split(
self, dataset: Dataset
) -> Iterator[Tuple[str, List[Element]]]:
"""
Calls `list_dataset_elements` but returns results grouped by Set
"""
def format_split(
split: Tuple[str, Iterator[Tuple[str, Element]]]
) -> Tuple[str, List[Element]]:
return (split[0], list(map(itemgetter(1), list(split[1]))))
return map(
format_split,
groupby(
sorted(self.list_dataset_elements(dataset), key=itemgetter(0)),
key=itemgetter(0),
),
)
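# For illustration (element names below are placeholders): the sort + groupby
# above turns the flat iterator produced by list_dataset_elements, e.g.
#   ("set_1", elt_a), ("set_1", elt_b), ("set_2", elt_c)
# into one tuple per set,
#   ("set_1", [elt_a, elt_b]), ("set_2", [elt_c])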
def process_dataset(self, dataset: Dataset):
"""
Override this method to implement your worker and process a single Arkindex dataset at once.
:param dataset: The dataset to process.
"""
def list_datasets(self) -> Iterator[Dataset] | Iterator[str]:
"""
Calls `list_process_datasets` if not is_read_only,
else simply give the list of IDs provided via CLI
"""
if self.is_read_only:
return map(str, self.args.dataset)
return self.list_process_datasets()
def run(self):
self.configure()
datasets: List[Dataset] | List[str] = list(self.list_datasets())
if not datasets:
logger.warning("No datasets to process, stopping.")
sys.exit(1)
# Process every dataset
count = len(datasets)
failed = 0
for i, item in enumerate(datasets, start=1):
dataset = None
try:
if not self.is_read_only:
# Just use the result of list_datasets as the dataset
dataset = item
else:
# Load dataset using the Arkindex API
dataset = Dataset(**self.request("RetrieveDataset", id=item))
if self.generator:
assert (
dataset.state == DatasetState.Open.value
), "When generating a new dataset, its state should be Open."
else:
assert (
dataset.state == DatasetState.Complete.value
), "When processing an existing dataset, its state should be Complete."
logger.info(f"Processing {dataset} ({i}/{count})")
if self.generator:
# Update the dataset state to Building
logger.info(f"Building {dataset} ({i}/{count})")
self.update_dataset_state(dataset, DatasetState.Building)
# Process the dataset
self.process_dataset(dataset)
if self.generator:
# Update the dataset state to Complete
logger.info(f"Completed {dataset} ({i}/{count})")
self.update_dataset_state(dataset, DatasetState.Complete)
except Exception as e:
# Handle errors occurring while retrieving, processing or patching the state for this dataset.
failed += 1
# Handle the case where we failed retrieving the dataset
dataset_id = dataset.id if dataset else item
if isinstance(e, ErrorResponse):
message = f"An API error occurred while processing dataset {dataset_id}: {e.title} - {e.content}"
else:
message = (
f"Failed running worker on dataset {dataset_id}: {repr(e)}"
)
logger.warning(
message,
exc_info=e if self.args.verbose else None,
)
if dataset and self.generator:
# Try to update the state to Error regardless of the response
try:
self.update_dataset_state(dataset, DatasetState.Error)
except Exception:
pass
if failed:
logger.error(
"Ran on {} datasets: {} completed, {} failed".format(
count, count - failed, failed
)
)
if failed >= count: # Everything failed!
sys.exit(1)
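For context, a minimal sketch of a worker built on the DatasetWorker class added above; the subclass name and the body of process_dataset are hypothetical, only DatasetWorker, process_dataset, list_dataset_elements_per_split and run come from this commit:

from arkindex_worker import logger
from arkindex_worker.worker import DatasetWorker


class DemoDatasetWorker(DatasetWorker):
    def process_dataset(self, dataset):
        # Iterate over the dataset elements, grouped by set name
        for set_name, elements in self.list_dataset_elements_per_split(dataset):
            # Hypothetical processing: log how many elements each set contains
            logger.info(f"{dataset.name}/{set_name}: {len(elements)} elements")


if __name__ == "__main__":
    DemoDatasetWorker(description="Demo dataset worker").run()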
@@ -23,12 +23,16 @@ from arkindex_worker.cache import (
init_cache_db,
)
from arkindex_worker.git import GitHelper, GitlabHelper
from arkindex_worker.worker import BaseWorker, ElementsWorker
from arkindex_worker.models import Dataset
from arkindex_worker.worker import BaseWorker, DatasetWorker, ElementsWorker
from arkindex_worker.worker.dataset import DatasetState
from arkindex_worker.worker.transcription import TextOrientation
FIXTURES_DIR = Path(__file__).resolve().parent / "data"
SAMPLES_DIR = Path(__file__).resolve().parent / "samples"
PROCESS_ID = "cafecafe-cafe-cafe-cafe-cafecafecafe"
__yaml_cache = {}
@@ -553,3 +557,58 @@ def mock_databases(tmp_path):
)
return out
@pytest.fixture
def default_dataset():
return Dataset(
**{
"id": "dataset_id",
"name": "My dataset",
"description": "A super dataset built by me",
"sets": ["set_1", "set_2", "set_3"],
"state": DatasetState.Open.value,
"corpus_id": "corpus_id",
"creator": "creator@teklia.com",
"task_id": "task_id",
"created": "2000-01-01T00:00:00Z",
"updated": "2000-01-01T00:00:00Z",
}
)
@pytest.fixture
def mock_dataset_worker(mocker, mock_worker_run_api):
mocker.patch.object(sys, "argv", ["worker"])
dataset_worker = DatasetWorker()
dataset_worker.configure()
dataset_worker.process_information = {"id": PROCESS_ID}
assert not dataset_worker.is_read_only
return dataset_worker
@pytest.fixture
def mock_dev_dataset_worker(mocker):
mocker.patch.object(
sys,
"argv",
[
"worker",
"--dev",
"--dataset",
"11111111-1111-1111-1111-111111111111",
"22222222-2222-2222-2222-222222222222",
],
)
dataset_worker = DatasetWorker()
dataset_worker.configure()
assert dataset_worker.args.dev is True
assert dataset_worker.process_information is None
assert dataset_worker.is_read_only is True
return dataset_worker
import logging
import pytest
from arkindex_worker.worker.dataset import DatasetState
from tests.conftest import PROCESS_ID
from tests.test_elements_worker import BASE_API_CALLS
def test_list_dataset_elements_per_split_api_error(
responses, mock_dataset_worker, default_dataset
):
responses.add(
responses.GET,
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/",
status=500,
)
with pytest.raises(
Exception, match="Stopping pagination as data will be incomplete"
):
mock_dataset_worker.list_dataset_elements_per_split(default_dataset)
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# The API call is retried 5 times
("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/"),
("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/"),
("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/"),
("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/"),
("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/"),
]
def test_list_dataset_elements_per_split(
responses, mock_dataset_worker, default_dataset
):
expected_results = [
{
"set": "set_1",
"element": {
"id": "0000",
"type": "page",
"name": "Test",
"corpus": {},
"thumbnail_url": None,
"zone": {},
"best_classes": None,
"has_children": None,
"worker_version_id": None,
"worker_run_id": None,
},
},
{
"set": "set_1",
"element": {
"id": "1111",
"type": "page",
"name": "Test 2",
"corpus": {},
"thumbnail_url": None,
"zone": {},
"best_classes": None,
"has_children": None,
"worker_version_id": None,
"worker_run_id": None,
},
},
{
"set": "set_2",
"element": {
"id": "2222",
"type": "page",
"name": "Test 3",
"corpus": {},
"thumbnail_url": None,
"zone": {},
"best_classes": None,
"has_children": None,
"worker_version_id": None,
"worker_run_id": None,
},
},
{
"set": "set_3",
"element": {
"id": "3333",
"type": "page",
"name": "Test 4",
"corpus": {},
"thumbnail_url": None,
"zone": {},
"best_classes": None,
"has_children": None,
"worker_version_id": None,
"worker_run_id": None,
},
},
]
responses.add(
responses.GET,
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/",
status=200,
json={
"count": 4,
"next": None,
"results": expected_results,
},
)
assert list(
mock_dataset_worker.list_dataset_elements_per_split(default_dataset)
) == [
("set_1", [expected_results[0]["element"], expected_results[1]["element"]]),
("set_2", [expected_results[2]["element"]]),
("set_3", [expected_results[3]["element"]]),
]
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/"),
]
def test_list_datasets_read_only(mock_dev_dataset_worker):
assert list(mock_dev_dataset_worker.list_datasets()) == [
"11111111-1111-1111-1111-111111111111",
"22222222-2222-2222-2222-222222222222",
]
def test_list_datasets_api_error(responses, mock_dataset_worker):
responses.add(
responses.GET,
f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/",
status=500,
)
with pytest.raises(
Exception, match="Stopping pagination as data will be incomplete"
):
mock_dataset_worker.list_datasets()
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# The API call is retried 5 times
("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
]
def test_list_datasets(responses, mock_dataset_worker):
expected_results = [
{
"id": "dataset_1",
"name": "Dataset 1",
"description": "My first great dataset",
"sets": ["train", "val", "test"],
"state": "open",
"corpus_id": "corpus_id",
"creator": "test@teklia.com",
"task_id": "task_id_1",
},
{
"id": "dataset_2",
"name": "Dataset 2",
"description": "My second great dataset",
"sets": ["train", "val"],
"state": "complete",
"corpus_id": "corpus_id",
"creator": "test@teklia.com",
"task_id": "task_id_2",
},
{
"id": "dataset_3",
"name": "Dataset 3 (TRASHME)",
"description": "My third dataset, in error",
"sets": ["nonsense", "random set"],
"state": "error",
"corpus_id": "corpus_id",
"creator": "test@teklia.com",
"task_id": "task_id_3",
},
]
responses.add(
responses.GET,
f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/",
status=200,
json={
"count": 3,
"next": None,
"results": expected_results,
},
)
assert list(mock_dataset_worker.list_datasets()) == expected_results
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
]
@pytest.mark.parametrize("generator", (True, False))
def test_run_no_datasets(mocker, caplog, mock_dataset_worker, generator):
mocker.patch("arkindex_worker.worker.DatasetWorker.list_datasets", return_value=[])
mock_dataset_worker.generator = generator
with pytest.raises(SystemExit):
mock_dataset_worker.run()
assert [(level, message) for _, level, message in caplog.record_tuples] == [
(logging.INFO, "Loaded worker Fake worker revision deadbee from API"),
(logging.WARNING, "No datasets to process, stopping."),
]
@pytest.mark.parametrize(
"generator, error",
[
(True, "When generating a new dataset, its state should be Open."),
(False, "When processing an existing dataset, its state should be Complete."),
],
)
def test_run_initial_dataset_state_error(
mocker, responses, caplog, mock_dataset_worker, default_dataset, generator, error
):
default_dataset.state = DatasetState.Building.value
mocker.patch(
"arkindex_worker.worker.DatasetWorker.list_datasets",
return_value=[default_dataset],
)
mock_dataset_worker.generator = generator
extra_call = []
if generator:
responses.add(
responses.PATCH,
f"http://testserver/api/v1/datasets/{default_dataset.id}/",
status=200,
json={},
)
extra_call = [
("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
]
with pytest.raises(SystemExit):
mock_dataset_worker.run()
assert len(responses.calls) == len(BASE_API_CALLS) * 2 + len(extra_call)
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS * 2 + extra_call
assert [(level, message) for _, level, message in caplog.record_tuples] == [
(logging.INFO, "Loaded worker Fake worker revision deadbee from API"),
(
logging.WARNING,
f"Failed running worker on dataset dataset_id: AssertionError('{error}')",
),
(
logging.ERROR,
"Ran on 1 datasets: 0 completed, 1 failed",
),
]
def test_run_update_dataset_state_api_error(
mocker, responses, caplog, mock_dataset_worker, default_dataset
):
mocker.patch(
"arkindex_worker.worker.DatasetWorker.list_datasets",
return_value=[default_dataset],
)
mock_dataset_worker.generator = True
responses.add(
responses.PATCH,
f"http://testserver/api/v1/datasets/{default_dataset.id}/",
status=500,
)
with pytest.raises(SystemExit):
mock_dataset_worker.run()
assert len(responses.calls) == len(BASE_API_CALLS) * 2 + 10
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS * 2 + [
# We retry 5 times the API call to update the Dataset as Building
("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
# We retry 5 times the API call to update the Dataset as in Error
("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
]
retries = [3.0, 4.0, 8.0, 16.0]
assert [(level, message) for _, level, message in caplog.record_tuples] == [
(logging.INFO, "Loaded worker Fake worker revision deadbee from API"),
(logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
(logging.INFO, "Building Dataset (dataset_id) (1/1)"),
*[
(
logging.INFO,
f"Retrying arkindex_worker.worker.base.BaseWorker.request in {retry} seconds as it raised ErrorResponse: .",
)
for retry in retries
],
(
logging.WARNING,
"An API error occurred while processing dataset dataset_id: 500 Internal Server Error - None",
),
*[
(
logging.INFO,
f"Retrying arkindex_worker.worker.base.BaseWorker.request in {retry} seconds as it raised ErrorResponse: .",
)
for retry in retries
],
(
logging.ERROR,
"Ran on 1 datasets: 0 completed, 1 failed",
),
]
@pytest.mark.parametrize(
"generator, state", [(True, DatasetState.Open), (False, DatasetState.Complete)]
)
def test_run(
mocker, responses, caplog, mock_dataset_worker, default_dataset, generator, state
):
mock_dataset_worker.generator = generator
default_dataset.state = state.value
mocker.patch(
"arkindex_worker.worker.DatasetWorker.list_datasets",
return_value=[default_dataset],
)
mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_dataset")
extra_calls = []
extra_logs = []
if generator:
responses.add(
responses.PATCH,
f"http://testserver/api/v1/datasets/{default_dataset.id}/",
status=200,
json={},
)
extra_calls += [
("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
] * 2
extra_logs = [
(logging.INFO, "Building Dataset (dataset_id) (1/1)"),
(logging.INFO, "Completed Dataset (dataset_id) (1/1)"),
]
mock_dataset_worker.run()
assert mock_process.call_count == 1
assert len(responses.calls) == len(BASE_API_CALLS) * 2 + len(extra_calls)
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS * 2 + extra_calls
assert [(level, message) for _, level, message in caplog.record_tuples] == [
(logging.INFO, "Loaded worker Fake worker revision deadbee from API"),
(logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
] + extra_logs
@pytest.mark.parametrize(
"generator, state", [(True, DatasetState.Open), (False, DatasetState.Complete)]
)
def test_run_read_only(
mocker,
responses,
caplog,
mock_dev_dataset_worker,
default_dataset,
generator,
state,
):
mock_dev_dataset_worker.generator = generator
default_dataset.state = state.value
mocker.patch(
"arkindex_worker.worker.DatasetWorker.list_datasets",
return_value=[default_dataset.id],
)
mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_dataset")
responses.add(
responses.GET,
f"http://testserver/api/v1/datasets/{default_dataset.id}/",
status=200,
json=default_dataset,
)
extra_logs = []
if generator:
extra_logs = [
(logging.INFO, "Building Dataset (dataset_id) (1/1)"),
(
logging.WARNING,
"Cannot update dataset as this worker is in read-only mode",
),
(logging.INFO, "Completed Dataset (dataset_id) (1/1)"),
(
logging.WARNING,
"Cannot update dataset as this worker is in read-only mode",
),
]
mock_dev_dataset_worker.run()
assert mock_process.call_count == 1
assert len(responses.calls) == 1
assert [(call.request.method, call.request.url) for call in responses.calls] == [
("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/")
]
assert [(level, message) for _, level, message in caplog.record_tuples] == [
(logging.WARNING, "Running without any extra configuration"),
(logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
] + extra_logs
# -*- coding: utf-8 -*-
import json
import logging
import sys
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.models import Dataset
from arkindex_worker.worker import BaseWorker
from arkindex_worker.worker.dataset import DatasetMixin, DatasetState
PROCESS_ID = "cafecafe-cafe-cafe-cafe-cafecafecafe"
@pytest.fixture
def default_dataset():
return {
"id": "dataset_id",
"name": "My dataset",
"description": "A super dataset built by me",
"sets": ["set_1", "set_2", "set_3"],
"state": DatasetState.Open,
"corpus_id": "corpus_id",
"creator": "creator@teklia.com",
"task_id": "task_id",
"created": "2000-01-01T00:00:00Z",
"updated": "2000-01-01T00:00:00Z",
}
@pytest.fixture
def mock_dataset_worker(monkeypatch, default_dataset):
class DatasetWorker(BaseWorker, DatasetMixin):
"""
This class is needed to run tests in the context of a dataset worker
"""
monkeypatch.setattr(sys, "argv", ["worker"])
dataset_worker = DatasetWorker()
dataset_worker.args = dataset_worker.parser.parse_args()
dataset_worker.process_information = {"id": PROCESS_ID}
dataset_worker.dataset = Dataset(default_dataset)
return dataset_worker
from arkindex_worker.worker.dataset import DatasetState
from tests.conftest import PROCESS_ID
from tests.test_elements_worker import BASE_API_CALLS
def test_list_process_datasets_readonly_error(mock_dataset_worker):
@@ -67,8 +33,10 @@ def test_list_process_datasets_api_error(responses, mock_dataset_worker):
):
next(mock_dataset_worker.list_process_datasets())
assert len(responses.calls) == 5
assert [(call.request.method, call.request.url) for call in responses.calls] == [
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# The API call is retried 5 times
("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
@@ -129,8 +97,10 @@ def test_list_process_datasets(
assert isinstance(dataset, Dataset)
assert dataset == expected_results[idx]
assert len(responses.calls) == 1
assert [(call.request.method, call.request.url) for call in responses.calls] == [
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
]
@@ -155,36 +125,38 @@ def test_list_dataset_elements_wrong_param_dataset(mock_dataset_worker, payload,
assert str(e.value) == error
def test_list_dataset_elements_api_error(responses, mock_dataset_worker):
dataset = mock_dataset_worker.dataset
def test_list_dataset_elements_api_error(
responses, mock_dataset_worker, default_dataset
):
responses.add(
responses.GET,
f"http://testserver/api/v1/datasets/{dataset.id}/elements/",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/",
status=500,
)
with pytest.raises(
Exception, match="Stopping pagination as data will be incomplete"
):
next(mock_dataset_worker.list_dataset_elements(dataset=dataset))
next(mock_dataset_worker.list_dataset_elements(dataset=default_dataset))
assert len(responses.calls) == 5
assert [(call.request.method, call.request.url) for call in responses.calls] == [
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# The API call is retried 5 times
("GET", f"http://testserver/api/v1/datasets/{dataset.id}/elements/"),
("GET", f"http://testserver/api/v1/datasets/{dataset.id}/elements/"),
("GET", f"http://testserver/api/v1/datasets/{dataset.id}/elements/"),
("GET", f"http://testserver/api/v1/datasets/{dataset.id}/elements/"),
("GET", f"http://testserver/api/v1/datasets/{dataset.id}/elements/"),
("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/"),
("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/"),
("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/"),
("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/"),
("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/"),
]
def test_list_dataset_elements(
responses,
mock_dataset_worker,
default_dataset,
):
dataset = mock_dataset_worker.dataset
expected_results = [
{
"set": "set_1",
@@ -249,7 +221,7 @@ def test_list_dataset_elements(
]
responses.add(
responses.GET,
f"http://testserver/api/v1/datasets/{dataset.id}/elements/",
f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/",
status=200,
json={
"count": 4,
@@ -259,16 +231,18 @@ def test_list_dataset_elements(
)
for idx, element in enumerate(
mock_dataset_worker.list_dataset_elements(dataset=dataset)
mock_dataset_worker.list_dataset_elements(dataset=default_dataset)
):
assert element == (
expected_results[idx]["set"],
expected_results[idx]["element"],
)
assert len(responses.calls) == 1
assert [(call.request.method, call.request.url) for call in responses.calls] == [
("GET", f"http://testserver/api/v1/datasets/{dataset.id}/elements/"),
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/"),
]
@@ -286,9 +260,11 @@ def test_list_dataset_elements(
),
),
)
def test_update_dataset_state_wrong_param_dataset(mock_dataset_worker, payload, error):
def test_update_dataset_state_wrong_param_dataset(
mock_dataset_worker, default_dataset, payload, error
):
api_payload = {
"dataset": mock_dataset_worker.dataset,
"dataset": Dataset(**default_dataset),
"state": DatasetState.Building,
**payload,
}
@@ -312,9 +288,11 @@ def test_update_dataset_state_wrong_param_dataset(mock_dataset_worker, payload,
),
),
)
def test_update_dataset_state_wrong_param_state(mock_dataset_worker, payload, error):
def test_update_dataset_state_wrong_param_state(
mock_dataset_worker, default_dataset, payload, error
):
api_payload = {
"dataset": mock_dataset_worker.dataset,
"dataset": Dataset(**default_dataset),
"state": DatasetState.Building,
**payload,
}
@@ -324,17 +302,15 @@ def test_update_dataset_state_wrong_param_state(mock_dataset_worker, payload, er
assert str(e.value) == error
def test_update_dataset_state_readonly_error(caplog, mock_dataset_worker):
# Set worker in read_only mode
mock_dataset_worker.worker_run_id = None
assert mock_dataset_worker.is_read_only
def test_update_dataset_state_readonly_error(
caplog, mock_dev_dataset_worker, default_dataset
):
api_payload = {
"dataset": mock_dataset_worker.dataset,
"dataset": Dataset(**default_dataset),
"state": DatasetState.Building,
}
assert not mock_dataset_worker.update_dataset_state(**api_payload)
assert not mock_dev_dataset_worker.update_dataset_state(**api_payload)
assert [(level, message) for _, level, message in caplog.record_tuples] == [
(
logging.WARNING,
@@ -343,28 +319,31 @@ def test_update_dataset_state_readonly_error(caplog, mock_dataset_worker):
]
def test_update_dataset_state_api_error(responses, mock_dataset_worker):
dataset = mock_dataset_worker.dataset
def test_update_dataset_state_api_error(
responses, mock_dataset_worker, default_dataset
):
responses.add(
responses.PATCH,
f"http://testserver/api/v1/datasets/{dataset.id}/",
f"http://testserver/api/v1/datasets/{default_dataset.id}/",
status=500,
)
with pytest.raises(ErrorResponse):
mock_dataset_worker.update_dataset_state(
dataset=dataset,
dataset=default_dataset,
state=DatasetState.Building,
)
assert len(responses.calls) == 5
assert [(call.request.method, call.request.url) for call in responses.calls] == [
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
("PATCH", f"http://testserver/api/v1/datasets/{dataset.id}/"),
("PATCH", f"http://testserver/api/v1/datasets/{dataset.id}/"),
("PATCH", f"http://testserver/api/v1/datasets/{dataset.id}/"),
("PATCH", f"http://testserver/api/v1/datasets/{dataset.id}/"),
("PATCH", f"http://testserver/api/v1/datasets/{dataset.id}/"),
("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
]
@@ -373,8 +352,6 @@ def test_update_dataset_state(
mock_dataset_worker,
default_dataset,
):
dataset = mock_dataset_worker.dataset
dataset_response = {
"name": "My dataset",
"description": "A super dataset built by me",
@@ -383,21 +360,23 @@
}
responses.add(
responses.PATCH,
f"http://testserver/api/v1/datasets/{dataset.id}/",
f"http://testserver/api/v1/datasets/{default_dataset.id}/",
status=200,
json=dataset_response,
)
updated_dataset = mock_dataset_worker.update_dataset_state(
dataset=dataset,
dataset=default_dataset,
state=DatasetState.Building,
)
assert len(responses.calls) == 1
assert [(call.request.method, call.request.url) for call in responses.calls] == [
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"PATCH",
f"http://testserver/api/v1/datasets/{dataset.id}/",
f"http://testserver/api/v1/datasets/{default_dataset.id}/",
),
]
assert json.loads(responses.calls[-1].request.body) == {
......