Skip to content
Snippets Groups Projects
Commit 4f71d07b authored by Martin's avatar Martin
Browse files

get ml_result manual slug

parent aebdd530
No related branches found
No related tags found
1 merge request!33Support manual source for ml results
Pipeline #78012 passed
......@@ -18,6 +18,8 @@ from arkindex_worker import logger
from arkindex_worker.models import Element
from arkindex_worker.reporting import Reporter
MANUAL_SLUG = "manual"
class BaseWorker(object):
def __init__(self, description="Arkindex Base Worker"):
......@@ -604,5 +606,11 @@ class ElementsWorker(BaseWorker):
# transcriptions have worker_version_id but elements have worker_version
elif "worker_version_id" in ml_result and ml_result["worker_version_id"]:
return self.get_worker_version_slug(ml_result["worker_version_id"])
elif "worker_version" in ml_result and ml_result["worker_version"] is None:
return MANUAL_SLUG
elif (
"worker_version_id" in ml_result and ml_result["worker_version_id"] is None
):
return MANUAL_SLUG
else:
raise ValueError(f"Unable to get slug from: {ml_result}")
......@@ -10,7 +10,12 @@ import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.models import Element
from arkindex_worker.worker import ElementsWorker, EntityType, TranscriptionType
from arkindex_worker.worker import (
MANUAL_SLUG,
ElementsWorker,
EntityType,
TranscriptionType,
)
TRANSCRIPTIONS_SAMPLE = [
{
......@@ -1888,3 +1893,42 @@ def test_get_slug__transcriptions(fake_transcriptions_small, fake_dummy_worker):
assert slug == TEST_SLUG
assert len(api_client.history) == 1
assert not api_client.responses
@pytest.mark.parametrize(
"ml_result, expected_slug",
(
# old
({"source": {"slug": "test_123"}}, "test_123"),
({"source": {"slug": "test_123"}, "worker_version": None}, "test_123"),
({"source": {"slug": "test_123"}, "worker_version_id": None}, "test_123"),
# new
({"source": None, "worker_version": "foo_1"}, "mock_slug"),
({"source": None, "worker_version_id": "foo_1"}, "mock_slug"),
({"worker_version_id": "foo_1"}, "mock_slug"),
# manual
({"worker_version_id": None}, MANUAL_SLUG),
({"worker_version": None}, MANUAL_SLUG),
({"source": None, "worker_version": None}, MANUAL_SLUG),
),
)
def test_get_ml_result_slug__ok(mocker, fake_dummy_worker, ml_result, expected_slug):
fake_dummy_worker.get_worker_version_slug = mocker.MagicMock()
fake_dummy_worker.get_worker_version_slug.return_value = "mock_slug"
slug = fake_dummy_worker.get_ml_result_slug(ml_result)
assert slug == expected_slug
@pytest.mark.parametrize(
"ml_result",
(
({},),
({"source": None},),
({"source": {"slug": None}},),
),
)
def test_get_ml_result_slug__fail(fake_dummy_worker, ml_result):
with pytest.raises(ValueError) as excinfo:
fake_dummy_worker.get_ml_result_slug(ml_result)
assert str(excinfo.value).startswith("Unable to get slug from")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment