Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
Commits on Source (14)
Showing
with 479 additions and 244 deletions
0.2.0-beta1
0.2.0-rc2
......@@ -17,6 +17,7 @@ from peewee import (
)
from arkindex_worker import logger
from arkindex_worker.image import open_image, polygon_bounding_box
db = SqliteDatabase(None)
......@@ -50,7 +51,7 @@ class CachedElement(Model):
id = UUIDField(primary_key=True)
parent_id = UUIDField(null=True)
type = CharField(max_length=50)
image_id = ForeignKeyField(CachedImage, backref="elements", null=True)
image = ForeignKeyField(CachedImage, backref="elements", null=True)
polygon = JSONField(null=True)
initial = BooleanField(default=False)
worker_version_id = UUIDField(null=True)
......@@ -59,10 +60,49 @@ class CachedElement(Model):
database = db
table_name = "elements"
def open_image(self, *args, max_size=None, **kwargs):
"""
Open this element's image as a Pillow image.
This does not crop the image to the element's polygon.
IIIF servers with maxWidth, maxHeight or maxArea restrictions on image size are not supported.
:param max_size: Subresolution of the image.
"""
if not self.image_id or not self.polygon:
raise ValueError(f"Element {self.id} has no image")
if max_size is None:
resize = "full"
else:
bounding_box = polygon_bounding_box(self.polygon)
# Do not resize for polygons that do not exactly match the images
if (
bounding_box.width != self.image.width
or bounding_box.height != self.image.height
):
resize = "full"
logger.warning(
"Only full size elements covered, downloading full size image"
)
# Do not resize when the image is below the maximum size
elif self.image.width <= max_size and self.image.height <= max_size:
resize = "full"
else:
ratio = max_size / max(self.image.width, self.image.height)
new_width, new_height = int(self.image.width * ratio), int(
self.image.height * ratio
)
resize = f"{new_width},{new_height}"
url = self.image.url
if not url.endswith("/"):
url += "/"
return open_image(f"{url}full/{resize}/0/default.jpg", *args, **kwargs)
class CachedTranscription(Model):
id = UUIDField(primary_key=True)
element_id = ForeignKeyField(CachedElement, backref="transcriptions")
element = ForeignKeyField(CachedElement, backref="transcriptions")
text = TextField()
confidence = FloatField()
worker_version_id = UUIDField()
......@@ -72,9 +112,22 @@ class CachedTranscription(Model):
table_name = "transcriptions"
class CachedClassification(Model):
id = UUIDField(primary_key=True)
element = ForeignKeyField(CachedElement, backref="classifications")
class_name = TextField()
confidence = FloatField()
state = CharField(max_length=10)
worker_version_id = UUIDField()
class Meta:
database = db
table_name = "classifications"
# Add all the managed models in that list
# It's used here, but also in unit tests
MODELS = [CachedImage, CachedElement, CachedTranscription]
MODELS = [CachedImage, CachedElement, CachedTranscription, CachedClassification]
def init_cache_db(path):
......
......@@ -48,10 +48,27 @@ def download_image(url):
Download an image and open it with Pillow
"""
assert url.startswith("http"), "Image URL must be HTTP(S)"
# Download the image
# Cannot use stream=True as urllib's responses do not support the seek(int) method,
# which is explicitly required by Image.open on file-like objects
resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
try:
resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
except requests.exceptions.SSLError:
logger.warning(
"An SSLError occurred during image download, retrying with a weaker and unsafe SSL configuration"
)
# Saving current ciphers
previous_ciphers = requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS
# Downgrading ciphers to download the image
requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS = "ALL:@SECLEVEL=1"
resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
# Restoring previous ciphers
requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS = previous_ciphers
resp.raise_for_status()
# Preprocess the image and prepare it for classification
......
......@@ -82,6 +82,7 @@ class Element(MagicDict):
def open_image(self, *args, max_size=None, **kwargs):
"""
Open this element's image as a Pillow image.
This does not crop the image to the element's polygon.
:param max_size: Subresolution of the image.
"""
if not self.get("zone"):
......
......@@ -8,6 +8,7 @@ from enum import Enum
from apistar.exceptions import ErrorResponse
from arkindex_worker import logger
from arkindex_worker.cache import CachedElement
from arkindex_worker.models import Element
from arkindex_worker.reporting import Reporter
from arkindex_worker.worker.base import BaseWorker
......@@ -16,7 +17,7 @@ from arkindex_worker.worker.element import ElementMixin
from arkindex_worker.worker.entity import EntityMixin, EntityType # noqa: F401
from arkindex_worker.worker.metadata import MetaDataMixin, MetaType # noqa: F401
from arkindex_worker.worker.transcription import TranscriptionMixin
from arkindex_worker.worker.version import MANUAL_SLUG, WorkerVersionMixin # noqa: F401
from arkindex_worker.worker.version import WorkerVersionMixin # noqa: F401
class ActivityState(Enum):
......@@ -64,8 +65,15 @@ class ElementsWorker(
), "elements-list and element CLI args shouldn't be both set"
out = []
# Load from the cache when available
# Flake8 wants us to use 'is True', but Peewee only supports '== True'
cache_query = CachedElement.select().where(
CachedElement.initial == True # noqa: E712
)
if self.use_cache and cache_query.exists():
return cache_query
# Process elements from JSON file
if self.args.elements_list:
elif self.args.elements_list:
data = json.load(self.args.elements_list)
assert isinstance(data, list), "Elements list must be a list"
assert len(data), "No elements in elements list"
......@@ -92,31 +100,40 @@ class ElementsWorker(
# Process every element
count = len(elements)
failed = 0
for i, element_id in enumerate(elements, start=1):
for i, item in enumerate(elements, start=1):
element = None
try:
# Load element using Arkindex API
element = Element(**self.request("RetrieveElement", id=element_id))
if self.use_cache:
# Just use the result of list_elements as the element
element = item
else:
# Load element using the Arkindex API
element = Element(**self.request("RetrieveElement", id=item))
logger.info(f"Processing {element} ({i}/{count})")
# Report start of process, run process, then report end of process
self.update_activity(element, ActivityState.Started)
self.update_activity(element.id, ActivityState.Started)
self.process_element(element)
self.update_activity(element, ActivityState.Processed)
except ErrorResponse as e:
failed += 1
logger.warning(
f"An API error occurred while processing element {element_id}: {e.title} - {e.content}",
exc_info=e if self.args.verbose else None,
)
self.update_activity(element, ActivityState.Error)
self.report.error(element_id, e)
self.update_activity(element.id, ActivityState.Processed)
except Exception as e:
failed += 1
element_id = (
element.id
if isinstance(element, (Element, CachedElement))
else item
)
if isinstance(e, ErrorResponse):
message = f"An API error occurred while processing element {element_id}: {e.title} - {e.content}"
else:
message = f"Failed running worker on element {element_id}: {e}"
logger.warning(
f"Failed running worker on element {element_id}: {e}",
message,
exc_info=e if self.args.verbose else None,
)
self.update_activity(element, ActivityState.Error)
self.update_activity(element_id, ActivityState.Error)
self.report.error(element_id, e)
# Save report as local artifact
......@@ -134,14 +151,14 @@ class ElementsWorker(
def process_element(self, element):
"""Override this method to analyze an Arkindex element from the provided list"""
def update_activity(self, element, state):
def update_activity(self, element_id, state):
"""
Update worker activity for this element
This method should not raise a runtime exception, but simply warn users
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
assert element_id and isinstance(
element_id, (uuid.UUID, str)
), "element_id shouldn't be null and should be an UUID or str"
assert isinstance(state, ActivityState), "state should be an ActivityState"
if not self.features.get("workers_activity"):
......@@ -157,17 +174,17 @@ class ElementsWorker(
"UpdateWorkerActivity",
id=self.worker_version_id,
body={
"element_id": element.id,
"element_id": str(element_id),
"state": state.value,
},
)
logger.debug(f"Updated activity of element {element.id} to {state}")
logger.debug(f"Updated activity of element {element_id} to {state}")
return out
except ErrorResponse as e:
logger.warning(
f"Failed to update activity of element {element.id} to {state.value} due to an API error: {e.content}"
f"Failed to update activity of element {element_id} to {state.value} due to an API error: {e.content}"
)
except Exception as e:
logger.warning(
f"Failed to update activity of element {element.id} to {state.value}: {e}"
f"Failed to update activity of element {element_id} to {state.value}: {e}"
)
......@@ -136,16 +136,15 @@ class BaseWorker(object):
if self.args.database is not None:
self.use_cache = True
task_id = os.environ.get("PONOS_TASK")
if self.use_cache is True:
if self.args.database is not None:
assert os.path.isfile(
self.args.database
), f"Database in {self.args.database} does not exist"
self.cache_path = self.args.database
elif os.environ.get("TASK_ID"):
cache_dir = os.path.join(
os.environ.get("PONOS_DATA", "/data"), os.environ.get("TASK_ID")
)
elif task_id:
cache_dir = os.path.join(os.environ.get("PONOS_DATA", "/data"), task_id)
assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}"
self.cache_path = os.path.join(cache_dir, "db.sqlite")
else:
......@@ -157,7 +156,6 @@ class BaseWorker(object):
logger.debug("Cache is disabled")
# Merging parents caches (if there are any) in the current task local cache, unless the database got overridden
task_id = os.environ.get("TASK_ID")
if self.use_cache and self.args.database is None and task_id is not None:
task = self.request("RetrieveTaskFromAgent", id=task_id)
merge_parents_cache(
......
# -*- coding: utf-8 -*-
import os
from apistar.exceptions import ErrorResponse
from peewee import IntegrityError
from arkindex_worker import logger
from arkindex_worker.cache import CachedClassification, CachedElement
from arkindex_worker.models import Element
......@@ -24,6 +28,9 @@ class ClassificationMixin(object):
Return the ID corresponding to the given class name on a specific corpus
This method will automatically create missing classes
"""
if corpus_id is None:
corpus_id = os.environ.get("ARKINDEX_CORPUS_ID")
if not self.classes.get(corpus_id):
self.load_corpus_classes(corpus_id)
......@@ -60,8 +67,8 @@ class ClassificationMixin(object):
Create a classification on the given element through API
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
element, (Element, CachedElement)
), "element shouldn't be null and should be an Element or CachedElement"
assert ml_class and isinstance(
ml_class, str
), "ml_class shouldn't be null and should be of type str"
......@@ -78,18 +85,36 @@ class ClassificationMixin(object):
return
try:
self.request(
created = self.request(
"CreateClassification",
body={
"element": element.id,
"ml_class": self.get_ml_class_id(element.corpus.id, ml_class),
"element": str(element.id),
"ml_class": self.get_ml_class_id(None, ml_class),
"worker_version": self.worker_version_id,
"confidence": confidence,
"high_confidence": high_confidence,
},
)
except ErrorResponse as e:
if self.use_cache:
# Store classification in local cache
try:
to_insert = [
{
"id": created["id"],
"element_id": element.id,
"class_name": ml_class,
"confidence": created["confidence"],
"state": created["state"],
"worker_version_id": self.worker_version_id,
}
]
CachedClassification.insert_many(to_insert).execute()
except IntegrityError as e:
logger.warning(
f"Couldn't save created classification in local cache: {e}"
)
except ErrorResponse as e:
# Detect already existing classification
if (
e.status_code == 400
......
......@@ -172,8 +172,8 @@ class ElementMixin(object):
List children of an element
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
element, (Element, CachedElement)
), "element shouldn't be null and should be an Element or CachedElement"
query_params = {}
if best_class is not None:
assert isinstance(best_class, str) or isinstance(
......
......@@ -233,8 +233,8 @@ class TranscriptionMixin(object):
List transcriptions on an element
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
element, (Element, CachedElement)
), "element shouldn't be null and should be an Element or CachedElement"
query_params = {}
if element_type:
assert isinstance(element_type, str), "element_type should be of type str"
......
# -*- coding: utf-8 -*-
MANUAL_SLUG = "manual"
class WorkerVersionMixin(object):
def get_worker_version(self, worker_version_id: str) -> dict:
"""
Get worker version from cache if possible, otherwise make API request
"""
if worker_version_id is None:
raise ValueError("No worker version ID")
if worker_version_id in self._worker_version_cache:
return self._worker_version_cache[worker_version_id]
......@@ -18,37 +19,11 @@ class WorkerVersionMixin(object):
def get_worker_version_slug(self, worker_version_id: str) -> str:
"""
Get worker version slug from cache if possible, otherwise make API request
Helper function to get the worker slug from element, classification or transcription.
Gets the worker version slug from cache if possible, otherwise makes an API request.
Returns None if there is no associated worker version.
Should use `get_ml_result_slug` instead of using this method directly
:type worker_version_id: A worker version UUID
"""
worker_version = self.get_worker_version(worker_version_id)
return worker_version["worker"]["slug"]
def get_ml_result_slug(self, ml_result) -> str:
"""
Helper function to get the slug from element, classification or transcription
Can handle old and new (source vs worker_version)
:type ml_result: Element or classification or transcription
"""
if (
"source" in ml_result
and ml_result["source"]
and ml_result["source"]["slug"]
):
return ml_result["source"]["slug"]
elif "worker_version" in ml_result and ml_result["worker_version"]:
return self.get_worker_version_slug(ml_result["worker_version"])
# transcriptions have worker_version_id but elements have worker_version
elif "worker_version_id" in ml_result and ml_result["worker_version_id"]:
return self.get_worker_version_slug(ml_result["worker_version_id"])
elif "worker_version" in ml_result and ml_result["worker_version"] is None:
return MANUAL_SLUG
elif (
"worker_version_id" in ml_result and ml_result["worker_version_id"] is None
):
return MANUAL_SLUG
else:
raise ValueError(f"Unable to get slug from: {ml_result}")
......@@ -2,6 +2,6 @@ arkindex-client==1.0.6
peewee==3.14.4
Pillow==8.1.0
python-gitlab==2.6.0
python-gnupg==0.4.6
python-gnupg==0.4.7
sh==1.14.1
tenacity==7.0.0
pytest==6.2.2
pytest==6.2.3
pytest-mock==3.5.1
pytest-responses==0.4.0
......@@ -165,6 +165,7 @@ def mock_user_api(responses):
def mock_elements_worker(monkeypatch, mock_worker_version_api):
"""Build and configure an ElementsWorker with fixed CLI parameters to avoid issues with pytest"""
monkeypatch.setattr(sys, "argv", ["worker"])
monkeypatch.setenv("ARKINDEX_CORPUS_ID", "11111111-1111-1111-1111-111111111111")
worker = ElementsWorker()
worker.configure()
......@@ -173,11 +174,11 @@ def mock_elements_worker(monkeypatch, mock_worker_version_api):
@pytest.fixture
def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_version_api):
"""Build a BaseWorker using SQLite cache, also mocking a TASK_ID"""
"""Build a BaseWorker using SQLite cache, also mocking a PONOS_TASK"""
monkeypatch.setattr(sys, "argv", ["worker"])
worker = BaseWorker(use_cache=True)
monkeypatch.setenv("TASK_ID", "my_task")
monkeypatch.setenv("PONOS_TASK", "my_task")
return worker
......@@ -185,6 +186,7 @@ def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_version_api):
def mock_elements_worker_with_cache(monkeypatch, mock_worker_version_api):
"""Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest"""
monkeypatch.setattr(sys, "argv", ["worker"])
monkeypatch.setenv("ARKINDEX_CORPUS_ID", "11111111-1111-1111-1111-111111111111")
worker = ElementsWorker(use_cache=True)
worker.configure()
......
# -*- coding: utf-8 -*-
import os
from uuid import UUID
import pytest
from peewee import OperationalError
from arkindex_worker.cache import create_tables, db, init_cache_db
from arkindex_worker.cache import (
CachedElement,
CachedImage,
create_tables,
db,
init_cache_db,
)
def test_init_non_existent_path():
......@@ -47,7 +54,8 @@ def test_create_tables(tmp_path):
init_cache_db(db_path)
create_tables()
expected_schema = """CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id"))
expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))
CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id"))
CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
......@@ -61,3 +69,66 @@ CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT
)
assert expected_schema == actual_schema
@pytest.mark.parametrize(
"image_width,image_height,polygon_width,polygon_height,max_size,expected_url",
[
# No max_size: no resize
(400, 600, 400, 600, None, "http://something/full/full/0/default.jpg"),
# max_size equal to the image size, no resize
(400, 600, 400, 600, 600, "http://something/full/full/0/default.jpg"),
(600, 400, 600, 400, 600, "http://something/full/full/0/default.jpg"),
(400, 400, 400, 400, 400, "http://something/full/full/0/default.jpg"),
# max_size is smaller than the image, resize
(400, 600, 400, 600, 400, "http://something/full/266,400/0/default.jpg"),
(400, 600, 200, 600, 400, "http://something/full/266,400/0/default.jpg"),
(600, 400, 600, 400, 400, "http://something/full/400,266/0/default.jpg"),
(400, 400, 400, 400, 200, "http://something/full/200,200/0/default.jpg"),
# max_size above the image size, no resize
(400, 600, 400, 600, 800, "http://something/full/full/0/default.jpg"),
(600, 400, 600, 400, 800, "http://something/full/full/0/default.jpg"),
(400, 400, 400, 400, 800, "http://something/full/full/0/default.jpg"),
],
)
def test_element_open_image(
mocker,
image_width,
image_height,
polygon_width,
polygon_height,
max_size,
expected_url,
):
open_mock = mocker.patch(
"arkindex_worker.cache.open_image", return_value="an image!"
)
image = CachedImage(
id=UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"),
width=image_width,
height=image_height,
url="http://something",
)
elt = CachedElement(
id=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"),
type="element",
image=image,
polygon=[
[0, 0],
[image_width, 0],
[image_width, image_height],
[0, image_height],
[0, 0],
],
)
assert elt.open_image(max_size=max_size) == "an image!"
assert open_mock.call_count == 1
assert open_mock.call_args == mocker.call(expected_url)
def test_element_open_image_requires_image():
with pytest.raises(ValueError) as e:
CachedElement(id=UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")).open_image()
assert str(e.value) == "Element aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa has no image"
# -*- coding: utf-8 -*-
import json
from uuid import UUID
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedClassification, CachedElement
from arkindex_worker.models import Element
......@@ -159,7 +161,10 @@ def test_create_classification_wrong_element(mock_elements_worker):
confidence=0.42,
high_confidence=True,
)
assert str(e.value) == "element shouldn't be null and should be of type Element"
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classification(
......@@ -168,16 +173,14 @@ def test_create_classification_wrong_element(mock_elements_worker):
confidence=0.42,
high_confidence=True,
)
assert str(e.value) == "element shouldn't be null and should be of type Element"
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
def test_create_classification_wrong_ml_class(mock_elements_worker, responses):
elt = Element(
{
"id": "12341234-1234-1234-1234-123412341234",
"corpus": {"id": "11111111-1111-1111-1111-111111111111"},
}
)
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classification(
......@@ -249,12 +252,7 @@ def test_create_classification_wrong_confidence(mock_elements_worker):
mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
}
elt = Element(
{
"id": "12341234-1234-1234-1234-123412341234",
"corpus": {"id": "11111111-1111-1111-1111-111111111111"},
}
)
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classification(
element=elt,
......@@ -308,12 +306,7 @@ def test_create_classification_wrong_high_confidence(mock_elements_worker):
mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
}
elt = Element(
{
"id": "12341234-1234-1234-1234-123412341234",
"corpus": {"id": "11111111-1111-1111-1111-111111111111"},
}
)
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classification(
......@@ -342,12 +335,7 @@ def test_create_classification_api_error(responses, mock_elements_worker):
mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
}
elt = Element(
{
"id": "12341234-1234-1234-1234-123412341234",
"corpus": {"id": "11111111-1111-1111-1111-111111111111"},
}
)
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
"http://testserver/api/v1/classifications/",
......@@ -379,12 +367,7 @@ def test_create_classification(responses, mock_elements_worker):
mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
}
elt = Element(
{
"id": "12341234-1234-1234-1234-123412341234",
"corpus": {"id": "11111111-1111-1111-1111-111111111111"},
}
)
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
"http://testserver/api/v1/classifications/",
......@@ -419,16 +402,72 @@ def test_create_classification(responses, mock_elements_worker):
] == {"a_class": 1}
def test_create_classification_with_cache(responses, mock_elements_worker_with_cache):
mock_elements_worker_with_cache.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
}
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
responses.add(
responses.POST,
"http://testserver/api/v1/classifications/",
status=200,
json={
"id": "56785678-5678-5678-5678-567856785678",
"element": "12341234-1234-1234-1234-123412341234",
"ml_class": "0000",
"worker_version": "12341234-1234-1234-1234-123412341234",
"confidence": 0.42,
"high_confidence": True,
"state": "pending",
},
)
mock_elements_worker_with_cache.create_classification(
element=elt,
ml_class="a_class",
confidence=0.42,
high_confidence=True,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/classifications/",
]
assert json.loads(responses.calls[2].request.body) == {
"element": "12341234-1234-1234-1234-123412341234",
"ml_class": "0000",
"worker_version": "12341234-1234-1234-1234-123412341234",
"confidence": 0.42,
"high_confidence": True,
}
# Classification has been created and reported
assert mock_elements_worker_with_cache.report.report_data["elements"][elt.id][
"classifications"
] == {"a_class": 1}
# Check that created classification was properly stored in SQLite cache
assert list(CachedClassification.select()) == [
CachedClassification(
id=UUID("56785678-5678-5678-5678-567856785678"),
element_id=UUID(elt.id),
class_name="a_class",
confidence=0.42,
state="pending",
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
)
]
def test_create_classification_duplicate(responses, mock_elements_worker):
mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
}
elt = Element(
{
"id": "12341234-1234-1234-1234-123412341234",
"corpus": {"id": "11111111-1111-1111-1111-111111111111"},
}
)
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
"http://testserver/api/v1/classifications/",
......
......@@ -900,11 +900,17 @@ def test_create_elements_integrity_error(
def test_list_element_children_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
mock_elements_worker.list_element_children(element=None)
assert str(e.value) == "element shouldn't be null and should be of type Element"
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.list_element_children(element="not element type")
assert str(e.value) == "element shouldn't be null and should be of type Element"
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
def test_list_element_children_wrong_best_class(mock_elements_worker):
......@@ -1125,7 +1131,7 @@ def test_list_element_children_with_cache_unhandled_param(
# Filter on element should give all elements inserted
(
{
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
},
(
"11111111-1111-1111-1111-111111111111",
......@@ -1135,7 +1141,7 @@ def test_list_element_children_with_cache_unhandled_param(
# Filter on element and page should give the second element
(
{
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
"type": "page",
},
("22222222-2222-2222-2222-222222222222",),
......@@ -1143,7 +1149,7 @@ def test_list_element_children_with_cache_unhandled_param(
# Filter on element and worker version should give all elements
(
{
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
"worker_version": "56785678-5678-5678-5678-567856785678",
},
(
......@@ -1154,7 +1160,7 @@ def test_list_element_children_with_cache_unhandled_param(
# Filter on element, type something and worker version should give first
(
{
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
"type": "something",
"worker_version": "56785678-5678-5678-5678-567856785678",
},
......
......@@ -143,7 +143,42 @@ def test_create_transcription_api_error(responses, mock_elements_worker):
]
def test_create_transcription(responses, mock_elements_worker_with_cache):
def test_create_transcription(responses, mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
f"http://testserver/api/v1/element/{elt.id}/transcription/",
status=200,
json={
"id": "56785678-5678-5678-5678-567856785678",
"text": "i am a line",
"score": 0.42,
"confidence": 0.42,
"worker_version_id": "12341234-1234-1234-1234-123412341234",
},
)
mock_elements_worker.create_transcription(
element=elt,
text="i am a line",
score=0.42,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
f"http://testserver/api/v1/element/{elt.id}/transcription/",
]
assert json.loads(responses.calls[2].request.body) == {
"text": "i am a line",
"worker_version": "12341234-1234-1234-1234-123412341234",
"score": 0.42,
}
def test_create_transcription_with_cache(responses, mock_elements_worker_with_cache):
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
responses.add(
......@@ -933,7 +968,72 @@ def test_create_element_transcriptions_api_error(responses, mock_elements_worker
]
def test_create_element_transcriptions(responses, mock_elements_worker_with_cache):
def test_create_element_transcriptions(responses, mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
status=200,
json=[
{
"id": "56785678-5678-5678-5678-567856785678",
"element_id": "11111111-1111-1111-1111-111111111111",
"created": True,
},
{
"id": "67896789-6789-6789-6789-678967896789",
"element_id": "22222222-2222-2222-2222-222222222222",
"created": False,
},
{
"id": "78907890-7890-7890-7890-789078907890",
"element_id": "11111111-1111-1111-1111-111111111111",
"created": True,
},
],
)
annotations = mock_elements_worker.create_element_transcriptions(
element=elt,
sub_element_type="page",
transcriptions=TRANSCRIPTIONS_SAMPLE,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
]
assert json.loads(responses.calls[2].request.body) == {
"element_type": "page",
"worker_version": "12341234-1234-1234-1234-123412341234",
"transcriptions": TRANSCRIPTIONS_SAMPLE,
"return_elements": True,
}
assert annotations == [
{
"id": "56785678-5678-5678-5678-567856785678",
"element_id": "11111111-1111-1111-1111-111111111111",
"created": True,
},
{
"id": "67896789-6789-6789-6789-678967896789",
"element_id": "22222222-2222-2222-2222-222222222222",
"created": False,
},
{
"id": "78907890-7890-7890-7890-789078907890",
"element_id": "11111111-1111-1111-1111-111111111111",
"created": True,
},
]
def test_create_element_transcriptions_with_cache(
responses, mock_elements_worker_with_cache
):
elt = CachedElement(id="12341234-1234-1234-1234-123412341234", type="thing")
responses.add(
......@@ -1041,11 +1141,17 @@ def test_create_element_transcriptions(responses, mock_elements_worker_with_cach
def test_list_transcriptions_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
mock_elements_worker.list_transcriptions(element=None)
assert str(e.value) == "element shouldn't be null and should be of type Element"
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.list_transcriptions(element="not element type")
assert str(e.value) == "element shouldn't be null and should be of type Element"
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
def test_list_transcriptions_wrong_element_type(mock_elements_worker):
......@@ -1215,7 +1321,7 @@ def test_list_transcriptions_with_cache_skip_recursive(
# Filter on element should give all elements inserted
(
{
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
},
(
"11111111-1111-1111-1111-111111111111",
......@@ -1225,7 +1331,7 @@ def test_list_transcriptions_with_cache_skip_recursive(
# Filter on element and worker version should give first element
(
{
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
"worker_version": "56785678-5678-5678-5678-567856785678",
},
("11111111-1111-1111-1111-111111111111",),
......
......@@ -4,8 +4,8 @@ import json
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.models import Element
from arkindex_worker.worker import MANUAL_SLUG, ActivityState
from arkindex_worker.cache import CachedElement
from arkindex_worker.worker import ActivityState
# Common API calls for all workers
BASE_API_CALLS = [
......@@ -48,109 +48,21 @@ def test_get_worker_version__uses_cache(fake_dummy_worker):
assert not api_client.responses
def test_get_slug__old_style(fake_dummy_worker):
element = {"source": {"slug": TEST_SLUG}}
slug = fake_dummy_worker.get_ml_result_slug(element)
assert slug == TEST_SLUG
def test_get_slug__worker_version(fake_dummy_worker):
api_client = fake_dummy_worker.api_client
response = {"worker": {"slug": TEST_SLUG}}
api_client.add_response("RetrieveWorkerVersion", response, id=TEST_VERSION_ID)
element = {"worker_version": TEST_VERSION_ID}
slug = fake_dummy_worker.get_ml_result_slug(element)
assert slug == TEST_SLUG
# assert that only one call to the API
assert len(api_client.history) == 1
assert not api_client.responses
def test_get_slug__both(fake_page_element, fake_ufcn_worker_version, fake_dummy_worker):
api_client = fake_dummy_worker.api_client
api_client.add_response(
"RetrieveWorkerVersion",
fake_ufcn_worker_version,
id=fake_ufcn_worker_version["id"],
)
expected_slugs = [
"scikit_portrait_outlier_balsac",
"scikit_portrait_outlier_balsac",
"ufcn_line_historical",
]
slugs = [
fake_dummy_worker.get_ml_result_slug(clf)
for clf in fake_page_element["classifications"]
]
assert slugs == expected_slugs
assert len(api_client.history) == 1
assert not api_client.responses
def test_get_slug__transcriptions(fake_transcriptions_small, fake_dummy_worker):
api_client = fake_dummy_worker.api_client
version_id = "3ca4a8e3-91d1-4b78-8d83-d8bbbf487996"
response = {"worker": {"slug": TEST_SLUG}}
api_client.add_response("RetrieveWorkerVersion", response, id=version_id)
slug = fake_dummy_worker.get_ml_result_slug(fake_transcriptions_small["results"][0])
assert slug == TEST_SLUG
assert len(api_client.history) == 1
assert not api_client.responses
@pytest.mark.parametrize(
"ml_result, expected_slug",
(
# old
({"source": {"slug": "test_123"}}, "test_123"),
({"source": {"slug": "test_123"}, "worker_version": None}, "test_123"),
({"source": {"slug": "test_123"}, "worker_version_id": None}, "test_123"),
# new
({"source": None, "worker_version": "foo_1"}, "mock_slug"),
({"source": None, "worker_version_id": "foo_1"}, "mock_slug"),
({"worker_version_id": "foo_1"}, "mock_slug"),
# manual
({"worker_version_id": None}, MANUAL_SLUG),
({"worker_version": None}, MANUAL_SLUG),
({"source": None, "worker_version": None}, MANUAL_SLUG),
),
)
def test_get_ml_result_slug__ok(mocker, fake_dummy_worker, ml_result, expected_slug):
fake_dummy_worker.get_worker_version_slug = mocker.MagicMock()
fake_dummy_worker.get_worker_version_slug.return_value = "mock_slug"
def test_get_worker_version_slug(mocker, fake_dummy_worker):
fake_dummy_worker.get_worker_version = mocker.MagicMock()
fake_dummy_worker.get_worker_version.return_value = {
"id": TEST_VERSION_ID,
"worker": {"slug": "mock_slug"},
}
slug = fake_dummy_worker.get_ml_result_slug(ml_result)
assert slug == expected_slug
slug = fake_dummy_worker.get_worker_version_slug(TEST_VERSION_ID)
assert slug == "mock_slug"
@pytest.mark.parametrize(
"ml_result",
(
({},),
({"source": None},),
({"source": {"slug": None}},),
),
)
def test_get_ml_result_slug__fail(fake_dummy_worker, ml_result):
def test_get_worker_version_slug_none(fake_dummy_worker):
with pytest.raises(ValueError) as excinfo:
fake_dummy_worker.get_ml_result_slug(ml_result)
assert str(excinfo.value).startswith("Unable to get slug from")
fake_dummy_worker.get_worker_version_slug(None)
assert str(excinfo.value) == "No worker version ID"
def test_defaults(responses, mock_elements_worker):
......@@ -169,9 +81,7 @@ def test_feature_disabled(responses, mock_elements_worker):
"""Test disabled calls do not trigger any API calls"""
assert not mock_elements_worker.is_read_only
out = mock_elements_worker.update_activity(
Element({"id": "1234-deadbeef"}), ActivityState.Processed
)
out = mock_elements_worker.update_activity("1234-deadbeef", ActivityState.Processed)
assert out is None
assert len(responses.calls) == 2
......@@ -186,9 +96,7 @@ def test_readonly(responses, mock_elements_worker):
assert mock_elements_worker.is_read_only is True
mock_elements_worker.features["workers_activity"] = True
out = mock_elements_worker.update_activity(
Element({"id": "1234-deadbeef"}), ActivityState.Processed
)
out = mock_elements_worker.update_activity("1234-deadbeef", ActivityState.Processed)
assert out is None
assert len(responses.calls) == 2
......@@ -210,9 +118,7 @@ def test_update_call(responses, mock_elements_worker):
# Enable worker activity
mock_elements_worker.features["workers_activity"] = True
out = mock_elements_worker.update_activity(
Element({"id": "1234-deadbeef"}), ActivityState.Processed
)
out = mock_elements_worker.update_activity("1234-deadbeef", ActivityState.Processed)
# Check the response received by worker
assert out == {
......@@ -327,3 +233,22 @@ def test_run(
"element_id": "1234-deadbeef",
"state": final_state,
}
def test_run_cache(
monkeypatch, mocker, mock_elements_worker_with_cache, mock_cached_elements
):
# Disable second configure call from run()
monkeypatch.setattr(mock_elements_worker_with_cache, "configure", lambda: None)
# Make all the cached elements from the fixture initial elements
CachedElement.update(initial=True).execute()
mock_elements_worker_with_cache.process_element = mocker.MagicMock()
mock_elements_worker_with_cache.run()
assert mock_elements_worker_with_cache.process_element.call_args_list == [
# Called once for each cached element
mocker.call(elt)
for elt in CachedElement.select()
]
arkindex-base-worker==0.1.12
arkindex-base-worker==0.1.14