Skip to content
Snippets Groups Projects
Commit 61dd9cb6 authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Bastien Abadie
Browse files

Support worker_version=False in both cache and no cache mode

parent ac91877b
No related branches found
No related tags found
1 merge request: !185 Support worker_version=False in both cache and no cache mode
Pipeline #79375 passed
......@@ -157,7 +157,7 @@ class CachedTranscription(Model):
text = TextField()
confidence = FloatField()
orientation = CharField(max_length=50)
worker_version_id = UUIDField()
worker_version_id = UUIDField(null=True)
class Meta:
database = db
......@@ -170,7 +170,7 @@ class CachedClassification(Model):
class_name = TextField()
confidence = FloatField()
state = CharField(max_length=10)
worker_version_id = UUIDField()
worker_version_id = UUIDField(null=True)
class Meta:
database = db
......@@ -183,7 +183,7 @@ class CachedEntity(Model):
name = TextField()
validated = BooleanField(default=False)
metas = JSONField(null=True)
worker_version_id = UUIDField()
worker_version_id = UUIDField(null=True)
class Meta:
database = db
......@@ -197,7 +197,7 @@ class CachedTranscriptionEntity(Model):
entity = ForeignKeyField(CachedEntity, backref="transcription_entities")
offset = IntegerField(constraints=[Check("offset >= 0")])
length = IntegerField(constraints=[Check("length > 0")])
worker_version_id = UUIDField()
worker_version_id = UUIDField(null=True)
confidence = FloatField(null=True)
class Meta:
......
......@@ -261,7 +261,7 @@ class ElementMixin(object):
with_corpus: Optional[bool] = None,
with_has_children: Optional[bool] = None,
with_zone: Optional[bool] = None,
worker_version: Optional[str] = None,
worker_version: Optional[Union[str, bool]] = None,
) -> Union[Iterable[dict], Iterable[CachedElement]]:
"""
List children of an element.
......@@ -295,7 +295,7 @@ class ElementMixin(object):
This parameter is not supported when caching is enabled.
:type with_zone: Optional[bool]
:param worker_version: Restrict to elements created by a worker version with this UUID. Set to False to look for manually created elements.
:type worker_version: Optional[str]
:type worker_version: Optional[Union[str, bool]]
:return: An iterable of dicts from the ``ListElementChildren`` API endpoint,
or an iterable of :class:`CachedElement` when caching is enabled.
:rtype: Union[Iterable[dict], Iterable[CachedElement]]
......@@ -330,10 +330,14 @@ class ElementMixin(object):
if with_zone is not None:
assert isinstance(with_zone, bool), "with_zone should be of type bool"
query_params["with_zone"] = with_zone
if worker_version:
if worker_version is not None:
assert isinstance(
worker_version, str
), "worker_version should be of type str"
worker_version, (str, bool)
), "worker_version should be of type str or bool"
if isinstance(worker_version, bool):
assert (
worker_version is False
), "if of type bool, worker_version can only be set to False"
query_params["worker_version"] = worker_version
if self.use_cache:
......@@ -346,8 +350,12 @@ class ElementMixin(object):
query = CachedElement.select().where(CachedElement.parent_id == element.id)
if type:
query = query.where(CachedElement.type == type)
if worker_version:
query = query.where(CachedElement.worker_version_id == worker_version)
if worker_version is not None:
# If worker_version=False, filter by manual worker_version, i.e. None
worker_version_id = worker_version if worker_version else None
query = query.where(
CachedElement.worker_version_id == worker_version_id
)
return query
else:
......
......@@ -4,6 +4,7 @@ ElementsWorker methods for transcriptions.
"""
from enum import Enum
from typing import Iterable, Optional, Union
from peewee import IntegrityError
......@@ -348,8 +349,12 @@ class TranscriptionMixin(object):
return annotations
def list_transcriptions(
self, element, element_type=None, recursive=None, worker_version=None
):
self,
element: Union[Element, CachedElement],
element_type: Optional[str] = None,
recursive: Optional[bool] = None,
worker_version: Optional[Union[str, bool]] = None,
) -> Union[Iterable[dict], Iterable[CachedTranscription]]:
"""
List transcriptions on an element.
......@@ -359,11 +364,11 @@ class TranscriptionMixin(object):
:type element_type: Optional[str]
:param recursive: Include transcriptions of any descendant of this element, recursively.
:type recursive: Optional[bool]
:param worker_version: Restrict to transcriptions created by a worker version with this UUID.
:type worker_version: Optional[str]
:param worker_version: Restrict to transcriptions created by a worker version with this UUID. Set to False to look for manually created transcriptions.
:type worker_version: Optional[Union[str, bool]]
:returns: An iterable of dicts representing each transcription,
or an iterable of CachedTranscription when cache support is enabled.
:rtype: Iterable[dict] or Iterable[CachedTranscription]
:rtype: Union[Iterable[dict], Iterable[CachedTranscription]]
"""
assert element and isinstance(
element, (Element, CachedElement)
......@@ -375,10 +380,14 @@ class TranscriptionMixin(object):
if recursive is not None:
assert isinstance(recursive, bool), "recursive should be of type bool"
query_params["recursive"] = recursive
if worker_version:
if worker_version is not None:
assert isinstance(
worker_version, str
), "worker_version should be of type str"
worker_version, (str, bool)
), "worker_version should be of type str or bool"
if isinstance(worker_version, bool):
assert (
worker_version is False
), "if of type bool, worker_version can only be set to False"
query_params["worker_version"] = worker_version
if self.use_cache:
......@@ -409,9 +418,11 @@ class TranscriptionMixin(object):
if element_type:
transcriptions = transcriptions.where(cte.c.type == element_type)
if worker_version:
if worker_version is not None:
# If worker_version=False, filter by manual worker_version, i.e. None
worker_version_id = worker_version if worker_version else None
transcriptions = transcriptions.where(
CachedTranscription.worker_version_id == worker_version
CachedTranscription.worker_version_id == worker_version_id
)
else:
transcriptions = self.api_client.paginate(
......
......@@ -327,7 +327,14 @@ def mock_cached_elements():
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
assert CachedElement.select().count() == 2
CachedElement.create(
id=UUID("33333333-3333-3333-3333-333333333333"),
parent_id=UUID("12341234-1234-1234-1234-123412341234"),
type="paragraph",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=None,
)
assert CachedElement.select().count() == 3
@pytest.fixture
......@@ -407,6 +414,14 @@ def mock_cached_transcriptions():
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
CachedTranscription.create(
id=UUID("66666666-6666-6666-6666-666666666666"),
element_id=UUID("11111111-1111-1111-1111-111111111111"),
text="This is a manual one",
confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=None,
)
@pytest.fixture(scope="function")
......
......@@ -58,12 +58,12 @@ def test_create_tables(tmp_path):
init_cache_db(db_path)
create_tables()
expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))
expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))
CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "rotation_angle" INTEGER NOT NULL, "mirrored" INTEGER NOT NULL, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, "confidence" REAL, FOREIGN KEY ("image_id") REFERENCES "images" ("id"))
CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL)
CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT)
CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
actual_schema = "\n".join(
[
......
......@@ -1255,7 +1255,18 @@ def test_list_element_children_wrong_worker_version(mock_elements_worker):
element=elt,
worker_version=1234,
)
assert str(e.value) == "worker_version should be of type str"
assert str(e.value) == "worker_version should be of type str or bool"
def test_list_element_children_wrong_bool_worker_version(mock_elements_worker):
    """
    Passing ``worker_version=True`` to ``list_element_children`` must be
    rejected: the only boolean value accepted for this filter is ``False``.
    """
    parent = Element({"id": "12341234-1234-1234-1234-123412341234"})

    with pytest.raises(AssertionError) as exc_info:
        mock_elements_worker.list_element_children(
            element=parent,
            worker_version=True,
        )

    expected_message = "if of type bool, worker_version can only be set to False"
    assert str(exc_info.value) == expected_message
def test_list_element_children_api_error(responses, mock_elements_worker):
......@@ -1363,6 +1374,48 @@ def test_list_element_children(responses, mock_elements_worker):
]
def test_list_element_children_manual_worker_version(responses, mock_elements_worker):
    """
    Listing children with ``worker_version=False`` (no cache) must call the
    children endpoint with the ``worker_version=False`` query parameter and
    yield exactly the results returned by the API.
    """
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    expected_children = [
        {
            "id": "0000",
            "type": "page",
            "name": "Test",
            "corpus": {},
            "thumbnail_url": None,
            "zone": {},
            "best_classes": None,
            "has_children": None,
            "worker_version_id": None,
        }
    ]
    responses.add(
        responses.GET,
        "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/?worker_version=False",
        status=200,
        json={
            "count": 1,
            "next": None,
            "results": expected_children,
        },
    )

    # Compare the whole list instead of asserting item by item inside a loop:
    # a per-item loop passes vacuously when nothing is yielded and never
    # checks that the number of results matches.
    children = list(
        mock_elements_worker.list_element_children(element=elt, worker_version=False)
    )
    assert children == expected_children

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == BASE_API_CALLS + [
        (
            "GET",
            "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/?worker_version=False",
        ),
    ]
def test_list_element_children_with_cache_unhandled_param(
mock_elements_worker_with_cache,
):
......@@ -1389,6 +1442,7 @@ def test_list_element_children_with_cache_unhandled_param(
(
"11111111-1111-1111-1111-111111111111",
"22222222-2222-2222-2222-222222222222",
"33333333-3333-3333-3333-333333333333",
),
),
# Filter on element and page should give the second element
......@@ -1399,7 +1453,7 @@ def test_list_element_children_with_cache_unhandled_param(
},
("22222222-2222-2222-2222-222222222222",),
),
# Filter on element and worker version should give all elements
# Filter on element and worker version should give first two elements
(
{
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
......@@ -1410,7 +1464,7 @@ def test_list_element_children_with_cache_unhandled_param(
"22222222-2222-2222-2222-222222222222",
),
),
# Filter on element, type something and worker version should give first
# Filter on element, type something and worker version should give first
(
{
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
......@@ -1419,6 +1473,14 @@ def test_list_element_children_with_cache_unhandled_param(
},
("11111111-1111-1111-1111-111111111111",),
),
# Filter on element, manual worker version should give third
(
{
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
"worker_version": False,
},
("33333333-3333-3333-3333-333333333333",),
),
),
)
def test_list_element_children_with_cache(
......@@ -1430,7 +1492,7 @@ def test_list_element_children_with_cache(
):
# Check we have 3 elements already present in database
assert CachedElement.select().count() == 2
assert CachedElement.select().count() == 3
# Query database through cache
elements = mock_elements_worker_with_cache.list_element_children(**filters)
......
......@@ -1682,7 +1682,18 @@ def test_list_transcriptions_wrong_worker_version(mock_elements_worker):
element=elt,
worker_version=1234,
)
assert str(e.value) == "worker_version should be of type str"
assert str(e.value) == "worker_version should be of type str or bool"
def test_list_transcriptions_wrong_bool_worker_version(mock_elements_worker):
    """
    Passing ``worker_version=True`` to ``list_transcriptions`` must be
    rejected: the only boolean value accepted for this filter is ``False``.
    """
    target = Element({"id": "12341234-1234-1234-1234-123412341234"})

    with pytest.raises(AssertionError) as exc_info:
        mock_elements_worker.list_transcriptions(
            element=target,
            worker_version=True,
        )

    expected_message = "if of type bool, worker_version can only be set to False"
    assert str(exc_info.value) == expected_message
