diff --git a/arkindex_worker/worker/__init__.py b/arkindex_worker/worker/__init__.py index ccb968b3a7af8828d9cc13d54a3506d5fb9a0365..2eba022973a4de72c3068868f8ae0800c9cbb926 100644 --- a/arkindex_worker/worker/__init__.py +++ b/arkindex_worker/worker/__init__.py @@ -16,7 +16,7 @@ from arkindex_worker.worker.element import ElementMixin from arkindex_worker.worker.entity import EntityMixin, EntityType # noqa: F401 from arkindex_worker.worker.metadata import MetaDataMixin, MetaType # noqa: F401 from arkindex_worker.worker.transcription import TranscriptionMixin -from arkindex_worker.worker.version import MANUAL_SLUG, WorkerVersionMixin # noqa: F401 +from arkindex_worker.worker.version import WorkerVersionMixin # noqa: F401 class ActivityState(Enum): diff --git a/arkindex_worker/worker/version.py b/arkindex_worker/worker/version.py index 9685d00143d1d06b8a4d4faa964366231508cd10..f32fdc3a2f3dc1780113bcf3d451e15841a88aee 100644 --- a/arkindex_worker/worker/version.py +++ b/arkindex_worker/worker/version.py @@ -1,7 +1,5 @@ # -*- coding: utf-8 -*- -MANUAL_SLUG = "manual" - class WorkerVersionMixin(object): def get_worker_version(self, worker_version_id: str) -> dict: @@ -16,39 +14,28 @@ class WorkerVersionMixin(object): return worker_version - def get_worker_version_slug(self, worker_version_id: str) -> str: - """ - Get worker version slug from cache if possible, otherwise make API request - - Should use `get_ml_result_slug` instead of using this method directly - """ - worker_version = self.get_worker_version(worker_version_id) - return worker_version["worker"]["slug"] def get_ml_result_slug(self, ml_result) -> str: """ - Helper function to get the slug from element, classification or transcription + Helper function to get the worker slug from element, classification or transcription. + Returns None if there is no associated worker version. - Can handle old and new (source vs worker_version) - - :type ml_result: Element or classification or transcription + :type ml_result: Element, classification or transcription """ - if ( - "source" in ml_result - and ml_result["source"] - and ml_result["source"]["slug"] - ): - return ml_result["source"]["slug"] - elif "worker_version" in ml_result and ml_result["worker_version"]: - return self.get_worker_version_slug(ml_result["worker_version"]) + + # Handle cached models + if hasattr(ml_result, 'worker_version_id'): + worker_version_id = ml_result.worker_version_id + elif "worker_version" in ml_result: + worker_version_id = ml_result['worker_version'] # transcriptions have worker_version_id but elements have worker_version - elif "worker_version_id" in ml_result and ml_result["worker_version_id"]: - return self.get_worker_version_slug(ml_result["worker_version_id"]) - elif "worker_version" in ml_result and ml_result["worker_version"] is None: - return MANUAL_SLUG - elif ( - "worker_version_id" in ml_result and ml_result["worker_version_id"] is None - ): - return MANUAL_SLUG + elif "worker_version_id" in ml_result: + worker_version_id = ml_result['worker_version_id'] else: raise ValueError(f"Unable to get slug from: {ml_result}") + + if worker_version_id is None: + return + + worker_version = self.get_worker_version(worker_version_id) + return worker_version['worker']['slug'] diff --git a/tests/test_elements_worker/test_worker.py b/tests/test_elements_worker/test_worker.py index 210abc81af506f9269c963ee2ddecd39ab8f9b21..6c520a9dc8704ab614ec0d280d3a8495abe8d069 100644 --- a/tests/test_elements_worker/test_worker.py +++ b/tests/test_elements_worker/test_worker.py @@ -5,7 +5,8 @@ import pytest from apistar.exceptions import ErrorResponse from arkindex_worker.models import Element -from arkindex_worker.worker import MANUAL_SLUG, ActivityState +from arkindex_worker.cache import CachedElement +from arkindex_worker.worker import ActivityState # Common API calls for all workers BASE_API_CALLS = [ @@ -48,14 +49,6 @@ def test_get_worker_version__uses_cache(fake_dummy_worker): assert not api_client.responses -def test_get_slug__old_style(fake_dummy_worker): - element = {"source": {"slug": TEST_SLUG}} - - slug = fake_dummy_worker.get_ml_result_slug(element) - - assert slug == TEST_SLUG - - def test_get_slug__worker_version(fake_dummy_worker): api_client = fake_dummy_worker.api_client @@ -84,8 +77,8 @@ def test_get_slug__both(fake_page_element, fake_ufcn_worker_version, fake_dummy_ ) expected_slugs = [ - "scikit_portrait_outlier_balsac", - "scikit_portrait_outlier_balsac", + None, + None, "ufcn_line_historical", ] @@ -117,23 +110,23 @@ def test_get_slug__transcriptions(fake_transcriptions_small, fake_dummy_worker): @pytest.mark.parametrize( "ml_result, expected_slug", ( - # old - ({"source": {"slug": "test_123"}}, "test_123"), - ({"source": {"slug": "test_123"}, "worker_version": None}, "test_123"), - ({"source": {"slug": "test_123"}, "worker_version_id": None}, "test_123"), - # new - ({"source": None, "worker_version": "foo_1"}, "mock_slug"), - ({"source": None, "worker_version_id": "foo_1"}, "mock_slug"), - ({"worker_version_id": "foo_1"}, "mock_slug"), + ({"worker_version": TEST_VERSION_ID}, "mock_slug"), + ({"worker_version_id": TEST_VERSION_ID}, "mock_slug"), + (CachedElement(worker_version_id=TEST_VERSION_ID), "mock_slug"), # manual - ({"worker_version_id": None}, MANUAL_SLUG), - ({"worker_version": None}, MANUAL_SLUG), - ({"source": None, "worker_version": None}, MANUAL_SLUG), + ({"worker_version_id": None}, None), + ({"worker_version": None}, None), + (CachedElement(), None), ), ) def test_get_ml_result_slug__ok(mocker, fake_dummy_worker, ml_result, expected_slug): - fake_dummy_worker.get_worker_version_slug = mocker.MagicMock() - fake_dummy_worker.get_worker_version_slug.return_value = "mock_slug" + fake_dummy_worker.get_worker_version = mocker.MagicMock() + fake_dummy_worker.get_worker_version.return_value = { + "id": TEST_VERSION_ID, + "worker": { + "slug": "mock_slug" + } + } slug = fake_dummy_worker.get_ml_result_slug(ml_result) assert slug == expected_slug