
Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Project: workers/base-worker
Commits on Source (12)
0.2.1-beta2 → 0.2.2-beta2
@@ -156,6 +156,7 @@ class CachedTranscriptionEntity(Model):
    entity = ForeignKeyField(CachedEntity, backref="transcription_entities")
    offset = IntegerField(constraints=[Check("offset >= 0")])
    length = IntegerField(constraints=[Check("length > 0")])
+    worker_version_id = UUIDField()

    class Meta:
        primary_key = CompositeKey("transcription", "entity")
...
@@ -62,6 +62,7 @@ class BaseWorker(object):
        logger.info(f"Worker will use {self.work_dir} as working directory")
        self.process_information = None
+        self.user_configuration = None
        self.support_cache = support_cache
        # use_cache will be updated in configure() if the cache is supported and if there
        # is at least one available sqlite database either given or in the parent tasks
@@ -160,6 +161,15 @@ class BaseWorker(object):
        # Load all required secrets
        self.secrets = {name: self.load_secret(name) for name in required_secrets}

+        # Load worker run configuration when available and not in dev mode
+        if os.environ.get("ARKINDEX_WORKER_RUN_ID") and not self.args.dev:
+            worker_run = self.request(
+                "RetrieveWorkerRun", id=os.environ["ARKINDEX_WORKER_RUN_ID"]
+            )
+            self.user_configuration = worker_run.get("configuration")
+            if self.user_configuration:
+                logger.info("Loaded user configuration from WorkerRun")
+
        task_id = os.environ.get("PONOS_TASK")
        paths = None
        if self.support_cache and self.args.database is not None:
...
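This change adds a user_configuration attribute to BaseWorker: when the ARKINDEX_WORKER_RUN_ID environment variable is set and the worker is not running with --dev, configure() calls RetrieveWorkerRun and keeps the run's configuration dict. A minimal sketch of reading it from a worker, assuming an ElementsWorker subclass; the "confidence_threshold" key and its default are hypothetical, not part of this diff:

    from arkindex_worker.worker import ElementsWorker

    class DemoWorker(ElementsWorker):
        def process_element(self, element):
            # user_configuration is None when the WorkerRun has no configuration;
            # "confidence_threshold" is a made-up key used only for illustration.
            config = self.user_configuration or {}
            threshold = config.get("confidence_threshold", 0.5)
            ...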
@@ -131,3 +131,71 @@ class ClassificationMixin(object):
raise
self.report.add_classification(element.id, ml_class)
def create_classifications(self, element, classifications):
"""
Create multiple classifications at once on the given element through the API
"""
assert element and isinstance(
element, (Element, CachedElement)
), "element shouldn't be null and should be an Element or CachedElement"
assert classifications and isinstance(
classifications, list
), "classifications shouldn't be null and should be of type list"
for index, classification in enumerate(classifications):
class_name = classification.get("class_name")
assert class_name and isinstance(
class_name, str
), f"Classification at index {index} in classifications: class_name shouldn't be null and should be of type str"
confidence = classification.get("confidence")
assert (
confidence is not None
and isinstance(confidence, float)
and 0 <= confidence <= 1
), f"Classification at index {index} in classifications: confidence shouldn't be null and should be a float in [0..1] range"
high_confidence = classification.get("high_confidence")
if high_confidence is not None:
assert isinstance(
high_confidence, bool
), f"Classification at index {index} in classifications: high_confidence should be of type bool"
if self.is_read_only:
logger.warning(
"Cannot create classifications as this worker is in read-only mode"
)
return
created_cls = self.request(
"CreateClassifications",
body={
"parent": str(element.id),
"worker_version": self.worker_version_id,
"classifications": classifications,
},
)["classifications"]
for created_cl in created_cls:
self.report.add_classification(element.id, created_cl["class_name"])
if self.use_cache:
# Store classifications in local cache
try:
to_insert = [
{
"id": created_cl["id"],
"element_id": element.id,
"class_name": created_cl["class_name"],
"confidence": created_cl["confidence"],
"state": created_cl["state"],
"worker_version_id": self.worker_version_id,
}
for created_cl in created_cls
]
CachedClassification.insert_many(to_insert).execute()
except IntegrityError as e:
logger.warning(
f"Couldn't save created classifications in local cache: {e}"
)
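The new create_classifications helper posts several classifications for one element in a single CreateClassifications call, reports each created classification, and mirrors them into the SQLite cache when it is enabled. A short usage sketch, assuming an ElementsWorker subclass; the class names and confidences are illustrative only:

    from arkindex_worker.worker import ElementsWorker

    class DemoWorker(ElementsWorker):
        def process_element(self, element):
            # Both classifications are sent in one CreateClassifications request;
            # in read-only mode the call only logs a warning and returns.
            self.create_classifications(
                element,
                [
                    {"class_name": "portrait", "confidence": 0.75, "high_confidence": False},
                    {"class_name": "landscape", "confidence": 0.25, "high_confidence": False},
                ],
            )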
@@ -8,7 +8,34 @@ from arkindex_worker.cache import CachedElement, CachedImage
from arkindex_worker.models import Element
class MissingTypeError(Exception):
"""
A required element type was not found in a corpus.
"""
class ElementMixin(object):
def check_required_types(self, corpus_id: str, *type_slugs: str) -> bool:
"""
Check that a corpus has a list of required element types,
and raise an exception if any of them are missing.
"""
assert len(type_slugs), "At least one element type slug is required."
assert all(
isinstance(slug, str) for slug in type_slugs
), "Element type slugs must be strings."
corpus = self.request("RetrieveCorpus", id=corpus_id)
available_slugs = {element_type["slug"] for element_type in corpus["types"]}
missing_slugs = set(type_slugs) - available_slugs
if missing_slugs:
raise MissingTypeError(
f'Element type(s) {", ".join(sorted(missing_slugs))} were not found in the {corpus["name"]} corpus ({corpus["id"]}).'
)
return True
def create_sub_element(self, element, type, name, polygon):
"""
Create a child element on the given element through API
...
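check_required_types retrieves the corpus over the API and raises the new MissingTypeError when any requested slug is absent, so a worker can fail early with a clear message. A minimal sketch of calling it from a worker, assuming hypothetical "page" and "text_line" type slugs and an example corpus ID:

    import logging

    from arkindex_worker.worker import ElementsWorker
    from arkindex_worker.worker.element import MissingTypeError

    logger = logging.getLogger(__name__)

    class DemoWorker(ElementsWorker):
        def process_element(self, element):
            # corpus_id and the slugs are examples; a real worker would take the
            # corpus from its process or from the element being handled.
            corpus_id = "12341234-1234-1234-1234-123412341234"
            try:
                self.check_required_types(corpus_id, "page", "text_line")
            except MissingTypeError as e:
                logger.error(f"Cannot process {element.id}: {e}")
                raise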
@@ -21,7 +21,7 @@ class EntityType(Enum):

class EntityMixin(object):
    def create_entity(
-        self, element, name, type, corpus=None, metas=None, validated=None
+        self, element, name, type, corpus=None, metas=dict(), validated=None
    ):
        """
        Create an entity on the given corpus through API
@@ -111,6 +111,7 @@ class EntityMixin(object):
                "entity": entity,
                "length": length,
                "offset": offset,
+                "worker_version_id": self.worker_version_id,
            },
        )
        # TODO: Report transcription entity creation
@@ -118,15 +119,13 @@ class EntityMixin(object):
        if self.use_cache:
            # Store transcription entity in local cache
            try:
-                to_insert = [
-                    {
-                        "transcription": transcription,
-                        "entity": entity,
-                        "offset": offset,
-                        "length": length,
-                    }
-                ]
-                CachedTranscriptionEntity.insert_many(to_insert).execute()
+                CachedTranscriptionEntity.create(
+                    transcription=transcription,
+                    entity=entity,
+                    offset=offset,
+                    length=length,
+                    worker_version_id=self.worker_version_id,
+                )
            except IntegrityError as e:
                logger.warning(
                    f"Couldn't save created transcription entity in local cache: {e}"
...
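With this change, CreateTranscriptionEntity request bodies and the cached CachedTranscriptionEntity rows both carry the worker version. A short sketch of linking an entity to a transcription from a worker; the parameter names (transcription, entity, offset, length) are inferred from the diff and the offsets are illustrative:

    # Inside a worker method: the entity covers characters 5..14 of the
    # transcription text. worker_version_id is filled in automatically.
    self.create_transcription_entity(
        transcription=transcription["id"],
        entity=entity["id"],
        offset=5,
        length=10,
    )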
arkindex-client==1.0.6
peewee==3.14.4
-Pillow==8.2.0
+Pillow==8.3.1
python-gitlab==2.7.1
python-gnupg==0.4.7
sh==1.14.2
-tenacity==7.0.0
+tenacity==8.0.1
@@ -119,18 +119,56 @@ def test_cli_arg_verbose_given(mocker, mock_config_api):
    logger.setLevel(logging.NOTSET)

-def test_configure_dev_mode(mocker, mock_user_api, mock_worker_version_api):
+def test_configure_dev_mode(
+    mocker, monkeypatch, mock_user_api, mock_worker_version_api
+):
    """
    Configuring a worker in developer mode avoid retrieving process information
    """
    worker = BaseWorker()
    mocker.patch.object(sys, "argv", ["worker", "--dev"])
+    monkeypatch.setenv(
+        "ARKINDEX_WORKER_RUN_ID", "aaaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
+    )

    worker.configure()

    assert worker.args.dev is True
    assert worker.process_information is None
    assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
    assert worker.is_read_only is True
+    assert worker.user_configuration is None
def test_configure_worker_run(mocker, monkeypatch, responses, mock_config_api):
worker = BaseWorker()
mocker.patch.object(sys, "argv", ["worker"])
run_id = "aaaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
monkeypatch.setenv("ARKINDEX_WORKER_RUN_ID", run_id)
responses.add(
responses.GET,
f"http://testserver/api/v1/imports/workers/{run_id}/",
json={"id": run_id, "configuration": {"a": "b"}},
)
worker.configure()
assert worker.user_configuration == {"a": "b"}
def test_configure_worker_run_missing_conf(
mocker, monkeypatch, responses, mock_config_api
):
worker = BaseWorker()
mocker.patch.object(sys, "argv", ["worker"])
run_id = "aaaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
monkeypatch.setenv("ARKINDEX_WORKER_RUN_ID", run_id)
responses.add(
responses.GET,
f"http://testserver/api/v1/imports/workers/{run_id}/",
json={"id": run_id},
)
worker.configure()
assert worker.user_configuration is None
def test_load_missing_secret():
...
@@ -58,7 +58,7 @@ def test_create_tables(tmp_path):
CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id"))
CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL)
CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
-CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
+CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
    actual_schema = "\n".join(
...
@@ -501,3 +501,373 @@ def test_create_classification_duplicate(responses, mock_elements_worker):
    # Classification has NOT been created
    assert mock_elements_worker.report.report_data["elements"] == {}
def test_create_classifications_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=None,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element="not element type",
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
def test_create_classifications_wrong_classifications(mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=None,
)
assert (
str(e.value) == "classifications shouldn't be null and should be of type list"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=1234,
)
assert (
str(e.value) == "classifications shouldn't be null and should be of type list"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": None,
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": 1234,
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": None,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": "wrong confidence",
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 2.00,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": "wrong high_confidence",
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: high_confidence should be of type bool"
)
def test_create_classifications_api_error(responses, mock_elements_worker):
responses.add(
responses.POST,
"http://testserver/api/v1/classification/bulk/",
status=500,
)
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
classes = [
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": False,
},
]
with pytest.raises(ErrorResponse):
mock_elements_worker.create_classifications(
element=elt, classifications=classes
)
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
("POST", "http://testserver/api/v1/classification/bulk/"),
("POST", "http://testserver/api/v1/classification/bulk/"),
("POST", "http://testserver/api/v1/classification/bulk/"),
("POST", "http://testserver/api/v1/classification/bulk/"),
("POST", "http://testserver/api/v1/classification/bulk/"),
]
def test_create_classifications(responses, mock_elements_worker_with_cache):
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
classes = [
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": False,
},
]
responses.add(
responses.POST,
"http://testserver/api/v1/classification/bulk/",
status=200,
json={
"parent": str(elt.id),
"worker_version": "12341234-1234-1234-1234-123412341234",
"classifications": [
{
"id": "00000000-0000-0000-0000-000000000000",
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
"state": "pending",
},
{
"id": "11111111-1111-1111-1111-111111111111",
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": False,
"state": "pending",
},
],
},
)
mock_elements_worker_with_cache.create_classifications(
element=elt, classifications=classes
)
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/classification/bulk/"),
]
assert json.loads(responses.calls[-1].request.body) == {
"parent": str(elt.id),
"worker_version": "12341234-1234-1234-1234-123412341234",
"classifications": classes,
}
# Check that created classifications were properly stored in SQLite cache
assert list(CachedClassification.select()) == [
CachedClassification(
id=UUID("00000000-0000-0000-0000-000000000000"),
element_id=UUID(elt.id),
class_name="portrait",
confidence=0.75,
state="pending",
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
),
CachedClassification(
id=UUID("11111111-1111-1111-1111-111111111111"),
element_id=UUID(elt.id),
class_name="landscape",
confidence=0.25,
state="pending",
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
),
]
# -*- coding: utf-8 -*-
import json
-import os
-import tempfile
from argparse import Namespace
from uuid import UUID
@@ -11,58 +9,102 @@ from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedElement, CachedImage
from arkindex_worker.models import Element
from arkindex_worker.worker import ElementsWorker
+from arkindex_worker.worker.element import MissingTypeError

from . import BASE_API_CALLS
-def test_list_elements_elements_list_arg_wrong_type(monkeypatch, mock_elements_worker):
-    _, path = tempfile.mkstemp()
-    with open(path, "w") as f:
-        json.dump({}, f)
-
-    monkeypatch.setenv("TASK_ELEMENTS", path)
+def test_check_required_types_argument_types(mock_elements_worker):
+    corpus_id = "12341234-1234-1234-1234-123412341234"
+    worker = ElementsWorker()
+
+    with pytest.raises(AssertionError) as e:
+        worker.check_required_types(corpus_id)
+    assert str(e.value) == "At least one element type slug is required."
+
+    with pytest.raises(AssertionError) as e:
+        worker.check_required_types(corpus_id, "lol", 42)
+    assert str(e.value) == "Element type slugs must be strings."
+
+
+def test_check_required_types(monkeypatch, tmp_path, mock_elements_worker, responses):
elements_path = tmp_path / "elements.json"
elements_path.write_text("[]")
monkeypatch.setenv("TASK_ELEMENTS", str(elements_path))
corpus_id = "12341234-1234-1234-1234-123412341234"
responses.add(
responses.GET,
f"http://testserver/api/v1/corpus/{corpus_id}/",
json={
"id": corpus_id,
"name": "Some Corpus",
"types": [{"slug": "folder"}, {"slug": "page"}],
},
)
worker = ElementsWorker()
worker.configure()
assert worker.check_required_types(corpus_id, "page")
assert worker.check_required_types(corpus_id, "page", "folder")
with pytest.raises(MissingTypeError) as e:
assert worker.check_required_types(corpus_id, "page", "text_line", "act")
assert (
str(e.value)
== "Element type(s) act, text_line were not found in the Some Corpus corpus (12341234-1234-1234-1234-123412341234)."
)
def test_list_elements_elements_list_arg_wrong_type(
monkeypatch, tmp_path, mock_elements_worker
):
elements_path = tmp_path / "elements.json"
elements_path.write_text("{}")
monkeypatch.setenv("TASK_ELEMENTS", str(elements_path))
    worker = ElementsWorker()
    worker.configure()
-    os.unlink(path)

    with pytest.raises(AssertionError) as e:
        worker.list_elements()
    assert str(e.value) == "Elements list must be a list"


-def test_list_elements_elements_list_arg_empty_list(monkeypatch, mock_elements_worker):
-    _, path = tempfile.mkstemp()
-    with open(path, "w") as f:
-        json.dump([], f)
-
-    monkeypatch.setenv("TASK_ELEMENTS", path)
+def test_list_elements_elements_list_arg_empty_list(
+    monkeypatch, tmp_path, mock_elements_worker
+):
+    elements_path = tmp_path / "elements.json"
+    elements_path.write_text("[]")
+
+    monkeypatch.setenv("TASK_ELEMENTS", str(elements_path))
    worker = ElementsWorker()
    worker.configure()
-    os.unlink(path)

    with pytest.raises(AssertionError) as e:
        worker.list_elements()
    assert str(e.value) == "No elements in elements list"
-def test_list_elements_elements_list_arg_missing_id(monkeypatch, mock_elements_worker):
-    _, path = tempfile.mkstemp()
-    with open(path, "w") as f:
+def test_list_elements_elements_list_arg_missing_id(
+    monkeypatch, tmp_path, mock_elements_worker
+):
+    elements_path = tmp_path / "elements.json"
+    with elements_path.open("w") as f:
        json.dump([{"type": "volume"}], f)

-    monkeypatch.setenv("TASK_ELEMENTS", path)
+    monkeypatch.setenv("TASK_ELEMENTS", str(elements_path))
    worker = ElementsWorker()
    worker.configure()
-    os.unlink(path)

    elt_list = worker.list_elements()
    assert elt_list == []


-def test_list_elements_elements_list_arg(monkeypatch, mock_elements_worker):
-    _, path = tempfile.mkstemp()
-    with open(path, "w") as f:
+def test_list_elements_elements_list_arg(monkeypatch, tmp_path, mock_elements_worker):
+    elements_path = tmp_path / "elements.json"
+    with elements_path.open("w") as f:
        json.dump(
            [
                {"id": "volumeid", "type": "volume"},
@@ -73,10 +115,9 @@ def test_list_elements_elements_list_arg(monkeypatch, mock_elements_worker):
            f,
        )

-    monkeypatch.setenv("TASK_ELEMENTS", path)
+    monkeypatch.setenv("TASK_ELEMENTS", str(elements_path))
    worker = ElementsWorker()
    worker.configure()
-    os.unlink(path)

    elt_list = worker.list_elements()
@@ -103,9 +144,9 @@ def test_list_elements_element_arg(mocker, mock_elements_worker):
    assert elt_list == ["volumeid", "pageid"]


-def test_list_elements_both_args_error(mocker, mock_elements_worker):
-    _, path = tempfile.mkstemp()
-    with open(path, "w") as f:
+def test_list_elements_both_args_error(mocker, mock_elements_worker, tmp_path):
+    elements_path = tmp_path / "elements.json"
+    with elements_path.open("w") as f:
        json.dump(
            [
                {"id": "volumeid", "type": "volume"},
@@ -120,7 +161,7 @@ def test_list_elements_both_args_error(mocker, mock_elements_worker):
        return_value=Namespace(
            element=["anotherid", "againanotherid"],
            verbose=False,
-            elements_list=open(path),
+            elements_list=elements_path.open(),
            database=None,
            dev=False,
        ),
@@ -128,7 +169,6 @@ def test_list_elements_both_args_error(mocker, mock_elements_worker):
    worker = ElementsWorker()
    worker.configure()
-    os.unlink(path)

    with pytest.raises(AssertionError) as e:
        worker.list_elements()
@@ -847,7 +887,7 @@ def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
    assert created_ids == [{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}]

    # Check that created elements were properly stored in SQLite cache
-    assert os.path.isfile(tmp_path / "db.sqlite")
+    assert (tmp_path / "db.sqlite").is_file()

    assert list(CachedElement.select()) == [
        CachedElement(
...
@@ -213,7 +213,7 @@ def test_create_entity(responses, mock_elements_worker):
    assert json.loads(responses.calls[-1].request.body) == {
        "name": "Bob Bob",
        "type": "person",
-        "metas": None,
+        "metas": {},
        "validated": None,
        "corpus": "12341234-1234-1234-1234-123412341234",
        "worker_version": "12341234-1234-1234-1234-123412341234",
@@ -247,7 +247,7 @@ def test_create_entity_with_cache(responses, mock_elements_worker_with_cache):
    assert json.loads(responses.calls[-1].request.body) == {
        "name": "Bob Bob",
        "type": "person",
-        "metas": None,
+        "metas": {},
        "validated": None,
        "corpus": "12341234-1234-1234-1234-123412341234",
        "worker_version": "12341234-1234-1234-1234-123412341234",
@@ -261,7 +261,7 @@ def test_create_entity_with_cache(responses, mock_elements_worker_with_cache):
            type="person",
            name="Bob Bob",
            validated=False,
-            metas=None,
+            metas={},
            worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
        )
    ]
@@ -449,6 +449,7 @@ def test_create_transcription_entity(responses, mock_elements_worker):
        "entity": "11111111-1111-1111-1111-111111111111",
        "offset": 5,
        "length": 10,
+        "worker_version_id": "12341234-1234-1234-1234-123412341234",
    }
@@ -504,6 +505,7 @@ def test_create_transcription_entity_with_cache(
        "entity": "11111111-1111-1111-1111-111111111111",
        "offset": 5,
        "length": 10,
+        "worker_version_id": "12341234-1234-1234-1234-123412341234",
    }

    # Check that created transcription entity was properly stored in SQLite cache
@@ -513,5 +515,6 @@ def test_create_transcription_entity_with_cache(
            entity=UUID("11111111-1111-1111-1111-111111111111"),
            offset=5,
            length=10,
+            worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
        )
    ]

-arkindex-base-worker==0.2.0
+arkindex-base-worker==0.2.1