diff --git a/arkindex_worker/worker/__init__.py b/arkindex_worker/worker/__init__.py
index efa40bc1a2c8df21d2310a8ef2d3663ab5eb10bb..18cde62edd6def777b4ab88be7007ca3c5e3f33f 100644
--- a/arkindex_worker/worker/__init__.py
+++ b/arkindex_worker/worker/__init__.py
@@ -10,6 +10,7 @@ import uuid
 from enum import Enum
 from itertools import groupby
 from operator import itemgetter
+from pathlib import Path
 from typing import Iterable, Iterator, List, Tuple, Union
 
 from apistar.exceptions import ErrorResponse
@@ -23,6 +24,7 @@ from arkindex_worker.worker.dataset import DatasetMixin, DatasetState
 from arkindex_worker.worker.element import ElementMixin
 from arkindex_worker.worker.entity import EntityMixin  # noqa: F401
 from arkindex_worker.worker.metadata import MetaDataMixin, MetaType  # noqa: F401
+from arkindex_worker.worker.task import TaskMixin
 from arkindex_worker.worker.transcription import TranscriptionMixin
 from arkindex_worker.worker.version import WorkerVersionMixin  # noqa: F401
 
@@ -302,13 +304,32 @@ class ElementsWorker(
         return True
 
 
-class DatasetWorker(BaseWorker, DatasetMixin):
+class MissingDatasetArchive(Exception):
+    """
+    Exception raised when the compressed `.zstd` archive associated with
+    a dataset isn't found in its task artifacts.
+    """
+
+
+class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
+    """
+    Base class for ML workers that operate on Arkindex datasets.
+
+    This class inherits from numerous mixin classes found in other modules of
+    ``arkindex_worker.worker``, which provide helpers to read from and write to the Arkindex API.
+    """
+
     def __init__(
         self,
         description: str = "Arkindex Dataset Worker",
         support_cache: bool = False,
         generator: bool = False,
     ):
+        """
+        :param description: The worker's description.
+        :param support_cache: Whether the worker supports caching.
+        :param generator: Whether the worker generates the dataset archive artifact.
+        """
         super().__init__(description, support_cache)
 
         self.parser.add_argument(
@@ -333,11 +354,42 @@ class DatasetWorker(BaseWorker, DatasetMixin):
         super().configure()
         super().configure_cache()
 
+    def download_dataset_artifact(self, dataset: Dataset) -> Path:
+        """
+        Find and download the compressed archive artifact describing a dataset, using
+        the [list_artifacts][arkindex_worker.worker.task.TaskMixin.list_artifacts] and
+        [download_artifact][arkindex_worker.worker.task.TaskMixin.download_artifact] methods.
+
+        :param dataset: The dataset to retrieve the compressed archive artifact for.
+        :returns: The path to the downloaded artifact.
+        :raises MissingDatasetArchive: When the dataset artifact is not found.
+        """
+
+        task_id = uuid.UUID(dataset.task_id)
+        archive_name = f"{dataset.id}.zstd"
+
+        for artifact in self.list_artifacts(task_id):
+            if artifact.path != archive_name:
+                continue
+
+            extra_dir = self.find_extras_directory()
+            archive = extra_dir / archive_name
+            archive.write_bytes(self.download_artifact(task_id, artifact).read())
+            return archive
+
+        raise MissingDatasetArchive(
+            "The dataset compressed archive artifact was not found."
+        )
+
     def list_dataset_elements_per_split(
         self, dataset: Dataset
     ) -> Iterator[Tuple[str, List[Element]]]:
         """
-        Calls `list_dataset_elements` but returns results grouped by Set
+        List the elements in the dataset, grouped by split, using the
+        [list_dataset_elements][arkindex_worker.worker.dataset.DatasetMixin.list_dataset_elements] method.
+
+        :param dataset: The dataset to retrieve elements from.
+        :returns: An iterator of tuples containing the split name and the list of its elements.
""" def format_split( @@ -362,8 +414,11 @@ class DatasetWorker(BaseWorker, DatasetMixin): def list_datasets(self) -> Iterator[Dataset] | Iterator[str]: """ - Calls `list_process_datasets` if not is_read_only, - else simply give the list of IDs provided via CLI + List the datasets to be processed, either from the CLI arguments or using the + [list_process_datasets][arkindex_worker.worker.dataset.DatasetMixin.list_process_datasets] method. + + :returns: An iterator of strings if the worker is in read-only mode, + else an iterator of ``Dataset`` objects. """ if self.is_read_only: return map(str, self.args.dataset) @@ -371,6 +426,14 @@ class DatasetWorker(BaseWorker, DatasetMixin): return self.list_process_datasets() def run(self): + """ + Implements an Arkindex worker that goes through each dataset returned by + [list_datasets][arkindex_worker.worker.DatasetWorker.list_datasets]. + + It calls [process_dataset][arkindex_worker.worker.DatasetWorker.process_dataset], + catching exceptions, and handles updating the [DatasetState][arkindex_worker.worker.dataset.DatasetState] + when the worker is a generator. + """ self.configure() datasets: List[Dataset] | List[str] = list(self.list_datasets()) @@ -406,6 +469,9 @@ class DatasetWorker(BaseWorker, DatasetMixin): # Update the dataset state to Building logger.info(f"Building {dataset} ({i}/{count})") self.update_dataset_state(dataset, DatasetState.Building) + else: + logger.info(f"Downloading data for {dataset} ({i}/{count})") + self.download_dataset_artifact(dataset) # Process the dataset self.process_dataset(dataset) diff --git a/docs/ref/api/task.md b/docs/ref/api/task.md index d66570361eef55719e65c2dcc9d87de7b22d36e9..257e978a02ec952d996860f7d4c45e3eef337ca1 100644 --- a/docs/ref/api/task.md +++ b/docs/ref/api/task.md @@ -3,6 +3,8 @@ ::: arkindex_worker.worker.task options: members: no + options: + show_category_heading: no ::: arkindex_worker.worker.task.TaskMixin options: diff --git a/tests/conftest.py b/tests/conftest.py index 40a6f831b7ba9f21d537261463bbe6b59af3d1d1..d01ee23e95df670d1dd1afbc593343b477501bd4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,7 @@ from arkindex_worker.cache import ( init_cache_db, ) from arkindex_worker.git import GitHelper, GitlabHelper -from arkindex_worker.models import Dataset +from arkindex_worker.models import Artifact, Dataset from arkindex_worker.worker import BaseWorker, DatasetWorker, ElementsWorker from arkindex_worker.worker.dataset import DatasetState from arkindex_worker.worker.transcription import TextOrientation @@ -570,7 +570,7 @@ def default_dataset(): "state": DatasetState.Open.value, "corpus_id": "corpus_id", "creator": "creator@teklia.com", - "task_id": "task_id", + "task_id": "11111111-1111-1111-1111-111111111111", "created": "2000-01-01T00:00:00Z", "updated": "2000-01-01T00:00:00Z", } @@ -578,7 +578,8 @@ def default_dataset(): @pytest.fixture -def mock_dataset_worker(mocker, mock_worker_run_api): +def mock_dataset_worker(monkeypatch, mocker, mock_worker_run_api): + monkeypatch.setenv("PONOS_TASK", "my_task") mocker.patch.object(sys, "argv", ["worker"]) dataset_worker = DatasetWorker() @@ -612,3 +613,18 @@ def mock_dev_dataset_worker(mocker): assert dataset_worker.is_read_only is True return dataset_worker + + +@pytest.fixture +def default_artifact(): + return Artifact( + **{ + "id": "artifact_id", + "path": "dataset_id.zstd", + "size": 42, + "content_type": "application/zstd", + "s3_put_url": None, + "created": "2000-01-01T00:00:00Z", + "updated": "2000-01-01T00:00:00Z", + 
+        }
+    )
diff --git a/tests/test_dataset_worker.py b/tests/test_dataset_worker.py
index 5e0169edfc4e81d24a5bfa29dd978ca282e6a866..944ccefb247703ea98b15c2f2348044dce123397 100644
--- a/tests/test_dataset_worker.py
+++ b/tests/test_dataset_worker.py
@@ -1,12 +1,192 @@
 import logging
 
 import pytest
+from apistar.exceptions import ErrorResponse
 
+from arkindex_worker.worker import MissingDatasetArchive
 from arkindex_worker.worker.dataset import DatasetState
-from tests.conftest import PROCESS_ID
+from tests.conftest import FIXTURES_DIR, PROCESS_ID
 from tests.test_elements_worker import BASE_API_CALLS
 
 
+def test_download_dataset_artifact_list_api_error(
+    responses, mock_dataset_worker, default_dataset
+):
+    task_id = default_dataset.task_id
+
+    responses.add(
+        responses.GET,
+        f"http://testserver/api/v1/task/{task_id}/artifacts/",
+        status=500,
+    )
+
+    with pytest.raises(ErrorResponse):
+        mock_dataset_worker.download_dataset_artifact(default_dataset)
+
+    assert len(responses.calls) == len(BASE_API_CALLS) + 5
+    assert [
+        (call.request.method, call.request.url) for call in responses.calls
+    ] == BASE_API_CALLS + [
+        # The API call is retried 5 times
+        ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
+        ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
+        ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
+        ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
+        ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
+    ]
+
+
+def test_download_dataset_artifact_download_api_error(
+    responses, mock_dataset_worker, default_dataset
+):
+    task_id = default_dataset.task_id
+
+    expected_results = [
+        {
+            "id": "artifact_1",
+            "path": "dataset_id.zstd",
+            "size": 42,
+            "content_type": "application/zstd",
+            "s3_put_url": None,
+            "created": "2000-01-01T00:00:00Z",
+            "updated": "2000-01-01T00:00:00Z",
+        },
+        {
+            "id": "artifact_2",
+            "path": "logs.log",
+            "size": 42,
+            "content_type": "text/plain",
+            "s3_put_url": None,
+            "created": "2000-01-01T00:00:00Z",
+            "updated": "2000-01-01T00:00:00Z",
+        },
+    ]
+    responses.add(
+        responses.GET,
+        f"http://testserver/api/v1/task/{task_id}/artifacts/",
+        status=200,
+        json=expected_results,
+    )
+    responses.add(
+        responses.GET,
+        f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.zstd",
+        status=500,
+    )
+
+    with pytest.raises(ErrorResponse):
+        mock_dataset_worker.download_dataset_artifact(default_dataset)
+
+    assert len(responses.calls) == len(BASE_API_CALLS) + 6
+    assert [
+        (call.request.method, call.request.url) for call in responses.calls
+    ] == BASE_API_CALLS + [
+        ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
+        # The API call is retried 5 times
+        ("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.zstd"),
+        ("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.zstd"),
+        ("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.zstd"),
+        ("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.zstd"),
+        ("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.zstd"),
+    ]
+
+
+def test_download_dataset_artifact_no_archive(
+    responses, mock_dataset_worker, default_dataset
+):
+    task_id = default_dataset.task_id
+
+    expected_results = [
+        {
+            "id": "artifact_id",
+            "path": "logs.log",
+            "size": 42,
+            "content_type": "text/plain",
+            "s3_put_url": None,
+            "created": "2000-01-01T00:00:00Z",
+            "updated": "2000-01-01T00:00:00Z",
+        },
+    ]
+    responses.add(
+        responses.GET,
f"http://testserver/api/v1/task/{task_id}/artifacts/", + status=200, + json=expected_results, + ) + + with pytest.raises( + MissingDatasetArchive, + match="The dataset compressed archive artifact was not found.", + ): + mock_dataset_worker.download_dataset_artifact(default_dataset) + + assert len(responses.calls) == len(BASE_API_CALLS) + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"), + ] + + +def test_download_dataset_artifact( + mocker, tmp_path, responses, mock_dataset_worker, default_dataset +): + task_id = default_dataset.task_id + archive_path = ( + FIXTURES_DIR / "extract_parent_archives" / "first_parent" / "arkindex_data.zstd" + ) + mocker.patch( + "arkindex_worker.worker.base.BaseWorker.find_extras_directory", + return_value=tmp_path, + ) + + expected_results = [ + { + "id": "artifact_1", + "path": "dataset_id.zstd", + "size": 42, + "content_type": "application/zstd", + "s3_put_url": None, + "created": "2000-01-01T00:00:00Z", + "updated": "2000-01-01T00:00:00Z", + }, + { + "id": "artifact_2", + "path": "logs.log", + "size": 42, + "content_type": "text/plain", + "s3_put_url": None, + "created": "2000-01-01T00:00:00Z", + "updated": "2000-01-01T00:00:00Z", + }, + ] + responses.add( + responses.GET, + f"http://testserver/api/v1/task/{task_id}/artifacts/", + status=200, + json=expected_results, + ) + responses.add( + responses.GET, + f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.zstd", + status=200, + body=archive_path.read_bytes(), + content_type="application/zstd", + ) + + archive = mock_dataset_worker.download_dataset_artifact(default_dataset) + assert archive == tmp_path / "dataset_id.zstd" + assert archive.read_bytes() == archive_path.read_bytes() + archive.unlink() + + assert len(responses.calls) == len(BASE_API_CALLS) + 2 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"), + ("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.zstd"), + ] + + def test_list_dataset_elements_per_split_api_error( responses, mock_dataset_worker, default_dataset ): @@ -342,11 +522,132 @@ def test_run_update_dataset_state_api_error( ] +def test_run_download_dataset_artifact_api_error( + mocker, + tmp_path, + responses, + caplog, + mock_dataset_worker, + default_dataset, +): + default_dataset.state = DatasetState.Complete.value + + mocker.patch( + "arkindex_worker.worker.DatasetWorker.list_datasets", + return_value=[default_dataset], + ) + mocker.patch( + "arkindex_worker.worker.base.BaseWorker.find_extras_directory", + return_value=tmp_path, + ) + + responses.add( + responses.GET, + f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/", + status=500, + ) + + with pytest.raises(SystemExit): + mock_dataset_worker.run() + + assert len(responses.calls) == len(BASE_API_CALLS) * 2 + 5 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS * 2 + [ + # We retry 5 times the API call + ("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"), + ("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"), + ("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"), + ("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"), + ("GET", 
f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"), + ] + + assert [(level, message) for _, level, message in caplog.record_tuples] == [ + (logging.INFO, "Loaded worker Fake worker revision deadbee from API"), + (logging.INFO, "Processing Dataset (dataset_id) (1/1)"), + (logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"), + *[ + ( + logging.INFO, + f"Retrying arkindex_worker.worker.base.BaseWorker.request in {retry} seconds as it raised ErrorResponse: .", + ) + for retry in [3.0, 4.0, 8.0, 16.0] + ], + ( + logging.WARNING, + "An API error occurred while processing dataset dataset_id: 500 Internal Server Error - None", + ), + ( + logging.ERROR, + "Ran on 1 datasets: 0 completed, 1 failed", + ), + ] + + +def test_run_no_downloaded_artifact_error( + mocker, + tmp_path, + responses, + caplog, + mock_dataset_worker, + default_dataset, +): + default_dataset.state = DatasetState.Complete.value + + mocker.patch( + "arkindex_worker.worker.DatasetWorker.list_datasets", + return_value=[default_dataset], + ) + mocker.patch( + "arkindex_worker.worker.base.BaseWorker.find_extras_directory", + return_value=tmp_path, + ) + + responses.add( + responses.GET, + f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/", + status=200, + json={}, + ) + + with pytest.raises(SystemExit): + mock_dataset_worker.run() + + assert len(responses.calls) == len(BASE_API_CALLS) * 2 + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS * 2 + [ + ("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"), + ] + + assert [(level, message) for _, level, message in caplog.record_tuples] == [ + (logging.INFO, "Loaded worker Fake worker revision deadbee from API"), + (logging.INFO, "Processing Dataset (dataset_id) (1/1)"), + (logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"), + ( + logging.WARNING, + "Failed running worker on dataset dataset_id: MissingDatasetArchive('The dataset compressed archive artifact was not found.')", + ), + ( + logging.ERROR, + "Ran on 1 datasets: 0 completed, 1 failed", + ), + ] + + @pytest.mark.parametrize( "generator, state", [(True, DatasetState.Open), (False, DatasetState.Complete)] ) def test_run( - mocker, responses, caplog, mock_dataset_worker, default_dataset, generator, state + mocker, + tmp_path, + responses, + caplog, + mock_dataset_worker, + default_dataset, + default_artifact, + generator, + state, ): mock_dataset_worker.generator = generator default_dataset.state = state.value @@ -355,6 +656,10 @@ def test_run( "arkindex_worker.worker.DatasetWorker.list_datasets", return_value=[default_dataset], ) + mocker.patch( + "arkindex_worker.worker.base.BaseWorker.find_extras_directory", + return_value=tmp_path, + ) mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_dataset") extra_calls = [] @@ -369,10 +674,43 @@ def test_run( extra_calls += [ ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"), ] * 2 - extra_logs = [ + extra_logs += [ (logging.INFO, "Building Dataset (dataset_id) (1/1)"), (logging.INFO, "Completed Dataset (dataset_id) (1/1)"), ] + else: + archive_path = ( + FIXTURES_DIR + / "extract_parent_archives" + / "first_parent" + / "arkindex_data.zstd" + ) + responses.add( + responses.GET, + f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/", + status=200, + json=[default_artifact], + ) + responses.add( + responses.GET, + 
f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.zstd", + status=200, + body=archive_path.read_bytes(), + content_type="application/zstd", + ) + extra_calls += [ + ( + "GET", + f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/", + ), + ( + "GET", + f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.zstd", + ), + ] + extra_logs += [ + (logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"), + ] mock_dataset_worker.run() @@ -394,10 +732,12 @@ def test_run( ) def test_run_read_only( mocker, + tmp_path, responses, caplog, mock_dev_dataset_worker, default_dataset, + default_artifact, generator, state, ): @@ -408,6 +748,10 @@ def test_run_read_only( "arkindex_worker.worker.DatasetWorker.list_datasets", return_value=[default_dataset.id], ) + mocker.patch( + "arkindex_worker.worker.base.BaseWorker.find_extras_directory", + return_value=tmp_path, + ) mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_dataset") responses.add( @@ -417,9 +761,10 @@ def test_run_read_only( json=default_dataset, ) + extra_calls = [] extra_logs = [] if generator: - extra_logs = [ + extra_logs += [ (logging.INFO, "Building Dataset (dataset_id) (1/1)"), ( logging.WARNING, @@ -431,15 +776,48 @@ def test_run_read_only( "Cannot update dataset as this worker is in read-only mode", ), ] + else: + archive_path = ( + FIXTURES_DIR + / "extract_parent_archives" + / "first_parent" + / "arkindex_data.zstd" + ) + responses.add( + responses.GET, + f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/", + status=200, + json=[default_artifact], + ) + responses.add( + responses.GET, + f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.zstd", + status=200, + body=archive_path.read_bytes(), + content_type="application/zstd", + ) + extra_calls += [ + ( + "GET", + f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/", + ), + ( + "GET", + f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.zstd", + ), + ] + extra_logs += [ + (logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"), + ] mock_dev_dataset_worker.run() assert mock_process.call_count == 1 - assert len(responses.calls) == 1 + assert len(responses.calls) == 1 + len(extra_calls) assert [(call.request.method, call.request.url) for call in responses.calls] == [ ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/") - ] + ] + extra_calls assert [(level, message) for _, level, message in caplog.record_tuples] == [ (logging.WARNING, "Running without any extra configuration"), diff --git a/tests/test_elements_worker/test_task.py b/tests/test_elements_worker/test_task.py index 754e467ecc7ef79d07820784d285fbbf3976e060..6d60b7b16efbd58ac28478d192a5957e50053f60 100644 --- a/tests/test_elements_worker/test_task.py +++ b/tests/test_elements_worker/test_task.py @@ -1,46 +1,16 @@ # -*- coding: utf-8 -*- -import sys import uuid import pytest from apistar.exceptions import ErrorResponse from arkindex_worker.models import Artifact -from arkindex_worker.worker import BaseWorker -from arkindex_worker.worker.dataset import DatasetMixin -from arkindex_worker.worker.task import TaskMixin from tests.conftest import FIXTURES_DIR +from tests.test_elements_worker import BASE_API_CALLS TASK_ID = uuid.UUID("cafecafe-cafe-cafe-cafe-cafecafecafe") -@pytest.fixture -def default_artifact(): - return { - "id": "artifact_id", - "path": "arkindex_data.zstd", - "size": 42, - "content_type": "application/zstd", 
-        "s3_put_url": None,
-        "created": "2000-01-01T00:00:00Z",
-        "updated": "2000-01-01T00:00:00Z",
-    }
-
-
-@pytest.fixture
-def mock_dataset_worker(monkeypatch):
-    class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
-        """
-        This class is needed to run tests in the context of a dataset worker
-        """
-
-    monkeypatch.setattr(sys, "argv", ["worker"])
-    dataset_worker = DatasetWorker()
-    dataset_worker.args = dataset_worker.parser.parse_args()
-
-    return dataset_worker
-
-
 @pytest.mark.parametrize(
     "payload, error",
     (
@@ -70,8 +40,10 @@ def test_list_artifacts_api_error(responses, mock_dataset_worker):
     with pytest.raises(ErrorResponse):
         mock_dataset_worker.list_artifacts(task_id=TASK_ID)
 
-    assert len(responses.calls) == 5
-    assert [(call.request.method, call.request.url) for call in responses.calls] == [
+    assert len(responses.calls) == len(BASE_API_CALLS) + 5
+    assert [
+        (call.request.method, call.request.url) for call in responses.calls
+    ] == BASE_API_CALLS + [
         # The API call is retried 5 times
         ("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifacts/"),
         ("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifacts/"),
@@ -116,8 +88,10 @@ def test_list_artifacts(
         assert isinstance(artifact, Artifact)
         assert artifact == expected_results[idx]
 
-    assert len(responses.calls) == 1
-    assert [(call.request.method, call.request.url) for call in responses.calls] == [
+    assert len(responses.calls) == len(BASE_API_CALLS) + 1
+    assert [
+        (call.request.method, call.request.url) for call in responses.calls
+    ] == BASE_API_CALLS + [
         ("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifacts/"),
     ]
 
@@ -141,7 +115,7 @@ def test_download_artifact_wrong_param_task_id(
 ):
     api_payload = {
         "task_id": TASK_ID,
-        "artifact": Artifact(default_artifact),
+        "artifact": default_artifact,
         **payload,
     }
 
@@ -168,7 +142,7 @@ def test_download_artifact_wrong_param_artifact(
 ):
     api_payload = {
         "task_id": TASK_ID,
-        "artifact": Artifact(default_artifact),
+        "artifact": default_artifact,
         **payload,
     }
 
@@ -177,25 +151,27 @@ def test_download_artifact_wrong_param_artifact(
 
 
 def test_download_artifact_api_error(responses, mock_dataset_worker, default_artifact):
-    artifact = Artifact(default_artifact)
-
     responses.add(
         responses.GET,
-        f"http://testserver/api/v1/task/{TASK_ID}/artifact/arkindex_data.zstd",
+        f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd",
         status=500,
     )
 
     with pytest.raises(ErrorResponse):
-        mock_dataset_worker.download_artifact(task_id=TASK_ID, artifact=artifact)
-
-    assert len(responses.calls) == 5
-    assert [(call.request.method, call.request.url) for call in responses.calls] == [
+        mock_dataset_worker.download_artifact(
+            task_id=TASK_ID, artifact=default_artifact
+        )
+
+    assert len(responses.calls) == len(BASE_API_CALLS) + 5
+    assert [
+        (call.request.method, call.request.url) for call in responses.calls
+    ] == BASE_API_CALLS + [
         # The API call is retried 5 times
-        ("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/arkindex_data.zstd"),
-        ("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/arkindex_data.zstd"),
-        ("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/arkindex_data.zstd"),
-        ("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/arkindex_data.zstd"),
-        ("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/arkindex_data.zstd"),
+        ("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd"),
+        ("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd"),
+        ("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd"),
+        ("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd"),
+        ("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd"),
     ]
 
 
@@ -204,25 +180,27 @@ def test_download_artifact(
     mock_dataset_worker,
     default_artifact,
 ):
-    artifact = Artifact(default_artifact)
-
     archive_path = (
         FIXTURES_DIR / "extract_parent_archives" / "first_parent" / "arkindex_data.zstd"
     )
     responses.add(
         responses.GET,
-        f"http://testserver/api/v1/task/{TASK_ID}/artifact/arkindex_data.zstd",
+        f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd",
         status=200,
         body=archive_path.read_bytes(),
         content_type="application/zstd",
     )
 
     assert (
-        mock_dataset_worker.download_artifact(task_id=TASK_ID, artifact=artifact).read()
+        mock_dataset_worker.download_artifact(
+            task_id=TASK_ID, artifact=default_artifact
+        ).read()
         == archive_path.read_bytes()
     )
 
-    assert len(responses.calls) == 1
-    assert [(call.request.method, call.request.url) for call in responses.calls] == [
-        ("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/arkindex_data.zstd"),
+    assert len(responses.calls) == len(BASE_API_CALLS) + 1
+    assert [
+        (call.request.method, call.request.url) for call in responses.calls
+    ] == BASE_API_CALLS + [
+        ("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd"),
     ]