Skip to content
Snippets Groups Projects
Commit 8f5cf687 authored by Erwan Rouchet's avatar Erwan Rouchet Committed by Bastien Abadie
Browse files

Drop DataSource handling

parent 82630075
No related branches found
No related tags found
1 merge request!72Drop DataSource handling
Pipeline #78404 passed
......@@ -16,7 +16,7 @@ from arkindex_worker.worker.element import ElementMixin
from arkindex_worker.worker.entity import EntityMixin, EntityType # noqa: F401
from arkindex_worker.worker.metadata import MetaDataMixin, MetaType # noqa: F401
from arkindex_worker.worker.transcription import TranscriptionMixin
from arkindex_worker.worker.version import MANUAL_SLUG, WorkerVersionMixin # noqa: F401
from arkindex_worker.worker.version import WorkerVersionMixin # noqa: F401
class ActivityState(Enum):
......
# -*- coding: utf-8 -*-
MANUAL_SLUG = "manual"
class WorkerVersionMixin(object):
def get_worker_version(self, worker_version_id: str) -> dict:
"""
Get worker version from cache if possible, otherwise make API request
"""
if worker_version_id is None:
raise ValueError("No worker version ID")
if worker_version_id in self._worker_version_cache:
return self._worker_version_cache[worker_version_id]
......@@ -18,37 +19,11 @@ class WorkerVersionMixin(object):
def get_worker_version_slug(self, worker_version_id: str) -> str:
"""
Get worker version slug from cache if possible, otherwise make API request
Helper function to get the worker slug from element, classification or transcription.
Gets the worker version slug from cache if possible, otherwise makes an API request.
Returns None if there is no associated worker version.
Should use `get_ml_result_slug` instead of using this method directly
:type worker_version_id: A worker version UUID
"""
worker_version = self.get_worker_version(worker_version_id)
return worker_version["worker"]["slug"]
def get_ml_result_slug(self, ml_result) -> str:
"""
Helper function to get the slug from element, classification or transcription
Can handle old and new (source vs worker_version)
:type ml_result: Element or classification or transcription
"""
if (
"source" in ml_result
and ml_result["source"]
and ml_result["source"]["slug"]
):
return ml_result["source"]["slug"]
elif "worker_version" in ml_result and ml_result["worker_version"]:
return self.get_worker_version_slug(ml_result["worker_version"])
# transcriptions have worker_version_id but elements have worker_version
elif "worker_version_id" in ml_result and ml_result["worker_version_id"]:
return self.get_worker_version_slug(ml_result["worker_version_id"])
elif "worker_version" in ml_result and ml_result["worker_version"] is None:
return MANUAL_SLUG
elif (
"worker_version_id" in ml_result and ml_result["worker_version_id"] is None
):
return MANUAL_SLUG
else:
raise ValueError(f"Unable to get slug from: {ml_result}")
......@@ -5,7 +5,7 @@ import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.models import Element
from arkindex_worker.worker import MANUAL_SLUG, ActivityState
from arkindex_worker.worker import ActivityState
# Common API calls for all workers
BASE_API_CALLS = [
......@@ -48,109 +48,21 @@ def test_get_worker_version__uses_cache(fake_dummy_worker):
assert not api_client.responses
def test_get_slug__old_style(fake_dummy_worker):
element = {"source": {"slug": TEST_SLUG}}
slug = fake_dummy_worker.get_ml_result_slug(element)
assert slug == TEST_SLUG
def test_get_slug__worker_version(fake_dummy_worker):
api_client = fake_dummy_worker.api_client
response = {"worker": {"slug": TEST_SLUG}}
api_client.add_response("RetrieveWorkerVersion", response, id=TEST_VERSION_ID)
element = {"worker_version": TEST_VERSION_ID}
slug = fake_dummy_worker.get_ml_result_slug(element)
assert slug == TEST_SLUG
# assert that only one call to the API
assert len(api_client.history) == 1
assert not api_client.responses
def test_get_slug__both(fake_page_element, fake_ufcn_worker_version, fake_dummy_worker):
api_client = fake_dummy_worker.api_client
api_client.add_response(
"RetrieveWorkerVersion",
fake_ufcn_worker_version,
id=fake_ufcn_worker_version["id"],
)
expected_slugs = [
"scikit_portrait_outlier_balsac",
"scikit_portrait_outlier_balsac",
"ufcn_line_historical",
]
slugs = [
fake_dummy_worker.get_ml_result_slug(clf)
for clf in fake_page_element["classifications"]
]
assert slugs == expected_slugs
assert len(api_client.history) == 1
assert not api_client.responses
def test_get_slug__transcriptions(fake_transcriptions_small, fake_dummy_worker):
api_client = fake_dummy_worker.api_client
version_id = "3ca4a8e3-91d1-4b78-8d83-d8bbbf487996"
response = {"worker": {"slug": TEST_SLUG}}
api_client.add_response("RetrieveWorkerVersion", response, id=version_id)
slug = fake_dummy_worker.get_ml_result_slug(fake_transcriptions_small["results"][0])
assert slug == TEST_SLUG
assert len(api_client.history) == 1
assert not api_client.responses
@pytest.mark.parametrize(
"ml_result, expected_slug",
(
# old
({"source": {"slug": "test_123"}}, "test_123"),
({"source": {"slug": "test_123"}, "worker_version": None}, "test_123"),
({"source": {"slug": "test_123"}, "worker_version_id": None}, "test_123"),
# new
({"source": None, "worker_version": "foo_1"}, "mock_slug"),
({"source": None, "worker_version_id": "foo_1"}, "mock_slug"),
({"worker_version_id": "foo_1"}, "mock_slug"),
# manual
({"worker_version_id": None}, MANUAL_SLUG),
({"worker_version": None}, MANUAL_SLUG),
({"source": None, "worker_version": None}, MANUAL_SLUG),
),
)
def test_get_ml_result_slug__ok(mocker, fake_dummy_worker, ml_result, expected_slug):
fake_dummy_worker.get_worker_version_slug = mocker.MagicMock()
fake_dummy_worker.get_worker_version_slug.return_value = "mock_slug"
def test_get_worker_version_slug(mocker, fake_dummy_worker):
fake_dummy_worker.get_worker_version = mocker.MagicMock()
fake_dummy_worker.get_worker_version.return_value = {
"id": TEST_VERSION_ID,
"worker": {"slug": "mock_slug"},
}
slug = fake_dummy_worker.get_ml_result_slug(ml_result)
assert slug == expected_slug
slug = fake_dummy_worker.get_worker_version_slug(TEST_VERSION_ID)
assert slug == "mock_slug"
@pytest.mark.parametrize(
"ml_result",
(
({},),
({"source": None},),
({"source": {"slug": None}},),
),
)
def test_get_ml_result_slug__fail(fake_dummy_worker, ml_result):
def test_get_worker_version_slug_none(fake_dummy_worker):
with pytest.raises(ValueError) as excinfo:
fake_dummy_worker.get_ml_result_slug(ml_result)
assert str(excinfo.value).startswith("Unable to get slug from")
fake_dummy_worker.get_worker_version_slug(None)
assert str(excinfo.value) == "No worker version ID"
def test_defaults(responses, mock_elements_worker):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment