Skip to content
Snippets Groups Projects
Verified Commit 03a3fb08 authored by Erwan Rouchet's avatar Erwan Rouchet
Browse files

Remove get_ml_result_slug

parent f904431e
No related branches found
No related tags found
1 merge request!72Drop DataSource handling
Pipeline #78402 passed
......@@ -6,6 +6,9 @@ class WorkerVersionMixin(object):
"""
Get worker version from cache if possible, otherwise make API request
"""
if worker_version_id is None:
raise ValueError("No worker version ID")
if worker_version_id in self._worker_version_cache:
return self._worker_version_cache[worker_version_id]
......@@ -15,32 +18,12 @@ class WorkerVersionMixin(object):
return worker_version
def get_worker_version_slug(self, worker_version_id: str) -> str:
"""
Get worker version slug from cache if possible, otherwise make API request
"""
worker_version = self.get_worker_version(worker_version_id)
return worker_version["worker"]["slug"]
def get_ml_result_slug(self, ml_result) -> str:
"""
Helper function to get the worker slug from element, classification or transcription.
Gets the worker version slug from cache if possible, otherwise makes an API request.
Returns None if there is no associated worker version.
:type ml_result: Element, classification or transcription
"""
# Handle cached models
if hasattr(ml_result, "worker_version_id"):
worker_version_id = ml_result.worker_version_id
elif "worker_version" in ml_result:
worker_version_id = ml_result["worker_version"]
# transcriptions have worker_version_id but elements have worker_version
elif "worker_version_id" in ml_result:
worker_version_id = ml_result["worker_version_id"]
else:
raise ValueError(f"Unable to get slug from: {ml_result}")
if worker_version_id is None:
return
return self.get_worker_version_slug(worker_version_id)
worker_version = self.get_worker_version(worker_version_id)
return worker_version["worker"]["slug"]
......@@ -4,7 +4,6 @@ import json
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedElement
from arkindex_worker.models import Element
from arkindex_worker.worker import ActivityState
......@@ -49,99 +48,21 @@ def test_get_worker_version__uses_cache(fake_dummy_worker):
assert not api_client.responses
def test_get_slug__worker_version(fake_dummy_worker):
api_client = fake_dummy_worker.api_client
response = {"worker": {"slug": TEST_SLUG}}
api_client.add_response("RetrieveWorkerVersion", response, id=TEST_VERSION_ID)
element = {"worker_version": TEST_VERSION_ID}
slug = fake_dummy_worker.get_ml_result_slug(element)
assert slug == TEST_SLUG
# assert that only one call to the API
assert len(api_client.history) == 1
assert not api_client.responses
def test_get_slug__both(fake_page_element, fake_ufcn_worker_version, fake_dummy_worker):
api_client = fake_dummy_worker.api_client
api_client.add_response(
"RetrieveWorkerVersion",
fake_ufcn_worker_version,
id=fake_ufcn_worker_version["id"],
)
expected_slugs = [
None,
None,
"ufcn_line_historical",
]
slugs = [
fake_dummy_worker.get_ml_result_slug(clf)
for clf in fake_page_element["classifications"]
]
assert slugs == expected_slugs
assert len(api_client.history) == 1
assert not api_client.responses
def test_get_slug__transcriptions(fake_transcriptions_small, fake_dummy_worker):
api_client = fake_dummy_worker.api_client
version_id = "3ca4a8e3-91d1-4b78-8d83-d8bbbf487996"
response = {"worker": {"slug": TEST_SLUG}}
api_client.add_response("RetrieveWorkerVersion", response, id=version_id)
slug = fake_dummy_worker.get_ml_result_slug(fake_transcriptions_small["results"][0])
assert slug == TEST_SLUG
assert len(api_client.history) == 1
assert not api_client.responses
@pytest.mark.parametrize(
"ml_result, expected_slug",
(
({"worker_version": TEST_VERSION_ID}, "mock_slug"),
({"worker_version_id": TEST_VERSION_ID}, "mock_slug"),
(CachedElement(worker_version_id=TEST_VERSION_ID), "mock_slug"),
# manual
({"worker_version_id": None}, None),
({"worker_version": None}, None),
(CachedElement(), None),
),
)
def test_get_ml_result_slug__ok(mocker, fake_dummy_worker, ml_result, expected_slug):
def test_get_worker_version_slug(mocker, fake_dummy_worker):
fake_dummy_worker.get_worker_version = mocker.MagicMock()
fake_dummy_worker.get_worker_version.return_value = {
"id": TEST_VERSION_ID,
"worker": {"slug": "mock_slug"},
}
slug = fake_dummy_worker.get_ml_result_slug(ml_result)
assert slug == expected_slug
slug = fake_dummy_worker.get_worker_version_slug(TEST_VERSION_ID)
assert slug == "mock_slug"
@pytest.mark.parametrize(
"ml_result",
(
({},),
({"source": None},),
({"source": {"slug": None}},),
),
)
def test_get_ml_result_slug__fail(fake_dummy_worker, ml_result):
def test_get_worker_version_slug_none(fake_dummy_worker):
with pytest.raises(ValueError) as excinfo:
fake_dummy_worker.get_ml_result_slug(ml_result)
assert str(excinfo.value).startswith("Unable to get slug from")
fake_dummy_worker.get_worker_version_slug(None)
assert str(excinfo.value) == "No worker version ID"
def test_defaults(responses, mock_elements_worker):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment