Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
Commits on Source (2)
......@@ -4,6 +4,7 @@
from enum import Enum
from arkindex_worker import logger
from arkindex_worker.cache import CachedElement
from arkindex_worker.models import Element
......@@ -24,8 +25,8 @@ class EntityMixin(object):
Return the ID of the created entity
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
element, (Element, CachedElement)
), "element shouldn't be null and should be an Element or CachedElement"
assert name and isinstance(
name, str
), "name shouldn't be null and should be of type str"
......
......@@ -248,24 +248,39 @@ class TranscriptionMixin(object):
), "worker_version should be of type str"
query_params["worker_version"] = worker_version
if self.use_cache and recursive is None:
# Checking that we only received query_params handled by the cache
assert set(query_params.keys()) <= {
"worker_version",
}, "When using the local cache, you can only filter by 'worker_version'"
transcriptions = CachedTranscription.select().where(
CachedTranscription.element_id == element.id
)
if self.use_cache:
if not recursive:
# In this case we don't have to return anything, it's easier to use an
# impossible condition (False) rather than filtering by type for nothing
if element_type and element_type != element.type:
return CachedTranscription.select().where(False)
transcriptions = CachedTranscription.select().where(
CachedTranscription.element_id == element.id
)
else:
base_case = (
CachedElement.select()
.where(CachedElement.id == element.id)
.cte("base", recursive=True)
)
recursive = CachedElement.select().join(
base_case, on=(CachedElement.parent_id == base_case.c.id)
)
cte = base_case.union_all(recursive)
transcriptions = (
CachedTranscription.select()
.join(cte, on=(CachedTranscription.element_id == cte.c.id))
.with_cte(cte)
)
if element_type:
transcriptions = transcriptions.where(cte.c.type == element_type)
if worker_version:
transcriptions = transcriptions.where(
CachedTranscription.worker_version_id == worker_version
)
else:
if self.use_cache:
logger.warning(
"'recursive' filter was set, results will be retrieved from the API since the local cache doesn't handle this filter."
)
transcriptions = self.api_client.paginate(
"ListTranscriptions", id=element.id, **query_params
)
......
......@@ -269,22 +269,71 @@ def mock_cached_elements():
def mock_cached_transcriptions():
"""Insert few transcriptions in local cache, on a shared element"""
CachedElement.create(
id=UUID("12341234-1234-1234-1234-123412341234"),
id=UUID("11111111-1111-1111-1111-111111111111"),
type="page",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedElement.create(
id=UUID("22222222-2222-2222-2222-222222222222"),
type="something_else",
parent_id=UUID("11111111-1111-1111-1111-111111111111"),
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedElement.create(
id=UUID("33333333-3333-3333-3333-333333333333"),
type="page",
parent_id=UUID("11111111-1111-1111-1111-111111111111"),
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedElement.create(
id=UUID("44444444-4444-4444-4444-444444444444"),
type="something_else",
parent_id=UUID("22222222-2222-2222-2222-222222222222"),
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedElement.create(
id=UUID("55555555-5555-5555-5555-555555555555"),
type="something_else",
parent_id=UUID("44444444-4444-4444-4444-444444444444"),
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedTranscription.create(
id=UUID("11111111-1111-1111-1111-111111111111"),
element_id=UUID("12341234-1234-1234-1234-123412341234"),
text="Hello!",
element_id=UUID("11111111-1111-1111-1111-111111111111"),
text="This",
confidence=0.42,
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedTranscription.create(
id=UUID("22222222-2222-2222-2222-222222222222"),
element_id=UUID("12341234-1234-1234-1234-123412341234"),
text="How are you?",
element_id=UUID("22222222-2222-2222-2222-222222222222"),
text="is",
confidence=0.42,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
CachedTranscription.create(
id=UUID("33333333-3333-3333-3333-333333333333"),
element_id=UUID("33333333-3333-3333-3333-333333333333"),
text="a",
confidence=0.42,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
CachedTranscription.create(
id=UUID("44444444-4444-4444-4444-444444444444"),
element_id=UUID("44444444-4444-4444-4444-444444444444"),
text="good",
confidence=0.42,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
CachedTranscription.create(
id=UUID("55555555-5555-5555-5555-555555555555"),
element_id=UUID("55555555-5555-5555-5555-555555555555"),
text="test",
confidence=0.42,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
......
......@@ -16,7 +16,10 @@ def test_create_entity_wrong_element(mock_elements_worker):
type=EntityType.Person,
corpus="12341234-1234-1234-1234-123412341234",
)
assert str(e.value) == "element shouldn't be null and should be of type Element"
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_entity(
......@@ -25,7 +28,10 @@ def test_create_entity_wrong_element(mock_elements_worker):
type=EntityType.Person,
corpus="12341234-1234-1234-1234-123412341234",
)
assert str(e.value) == "element shouldn't be null and should be of type Element"
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
def test_create_entity_wrong_name(mock_elements_worker):
......
......@@ -1262,79 +1262,84 @@ def test_list_transcriptions(responses, mock_elements_worker):
]
def test_list_transcriptions_with_cache_unhandled_param(
responses, mock_elements_worker_with_cache
):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
with pytest.raises(AssertionError) as e:
mock_elements_worker_with_cache.list_transcriptions(
element=elt, element_type="page"
)
assert (
str(e.value)
== "When using the local cache, you can only filter by 'worker_version'"
)
def test_list_transcriptions_with_cache_skip_recursive(
responses, mock_elements_worker_with_cache
):
# When the local cache is activated and the user defines the recursive filter, we should fallback to the API
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
trans = [
{
"id": "0000",
"text": "hey",
"confidence": 0.42,
"worker_version_id": "56785678-5678-5678-5678-567856785678",
"element": None,
},
]
responses.add(
responses.GET,
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/?recursive=True",
status=200,
json={
"count": 3,
"next": None,
"results": trans,
},
)
for idx, transcription in enumerate(
mock_elements_worker_with_cache.list_transcriptions(element=elt, recursive=True)
):
assert transcription == trans[idx]
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/?recursive=True",
]
@pytest.mark.parametrize(
"filters, expected_ids",
(
# Filter on element should give all elements inserted
# Filter on element should give first transcription
(
{
"element": CachedElement(
id="11111111-1111-1111-1111-111111111111", type="page"
),
},
("11111111-1111-1111-1111-111111111111",),
),
# Filter on element and element_type should give first transcription
(
{
"element": CachedElement(
id="11111111-1111-1111-1111-111111111111", type="page"
),
"element_type": "page",
},
("11111111-1111-1111-1111-111111111111",),
),
# Filter on element and worker_version should give first transcription
(
{
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
"element": CachedElement(
id="11111111-1111-1111-1111-111111111111", type="page"
),
"worker_version": "56785678-5678-5678-5678-567856785678",
},
("11111111-1111-1111-1111-111111111111",),
),
# Filter recursively on element should give all transcriptions inserted
(
{
"element": CachedElement(
id="11111111-1111-1111-1111-111111111111", type="page"
),
"recursive": True,
},
(
"11111111-1111-1111-1111-111111111111",
"22222222-2222-2222-2222-222222222222",
"33333333-3333-3333-3333-333333333333",
"44444444-4444-4444-4444-444444444444",
"55555555-5555-5555-5555-555555555555",
),
),
# Filter on element and worker version should give first element
# Filter recursively on element and worker_version should give four transcriptions
(
{
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
"worker_version": "56785678-5678-5678-5678-567856785678",
"element": CachedElement(
id="11111111-1111-1111-1111-111111111111", type="page"
),
"worker_version": "90129012-9012-9012-9012-901290129012",
"recursive": True,
},
("11111111-1111-1111-1111-111111111111",),
(
"22222222-2222-2222-2222-222222222222",
"33333333-3333-3333-3333-333333333333",
"44444444-4444-4444-4444-444444444444",
"55555555-5555-5555-5555-555555555555",
),
),
# Filter recursively on element and element_type should give three transcriptions
(
{
"element": CachedElement(
id="11111111-1111-1111-1111-111111111111", type="page"
),
"element_type": "something_else",
"recursive": True,
},
(
"22222222-2222-2222-2222-222222222222",
"44444444-4444-4444-4444-444444444444",
"55555555-5555-5555-5555-555555555555",
),
),
),
)
......@@ -1345,8 +1350,8 @@ def test_list_transcriptions_with_cache(
filters,
expected_ids,
):
# Check we have 2 elements already present in database
assert CachedTranscription.select().count() == 2
# Check we have 5 elements already present in database
assert CachedTranscription.select().count() == 5
# Query database through cache
transcriptions = mock_elements_worker_with_cache.list_transcriptions(**filters)
......