From 03a3fb0843bc2e2d6232068a4f456e07b6124a66 Mon Sep 17 00:00:00 2001 From: Erwan Rouchet <rouchet@teklia.com> Date: Tue, 6 Apr 2021 10:57:30 +0200 Subject: [PATCH] Remove get_ml_result_slug --- arkindex_worker/worker/version.py | 29 ++------ tests/test_elements_worker/test_worker.py | 91 ++--------------------- 2 files changed, 12 insertions(+), 108 deletions(-) diff --git a/arkindex_worker/worker/version.py b/arkindex_worker/worker/version.py index df12660a..cce8fb21 100644 --- a/arkindex_worker/worker/version.py +++ b/arkindex_worker/worker/version.py @@ -6,6 +6,9 @@ class WorkerVersionMixin(object): """ Get worker version from cache if possible, otherwise make API request """ + if worker_version_id is None: + raise ValueError("No worker version ID") + if worker_version_id in self._worker_version_cache: return self._worker_version_cache[worker_version_id] @@ -15,32 +18,12 @@ class WorkerVersionMixin(object): return worker_version def get_worker_version_slug(self, worker_version_id: str) -> str: - """ - Get worker version slug from cache if possible, otherwise make API request - """ - worker_version = self.get_worker_version(worker_version_id) - return worker_version["worker"]["slug"] - - def get_ml_result_slug(self, ml_result) -> str: """ Helper function to get the worker slug from element, classification or transcription. + Gets the worker version slug from cache if possible, otherwise makes an API request. Returns None if there is no associated worker version. :type ml_result: Element, classification or transcription """ - - # Handle cached models - if hasattr(ml_result, "worker_version_id"): - worker_version_id = ml_result.worker_version_id - elif "worker_version" in ml_result: - worker_version_id = ml_result["worker_version"] - # transcriptions have worker_version_id but elements have worker_version - elif "worker_version_id" in ml_result: - worker_version_id = ml_result["worker_version_id"] - else: - raise ValueError(f"Unable to get slug from: {ml_result}") - - if worker_version_id is None: - return - - return self.get_worker_version_slug(worker_version_id) + worker_version = self.get_worker_version(worker_version_id) + return worker_version["worker"]["slug"] diff --git a/tests/test_elements_worker/test_worker.py b/tests/test_elements_worker/test_worker.py index ce85d12d..9942f3a9 100644 --- a/tests/test_elements_worker/test_worker.py +++ b/tests/test_elements_worker/test_worker.py @@ -4,7 +4,6 @@ import json import pytest from apistar.exceptions import ErrorResponse -from arkindex_worker.cache import CachedElement from arkindex_worker.models import Element from arkindex_worker.worker import ActivityState @@ -49,99 +48,21 @@ def test_get_worker_version__uses_cache(fake_dummy_worker): assert not api_client.responses -def test_get_slug__worker_version(fake_dummy_worker): - api_client = fake_dummy_worker.api_client - - response = {"worker": {"slug": TEST_SLUG}} - - api_client.add_response("RetrieveWorkerVersion", response, id=TEST_VERSION_ID) - - element = {"worker_version": TEST_VERSION_ID} - - slug = fake_dummy_worker.get_ml_result_slug(element) - - assert slug == TEST_SLUG - - # assert that only one call to the API - assert len(api_client.history) == 1 - assert not api_client.responses - - -def test_get_slug__both(fake_page_element, fake_ufcn_worker_version, fake_dummy_worker): - api_client = fake_dummy_worker.api_client - - api_client.add_response( - "RetrieveWorkerVersion", - fake_ufcn_worker_version, - id=fake_ufcn_worker_version["id"], - ) - - expected_slugs = [ - None, - None, - "ufcn_line_historical", - ] - - slugs = [ - fake_dummy_worker.get_ml_result_slug(clf) - for clf in fake_page_element["classifications"] - ] - - assert slugs == expected_slugs - assert len(api_client.history) == 1 - assert not api_client.responses - - -def test_get_slug__transcriptions(fake_transcriptions_small, fake_dummy_worker): - api_client = fake_dummy_worker.api_client - - version_id = "3ca4a8e3-91d1-4b78-8d83-d8bbbf487996" - response = {"worker": {"slug": TEST_SLUG}} - - api_client.add_response("RetrieveWorkerVersion", response, id=version_id) - - slug = fake_dummy_worker.get_ml_result_slug(fake_transcriptions_small["results"][0]) - - assert slug == TEST_SLUG - assert len(api_client.history) == 1 - assert not api_client.responses - - -@pytest.mark.parametrize( - "ml_result, expected_slug", - ( - ({"worker_version": TEST_VERSION_ID}, "mock_slug"), - ({"worker_version_id": TEST_VERSION_ID}, "mock_slug"), - (CachedElement(worker_version_id=TEST_VERSION_ID), "mock_slug"), - # manual - ({"worker_version_id": None}, None), - ({"worker_version": None}, None), - (CachedElement(), None), - ), -) -def test_get_ml_result_slug__ok(mocker, fake_dummy_worker, ml_result, expected_slug): +def test_get_worker_version_slug(mocker, fake_dummy_worker): fake_dummy_worker.get_worker_version = mocker.MagicMock() fake_dummy_worker.get_worker_version.return_value = { "id": TEST_VERSION_ID, "worker": {"slug": "mock_slug"}, } - slug = fake_dummy_worker.get_ml_result_slug(ml_result) - assert slug == expected_slug + slug = fake_dummy_worker.get_worker_version_slug(TEST_VERSION_ID) + assert slug == "mock_slug" -@pytest.mark.parametrize( - "ml_result", - ( - ({},), - ({"source": None},), - ({"source": {"slug": None}},), - ), -) -def test_get_ml_result_slug__fail(fake_dummy_worker, ml_result): +def test_get_worker_version_slug_none(fake_dummy_worker): with pytest.raises(ValueError) as excinfo: - fake_dummy_worker.get_ml_result_slug(ml_result) - assert str(excinfo.value).startswith("Unable to get slug from") + fake_dummy_worker.get_worker_version_slug(None) + assert str(excinfo.value) == "No worker version ID" def test_defaults(responses, mock_elements_worker): -- GitLab