Commit 598f083a authored by Eva Bardou, committed by Yoann Schneider

Download the dataset's compressed archive during processing

parent bf76f8d2
1 merge request: !420 Download the dataset's compressed archive during processing
Pipeline #138598 passed
@@ -10,6 +10,7 @@ import uuid
from enum import Enum
from itertools import groupby
from operator import itemgetter
from pathlib import Path
from typing import Iterable, Iterator, List, Tuple, Union
from apistar.exceptions import ErrorResponse
@@ -23,6 +24,7 @@ from arkindex_worker.worker.dataset import DatasetMixin, DatasetState
from arkindex_worker.worker.element import ElementMixin
from arkindex_worker.worker.entity import EntityMixin # noqa: F401
from arkindex_worker.worker.metadata import MetaDataMixin, MetaType # noqa: F401
from arkindex_worker.worker.task import TaskMixin
from arkindex_worker.worker.transcription import TranscriptionMixin
from arkindex_worker.worker.version import WorkerVersionMixin # noqa: F401
@@ -302,13 +304,32 @@ class ElementsWorker(
return True
class DatasetWorker(BaseWorker, DatasetMixin):
class MissingDatasetArchive(Exception):
"""
Exception raised when the compressed `.zstd` archive associated with
a dataset isn't found in its task artifacts.
"""
class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
"""
Base class for ML workers that operate on Arkindex datasets.
This class inherits from numerous mixin classes found in other modules of
``arkindex.worker``, which provide helpers to read and write to the Arkindex API.
"""
def __init__(
self,
description: str = "Arkindex Dataset Worker",
support_cache: bool = False,
generator: bool = False,
):
"""
:param description: The worker's description.
:param support_cache: Whether the worker supports cache.
:param generator: Whether the worker generates the dataset archive artifact.
"""
super().__init__(description, support_cache)
self.parser.add_argument(
@@ -333,11 +354,42 @@ class DatasetWorker(BaseWorker, DatasetMixin):
super().configure()
super().configure_cache()
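As a usage illustration (not part of this commit), a downstream ML worker would subclass DatasetWorker and implement process_dataset; the MyDatasetWorker name and its body below are hypothetical, only the base class and the run() entrypoint come from this file:

from arkindex_worker.models import Dataset
from arkindex_worker.worker import DatasetWorker


class MyDatasetWorker(DatasetWorker):
    """Hypothetical consumer of dataset archives."""

    def process_dataset(self, dataset: Dataset):
        # When the worker is not a generator, run() has already downloaded the
        # dataset's compressed archive into the extras directory before calling
        # this method (see download_dataset_artifact below).
        print(f"Processing dataset {dataset.id} in state {dataset.state}")


if __name__ == "__main__":
    MyDatasetWorker(description="My dataset worker").run()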
def download_dataset_artifact(self, dataset: Dataset) -> Path:
"""
Find and download the compressed archive artifact describing a dataset using
the [list_artifacts][arkindex_worker.worker.task.TaskMixin.list_artifacts] and
[download_artifact][arkindex_worker.worker.task.TaskMixin.download_artifact] methods.
:param dataset: The dataset to retrieve the compressed archive artifact for.
:returns: A path to the downloaded artifact.
:raises MissingDatasetArchive: When the dataset artifact is not found.
"""
task_id = uuid.UUID(dataset.task_id)
archive_name = f"{dataset.id}.zstd"
for artifact in self.list_artifacts(task_id):
if artifact.path != archive_name:
continue
extra_dir = self.find_extras_directory()
archive = extra_dir / archive_name
archive.write_bytes(self.download_artifact(task_id, artifact).read())
return archive
raise MissingDatasetArchive(
"The dataset compressed archive artifact was not found."
)
def list_dataset_elements_per_split(
self, dataset: Dataset
) -> Iterator[Tuple[str, List[Element]]]:
"""
Calls `list_dataset_elements` but returns results grouped by Set
List the elements in the dataset, grouped by split, using the
[list_dataset_elements][arkindex_worker.worker.dataset.DatasetMixin.list_dataset_elements] method.
:param dataset: The dataset to retrieve elements from.
:returns: An iterator of tuples containing the split name and the list of its elements.
"""
def format_split(
@@ -362,8 +414,11 @@ class DatasetWorker(BaseWorker, DatasetMixin):
def list_datasets(self) -> Iterator[Dataset] | Iterator[str]:
"""
Calls `list_process_datasets` if not is_read_only,
else simply give the list of IDs provided via CLI
List the datasets to be processed, either from the CLI arguments or using the
[list_process_datasets][arkindex_worker.worker.dataset.DatasetMixin.list_process_datasets] method.
:returns: An iterator of strings if the worker is in read-only mode,
else an iterator of ``Dataset`` objects.
"""
if self.is_read_only:
return map(str, self.args.dataset)
@@ -371,6 +426,14 @@ class DatasetWorker(BaseWorker, DatasetMixin):
return self.list_process_datasets()
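Tying the two listing helpers above together, a hypothetical process_dataset implementation on a DatasetWorker subclass could iterate the splits like this (a sketch, assuming the split names come from the dataset itself):

def process_dataset(self, dataset: Dataset):
    # Hypothetical consumer: handle the dataset's elements one split at a time.
    for split_name, elements in self.list_dataset_elements_per_split(dataset):
        logger.info(f"Split {split_name!r} has {len(elements)} elements")
        for element in elements:
            ...  # e.g. feed element.id to a training pipeline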
def run(self):
"""
Implements an Arkindex worker that goes through each dataset returned by
[list_datasets][arkindex_worker.worker.DatasetWorker.list_datasets].
It calls [process_dataset][arkindex_worker.worker.DatasetWorker.process_dataset],
catching exceptions, and handles updating the [DatasetState][arkindex_worker.worker.dataset.DatasetState]
when the worker is a generator.
"""
self.configure()
datasets: List[Dataset] | List[str] = list(self.list_datasets())
@@ -406,6 +469,9 @@ class DatasetWorker(BaseWorker, DatasetMixin):
# Update the dataset state to Building
logger.info(f"Building {dataset} ({i}/{count})")
self.update_dataset_state(dataset, DatasetState.Building)
else:
logger.info(f"Downloading data for {dataset} ({i}/{count})")
self.download_dataset_artifact(dataset)
# Process the dataset
self.process_dataset(dataset)
......
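Most of the loop that the run() docstring describes is collapsed in the diff above; the following is a hedged paraphrase of its control flow, pieced together from the visible lines and from the log messages asserted in the tests further down. The error accounting and final summary are assumptions, and read-only handling is omitted.

def run(self):
    # Hedged sketch of DatasetWorker.run, not a verbatim copy of the source.
    self.configure()
    datasets = list(self.list_datasets())
    count, failed = len(datasets), 0
    for i, dataset in enumerate(datasets, start=1):
        try:
            logger.info(f"Processing {dataset} ({i}/{count})")
            if self.generator:
                # Generator workers build the archive: mark the dataset as
                # Building first, and as Complete once processing succeeds.
                logger.info(f"Building {dataset} ({i}/{count})")
                self.update_dataset_state(dataset, DatasetState.Building)
            else:
                # Other workers fetch the archive built by the generator task.
                logger.info(f"Downloading data for {dataset} ({i}/{count})")
                self.download_dataset_artifact(dataset)
            self.process_dataset(dataset)
            if self.generator:
                logger.info(f"Completed {dataset} ({i}/{count})")
                self.update_dataset_state(dataset, DatasetState.Complete)
        except Exception as e:
            failed += 1
            logger.warning(f"Failed running worker on dataset {dataset.id}: {e!r}")
    if failed:
        logger.error(f"Ran on {count} datasets: {count - failed} completed, {failed} failed")
        sys.exit(1)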
@@ -3,6 +3,8 @@
::: arkindex_worker.worker.task
options:
members: no
options:
show_category_heading: no
::: arkindex_worker.worker.task.TaskMixin
options:
......
@@ -23,7 +23,7 @@ from arkindex_worker.cache import (
init_cache_db,
)
from arkindex_worker.git import GitHelper, GitlabHelper
from arkindex_worker.models import Dataset
from arkindex_worker.models import Artifact, Dataset
from arkindex_worker.worker import BaseWorker, DatasetWorker, ElementsWorker
from arkindex_worker.worker.dataset import DatasetState
from arkindex_worker.worker.transcription import TextOrientation
@@ -570,7 +570,7 @@ def default_dataset():
"state": DatasetState.Open.value,
"corpus_id": "corpus_id",
"creator": "creator@teklia.com",
"task_id": "task_id",
"task_id": "11111111-1111-1111-1111-111111111111",
"created": "2000-01-01T00:00:00Z",
"updated": "2000-01-01T00:00:00Z",
}
@@ -578,7 +578,8 @@ def default_dataset():
@pytest.fixture
def mock_dataset_worker(mocker, mock_worker_run_api):
def mock_dataset_worker(monkeypatch, mocker, mock_worker_run_api):
monkeypatch.setenv("PONOS_TASK", "my_task")
mocker.patch.object(sys, "argv", ["worker"])
dataset_worker = DatasetWorker()
@@ -612,3 +613,18 @@ def mock_dev_dataset_worker(mocker):
assert dataset_worker.is_read_only is True
return dataset_worker
@pytest.fixture
def default_artifact():
return Artifact(
**{
"id": "artifact_id",
"path": "dataset_id.zstd",
"size": 42,
"content_type": "application/zstd",
"s3_put_url": None,
"created": "2000-01-01T00:00:00Z",
"updated": "2000-01-01T00:00:00Z",
}
)
import logging
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.worker import MissingDatasetArchive
from arkindex_worker.worker.dataset import DatasetState
from tests.conftest import PROCESS_ID
from tests.conftest import FIXTURES_DIR, PROCESS_ID
from tests.test_elements_worker import BASE_API_CALLS
def test_download_dataset_artifact_list_api_error(
responses, mock_dataset_worker, default_dataset
):
task_id = default_dataset.task_id
responses.add(
responses.GET,
f"http://testserver/api/v1/task/{task_id}/artifacts/",
status=500,
)
with pytest.raises(ErrorResponse):
mock_dataset_worker.download_dataset_artifact(default_dataset)
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# The API call is retried 5 times
("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
]
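The five identical GET calls asserted above come from the automatic retrying inside BaseWorker.request, which is not part of this diff. The 3, 4, 8 and 16 second delays asserted in the log checks further down are consistent with a tenacity-style exponential backoff; the sketch below is purely an assumption about how such a retry could look, not the actual implementation.

import logging

from apistar.exceptions import ErrorResponse
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

logger = logging.getLogger(__name__)


@retry(
    retry=retry_if_exception_type(ErrorResponse),
    stop=stop_after_attempt(5),    # one initial call plus four retries = five GETs
    wait=wait_exponential(min=3),  # roughly reproduces the 3, 4, 8, 16 second waits
    before_sleep=before_sleep_log(logger, logging.INFO),
    reraise=True,
)
def request_with_retry(client, operation_id, **params):
    # Hypothetical stand-in for BaseWorker.request: retry any ErrorResponse.
    return client.request(operation_id, **params)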
def test_download_dataset_artifact_download_api_error(
responses, mock_dataset_worker, default_dataset
):
task_id = default_dataset.task_id
expected_results = [
{
"id": "artifact_1",
"path": "dataset_id.zstd",
"size": 42,
"content_type": "application/zstd",
"s3_put_url": None,
"created": "2000-01-01T00:00:00Z",
"updated": "2000-01-01T00:00:00Z",
},
{
"id": "artifact_2",
"path": "logs.log",
"size": 42,
"content_type": "text/plain",
"s3_put_url": None,
"created": "2000-01-01T00:00:00Z",
"updated": "2000-01-01T00:00:00Z",
},
]
responses.add(
responses.GET,
f"http://testserver/api/v1/task/{task_id}/artifacts/",
status=200,
json=expected_results,
)
responses.add(
responses.GET,
f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.zstd",
status=500,
)
with pytest.raises(ErrorResponse):
mock_dataset_worker.download_dataset_artifact(default_dataset)
assert len(responses.calls) == len(BASE_API_CALLS) + 6
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
# The API call is retried 5 times
("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.zstd"),
("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.zstd"),
("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.zstd"),
("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.zstd"),
("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.zstd"),
]
def test_download_dataset_artifact_no_archive(
responses, mock_dataset_worker, default_dataset
):
task_id = default_dataset.task_id
expected_results = [
{
"id": "artifact_id",
"path": "logs.log",
"size": 42,
"content_type": "text/plain",
"s3_put_url": None,
"created": "2000-01-01T00:00:00Z",
"updated": "2000-01-01T00:00:00Z",
},
]
responses.add(
responses.GET,
f"http://testserver/api/v1/task/{task_id}/artifacts/",
status=200,
json=expected_results,
)
with pytest.raises(
MissingDatasetArchive,
match="The dataset compressed archive artifact was not found.",
):
mock_dataset_worker.download_dataset_artifact(default_dataset)
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
]
def test_download_dataset_artifact(
mocker, tmp_path, responses, mock_dataset_worker, default_dataset
):
task_id = default_dataset.task_id
archive_path = (
FIXTURES_DIR / "extract_parent_archives" / "first_parent" / "arkindex_data.zstd"
)
mocker.patch(
"arkindex_worker.worker.base.BaseWorker.find_extras_directory",
return_value=tmp_path,
)
expected_results = [
{
"id": "artifact_1",
"path": "dataset_id.zstd",
"size": 42,
"content_type": "application/zstd",
"s3_put_url": None,
"created": "2000-01-01T00:00:00Z",
"updated": "2000-01-01T00:00:00Z",
},
{
"id": "artifact_2",
"path": "logs.log",
"size": 42,
"content_type": "text/plain",
"s3_put_url": None,
"created": "2000-01-01T00:00:00Z",
"updated": "2000-01-01T00:00:00Z",
},
]
responses.add(
responses.GET,
f"http://testserver/api/v1/task/{task_id}/artifacts/",
status=200,
json=expected_results,
)
responses.add(
responses.GET,
f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.zstd",
status=200,
body=archive_path.read_bytes(),
content_type="application/zstd",
)
archive = mock_dataset_worker.download_dataset_artifact(default_dataset)
assert archive == tmp_path / "dataset_id.zstd"
assert archive.read_bytes() == archive_path.read_bytes()
archive.unlink()
assert len(responses.calls) == len(BASE_API_CALLS) + 2
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.zstd"),
]
def test_list_dataset_elements_per_split_api_error(
responses, mock_dataset_worker, default_dataset
):
@@ -342,11 +522,132 @@ def test_run_update_dataset_state_api_error(
]
def test_run_download_dataset_artifact_api_error(
mocker,
tmp_path,
responses,
caplog,
mock_dataset_worker,
default_dataset,
):
default_dataset.state = DatasetState.Complete.value
mocker.patch(
"arkindex_worker.worker.DatasetWorker.list_datasets",
return_value=[default_dataset],
)
mocker.patch(
"arkindex_worker.worker.base.BaseWorker.find_extras_directory",
return_value=tmp_path,
)
responses.add(
responses.GET,
f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
status=500,
)
with pytest.raises(SystemExit):
mock_dataset_worker.run()
assert len(responses.calls) == len(BASE_API_CALLS) * 2 + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS * 2 + [
# The API call is retried 5 times
("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
]
assert [(level, message) for _, level, message in caplog.record_tuples] == [
(logging.INFO, "Loaded worker Fake worker revision deadbee from API"),
(logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
(logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"),
*[
(
logging.INFO,
f"Retrying arkindex_worker.worker.base.BaseWorker.request in {retry} seconds as it raised ErrorResponse: .",
)
for retry in [3.0, 4.0, 8.0, 16.0]
],
(
logging.WARNING,
"An API error occurred while processing dataset dataset_id: 500 Internal Server Error - None",
),
(
logging.ERROR,
"Ran on 1 datasets: 0 completed, 1 failed",
),
]
def test_run_no_downloaded_artifact_error(
mocker,
tmp_path,
responses,
caplog,
mock_dataset_worker,
default_dataset,
):
default_dataset.state = DatasetState.Complete.value
mocker.patch(
"arkindex_worker.worker.DatasetWorker.list_datasets",
return_value=[default_dataset],
)
mocker.patch(
"arkindex_worker.worker.base.BaseWorker.find_extras_directory",
return_value=tmp_path,
)
responses.add(
responses.GET,
f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
status=200,
json={},
)
with pytest.raises(SystemExit):
mock_dataset_worker.run()
assert len(responses.calls) == len(BASE_API_CALLS) * 2 + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS * 2 + [
("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
]
assert [(level, message) for _, level, message in caplog.record_tuples] == [
(logging.INFO, "Loaded worker Fake worker revision deadbee from API"),
(logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
(logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"),
(
logging.WARNING,
"Failed running worker on dataset dataset_id: MissingDatasetArchive('The dataset compressed archive artifact was not found.')",
),
(
logging.ERROR,
"Ran on 1 datasets: 0 completed, 1 failed",
),
]
@pytest.mark.parametrize(
"generator, state", [(True, DatasetState.Open), (False, DatasetState.Complete)]
)
def test_run(
mocker, responses, caplog, mock_dataset_worker, default_dataset, generator, state
mocker,
tmp_path,
responses,
caplog,
mock_dataset_worker,
default_dataset,
default_artifact,
generator,
state,
):
mock_dataset_worker.generator = generator
default_dataset.state = state.value
@@ -355,6 +656,10 @@ def test_run(
"arkindex_worker.worker.DatasetWorker.list_datasets",
return_value=[default_dataset],
)
mocker.patch(
"arkindex_worker.worker.base.BaseWorker.find_extras_directory",
return_value=tmp_path,
)
mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_dataset")
extra_calls = []
@@ -369,10 +674,43 @@ def test_run(
extra_calls += [
("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
] * 2
extra_logs = [
extra_logs += [
(logging.INFO, "Building Dataset (dataset_id) (1/1)"),
(logging.INFO, "Completed Dataset (dataset_id) (1/1)"),
]
else:
archive_path = (
FIXTURES_DIR
/ "extract_parent_archives"
/ "first_parent"
/ "arkindex_data.zstd"
)
responses.add(
responses.GET,
f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
status=200,
json=[default_artifact],
)
responses.add(
responses.GET,
f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.zstd",
status=200,
body=archive_path.read_bytes(),
content_type="application/zstd",
)
extra_calls += [
(
"GET",
f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
),
(
"GET",
f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.zstd",
),
]
extra_logs += [
(logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"),
]
mock_dataset_worker.run()
@@ -394,10 +732,12 @@ def test_run(
)
def test_run_read_only(
mocker,
tmp_path,
responses,
caplog,
mock_dev_dataset_worker,
default_dataset,
default_artifact,
generator,
state,
):
@@ -408,6 +748,10 @@ def test_run_read_only(
"arkindex_worker.worker.DatasetWorker.list_datasets",
return_value=[default_dataset.id],
)
mocker.patch(
"arkindex_worker.worker.base.BaseWorker.find_extras_directory",
return_value=tmp_path,
)
mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_dataset")
responses.add(
@@ -417,9 +761,10 @@ def test_run_read_only(
json=default_dataset,
)
extra_calls = []
extra_logs = []
if generator:
extra_logs = [
extra_logs += [
(logging.INFO, "Building Dataset (dataset_id) (1/1)"),
(
logging.WARNING,
@@ -431,15 +776,48 @@ def test_run_read_only(
"Cannot update dataset as this worker is in read-only mode",
),
]
else:
archive_path = (
FIXTURES_DIR
/ "extract_parent_archives"
/ "first_parent"
/ "arkindex_data.zstd"
)
responses.add(
responses.GET,
f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
status=200,
json=[default_artifact],
)
responses.add(
responses.GET,
f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.zstd",
status=200,
body=archive_path.read_bytes(),
content_type="application/zstd",
)
extra_calls += [
(
"GET",
f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
),
(
"GET",
f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.zstd",
),
]
extra_logs += [
(logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"),
]
mock_dev_dataset_worker.run()
assert mock_process.call_count == 1
assert len(responses.calls) == 1
assert len(responses.calls) == 1 + len(extra_calls)
assert [(call.request.method, call.request.url) for call in responses.calls] == [
("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/")
]
] + extra_calls
assert [(level, message) for _, level, message in caplog.record_tuples] == [
(logging.WARNING, "Running without any extra configuration"),
......
# -*- coding: utf-8 -*-
import sys
import uuid
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.models import Artifact
from arkindex_worker.worker import BaseWorker
from arkindex_worker.worker.dataset import DatasetMixin
from arkindex_worker.worker.task import TaskMixin
from tests.conftest import FIXTURES_DIR
from tests.test_elements_worker import BASE_API_CALLS
TASK_ID = uuid.UUID("cafecafe-cafe-cafe-cafe-cafecafecafe")
@pytest.fixture
def default_artifact():
return {
"id": "artifact_id",
"path": "arkindex_data.zstd",
"size": 42,
"content_type": "application/zstd",
"s3_put_url": None,
"created": "2000-01-01T00:00:00Z",
"updated": "2000-01-01T00:00:00Z",
}
@pytest.fixture
def mock_dataset_worker(monkeypatch):
class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
"""
This class is needed to run tests in the context of a dataset worker
"""
monkeypatch.setattr(sys, "argv", ["worker"])
dataset_worker = DatasetWorker()
dataset_worker.args = dataset_worker.parser.parse_args()
return dataset_worker
@pytest.mark.parametrize(
"payload, error",
(
@@ -70,8 +40,10 @@ def test_list_artifacts_api_error(responses, mock_dataset_worker):
with pytest.raises(ErrorResponse):
mock_dataset_worker.list_artifacts(task_id=TASK_ID)
assert len(responses.calls) == 5
assert [(call.request.method, call.request.url) for call in responses.calls] == [
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# The API call is retried 5 times
("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifacts/"),
("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifacts/"),
@@ -116,8 +88,10 @@ def test_list_artifacts(
assert isinstance(artifact, Artifact)
assert artifact == expected_results[idx]
assert len(responses.calls) == 1
assert [(call.request.method, call.request.url) for call in responses.calls] == [
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifacts/"),
]
@@ -141,7 +115,7 @@ def test_download_artifact_wrong_param_task_id(
):
api_payload = {
"task_id": TASK_ID,
"artifact": Artifact(default_artifact),
"artifact": default_artifact,
**payload,
}
@@ -168,7 +142,7 @@ def test_download_artifact_wrong_param_artifact(
):
api_payload = {
"task_id": TASK_ID,
"artifact": Artifact(default_artifact),
"artifact": default_artifact,
**payload,
}
@@ -177,25 +151,27 @@ def test_download_artifact_api_error(responses, mock_dataset_worker, default_artifact):
def test_download_artifact_api_error(responses, mock_dataset_worker, default_artifact):
artifact = Artifact(default_artifact)
responses.add(
responses.GET,
f"http://testserver/api/v1/task/{TASK_ID}/artifact/arkindex_data.zstd",
f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd",
status=500,
)
with pytest.raises(ErrorResponse):
mock_dataset_worker.download_artifact(task_id=TASK_ID, artifact=artifact)
assert len(responses.calls) == 5
assert [(call.request.method, call.request.url) for call in responses.calls] == [
mock_dataset_worker.download_artifact(
task_id=TASK_ID, artifact=default_artifact
)
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# The API call is retried 5 times
("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/arkindex_data.zstd"),
("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/arkindex_data.zstd"),
("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/arkindex_data.zstd"),
("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/arkindex_data.zstd"),
("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/arkindex_data.zstd"),
("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd"),
("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd"),
("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd"),
("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd"),
("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd"),
]
@@ -204,25 +180,27 @@ def test_download_artifact(
mock_dataset_worker,
default_artifact,
):
artifact = Artifact(default_artifact)
archive_path = (
FIXTURES_DIR / "extract_parent_archives" / "first_parent" / "arkindex_data.zstd"
)
responses.add(
responses.GET,
f"http://testserver/api/v1/task/{TASK_ID}/artifact/arkindex_data.zstd",
f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd",
status=200,
body=archive_path.read_bytes(),
content_type="application/zstd",
)
assert (
mock_dataset_worker.download_artifact(task_id=TASK_ID, artifact=artifact).read()
mock_dataset_worker.download_artifact(
task_id=TASK_ID, artifact=default_artifact
).read()
== archive_path.read_bytes()
)
assert len(responses.calls) == 1
assert [(call.request.method, call.request.url) for call in responses.calls] == [
("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/arkindex_data.zstd"),
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.zstd"),
]