Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
Commits on Source (2)
......@@ -157,7 +157,7 @@ class CachedTranscription(Model):
text = TextField()
confidence = FloatField()
orientation = CharField(max_length=50)
worker_version_id = UUIDField()
worker_version_id = UUIDField(null=True)
class Meta:
database = db
......@@ -170,7 +170,7 @@ class CachedClassification(Model):
class_name = TextField()
confidence = FloatField()
state = CharField(max_length=10)
worker_version_id = UUIDField()
worker_version_id = UUIDField(null=True)
class Meta:
database = db
......@@ -183,7 +183,7 @@ class CachedEntity(Model):
name = TextField()
validated = BooleanField(default=False)
metas = JSONField(null=True)
worker_version_id = UUIDField()
worker_version_id = UUIDField(null=True)
class Meta:
database = db
......@@ -197,7 +197,7 @@ class CachedTranscriptionEntity(Model):
entity = ForeignKeyField(CachedEntity, backref="transcription_entities")
offset = IntegerField(constraints=[Check("offset >= 0")])
length = IntegerField(constraints=[Check("length > 0")])
worker_version_id = UUIDField()
worker_version_id = UUIDField(null=True)
confidence = FloatField(null=True)
class Meta:
......
......@@ -139,8 +139,6 @@ class ElementsWorker(
# CLI args are stored on the instance so that implementations can access them
self.args = self.parser.parse_args()
super().setup_api_client()
if self.is_read_only:
super().configure_for_developers()
else:
......
......@@ -125,6 +125,9 @@ class BaseWorker(object):
# is at least one available sqlite database either given or in the parent tasks
self.use_cache = False
# Define API Client
self.setup_api_client()
@property
def is_read_only(self) -> bool:
"""
......
......@@ -261,7 +261,7 @@ class ElementMixin(object):
with_corpus: Optional[bool] = None,
with_has_children: Optional[bool] = None,
with_zone: Optional[bool] = None,
worker_version: Optional[str] = None,
worker_version: Optional[Union[str, bool]] = None,
) -> Union[Iterable[dict], Iterable[CachedElement]]:
"""
List children of an element.
......@@ -295,7 +295,7 @@ class ElementMixin(object):
This parameter is not supported when caching is enabled.
:type with_zone: Optional[bool]
:param worker_version: Restrict to elements created by a worker version with this UUID. Set to False to look for manually created elements.
:type worker_version: Optional[str]
:type worker_version: Optional[Union[str, bool]]
:return: An iterable of dicts from the ``ListElementChildren`` API endpoint,
or an iterable of :class:`CachedElement` when caching is enabled.
:rtype: Union[Iterable[dict], Iterable[CachedElement]]
......@@ -330,10 +330,14 @@ class ElementMixin(object):
if with_zone is not None:
assert isinstance(with_zone, bool), "with_zone should be of type bool"
query_params["with_zone"] = with_zone
if worker_version:
if worker_version is not None:
assert isinstance(
worker_version, str
), "worker_version should be of type str"
worker_version, (str, bool)
), "worker_version should be of type str or bool"
if isinstance(worker_version, bool):
assert (
worker_version is False
), "if of type bool, worker_version can only be set to False"
query_params["worker_version"] = worker_version
if self.use_cache:
......@@ -346,8 +350,12 @@ class ElementMixin(object):
query = CachedElement.select().where(CachedElement.parent_id == element.id)
if type:
query = query.where(CachedElement.type == type)
if worker_version:
query = query.where(CachedElement.worker_version_id == worker_version)
if worker_version is not None:
# If worker_version=False, filter on the manual worker_version, i.e. None
worker_version_id = worker_version if worker_version else None
query = query.where(
CachedElement.worker_version_id == worker_version_id
)
return query
else:
......
......@@ -4,6 +4,7 @@ ElementsWorker methods for transcriptions.
"""
from enum import Enum
from typing import Iterable, Optional, Union
from peewee import IntegrityError
......@@ -348,8 +349,12 @@ class TranscriptionMixin(object):
return annotations
def list_transcriptions(
self, element, element_type=None, recursive=None, worker_version=None
):
self,
element: Union[Element, CachedElement],
element_type: Optional[str] = None,
recursive: Optional[bool] = None,
worker_version: Optional[Union[str, bool]] = None,
) -> Union[Iterable[dict], Iterable[CachedTranscription]]:
"""
List transcriptions on an element.
......@@ -359,11 +364,11 @@ class TranscriptionMixin(object):
:type element_type: Optional[str]
:param recursive: Include transcriptions of any descendant of this element, recursively.
:type recursive: Optional[bool]
:param worker_version: Restrict to transcriptions created by a worker version with this UUID.
:type worker_version: Optional[str]
:param worker_version: Restrict to transcriptions created by a worker version with this UUID. Set to False to look for manually created transcriptions.
:type worker_version: Optional[Union[str, bool]]
:returns: An iterable of dicts representing each transcription,
or an iterable of CachedTranscription when cache support is enabled.
:rtype: Iterable[dict] or Iterable[CachedTranscription]
:rtype: Union[Iterable[dict], Iterable[CachedTranscription]]
"""
assert element and isinstance(
element, (Element, CachedElement)
......@@ -375,10 +380,14 @@ class TranscriptionMixin(object):
if recursive is not None:
assert isinstance(recursive, bool), "recursive should be of type bool"
query_params["recursive"] = recursive
if worker_version:
if worker_version is not None:
assert isinstance(
worker_version, str
), "worker_version should be of type str"
worker_version, (str, bool)
), "worker_version should be of type str or bool"
if isinstance(worker_version, bool):
assert (
worker_version is False
), "if of type bool, worker_version can only be set to False"
query_params["worker_version"] = worker_version
if self.use_cache:
......@@ -409,9 +418,11 @@ class TranscriptionMixin(object):
if element_type:
transcriptions = transcriptions.where(cte.c.type == element_type)
if worker_version:
if worker_version is not None:
# If worker_version=False, filter on the manual worker_version, i.e. None
worker_version_id = worker_version if worker_version else None
transcriptions = transcriptions.where(
CachedTranscription.worker_version_id == worker_version
CachedTranscription.worker_version_id == worker_version_id
)
else:
transcriptions = self.api_client.paginate(
......
......@@ -327,7 +327,14 @@ def mock_cached_elements():
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
assert CachedElement.select().count() == 2
CachedElement.create(
id=UUID("33333333-3333-3333-3333-333333333333"),
parent_id=UUID("12341234-1234-1234-1234-123412341234"),
type="paragraph",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=None,
)
assert CachedElement.select().count() == 3
@pytest.fixture
......@@ -407,6 +414,14 @@ def mock_cached_transcriptions():
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
CachedTranscription.create(
id=UUID("66666666-6666-6666-6666-666666666666"),
element_id=UUID("11111111-1111-1111-1111-111111111111"),
text="This is a manual one",
confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=None,
)
@pytest.fixture(scope="function")
......
......@@ -90,7 +90,6 @@ def test_init_var_worker_local_file(monkeypatch, tmp_path):
def test_cli_default(mocker, mock_worker_run_api):
worker = BaseWorker()
assert logger.level == logging.NOTSET
assert not hasattr(worker, "api_client")
mocker.patch.object(sys, "argv", ["worker"])
worker.args = worker.parser.parse_args()
......@@ -98,7 +97,6 @@ def test_cli_default(mocker, mock_worker_run_api):
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
assert worker.worker_run_id == "56785678-5678-5678-5678-567856785678"
worker.setup_api_client()
worker.configure()
assert not worker.args.verbose
assert logger.level == logging.NOTSET
......@@ -111,7 +109,6 @@ def test_cli_default(mocker, mock_worker_run_api):
def test_cli_arg_verbose_given(mocker, mock_worker_run_api):
worker = BaseWorker()
assert logger.level == logging.NOTSET
assert not hasattr(worker, "api_client")
mocker.patch.object(sys, "argv", ["worker", "-v"])
worker.args = worker.parser.parse_args()
......@@ -119,7 +116,6 @@ def test_cli_arg_verbose_given(mocker, mock_worker_run_api):
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
assert worker.worker_run_id == "56785678-5678-5678-5678-567856785678"
worker.setup_api_client()
worker.configure()
assert worker.args.verbose
assert logger.level == logging.DEBUG
......@@ -133,7 +129,6 @@ def test_cli_envvar_debug_given(mocker, monkeypatch, mock_worker_run_api):
worker = BaseWorker()
assert logger.level == logging.NOTSET
assert not hasattr(worker, "api_client")
mocker.patch.object(sys, "argv", ["worker"])
monkeypatch.setenv("ARKINDEX_DEBUG", True)
worker.args = worker.parser.parse_args()
......@@ -141,7 +136,6 @@ def test_cli_envvar_debug_given(mocker, monkeypatch, mock_worker_run_api):
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
assert worker.worker_run_id == "56785678-5678-5678-5678-567856785678"
worker.setup_api_client()
worker.configure()
assert logger.level == logging.DEBUG
assert worker.api_client
......@@ -215,7 +209,6 @@ def test_configure_worker_run(mocker, monkeypatch, responses):
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
assert worker.worker_run_id == "56785678-5678-5678-5678-567856785678"
worker.setup_api_client()
worker.configure()
assert worker.user_configuration == {"a": "b"}
......@@ -274,7 +267,6 @@ def test_configure_user_configuration_defaults(
content_type="application/json",
)
worker.setup_api_client()
worker.configure()
assert worker.config == {"param_1": "/some/path/file.pth", "param_2": 12}
......@@ -328,7 +320,6 @@ def test_configure_user_config_debug(mocker, monkeypatch, responses, debug):
content_type="application/json",
)
worker.args = worker.parser.parse_args()
worker.setup_api_client()
worker.configure()
assert worker.user_configuration == {"debug": debug}
......@@ -376,7 +367,6 @@ def test_configure_worker_run_missing_conf(mocker, monkeypatch, responses):
content_type="application/json",
)
worker.args = worker.parser.parse_args()
worker.setup_api_client()
worker.configure()
assert worker.user_configuration is None
......
......@@ -58,12 +58,12 @@ def test_create_tables(tmp_path):
init_cache_db(db_path)
create_tables()
expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))
expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))
CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "rotation_angle" INTEGER NOT NULL, "mirrored" INTEGER NOT NULL, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, "confidence" REAL, FOREIGN KEY ("image_id") REFERENCES "images" ("id"))
CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL)
CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT)
CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
actual_schema = "\n".join(
[
......
......@@ -1255,7 +1255,18 @@ def test_list_element_children_wrong_worker_version(mock_elements_worker):
element=elt,
worker_version=1234,
)
assert str(e.value) == "worker_version should be of type str"
assert str(e.value) == "worker_version should be of type str or bool"
def test_list_element_children_wrong_bool_worker_version(mock_elements_worker):
    # False is the only boolean accepted for worker_version: True must be
    # rejected by the argument validation with a dedicated message.
    parent = Element({"id": "12341234-1234-1234-1234-123412341234"})
    expected_message = "if of type bool, worker_version can only be set to False"

    with pytest.raises(AssertionError) as excinfo:
        mock_elements_worker.list_element_children(
            element=parent,
            worker_version=True,
        )

    assert str(excinfo.value) == expected_message
def test_list_element_children_api_error(responses, mock_elements_worker):
......@@ -1363,6 +1374,48 @@ def test_list_element_children(responses, mock_elements_worker):
]
def test_list_element_children_manual_worker_version(responses, mock_elements_worker):
    """
    Listing children with ``worker_version=False`` sends the value as a query
    parameter and yields the children returned by the API as-is.
    """
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    # Single definition of the URL shared by the mocked response and the
    # request-log assertion below.
    children_url = "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/?worker_version=False"
    expected_children = [
        {
            "id": "0000",
            "type": "page",
            "name": "Test",
            "corpus": {},
            "thumbnail_url": None,
            "zone": {},
            "best_classes": None,
            "has_children": None,
            "worker_version_id": None,
        }
    ]
    responses.add(
        responses.GET,
        children_url,
        status=200,
        json={
            "count": 1,
            "next": None,
            "results": expected_children,
        },
    )

    # Materialize the iterable and compare it as a whole: an enumerate() loop
    # would pass without asserting anything if no child were yielded.
    children = list(
        mock_elements_worker.list_element_children(element=elt, worker_version=False)
    )
    assert children == expected_children

    # Exactly one extra API call, on the children endpoint with the filter.
    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == BASE_API_CALLS + [
        ("GET", children_url),
    ]
def test_list_element_children_with_cache_unhandled_param(
mock_elements_worker_with_cache,
):
......@@ -1389,6 +1442,7 @@ def test_list_element_children_with_cache_unhandled_param(
(
"11111111-1111-1111-1111-111111111111",
"22222222-2222-2222-2222-222222222222",
"33333333-3333-3333-3333-333333333333",
),
),
# Filter on element and page should give the second element
......@@ -1399,7 +1453,7 @@ def test_list_element_children_with_cache_unhandled_param(
},
("22222222-2222-2222-2222-222222222222",),
),
# Filter on element and worker version should give all elements
# Filter on element and worker version should give first two elements
(
{
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
......@@ -1410,7 +1464,7 @@ def test_list_element_children_with_cache_unhandled_param(
"22222222-2222-2222-2222-222222222222",
),
),
# Filter on element, type something and worker version should give first
# Filter on element, type something and worker version should give first
(
{
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
......@@ -1419,6 +1473,14 @@ def test_list_element_children_with_cache_unhandled_param(
},
("11111111-1111-1111-1111-111111111111",),
),
# Filter on element, manual worker version should give third
(
{
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
"worker_version": False,
},
("33333333-3333-3333-3333-333333333333",),
),
),
)
def test_list_element_children_with_cache(
......@@ -1430,7 +1492,7 @@ def test_list_element_children_with_cache(
):
# Check we have 3 elements already present in database
assert CachedElement.select().count() == 2
assert CachedElement.select().count() == 3
# Query database through cache
elements = mock_elements_worker_with_cache.list_element_children(**filters)
......
......@@ -1682,7 +1682,18 @@ def test_list_transcriptions_wrong_worker_version(mock_elements_worker):
element=elt,
worker_version=1234,
)
assert str(e.value) == "worker_version should be of type str"
assert str(e.value) == "worker_version should be of type str or bool"
def test_list_transcriptions_wrong_bool_worker_version(mock_elements_worker):
    # False is the only boolean accepted for worker_version: True must be
    # rejected by the argument validation with a dedicated message.
    target = Element({"id": "12341234-1234-1234-1234-123412341234"})
    expected_message = "if of type bool, worker_version can only be set to False"

    with pytest.raises(AssertionError) as excinfo:
        mock_elements_worker.list_transcriptions(
            element=target,
            worker_version=True,
        )

    assert str(excinfo.value) == expected_message
def test_list_transcriptions_api_error(responses, mock_elements_worker):
......@@ -1778,19 +1789,60 @@ def test_list_transcriptions(responses, mock_elements_worker):
]
def test_list_transcriptions_manual_worker_version(responses, mock_elements_worker):
    """
    Listing transcriptions with ``worker_version=False`` sends the value as a
    query parameter and yields the transcriptions returned by the API as-is.
    """
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    # Single definition of the URL shared by the mocked response and the
    # request-log assertion below.
    transcriptions_url = "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/?worker_version=False"
    trans = [
        {
            "id": "0000",
            "text": "hey",
            "confidence": 0.42,
            "worker_version_id": None,
            "element": None,
        }
    ]
    responses.add(
        responses.GET,
        transcriptions_url,
        status=200,
        json={
            "count": 1,
            "next": None,
            "results": trans,
        },
    )

    # Materialize the iterable and compare it as a whole: an enumerate() loop
    # would pass without asserting anything if no transcription were yielded.
    listed = list(
        mock_elements_worker.list_transcriptions(element=elt, worker_version=False)
    )
    assert listed == trans

    # Exactly one extra API call, on the transcriptions endpoint with the filter.
    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == BASE_API_CALLS + [
        ("GET", transcriptions_url),
    ]
@pytest.mark.parametrize(
"filters, expected_ids",
(
# Filter on element should give first transcription
# Filter on element should give first and sixth transcription
(
{
"element": CachedElement(
id="11111111-1111-1111-1111-111111111111", type="page"
),
},
("11111111-1111-1111-1111-111111111111",),
(
"11111111-1111-1111-1111-111111111111",
"66666666-6666-6666-6666-666666666666",
),
),
# Filter on element and element_type should give first transcription
# Filter on element and element_type should give first and sixth transcription
(
{
"element": CachedElement(
......@@ -1798,7 +1850,10 @@ def test_list_transcriptions(responses, mock_elements_worker):
),
"element_type": "page",
},
("11111111-1111-1111-1111-111111111111",),
(
"11111111-1111-1111-1111-111111111111",
"66666666-6666-6666-6666-666666666666",
),
),
# Filter on element and worker_version should give first transcription
(
......@@ -1824,6 +1879,7 @@ def test_list_transcriptions(responses, mock_elements_worker):
"33333333-3333-3333-3333-333333333333",
"44444444-4444-4444-4444-444444444444",
"55555555-5555-5555-5555-555555555555",
"66666666-6666-6666-6666-666666666666",
),
),
# Filter recursively on element and worker_version should give four transcriptions
......@@ -1857,6 +1913,16 @@ def test_list_transcriptions(responses, mock_elements_worker):
"55555555-5555-5555-5555-555555555555",
),
),
# Filter on element with manually created transcription should give sixth transcription
(
{
"element": CachedElement(
id="11111111-1111-1111-1111-111111111111", type="page"
),
"worker_version": False,
},
("66666666-6666-6666-6666-666666666666",),
),
),
)
def test_list_transcriptions_with_cache(
......@@ -1867,7 +1933,7 @@ def test_list_transcriptions_with_cache(
expected_ids,
):
# Check we have 6 transcriptions already present in database
assert CachedTranscription.select().count() == 5
assert CachedTranscription.select().count() == 6
# Query database through cache
transcriptions = mock_elements_worker_with_cache.list_transcriptions(**filters)
......