def test_list_transcriptions_api_error(responses, mock_elements_worker):
......@@ -1778,19 +1789,60 @@ def test_list_transcriptions(responses, mock_elements_worker):
]
def test_list_transcriptions_manual_worker_version(responses, mock_elements_worker):
    """
    Listing transcriptions with ``worker_version=False`` (no cache) must call
    the transcriptions endpoint with the ``worker_version=False`` query
    parameter and yield exactly the results returned by the API.
    """
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    trans = [
        {
            "id": "0000",
            "text": "hey",
            "confidence": 0.42,
            "worker_version_id": None,
            "element": None,
        }
    ]
    responses.add(
        responses.GET,
        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/?worker_version=False",
        status=200,
        json={
            "count": 1,
            "next": None,
            "results": trans,
        },
    )

    # Compare the whole list instead of asserting item by item inside a loop:
    # a per-item loop passes vacuously when nothing is yielded and never
    # checks that the number of results matches.
    transcriptions = list(
        mock_elements_worker.list_transcriptions(element=elt, worker_version=False)
    )
    assert transcriptions == trans

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == BASE_API_CALLS + [
        (
            "GET",
            "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/?worker_version=False",
        ),
    ]
@pytest.mark.parametrize(
"filters, expected_ids",
(
# Filter on element should give first transcription
# Filter on element should give first and sixth transcription
(
{
"element": CachedElement(
id="11111111-1111-1111-1111-111111111111", type="page"
),
},
("11111111-1111-1111-1111-111111111111",),
(
"11111111-1111-1111-1111-111111111111",
"66666666-6666-6666-6666-666666666666",
),
),
# Filter on element and element_type should give first transcription
# Filter on element and element_type should give first and sixth transcription
(
{
"element": CachedElement(
......@@ -1798,7 +1850,10 @@ def test_list_transcriptions(responses, mock_elements_worker):
),
"element_type": "page",
},
("11111111-1111-1111-1111-111111111111",),
(
"11111111-1111-1111-1111-111111111111",
"66666666-6666-6666-6666-666666666666",
),
),
# Filter on element and worker_version should give first transcription
(
......@@ -1824,6 +1879,7 @@ def test_list_transcriptions(responses, mock_elements_worker):
"33333333-3333-3333-3333-333333333333",
"44444444-4444-4444-4444-444444444444",
"55555555-5555-5555-5555-555555555555",
"66666666-6666-6666-6666-666666666666",
),
),
# Filter recursively on element and worker_version should give four transcriptions
......@@ -1857,6 +1913,16 @@ def test_list_transcriptions(responses, mock_elements_worker):
"55555555-5555-5555-5555-555555555555",
),
),
# Filter on element with manually created transcription should give sixth transcription
(
{
"element": CachedElement(
id="11111111-1111-1111-1111-111111111111", type="page"
),
"worker_version": False,
},
("66666666-6666-6666-6666-666666666666",),
),
),
)
def test_list_transcriptions_with_cache(
......@@ -1867,7 +1933,7 @@ def test_list_transcriptions_with_cache(
expected_ids,
):
# Check we have 6 transcriptions already present in database
assert CachedTranscription.select().count() == 5
assert CachedTranscription.select().count() == 6
# Query database through cache
transcriptions = mock_elements_worker_with_cache.list_transcriptions(**filters)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment