diff --git a/arkindex_worker/worker/corpus.py b/arkindex_worker/worker/corpus.py index 36f7d3c9cf9b6efe9f8b235641ad2cf7195d4baa..75d20bd6fe9a9d1121507bf09c8ecc16c0ccff54 100644 --- a/arkindex_worker/worker/corpus.py +++ b/arkindex_worker/worker/corpus.py @@ -5,6 +5,7 @@ BaseWorker methods for corpora. from enum import Enum from operator import itemgetter from tempfile import _TemporaryFileWrapper +from uuid import UUID from arkindex_worker import logger @@ -36,6 +37,25 @@ class CorpusExportState(Enum): class CorpusMixin: + def download_export(self, export_id: str) -> _TemporaryFileWrapper: + """ + Download an export. + + :param export_id: UUID of the export to download + :returns: The downloaded export stored in a temporary file. + """ + try: + UUID(export_id) + except ValueError as e: + raise ValueError("export_id is not a valid uuid.") from e + + logger.info(f"Downloading export ({export_id})...") + export: _TemporaryFileWrapper = self.api_client.request( + "DownloadExport", id=export_id + ) + logger.info(f"Downloaded export ({export_id}) @ `{export.name}`") + return export + def download_latest_export(self) -> _TemporaryFileWrapper: """ Download the latest export in `done` state of the current corpus. @@ -62,10 +82,5 @@ class CorpusMixin: # Download latest export export_id: str = exports[0]["id"] - logger.info(f"Downloading export ({export_id})...") - export: _TemporaryFileWrapper = self.api_client.request( - "DownloadExport", id=export_id - ) - logger.info(f"Downloaded export ({export_id}) @ `{export.name}`") - return export + return self.download_export(export_id) diff --git a/tests/test_elements_worker/test_corpus.py b/tests/test_elements_worker/test_corpus.py index f549a244f848d9456a733eee620a8e8de3112ae3..72586eaaf06f368d65cdb9151464fff88b7fa3ca 100644 --- a/tests/test_elements_worker/test_corpus.py +++ b/tests/test_elements_worker/test_corpus.py @@ -135,3 +135,34 @@ def test_download_latest_export(responses, mock_elements_worker): ("GET", f"http://testserver/api/v1/corpus/{CORPUS_ID}/export/"), ("GET", f"http://testserver/api/v1/export/{export_id}/"), ] + + +def test_download_export_not_a_uuid(responses, mock_elements_worker): + with pytest.raises(ValueError, match="export_id is not a valid uuid."): + mock_elements_worker.download_export("mon export") + + +def test_download_export(responses, mock_elements_worker): + responses.add( + responses.GET, + "http://testserver/api/v1/export/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff/", + status=302, + body=b"some SQLite export", + content_type="application/x-sqlite3", + stream=True, + ) + + export = mock_elements_worker.download_export( + "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff" + ) + assert export.name == "/tmp/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff" + + assert len(responses.calls) == len(BASE_API_CALLS) + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ( + "GET", + "http://testserver/api/v1/export/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff/", + ), + ]