diff --git a/arkindex_worker/worker/__init__.py b/arkindex_worker/worker/__init__.py index ccb968b3a7af8828d9cc13d54a3506d5fb9a0365..2eba022973a4de72c3068868f8ae0800c9cbb926 100644 --- a/arkindex_worker/worker/__init__.py +++ b/arkindex_worker/worker/__init__.py @@ -16,7 +16,7 @@ from arkindex_worker.worker.element import ElementMixin from arkindex_worker.worker.entity import EntityMixin, EntityType # noqa: F401 from arkindex_worker.worker.metadata import MetaDataMixin, MetaType # noqa: F401 from arkindex_worker.worker.transcription import TranscriptionMixin -from arkindex_worker.worker.version import MANUAL_SLUG, WorkerVersionMixin # noqa: F401 +from arkindex_worker.worker.version import WorkerVersionMixin # noqa: F401 class ActivityState(Enum): diff --git a/arkindex_worker/worker/version.py b/arkindex_worker/worker/version.py index 9685d00143d1d06b8a4d4faa964366231508cd10..0000cc96d4041b485a1325b70b1bf3d75db4e933 100644 --- a/arkindex_worker/worker/version.py +++ b/arkindex_worker/worker/version.py @@ -1,13 +1,14 @@ # -*- coding: utf-8 -*- -MANUAL_SLUG = "manual" - class WorkerVersionMixin(object): def get_worker_version(self, worker_version_id: str) -> dict: """ Get worker version from cache if possible, otherwise make API request """ + if worker_version_id is None: + raise ValueError("No worker version ID") + if worker_version_id in self._worker_version_cache: return self._worker_version_cache[worker_version_id] @@ -18,37 +19,11 @@ class WorkerVersionMixin(object): def get_worker_version_slug(self, worker_version_id: str) -> str: """ - Get worker version slug from cache if possible, otherwise make API request + Helper function to get the worker slug from element, classification or transcription. + Gets the worker version slug from cache if possible, otherwise makes an API request. + Returns None if there is no associated worker version. - Should use `get_ml_result_slug` instead of using this method directly + :type worker_version_id: A worker version UUID """ worker_version = self.get_worker_version(worker_version_id) return worker_version["worker"]["slug"] - - def get_ml_result_slug(self, ml_result) -> str: - """ - Helper function to get the slug from element, classification or transcription - - Can handle old and new (source vs worker_version) - - :type ml_result: Element or classification or transcription - """ - if ( - "source" in ml_result - and ml_result["source"] - and ml_result["source"]["slug"] - ): - return ml_result["source"]["slug"] - elif "worker_version" in ml_result and ml_result["worker_version"]: - return self.get_worker_version_slug(ml_result["worker_version"]) - # transcriptions have worker_version_id but elements have worker_version - elif "worker_version_id" in ml_result and ml_result["worker_version_id"]: - return self.get_worker_version_slug(ml_result["worker_version_id"]) - elif "worker_version" in ml_result and ml_result["worker_version"] is None: - return MANUAL_SLUG - elif ( - "worker_version_id" in ml_result and ml_result["worker_version_id"] is None - ): - return MANUAL_SLUG - else: - raise ValueError(f"Unable to get slug from: {ml_result}") diff --git a/tests/test_elements_worker/test_worker.py b/tests/test_elements_worker/test_worker.py index 210abc81af506f9269c963ee2ddecd39ab8f9b21..9942f3a9d0d1b09f5100a3bc6dc5d6a01e3dffc0 100644 --- a/tests/test_elements_worker/test_worker.py +++ b/tests/test_elements_worker/test_worker.py @@ -5,7 +5,7 @@ import pytest from apistar.exceptions import ErrorResponse from arkindex_worker.models import Element -from arkindex_worker.worker import MANUAL_SLUG, ActivityState +from arkindex_worker.worker import ActivityState # Common API calls for all workers BASE_API_CALLS = [ @@ -48,109 +48,21 @@ def test_get_worker_version__uses_cache(fake_dummy_worker): assert not api_client.responses -def test_get_slug__old_style(fake_dummy_worker): - element = {"source": {"slug": TEST_SLUG}} - - slug = fake_dummy_worker.get_ml_result_slug(element) - - assert slug == TEST_SLUG - - -def test_get_slug__worker_version(fake_dummy_worker): - api_client = fake_dummy_worker.api_client - - response = {"worker": {"slug": TEST_SLUG}} - - api_client.add_response("RetrieveWorkerVersion", response, id=TEST_VERSION_ID) - - element = {"worker_version": TEST_VERSION_ID} - - slug = fake_dummy_worker.get_ml_result_slug(element) - - assert slug == TEST_SLUG - - # assert that only one call to the API - assert len(api_client.history) == 1 - assert not api_client.responses - - -def test_get_slug__both(fake_page_element, fake_ufcn_worker_version, fake_dummy_worker): - api_client = fake_dummy_worker.api_client - - api_client.add_response( - "RetrieveWorkerVersion", - fake_ufcn_worker_version, - id=fake_ufcn_worker_version["id"], - ) - - expected_slugs = [ - "scikit_portrait_outlier_balsac", - "scikit_portrait_outlier_balsac", - "ufcn_line_historical", - ] - - slugs = [ - fake_dummy_worker.get_ml_result_slug(clf) - for clf in fake_page_element["classifications"] - ] - - assert slugs == expected_slugs - assert len(api_client.history) == 1 - assert not api_client.responses - - -def test_get_slug__transcriptions(fake_transcriptions_small, fake_dummy_worker): - api_client = fake_dummy_worker.api_client - - version_id = "3ca4a8e3-91d1-4b78-8d83-d8bbbf487996" - response = {"worker": {"slug": TEST_SLUG}} - - api_client.add_response("RetrieveWorkerVersion", response, id=version_id) - - slug = fake_dummy_worker.get_ml_result_slug(fake_transcriptions_small["results"][0]) - - assert slug == TEST_SLUG - assert len(api_client.history) == 1 - assert not api_client.responses - - -@pytest.mark.parametrize( - "ml_result, expected_slug", - ( - # old - ({"source": {"slug": "test_123"}}, "test_123"), - ({"source": {"slug": "test_123"}, "worker_version": None}, "test_123"), - ({"source": {"slug": "test_123"}, "worker_version_id": None}, "test_123"), - # new - ({"source": None, "worker_version": "foo_1"}, "mock_slug"), - ({"source": None, "worker_version_id": "foo_1"}, "mock_slug"), - ({"worker_version_id": "foo_1"}, "mock_slug"), - # manual - ({"worker_version_id": None}, MANUAL_SLUG), - ({"worker_version": None}, MANUAL_SLUG), - ({"source": None, "worker_version": None}, MANUAL_SLUG), - ), -) -def test_get_ml_result_slug__ok(mocker, fake_dummy_worker, ml_result, expected_slug): - fake_dummy_worker.get_worker_version_slug = mocker.MagicMock() - fake_dummy_worker.get_worker_version_slug.return_value = "mock_slug" +def test_get_worker_version_slug(mocker, fake_dummy_worker): + fake_dummy_worker.get_worker_version = mocker.MagicMock() + fake_dummy_worker.get_worker_version.return_value = { + "id": TEST_VERSION_ID, + "worker": {"slug": "mock_slug"}, + } - slug = fake_dummy_worker.get_ml_result_slug(ml_result) - assert slug == expected_slug + slug = fake_dummy_worker.get_worker_version_slug(TEST_VERSION_ID) + assert slug == "mock_slug" -@pytest.mark.parametrize( - "ml_result", - ( - ({},), - ({"source": None},), - ({"source": {"slug": None}},), - ), -) -def test_get_ml_result_slug__fail(fake_dummy_worker, ml_result): +def test_get_worker_version_slug_none(fake_dummy_worker): with pytest.raises(ValueError) as excinfo: - fake_dummy_worker.get_ml_result_slug(ml_result) - assert str(excinfo.value).startswith("Unable to get slug from") + fake_dummy_worker.get_worker_version_slug(None) + assert str(excinfo.value) == "No worker version ID" def test_defaults(responses, mock_elements_worker):