Skip to content
Snippets Groups Projects
Commit fa82e866 authored by Erwan Rouchet's avatar Erwan Rouchet
Browse files

WIP: Drop DataSource handling

parent 82630075
No related branches found
No related tags found
No related merge requests found
...@@ -16,7 +16,7 @@ from arkindex_worker.worker.element import ElementMixin ...@@ -16,7 +16,7 @@ from arkindex_worker.worker.element import ElementMixin
from arkindex_worker.worker.entity import EntityMixin, EntityType # noqa: F401 from arkindex_worker.worker.entity import EntityMixin, EntityType # noqa: F401
from arkindex_worker.worker.metadata import MetaDataMixin, MetaType # noqa: F401 from arkindex_worker.worker.metadata import MetaDataMixin, MetaType # noqa: F401
from arkindex_worker.worker.transcription import TranscriptionMixin from arkindex_worker.worker.transcription import TranscriptionMixin
from arkindex_worker.worker.version import MANUAL_SLUG, WorkerVersionMixin # noqa: F401 from arkindex_worker.worker.version import WorkerVersionMixin # noqa: F401
class ActivityState(Enum): class ActivityState(Enum):
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
MANUAL_SLUG = "manual"
class WorkerVersionMixin(object): class WorkerVersionMixin(object):
def get_worker_version(self, worker_version_id: str) -> dict: def get_worker_version(self, worker_version_id: str) -> dict:
...@@ -16,39 +14,28 @@ class WorkerVersionMixin(object): ...@@ -16,39 +14,28 @@ class WorkerVersionMixin(object):
return worker_version return worker_version
def get_worker_version_slug(self, worker_version_id: str) -> str:
"""
Get worker version slug from cache if possible, otherwise make API request
Should use `get_ml_result_slug` instead of using this method directly
"""
worker_version = self.get_worker_version(worker_version_id)
return worker_version["worker"]["slug"]
def get_ml_result_slug(self, ml_result) -> str: def get_ml_result_slug(self, ml_result) -> str:
""" """
Helper function to get the slug from element, classification or transcription Helper function to get the worker slug from element, classification or transcription.
Returns None if there is no associated worker version.
Can handle old and new (source vs worker_version) :type ml_result: Element, classification or transcription
:type ml_result: Element or classification or transcription
""" """
if (
"source" in ml_result # Handle cached models
and ml_result["source"] if hasattr(ml_result, 'worker_version_id'):
and ml_result["source"]["slug"] worker_version_id = ml_result.worker_version_id
): elif "worker_version" in ml_result:
return ml_result["source"]["slug"] worker_version_id = ml_result['worker_version']
elif "worker_version" in ml_result and ml_result["worker_version"]:
return self.get_worker_version_slug(ml_result["worker_version"])
# transcriptions have worker_version_id but elements have worker_version # transcriptions have worker_version_id but elements have worker_version
elif "worker_version_id" in ml_result and ml_result["worker_version_id"]: elif "worker_version_id" in ml_result:
return self.get_worker_version_slug(ml_result["worker_version_id"]) worker_version_id = ml_result['worker_version_id']
elif "worker_version" in ml_result and ml_result["worker_version"] is None:
return MANUAL_SLUG
elif (
"worker_version_id" in ml_result and ml_result["worker_version_id"] is None
):
return MANUAL_SLUG
else: else:
raise ValueError(f"Unable to get slug from: {ml_result}") raise ValueError(f"Unable to get slug from: {ml_result}")
if worker_version_id is None:
return
worker_version = self.get_worker_version(worker_version_id)
return worker_version['worker']['slug']
...@@ -5,7 +5,8 @@ import pytest ...@@ -5,7 +5,8 @@ import pytest
from apistar.exceptions import ErrorResponse from apistar.exceptions import ErrorResponse
from arkindex_worker.models import Element from arkindex_worker.models import Element
from arkindex_worker.worker import MANUAL_SLUG, ActivityState from arkindex_worker.cache import CachedElement
from arkindex_worker.worker import ActivityState
# Common API calls for all workers # Common API calls for all workers
BASE_API_CALLS = [ BASE_API_CALLS = [
...@@ -48,14 +49,6 @@ def test_get_worker_version__uses_cache(fake_dummy_worker): ...@@ -48,14 +49,6 @@ def test_get_worker_version__uses_cache(fake_dummy_worker):
assert not api_client.responses assert not api_client.responses
def test_get_slug__old_style(fake_dummy_worker):
element = {"source": {"slug": TEST_SLUG}}
slug = fake_dummy_worker.get_ml_result_slug(element)
assert slug == TEST_SLUG
def test_get_slug__worker_version(fake_dummy_worker): def test_get_slug__worker_version(fake_dummy_worker):
api_client = fake_dummy_worker.api_client api_client = fake_dummy_worker.api_client
...@@ -84,8 +77,8 @@ def test_get_slug__both(fake_page_element, fake_ufcn_worker_version, fake_dummy_ ...@@ -84,8 +77,8 @@ def test_get_slug__both(fake_page_element, fake_ufcn_worker_version, fake_dummy_
) )
expected_slugs = [ expected_slugs = [
"scikit_portrait_outlier_balsac", None,
"scikit_portrait_outlier_balsac", None,
"ufcn_line_historical", "ufcn_line_historical",
] ]
...@@ -117,23 +110,23 @@ def test_get_slug__transcriptions(fake_transcriptions_small, fake_dummy_worker): ...@@ -117,23 +110,23 @@ def test_get_slug__transcriptions(fake_transcriptions_small, fake_dummy_worker):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"ml_result, expected_slug", "ml_result, expected_slug",
( (
# old ({"worker_version": TEST_VERSION_ID}, "mock_slug"),
({"source": {"slug": "test_123"}}, "test_123"), ({"worker_version_id": TEST_VERSION_ID}, "mock_slug"),
({"source": {"slug": "test_123"}, "worker_version": None}, "test_123"), (CachedElement(worker_version_id=TEST_VERSION_ID), "mock_slug"),
({"source": {"slug": "test_123"}, "worker_version_id": None}, "test_123"),
# new
({"source": None, "worker_version": "foo_1"}, "mock_slug"),
({"source": None, "worker_version_id": "foo_1"}, "mock_slug"),
({"worker_version_id": "foo_1"}, "mock_slug"),
# manual # manual
({"worker_version_id": None}, MANUAL_SLUG), ({"worker_version_id": None}, None),
({"worker_version": None}, MANUAL_SLUG), ({"worker_version": None}, None),
({"source": None, "worker_version": None}, MANUAL_SLUG), (CachedElement(), None),
), ),
) )
def test_get_ml_result_slug__ok(mocker, fake_dummy_worker, ml_result, expected_slug): def test_get_ml_result_slug__ok(mocker, fake_dummy_worker, ml_result, expected_slug):
fake_dummy_worker.get_worker_version_slug = mocker.MagicMock() fake_dummy_worker.get_worker_version = mocker.MagicMock()
fake_dummy_worker.get_worker_version_slug.return_value = "mock_slug" fake_dummy_worker.get_worker_version.return_value = {
"id": TEST_VERSION_ID,
"worker": {
"slug": "mock_slug"
}
}
slug = fake_dummy_worker.get_ml_result_slug(ml_result) slug = fake_dummy_worker.get_ml_result_slug(ml_result)
assert slug == expected_slug assert slug == expected_slug
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment