Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
Commits on Source (6)
0.2.0-beta2 0.2.0-rc2
...@@ -112,9 +112,22 @@ class CachedTranscription(Model): ...@@ -112,9 +112,22 @@ class CachedTranscription(Model):
table_name = "transcriptions" table_name = "transcriptions"
class CachedClassification(Model):
    """
    Classification stored in the local SQLite cache, attached to a cached element
    """

    # Keep the declaration order of the fields: the expected "classifications"
    # schema lists the columns in exactly this order
    id = UUIDField(primary_key=True)
    # Owning element; exposed on CachedElement through the "classifications" backref
    element = ForeignKeyField(CachedElement, backref="classifications")
    # Name of the ML class, stored as free text
    class_name = TextField()
    confidence = FloatField()
    # Short state label, limited to 10 characters
    state = CharField(max_length=10)
    worker_version_id = UUIDField()

    class Meta:
        table_name = "classifications"
        database = db
# Add all the managed models in that list # Add all the managed models in that list
# It's used here, but also in unit tests # It's used here, but also in unit tests
MODELS = [CachedImage, CachedElement, CachedTranscription] MODELS = [CachedImage, CachedElement, CachedTranscription, CachedClassification]
def init_cache_db(path): def init_cache_db(path):
......
...@@ -136,16 +136,15 @@ class BaseWorker(object): ...@@ -136,16 +136,15 @@ class BaseWorker(object):
if self.args.database is not None: if self.args.database is not None:
self.use_cache = True self.use_cache = True
task_id = os.environ.get("PONOS_TASK")
if self.use_cache is True: if self.use_cache is True:
if self.args.database is not None: if self.args.database is not None:
assert os.path.isfile( assert os.path.isfile(
self.args.database self.args.database
), f"Database in {self.args.database} does not exist" ), f"Database in {self.args.database} does not exist"
self.cache_path = self.args.database self.cache_path = self.args.database
elif os.environ.get("TASK_ID"): elif task_id:
cache_dir = os.path.join( cache_dir = os.path.join(os.environ.get("PONOS_DATA", "/data"), task_id)
os.environ.get("PONOS_DATA", "/data"), os.environ.get("TASK_ID")
)
assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}" assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}"
self.cache_path = os.path.join(cache_dir, "db.sqlite") self.cache_path = os.path.join(cache_dir, "db.sqlite")
else: else:
...@@ -157,7 +156,6 @@ class BaseWorker(object): ...@@ -157,7 +156,6 @@ class BaseWorker(object):
logger.debug("Cache is disabled") logger.debug("Cache is disabled")
# Merging parents caches (if there are any) in the current task local cache, unless the database got overridden # Merging parents caches (if there are any) in the current task local cache, unless the database got overridden
task_id = os.environ.get("TASK_ID")
if self.use_cache and self.args.database is None and task_id is not None: if self.use_cache and self.args.database is None and task_id is not None:
task = self.request("RetrieveTaskFromAgent", id=task_id) task = self.request("RetrieveTaskFromAgent", id=task_id)
merge_parents_cache( merge_parents_cache(
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os
from apistar.exceptions import ErrorResponse from apistar.exceptions import ErrorResponse
from peewee import IntegrityError
from arkindex_worker import logger from arkindex_worker import logger
from arkindex_worker.cache import CachedClassification, CachedElement
from arkindex_worker.models import Element from arkindex_worker.models import Element
...@@ -24,6 +28,9 @@ class ClassificationMixin(object): ...@@ -24,6 +28,9 @@ class ClassificationMixin(object):
Return the ID corresponding to the given class name on a specific corpus Return the ID corresponding to the given class name on a specific corpus
This method will automatically create missing classes This method will automatically create missing classes
""" """
if corpus_id is None:
corpus_id = os.environ.get("ARKINDEX_CORPUS_ID")
if not self.classes.get(corpus_id): if not self.classes.get(corpus_id):
self.load_corpus_classes(corpus_id) self.load_corpus_classes(corpus_id)
...@@ -60,8 +67,8 @@ class ClassificationMixin(object): ...@@ -60,8 +67,8 @@ class ClassificationMixin(object):
Create a classification on the given element through API Create a classification on the given element through API
""" """
assert element and isinstance( assert element and isinstance(
element, Element element, (Element, CachedElement)
), "element shouldn't be null and should be of type Element" ), "element shouldn't be null and should be an Element or CachedElement"
assert ml_class and isinstance( assert ml_class and isinstance(
ml_class, str ml_class, str
), "ml_class shouldn't be null and should be of type str" ), "ml_class shouldn't be null and should be of type str"
...@@ -78,18 +85,36 @@ class ClassificationMixin(object): ...@@ -78,18 +85,36 @@ class ClassificationMixin(object):
return return
try: try:
self.request( created = self.request(
"CreateClassification", "CreateClassification",
body={ body={
"element": element.id, "element": str(element.id),
"ml_class": self.get_ml_class_id(element.corpus.id, ml_class), "ml_class": self.get_ml_class_id(None, ml_class),
"worker_version": self.worker_version_id, "worker_version": self.worker_version_id,
"confidence": confidence, "confidence": confidence,
"high_confidence": high_confidence, "high_confidence": high_confidence,
}, },
) )
except ErrorResponse as e:
if self.use_cache:
# Store classification in local cache
try:
to_insert = [
{
"id": created["id"],
"element_id": element.id,
"class_name": ml_class,
"confidence": created["confidence"],
"state": created["state"],
"worker_version_id": self.worker_version_id,
}
]
CachedClassification.insert_many(to_insert).execute()
except IntegrityError as e:
logger.warning(
f"Couldn't save created classification in local cache: {e}"
)
except ErrorResponse as e:
# Detect already existing classification # Detect already existing classification
if ( if (
e.status_code == 400 e.status_code == 400
......
...@@ -172,8 +172,8 @@ class ElementMixin(object): ...@@ -172,8 +172,8 @@ class ElementMixin(object):
List children of an element List children of an element
""" """
assert element and isinstance( assert element and isinstance(
element, Element element, (Element, CachedElement)
), "element shouldn't be null and should be of type Element" ), "element shouldn't be null and should be an Element or CachedElement"
query_params = {} query_params = {}
if best_class is not None: if best_class is not None:
assert isinstance(best_class, str) or isinstance( assert isinstance(best_class, str) or isinstance(
......
...@@ -233,8 +233,8 @@ class TranscriptionMixin(object): ...@@ -233,8 +233,8 @@ class TranscriptionMixin(object):
List transcriptions on an element List transcriptions on an element
""" """
assert element and isinstance( assert element and isinstance(
element, Element element, (Element, CachedElement)
), "element shouldn't be null and should be of type Element" ), "element shouldn't be null and should be an Element or CachedElement"
query_params = {} query_params = {}
if element_type: if element_type:
assert isinstance(element_type, str), "element_type should be of type str" assert isinstance(element_type, str), "element_type should be of type str"
......
...@@ -165,6 +165,7 @@ def mock_user_api(responses): ...@@ -165,6 +165,7 @@ def mock_user_api(responses):
def mock_elements_worker(monkeypatch, mock_worker_version_api): def mock_elements_worker(monkeypatch, mock_worker_version_api):
"""Build and configure an ElementsWorker with fixed CLI parameters to avoid issues with pytest""" """Build and configure an ElementsWorker with fixed CLI parameters to avoid issues with pytest"""
monkeypatch.setattr(sys, "argv", ["worker"]) monkeypatch.setattr(sys, "argv", ["worker"])
monkeypatch.setenv("ARKINDEX_CORPUS_ID", "11111111-1111-1111-1111-111111111111")
worker = ElementsWorker() worker = ElementsWorker()
worker.configure() worker.configure()
...@@ -173,11 +174,11 @@ def mock_elements_worker(monkeypatch, mock_worker_version_api): ...@@ -173,11 +174,11 @@ def mock_elements_worker(monkeypatch, mock_worker_version_api):
@pytest.fixture @pytest.fixture
def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_version_api): def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_version_api):
"""Build a BaseWorker using SQLite cache, also mocking a TASK_ID""" """Build a BaseWorker using SQLite cache, also mocking a PONOS_TASK"""
monkeypatch.setattr(sys, "argv", ["worker"]) monkeypatch.setattr(sys, "argv", ["worker"])
worker = BaseWorker(use_cache=True) worker = BaseWorker(use_cache=True)
monkeypatch.setenv("TASK_ID", "my_task") monkeypatch.setenv("PONOS_TASK", "my_task")
return worker return worker
...@@ -185,6 +186,7 @@ def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_version_api): ...@@ -185,6 +186,7 @@ def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_version_api):
def mock_elements_worker_with_cache(monkeypatch, mock_worker_version_api): def mock_elements_worker_with_cache(monkeypatch, mock_worker_version_api):
"""Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest""" """Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest"""
monkeypatch.setattr(sys, "argv", ["worker"]) monkeypatch.setattr(sys, "argv", ["worker"])
monkeypatch.setenv("ARKINDEX_CORPUS_ID", "11111111-1111-1111-1111-111111111111")
worker = ElementsWorker(use_cache=True) worker = ElementsWorker(use_cache=True)
worker.configure() worker.configure()
......
...@@ -54,7 +54,8 @@ def test_create_tables(tmp_path): ...@@ -54,7 +54,8 @@ def test_create_tables(tmp_path):
init_cache_db(db_path) init_cache_db(db_path)
create_tables() create_tables()
expected_schema = """CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id")) expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))
CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id"))
CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL) CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))""" CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import json import json
from uuid import UUID
import pytest import pytest
from apistar.exceptions import ErrorResponse from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedClassification, CachedElement
from arkindex_worker.models import Element from arkindex_worker.models import Element
...@@ -159,7 +161,10 @@ def test_create_classification_wrong_element(mock_elements_worker): ...@@ -159,7 +161,10 @@ def test_create_classification_wrong_element(mock_elements_worker):
confidence=0.42, confidence=0.42,
high_confidence=True, high_confidence=True,
) )
assert str(e.value) == "element shouldn't be null and should be of type Element" assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
with pytest.raises(AssertionError) as e: with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classification( mock_elements_worker.create_classification(
...@@ -168,16 +173,14 @@ def test_create_classification_wrong_element(mock_elements_worker): ...@@ -168,16 +173,14 @@ def test_create_classification_wrong_element(mock_elements_worker):
confidence=0.42, confidence=0.42,
high_confidence=True, high_confidence=True,
) )
assert str(e.value) == "element shouldn't be null and should be of type Element" assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
def test_create_classification_wrong_ml_class(mock_elements_worker, responses): def test_create_classification_wrong_ml_class(mock_elements_worker, responses):
elt = Element( elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
{
"id": "12341234-1234-1234-1234-123412341234",
"corpus": {"id": "11111111-1111-1111-1111-111111111111"},
}
)
with pytest.raises(AssertionError) as e: with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classification( mock_elements_worker.create_classification(
...@@ -249,12 +252,7 @@ def test_create_classification_wrong_confidence(mock_elements_worker): ...@@ -249,12 +252,7 @@ def test_create_classification_wrong_confidence(mock_elements_worker):
mock_elements_worker.classes = { mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"} "11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
} }
elt = Element( elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
{
"id": "12341234-1234-1234-1234-123412341234",
"corpus": {"id": "11111111-1111-1111-1111-111111111111"},
}
)
with pytest.raises(AssertionError) as e: with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classification( mock_elements_worker.create_classification(
element=elt, element=elt,
...@@ -308,12 +306,7 @@ def test_create_classification_wrong_high_confidence(mock_elements_worker): ...@@ -308,12 +306,7 @@ def test_create_classification_wrong_high_confidence(mock_elements_worker):
mock_elements_worker.classes = { mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"} "11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
} }
elt = Element( elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
{
"id": "12341234-1234-1234-1234-123412341234",
"corpus": {"id": "11111111-1111-1111-1111-111111111111"},
}
)
with pytest.raises(AssertionError) as e: with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classification( mock_elements_worker.create_classification(
...@@ -342,12 +335,7 @@ def test_create_classification_api_error(responses, mock_elements_worker): ...@@ -342,12 +335,7 @@ def test_create_classification_api_error(responses, mock_elements_worker):
mock_elements_worker.classes = { mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"} "11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
} }
elt = Element( elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
{
"id": "12341234-1234-1234-1234-123412341234",
"corpus": {"id": "11111111-1111-1111-1111-111111111111"},
}
)
responses.add( responses.add(
responses.POST, responses.POST,
"http://testserver/api/v1/classifications/", "http://testserver/api/v1/classifications/",
...@@ -379,12 +367,7 @@ def test_create_classification(responses, mock_elements_worker): ...@@ -379,12 +367,7 @@ def test_create_classification(responses, mock_elements_worker):
mock_elements_worker.classes = { mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"} "11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
} }
elt = Element( elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
{
"id": "12341234-1234-1234-1234-123412341234",
"corpus": {"id": "11111111-1111-1111-1111-111111111111"},
}
)
responses.add( responses.add(
responses.POST, responses.POST,
"http://testserver/api/v1/classifications/", "http://testserver/api/v1/classifications/",
...@@ -419,16 +402,72 @@ def test_create_classification(responses, mock_elements_worker): ...@@ -419,16 +402,72 @@ def test_create_classification(responses, mock_elements_worker):
] == {"a_class": 1} ] == {"a_class": 1}
def test_create_classification_with_cache(responses, mock_elements_worker_with_cache):
    """
    A classification created with the cache enabled is sent to the API,
    reported, and stored in the SQLite cache
    """
    worker = mock_elements_worker_with_cache
    worker.classes = {"11111111-1111-1111-1111-111111111111": {"a_class": "0000"}}
    elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")

    # Mocked answer of the classification creation endpoint
    api_response = {
        "id": "56785678-5678-5678-5678-567856785678",
        "element": "12341234-1234-1234-1234-123412341234",
        "ml_class": "0000",
        "worker_version": "12341234-1234-1234-1234-123412341234",
        "confidence": 0.42,
        "high_confidence": True,
        "state": "pending",
    }
    responses.add(
        responses.POST,
        "http://testserver/api/v1/classifications/",
        status=200,
        json=api_response,
    )

    worker.create_classification(
        element=elt,
        ml_class="a_class",
        confidence=0.42,
        high_confidence=True,
    )

    # User and worker version are requested first, then the classification is created
    called_urls = [call.request.url for call in responses.calls]
    assert len(called_urls) == 3
    assert called_urls == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        "http://testserver/api/v1/classifications/",
    ]

    sent_body = json.loads(responses.calls[2].request.body)
    assert sent_body == {
        "element": "12341234-1234-1234-1234-123412341234",
        "ml_class": "0000",
        "worker_version": "12341234-1234-1234-1234-123412341234",
        "confidence": 0.42,
        "high_confidence": True,
    }

    # Classification has been created and reported
    reported = worker.report.report_data["elements"][elt.id]["classifications"]
    assert reported == {"a_class": 1}

    # Check that created classification was properly stored in SQLite cache
    expected_row = CachedClassification(
        id=UUID("56785678-5678-5678-5678-567856785678"),
        element_id=UUID(elt.id),
        class_name="a_class",
        confidence=0.42,
        state="pending",
        worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
    )
    stored_rows = list(CachedClassification.select())
    assert stored_rows == [expected_row]
def test_create_classification_duplicate(responses, mock_elements_worker): def test_create_classification_duplicate(responses, mock_elements_worker):
mock_elements_worker.classes = { mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"} "11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
} }
elt = Element( elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
{
"id": "12341234-1234-1234-1234-123412341234",
"corpus": {"id": "11111111-1111-1111-1111-111111111111"},
}
)
responses.add( responses.add(
responses.POST, responses.POST,
"http://testserver/api/v1/classifications/", "http://testserver/api/v1/classifications/",
......
...@@ -900,11 +900,17 @@ def test_create_elements_integrity_error( ...@@ -900,11 +900,17 @@ def test_create_elements_integrity_error(
def test_list_element_children_wrong_element(mock_elements_worker): def test_list_element_children_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e: with pytest.raises(AssertionError) as e:
mock_elements_worker.list_element_children(element=None) mock_elements_worker.list_element_children(element=None)
assert str(e.value) == "element shouldn't be null and should be of type Element" assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
with pytest.raises(AssertionError) as e: with pytest.raises(AssertionError) as e:
mock_elements_worker.list_element_children(element="not element type") mock_elements_worker.list_element_children(element="not element type")
assert str(e.value) == "element shouldn't be null and should be of type Element" assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
def test_list_element_children_wrong_best_class(mock_elements_worker): def test_list_element_children_wrong_best_class(mock_elements_worker):
...@@ -1125,7 +1131,7 @@ def test_list_element_children_with_cache_unhandled_param( ...@@ -1125,7 +1131,7 @@ def test_list_element_children_with_cache_unhandled_param(
# Filter on element should give all elements inserted # Filter on element should give all elements inserted
( (
{ {
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}), "element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
}, },
( (
"11111111-1111-1111-1111-111111111111", "11111111-1111-1111-1111-111111111111",
...@@ -1135,7 +1141,7 @@ def test_list_element_children_with_cache_unhandled_param( ...@@ -1135,7 +1141,7 @@ def test_list_element_children_with_cache_unhandled_param(
# Filter on element and page should give the second element # Filter on element and page should give the second element
( (
{ {
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}), "element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
"type": "page", "type": "page",
}, },
("22222222-2222-2222-2222-222222222222",), ("22222222-2222-2222-2222-222222222222",),
...@@ -1143,7 +1149,7 @@ def test_list_element_children_with_cache_unhandled_param( ...@@ -1143,7 +1149,7 @@ def test_list_element_children_with_cache_unhandled_param(
# Filter on element and worker version should give all elements # Filter on element and worker version should give all elements
( (
{ {
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}), "element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
"worker_version": "56785678-5678-5678-5678-567856785678", "worker_version": "56785678-5678-5678-5678-567856785678",
}, },
( (
...@@ -1154,7 +1160,7 @@ def test_list_element_children_with_cache_unhandled_param( ...@@ -1154,7 +1160,7 @@ def test_list_element_children_with_cache_unhandled_param(
# Filter on element, type something and worker version should give first # Filter on element, type something and worker version should give first
( (
{ {
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}), "element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
"type": "something", "type": "something",
"worker_version": "56785678-5678-5678-5678-567856785678", "worker_version": "56785678-5678-5678-5678-567856785678",
}, },
......
...@@ -143,7 +143,42 @@ def test_create_transcription_api_error(responses, mock_elements_worker): ...@@ -143,7 +143,42 @@ def test_create_transcription_api_error(responses, mock_elements_worker):
] ]
def test_create_transcription(responses, mock_elements_worker_with_cache): def test_create_transcription(responses, mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
f"http://testserver/api/v1/element/{elt.id}/transcription/",
status=200,
json={
"id": "56785678-5678-5678-5678-567856785678",
"text": "i am a line",
"score": 0.42,
"confidence": 0.42,
"worker_version_id": "12341234-1234-1234-1234-123412341234",
},
)
mock_elements_worker.create_transcription(
element=elt,
text="i am a line",
score=0.42,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
f"http://testserver/api/v1/element/{elt.id}/transcription/",
]
assert json.loads(responses.calls[2].request.body) == {
"text": "i am a line",
"worker_version": "12341234-1234-1234-1234-123412341234",
"score": 0.42,
}
def test_create_transcription_with_cache(responses, mock_elements_worker_with_cache):
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing") elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
responses.add( responses.add(
...@@ -933,7 +968,72 @@ def test_create_element_transcriptions_api_error(responses, mock_elements_worker ...@@ -933,7 +968,72 @@ def test_create_element_transcriptions_api_error(responses, mock_elements_worker
] ]
def test_create_element_transcriptions(responses, mock_elements_worker_with_cache): def test_create_element_transcriptions(responses, mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
status=200,
json=[
{
"id": "56785678-5678-5678-5678-567856785678",
"element_id": "11111111-1111-1111-1111-111111111111",
"created": True,
},
{
"id": "67896789-6789-6789-6789-678967896789",
"element_id": "22222222-2222-2222-2222-222222222222",
"created": False,
},
{
"id": "78907890-7890-7890-7890-789078907890",
"element_id": "11111111-1111-1111-1111-111111111111",
"created": True,
},
],
)
annotations = mock_elements_worker.create_element_transcriptions(
element=elt,
sub_element_type="page",
transcriptions=TRANSCRIPTIONS_SAMPLE,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
]
assert json.loads(responses.calls[2].request.body) == {
"element_type": "page",
"worker_version": "12341234-1234-1234-1234-123412341234",
"transcriptions": TRANSCRIPTIONS_SAMPLE,
"return_elements": True,
}
assert annotations == [
{
"id": "56785678-5678-5678-5678-567856785678",
"element_id": "11111111-1111-1111-1111-111111111111",
"created": True,
},
{
"id": "67896789-6789-6789-6789-678967896789",
"element_id": "22222222-2222-2222-2222-222222222222",
"created": False,
},
{
"id": "78907890-7890-7890-7890-789078907890",
"element_id": "11111111-1111-1111-1111-111111111111",
"created": True,
},
]
def test_create_element_transcriptions_with_cache(
responses, mock_elements_worker_with_cache
):
elt = CachedElement(id="12341234-1234-1234-1234-123412341234", type="thing") elt = CachedElement(id="12341234-1234-1234-1234-123412341234", type="thing")
responses.add( responses.add(
...@@ -1041,11 +1141,17 @@ def test_create_element_transcriptions(responses, mock_elements_worker_with_cach ...@@ -1041,11 +1141,17 @@ def test_create_element_transcriptions(responses, mock_elements_worker_with_cach
def test_list_transcriptions_wrong_element(mock_elements_worker): def test_list_transcriptions_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e: with pytest.raises(AssertionError) as e:
mock_elements_worker.list_transcriptions(element=None) mock_elements_worker.list_transcriptions(element=None)
assert str(e.value) == "element shouldn't be null and should be of type Element" assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
with pytest.raises(AssertionError) as e: with pytest.raises(AssertionError) as e:
mock_elements_worker.list_transcriptions(element="not element type") mock_elements_worker.list_transcriptions(element="not element type")
assert str(e.value) == "element shouldn't be null and should be of type Element" assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
def test_list_transcriptions_wrong_element_type(mock_elements_worker): def test_list_transcriptions_wrong_element_type(mock_elements_worker):
...@@ -1215,7 +1321,7 @@ def test_list_transcriptions_with_cache_skip_recursive( ...@@ -1215,7 +1321,7 @@ def test_list_transcriptions_with_cache_skip_recursive(
# Filter on element should give all elements inserted # Filter on element should give all elements inserted
( (
{ {
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}), "element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
}, },
( (
"11111111-1111-1111-1111-111111111111", "11111111-1111-1111-1111-111111111111",
...@@ -1225,7 +1331,7 @@ def test_list_transcriptions_with_cache_skip_recursive( ...@@ -1225,7 +1331,7 @@ def test_list_transcriptions_with_cache_skip_recursive(
# Filter on element and worker version should give first element # Filter on element and worker version should give first element
( (
{ {
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}), "element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
"worker_version": "56785678-5678-5678-5678-567856785678", "worker_version": "56785678-5678-5678-5678-567856785678",
}, },
("11111111-1111-1111-1111-111111111111",), ("11111111-1111-1111-1111-111111111111",),
......