Skip to content
Snippets Groups Projects
Verified Commit 6cefad90 authored by Erwan Rouchet's avatar Erwan Rouchet
Browse files

WIP: Drop DataSource handling

parent d2821a65
No related branches found
No related tags found
No related merge requests found
Pipeline #78390 failed
......@@ -16,7 +16,7 @@ from arkindex_worker.worker.element import ElementMixin
from arkindex_worker.worker.entity import EntityMixin, EntityType # noqa: F401
from arkindex_worker.worker.metadata import MetaDataMixin, MetaType # noqa: F401
from arkindex_worker.worker.transcription import TranscriptionMixin
from arkindex_worker.worker.version import MANUAL_SLUG, WorkerVersionMixin # noqa: F401
from arkindex_worker.worker.version import WorkerVersionMixin # noqa: F401
class ActivityState(Enum):
......
# -*- coding: utf-8 -*-
MANUAL_SLUG = "manual"
class WorkerVersionMixin(object):
def get_worker_version(self, worker_version_id: str) -> dict:
......@@ -16,39 +14,28 @@ class WorkerVersionMixin(object):
return worker_version
def get_worker_version_slug(self, worker_version_id: str) -> str:
"""
Get worker version slug from cache if possible, otherwise make API request
Should use `get_ml_result_slug` instead of using this method directly
"""
worker_version = self.get_worker_version(worker_version_id)
return worker_version["worker"]["slug"]
def get_ml_result_slug(self, ml_result) -> str:
"""
Helper function to get the slug from element, classification or transcription
Helper function to get the worker slug from element, classification or transcription.
Returns None if there is no associated worker version.
Can handle old and new (source vs worker_version)
:type ml_result: Element or classification or transcription
:type ml_result: Element, classification or transcription
"""
if (
"source" in ml_result
and ml_result["source"]
and ml_result["source"]["slug"]
):
return ml_result["source"]["slug"]
elif "worker_version" in ml_result and ml_result["worker_version"]:
return self.get_worker_version_slug(ml_result["worker_version"])
# Handle cached models
if hasattr(ml_result, 'worker_version_id'):
worker_version_id = ml_result.worker_version_id
elif "worker_version" in ml_result:
worker_version_id = ml_result['worker_version']
# transcriptions have worker_version_id but elements have worker_version
elif "worker_version_id" in ml_result and ml_result["worker_version_id"]:
return self.get_worker_version_slug(ml_result["worker_version_id"])
elif "worker_version" in ml_result and ml_result["worker_version"] is None:
return MANUAL_SLUG
elif (
"worker_version_id" in ml_result and ml_result["worker_version_id"] is None
):
return MANUAL_SLUG
elif "worker_version_id" in ml_result:
worker_version_id = ml_result['worker_version_id']
else:
raise ValueError(f"Unable to get slug from: {ml_result}")
if worker_version_id is None:
return
worker_version = self.get_worker_version(worker_version_id)
return worker_version['worker']['slug']
......@@ -5,7 +5,8 @@ import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.models import Element
from arkindex_worker.worker import MANUAL_SLUG, ActivityState
from arkindex_worker.cache import CachedElement
from arkindex_worker.worker import ActivityState
# Common API calls for all workers
BASE_API_CALLS = [
......@@ -48,14 +49,6 @@ def test_get_worker_version__uses_cache(fake_dummy_worker):
assert not api_client.responses
def test_get_slug__old_style(fake_dummy_worker):
element = {"source": {"slug": TEST_SLUG}}
slug = fake_dummy_worker.get_ml_result_slug(element)
assert slug == TEST_SLUG
def test_get_slug__worker_version(fake_dummy_worker):
api_client = fake_dummy_worker.api_client
......@@ -84,8 +77,8 @@ def test_get_slug__both(fake_page_element, fake_ufcn_worker_version, fake_dummy_
)
expected_slugs = [
"scikit_portrait_outlier_balsac",
"scikit_portrait_outlier_balsac",
None,
None,
"ufcn_line_historical",
]
......@@ -117,23 +110,23 @@ def test_get_slug__transcriptions(fake_transcriptions_small, fake_dummy_worker):
@pytest.mark.parametrize(
"ml_result, expected_slug",
(
# old
({"source": {"slug": "test_123"}}, "test_123"),
({"source": {"slug": "test_123"}, "worker_version": None}, "test_123"),
({"source": {"slug": "test_123"}, "worker_version_id": None}, "test_123"),
# new
({"source": None, "worker_version": "foo_1"}, "mock_slug"),
({"source": None, "worker_version_id": "foo_1"}, "mock_slug"),
({"worker_version_id": "foo_1"}, "mock_slug"),
({"worker_version": TEST_VERSION_ID}, "mock_slug"),
({"worker_version_id": TEST_VERSION_ID}, "mock_slug"),
(CachedElement(worker_version_id=TEST_VERSION_ID), "mock_slug"),
# manual
({"worker_version_id": None}, MANUAL_SLUG),
({"worker_version": None}, MANUAL_SLUG),
({"source": None, "worker_version": None}, MANUAL_SLUG),
({"worker_version_id": None}, None),
({"worker_version": None}, None),
(CachedElement(), None),
),
)
def test_get_ml_result_slug__ok(mocker, fake_dummy_worker, ml_result, expected_slug):
fake_dummy_worker.get_worker_version_slug = mocker.MagicMock()
fake_dummy_worker.get_worker_version_slug.return_value = "mock_slug"
fake_dummy_worker.get_worker_version = mocker.MagicMock()
fake_dummy_worker.get_worker_version.return_value = {
"id": TEST_VERSION_ID,
"worker": {
"slug": "mock_slug"
}
}
slug = fake_dummy_worker.get_ml_result_slug(ml_result)
assert slug == expected_slug
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment