Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
Commits on Source (6)
0.2.0-beta2
0.2.0-rc2
......@@ -112,9 +112,22 @@ class CachedTranscription(Model):
table_name = "transcriptions"
class CachedClassification(Model):
id = UUIDField(primary_key=True)
element = ForeignKeyField(CachedElement, backref="classifications")
class_name = TextField()
confidence = FloatField()
state = CharField(max_length=10)
worker_version_id = UUIDField()
class Meta:
database = db
table_name = "classifications"
# Add all the managed models in that list
# It's used here, but also in unit tests
MODELS = [CachedImage, CachedElement, CachedTranscription]
MODELS = [CachedImage, CachedElement, CachedTranscription, CachedClassification]
def init_cache_db(path):
......
......@@ -136,16 +136,15 @@ class BaseWorker(object):
if self.args.database is not None:
self.use_cache = True
task_id = os.environ.get("PONOS_TASK")
if self.use_cache is True:
if self.args.database is not None:
assert os.path.isfile(
self.args.database
), f"Database in {self.args.database} does not exist"
self.cache_path = self.args.database
elif os.environ.get("TASK_ID"):
cache_dir = os.path.join(
os.environ.get("PONOS_DATA", "/data"), os.environ.get("TASK_ID")
)
elif task_id:
cache_dir = os.path.join(os.environ.get("PONOS_DATA", "/data"), task_id)
assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}"
self.cache_path = os.path.join(cache_dir, "db.sqlite")
else:
......@@ -157,7 +156,6 @@ class BaseWorker(object):
logger.debug("Cache is disabled")
# Merging parents caches (if there are any) in the current task local cache, unless the database got overridden
task_id = os.environ.get("TASK_ID")
if self.use_cache and self.args.database is None and task_id is not None:
task = self.request("RetrieveTaskFromAgent", id=task_id)
merge_parents_cache(
......
# -*- coding: utf-8 -*-
import os
from apistar.exceptions import ErrorResponse
from peewee import IntegrityError
from arkindex_worker import logger
from arkindex_worker.cache import CachedClassification, CachedElement
from arkindex_worker.models import Element
......@@ -24,6 +28,9 @@ class ClassificationMixin(object):
Return the ID corresponding to the given class name on a specific corpus
This method will automatically create missing classes
"""
if corpus_id is None:
corpus_id = os.environ.get("ARKINDEX_CORPUS_ID")
if not self.classes.get(corpus_id):
self.load_corpus_classes(corpus_id)
......@@ -60,8 +67,8 @@ class ClassificationMixin(object):
Create a classification on the given element through API
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
element, (Element, CachedElement)
), "element shouldn't be null and should be an Element or CachedElement"
assert ml_class and isinstance(
ml_class, str
), "ml_class shouldn't be null and should be of type str"
......@@ -78,18 +85,36 @@ class ClassificationMixin(object):
return
try:
self.request(
created = self.request(
"CreateClassification",
body={
"element": element.id,
"ml_class": self.get_ml_class_id(element.corpus.id, ml_class),
"element": str(element.id),
"ml_class": self.get_ml_class_id(None, ml_class),
"worker_version": self.worker_version_id,
"confidence": confidence,
"high_confidence": high_confidence,
},
)
except ErrorResponse as e:
if self.use_cache:
# Store classification in local cache
try:
to_insert = [
{
"id": created["id"],
"element_id": element.id,
"class_name": ml_class,
"confidence": created["confidence"],
"state": created["state"],
"worker_version_id": self.worker_version_id,
}
]
CachedClassification.insert_many(to_insert).execute()
except IntegrityError as e:
logger.warning(
f"Couldn't save created classification in local cache: {e}"
)
except ErrorResponse as e:
# Detect already existing classification
if (
e.status_code == 400
......
......@@ -172,8 +172,8 @@ class ElementMixin(object):
List children of an element
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
element, (Element, CachedElement)
), "element shouldn't be null and should be an Element or CachedElement"
query_params = {}
if best_class is not None:
assert isinstance(best_class, str) or isinstance(
......
......@@ -233,8 +233,8 @@ class TranscriptionMixin(object):
List transcriptions on an element
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
element, (Element, CachedElement)
), "element shouldn't be null and should be an Element or CachedElement"
query_params = {}
if element_type:
assert isinstance(element_type, str), "element_type should be of type str"
......
......@@ -165,6 +165,7 @@ def mock_user_api(responses):
def mock_elements_worker(monkeypatch, mock_worker_version_api):
"""Build and configure an ElementsWorker with fixed CLI parameters to avoid issues with pytest"""
monkeypatch.setattr(sys, "argv", ["worker"])
monkeypatch.setenv("ARKINDEX_CORPUS_ID", "11111111-1111-1111-1111-111111111111")
worker = ElementsWorker()
worker.configure()
......@@ -173,11 +174,11 @@ def mock_elements_worker(monkeypatch, mock_worker_version_api):
@pytest.fixture
def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_version_api):
"""Build a BaseWorker using SQLite cache, also mocking a TASK_ID"""
"""Build a BaseWorker using SQLite cache, also mocking a PONOS_TASK"""
monkeypatch.setattr(sys, "argv", ["worker"])
worker = BaseWorker(use_cache=True)
monkeypatch.setenv("TASK_ID", "my_task")
monkeypatch.setenv("PONOS_TASK", "my_task")
return worker
......@@ -185,6 +186,7 @@ def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_version_api):
def mock_elements_worker_with_cache(monkeypatch, mock_worker_version_api):
"""Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest"""
monkeypatch.setattr(sys, "argv", ["worker"])
monkeypatch.setenv("ARKINDEX_CORPUS_ID", "11111111-1111-1111-1111-111111111111")
worker = ElementsWorker(use_cache=True)
worker.configure()
......
......@@ -54,7 +54,8 @@ def test_create_tables(tmp_path):
init_cache_db(db_path)
create_tables()
expected_schema = """CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id"))
expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))
CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id"))
CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
......
# -*- coding: utf-8 -*-
import json
from uuid import UUID
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedClassification, CachedElement
from arkindex_worker.models import Element
......@@ -159,7 +161,10 @@ def test_create_classification_wrong_element(mock_elements_worker):
confidence=0.42,
high_confidence=True,
)
assert str(e.value) == "element shouldn't be null and should be of type Element"
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classification(
......@@ -168,16 +173,14 @@ def test_create_classification_wrong_element(mock_elements_worker):
confidence=0.42,
high_confidence=True,
)
assert str(e.value) == "element shouldn't be null and should be of type Element"
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
def test_create_classification_wrong_ml_class(mock_elements_worker, responses):
elt = Element(
{
"id": "12341234-1234-1234-1234-123412341234",
"corpus": {"id": "11111111-1111-1111-1111-111111111111"},
}
)
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classification(
......@@ -249,12 +252,7 @@ def test_create_classification_wrong_confidence(mock_elements_worker):
mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
}
elt = Element(
{
"id": "12341234-1234-1234-1234-123412341234",
"corpus": {"id": "11111111-1111-1111-1111-111111111111"},
}
)
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classification(
element=elt,
......@@ -308,12 +306,7 @@ def test_create_classification_wrong_high_confidence(mock_elements_worker):
mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
}
elt = Element(
{
"id": "12341234-1234-1234-1234-123412341234",
"corpus": {"id": "11111111-1111-1111-1111-111111111111"},
}
)
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classification(
......@@ -342,12 +335,7 @@ def test_create_classification_api_error(responses, mock_elements_worker):
mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
}
elt = Element(
{
"id": "12341234-1234-1234-1234-123412341234",
"corpus": {"id": "11111111-1111-1111-1111-111111111111"},
}
)
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
"http://testserver/api/v1/classifications/",
......@@ -379,12 +367,7 @@ def test_create_classification(responses, mock_elements_worker):
mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
}
elt = Element(
{
"id": "12341234-1234-1234-1234-123412341234",
"corpus": {"id": "11111111-1111-1111-1111-111111111111"},
}
)
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
"http://testserver/api/v1/classifications/",
......@@ -419,16 +402,72 @@ def test_create_classification(responses, mock_elements_worker):
] == {"a_class": 1}
def test_create_classification_with_cache(responses, mock_elements_worker_with_cache):
mock_elements_worker_with_cache.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
}
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
responses.add(
responses.POST,
"http://testserver/api/v1/classifications/",
status=200,
json={
"id": "56785678-5678-5678-5678-567856785678",
"element": "12341234-1234-1234-1234-123412341234",
"ml_class": "0000",
"worker_version": "12341234-1234-1234-1234-123412341234",
"confidence": 0.42,
"high_confidence": True,
"state": "pending",
},
)
mock_elements_worker_with_cache.create_classification(
element=elt,
ml_class="a_class",
confidence=0.42,
high_confidence=True,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/classifications/",
]
assert json.loads(responses.calls[2].request.body) == {
"element": "12341234-1234-1234-1234-123412341234",
"ml_class": "0000",
"worker_version": "12341234-1234-1234-1234-123412341234",
"confidence": 0.42,
"high_confidence": True,
}
# Classification has been created and reported
assert mock_elements_worker_with_cache.report.report_data["elements"][elt.id][
"classifications"
] == {"a_class": 1}
# Check that created classification was properly stored in SQLite cache
assert list(CachedClassification.select()) == [
CachedClassification(
id=UUID("56785678-5678-5678-5678-567856785678"),
element_id=UUID(elt.id),
class_name="a_class",
confidence=0.42,
state="pending",
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
)
]
def test_create_classification_duplicate(responses, mock_elements_worker):
mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
}
elt = Element(
{
"id": "12341234-1234-1234-1234-123412341234",
"corpus": {"id": "11111111-1111-1111-1111-111111111111"},
}
)
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
"http://testserver/api/v1/classifications/",
......
......@@ -900,11 +900,17 @@ def test_create_elements_integrity_error(
def test_list_element_children_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
mock_elements_worker.list_element_children(element=None)
assert str(e.value) == "element shouldn't be null and should be of type Element"
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.list_element_children(element="not element type")
assert str(e.value) == "element shouldn't be null and should be of type Element"
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
def test_list_element_children_wrong_best_class(mock_elements_worker):
......@@ -1125,7 +1131,7 @@ def test_list_element_children_with_cache_unhandled_param(
# Filter on element should give all elements inserted
(
{
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
},
(
"11111111-1111-1111-1111-111111111111",
......@@ -1135,7 +1141,7 @@ def test_list_element_children_with_cache_unhandled_param(
# Filter on element and page should give the second element
(
{
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
"type": "page",
},
("22222222-2222-2222-2222-222222222222",),
......@@ -1143,7 +1149,7 @@ def test_list_element_children_with_cache_unhandled_param(
# Filter on element and worker version should give all elements
(
{
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
"worker_version": "56785678-5678-5678-5678-567856785678",
},
(
......@@ -1154,7 +1160,7 @@ def test_list_element_children_with_cache_unhandled_param(
# Filter on element, type something and worker version should give first
(
{
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
"type": "something",
"worker_version": "56785678-5678-5678-5678-567856785678",
},
......
......@@ -143,7 +143,42 @@ def test_create_transcription_api_error(responses, mock_elements_worker):
]
def test_create_transcription(responses, mock_elements_worker_with_cache):
def test_create_transcription(responses, mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
f"http://testserver/api/v1/element/{elt.id}/transcription/",
status=200,
json={
"id": "56785678-5678-5678-5678-567856785678",
"text": "i am a line",
"score": 0.42,
"confidence": 0.42,
"worker_version_id": "12341234-1234-1234-1234-123412341234",
},
)
mock_elements_worker.create_transcription(
element=elt,
text="i am a line",
score=0.42,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
f"http://testserver/api/v1/element/{elt.id}/transcription/",
]
assert json.loads(responses.calls[2].request.body) == {
"text": "i am a line",
"worker_version": "12341234-1234-1234-1234-123412341234",
"score": 0.42,
}
def test_create_transcription_with_cache(responses, mock_elements_worker_with_cache):
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
responses.add(
......@@ -933,7 +968,72 @@ def test_create_element_transcriptions_api_error(responses, mock_elements_worker
]
def test_create_element_transcriptions(responses, mock_elements_worker_with_cache):
def test_create_element_transcriptions(responses, mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
status=200,
json=[
{
"id": "56785678-5678-5678-5678-567856785678",
"element_id": "11111111-1111-1111-1111-111111111111",
"created": True,
},
{
"id": "67896789-6789-6789-6789-678967896789",
"element_id": "22222222-2222-2222-2222-222222222222",
"created": False,
},
{
"id": "78907890-7890-7890-7890-789078907890",
"element_id": "11111111-1111-1111-1111-111111111111",
"created": True,
},
],
)
annotations = mock_elements_worker.create_element_transcriptions(
element=elt,
sub_element_type="page",
transcriptions=TRANSCRIPTIONS_SAMPLE,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
]
assert json.loads(responses.calls[2].request.body) == {
"element_type": "page",
"worker_version": "12341234-1234-1234-1234-123412341234",
"transcriptions": TRANSCRIPTIONS_SAMPLE,
"return_elements": True,
}
assert annotations == [
{
"id": "56785678-5678-5678-5678-567856785678",
"element_id": "11111111-1111-1111-1111-111111111111",
"created": True,
},
{
"id": "67896789-6789-6789-6789-678967896789",
"element_id": "22222222-2222-2222-2222-222222222222",
"created": False,
},
{
"id": "78907890-7890-7890-7890-789078907890",
"element_id": "11111111-1111-1111-1111-111111111111",
"created": True,
},
]
def test_create_element_transcriptions_with_cache(
responses, mock_elements_worker_with_cache
):
elt = CachedElement(id="12341234-1234-1234-1234-123412341234", type="thing")
responses.add(
......@@ -1041,11 +1141,17 @@ def test_create_element_transcriptions(responses, mock_elements_worker_with_cach
def test_list_transcriptions_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
mock_elements_worker.list_transcriptions(element=None)
assert str(e.value) == "element shouldn't be null and should be of type Element"
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.list_transcriptions(element="not element type")
assert str(e.value) == "element shouldn't be null and should be of type Element"
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
def test_list_transcriptions_wrong_element_type(mock_elements_worker):
......@@ -1215,7 +1321,7 @@ def test_list_transcriptions_with_cache_skip_recursive(
# Filter on element should give all elements inserted
(
{
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
},
(
"11111111-1111-1111-1111-111111111111",
......@@ -1225,7 +1331,7 @@ def test_list_transcriptions_with_cache_skip_recursive(
# Filter on element and worker version should give first element
(
{
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
"worker_version": "56785678-5678-5678-5678-567856785678",
},
("11111111-1111-1111-1111-111111111111",),
......