Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
Commits on Source (17)
Showing
with 667 additions and 388 deletions
0.2.0-rc5
0.2.1-beta2
......@@ -72,19 +72,28 @@ class CachedElement(Model):
"""
if not self.image_id or not self.polygon:
raise ValueError(f"Element {self.id} has no image")
# Always fetch the image from the bounding box when size differs from full image
bounding_box = polygon_bounding_box(self.polygon)
if (
bounding_box.width != self.image.width
or bounding_box.height != self.image.height
):
box = f"{bounding_box.x},{bounding_box.y},{bounding_box.x + bounding_box.width},{bounding_box.y + bounding_box.height}"
else:
box = "full"
if max_size is None:
resize = "full"
else:
bounding_box = polygon_bounding_box(self.polygon)
# Do not resize for polygons that do not exactly match the images
# as the resize is made directly by the IIIF server using the box parameter
if (
bounding_box.width != self.image.width
or bounding_box.height != self.image.height
):
resize = "full"
logger.warning(
"Only full size elements covered, downloading full size image"
)
# Do not resize when the image is below the maximum size
elif self.image.width <= max_size and self.image.height <= max_size:
resize = "full"
......@@ -99,7 +108,7 @@ class CachedElement(Model):
if not url.endswith("/"):
url += "/"
return open_image(f"{url}full/{resize}/0/default.jpg", *args, **kwargs)
return open_image(f"{url}{box}/{resize}/0/default.jpg", *args, **kwargs)
class CachedTranscription(Model):
......@@ -186,13 +195,9 @@ def create_tables():
db.create_tables(MODELS)
def merge_parents_cache(parent_ids, current_database, data_dir="/data", chunk=None):
"""
Merge all the potential parent task's databases into the existing local one
"""
def retrieve_parents_cache_path(parent_ids, data_dir="/data", chunk=None):
assert isinstance(parent_ids, list)
assert os.path.isdir(data_dir)
assert os.path.exists(current_database)
# Handle possible chunk in parent task name
# This is needed to support the init_elements databases
......@@ -203,7 +208,7 @@ def merge_parents_cache(parent_ids, current_database, data_dir="/data", chunk=No
filenames.append(f"db_{chunk}.sqlite")
# Find all the paths for these databases
paths = list(
return list(
filter(
lambda p: os.path.isfile(p),
[
......@@ -214,6 +219,13 @@ def merge_parents_cache(parent_ids, current_database, data_dir="/data", chunk=No
)
)
def merge_parents_cache(paths, current_database):
"""
Merge all the potential parent task's databases into the existing local one
"""
assert os.path.exists(current_database)
if not paths:
logger.info("No parents cache to use")
return
......
......@@ -351,7 +351,10 @@ class GitHelper:
while keeping the same directory structure
"""
file_count = 0
for file in export_out_dir.rglob("*.*"):
file_names = [
file_name for file_name in export_out_dir.rglob("*") if file_name.is_file()
]
for file in file_names:
rel_file_path = file.relative_to(export_out_dir)
out_file = self.export_path / rel_file_path
if not out_file.exists():
......
......@@ -55,6 +55,14 @@ class Element(MagicDict):
Describes any kind of element.
"""
def resize_zone_url(self, size="full"):
if size == "full":
return self.zone.url
else:
parts = self.zone.url.split("/")
parts[-3] = size
return "/".join(parts)
def image_url(self, size="full"):
"""
When possible, will return the S3 URL for images, so an ML worker can bypass IIIF servers
......@@ -79,10 +87,16 @@ class Element(MagicDict):
bounding_box = polygon_bounding_box(self.zone.polygon)
return bounding_box.width > max_width or bounding_box.height > max_height
def open_image(self, *args, max_size=None, **kwargs):
def open_image(self, *args, max_size=None, use_full_image=False, **kwargs):
"""
Open this element's image as a Pillow image.
This does not crop the image to the element's polygon.
If use_full_image is False:
Use zone url, instead of the full image url to be able to get
single page from double page image.
Else:
This does not crop the image to the element's polygon.
:param max_size: Subresolution of the image.
"""
if not self.get("zone"):
......@@ -119,7 +133,10 @@ class Element(MagicDict):
resize = "full"
try:
return open_image(self.image_url(resize), *args, **kwargs)
if use_full_image:
return open_image(self.image_url(resize), *args, **kwargs)
else:
return open_image(self.resize_zone_url(resize), *args, **kwargs)
except HTTPError as e:
if (
self.zone.image.get("s3_url") is not None
......
......@@ -36,8 +36,8 @@ class ElementsWorker(
EntityMixin,
MetaDataMixin,
):
def __init__(self, description="Arkindex Elements Worker", use_cache=False):
super().__init__(description, use_cache)
def __init__(self, description="Arkindex Elements Worker", support_cache=False):
super().__init__(description, support_cache)
# Add report concerning elements
self.report = Reporter("unknown worker")
......@@ -84,6 +84,13 @@ class ElementsWorker(
return out
@property
def store_activity(self):
assert (
self.process_information
), "Worker must be configured to access its process activity state"
return self.process_information.get("activity_state") == "ready"
def run(self):
"""
Process every elements from the provided list
......@@ -97,6 +104,11 @@ class ElementsWorker(
logger.warning("No elements to process, stopping.")
sys.exit(1)
if not self.store_activity:
logger.info(
"No worker activity will be stored as it is disabled for this process"
)
# Process every element
count = len(elements)
failed = 0
......@@ -112,7 +124,7 @@ class ElementsWorker(
logger.info(f"Processing {element} ({i}/{count})")
# Report start of process, run process, then report end of process
# Process the element and report its progress if activities are enabled
self.update_activity(element.id, ActivityState.Started)
self.process_element(element)
self.update_activity(element.id, ActivityState.Processed)
......@@ -156,15 +168,17 @@ class ElementsWorker(
Update worker activity for this element
This method should not raise a runtime exception, but simply warn users
"""
if not self.store_activity:
logger.debug(
"Activity is not stored as the feature is disabled on this process"
)
return
assert element_id and isinstance(
element_id, (uuid.UUID, str)
), "element_id shouldn't be null and should be an UUID or str"
assert isinstance(state, ActivityState), "state should be an ActivityState"
if not self.features.get("workers_activity"):
logger.debug("Skipping Worker activity update as it's disabled on backend")
return
if self.is_read_only:
logger.warning("Cannot update activity as this worker is in read-only mode")
return
......@@ -175,6 +189,7 @@ class ElementsWorker(
id=self.worker_version_id,
body={
"element_id": str(element_id),
"process_id": self.process_information["id"],
"state": state.value,
},
)
......
......@@ -18,7 +18,12 @@ from tenacity import (
from arkindex import ArkindexClient, options_from_env
from arkindex_worker import logger
from arkindex_worker.cache import create_tables, init_cache_db, merge_parents_cache
from arkindex_worker.cache import (
create_tables,
init_cache_db,
merge_parents_cache,
retrieve_parents_cache_path,
)
def _is_500_error(exc):
......@@ -33,7 +38,7 @@ def _is_500_error(exc):
class BaseWorker(object):
def __init__(self, description="Arkindex Base Worker", use_cache=False):
def __init__(self, description="Arkindex Base Worker", support_cache=False):
self.parser = argparse.ArgumentParser(description=description)
# Setup workdir either in Ponos environment or on host's home
......@@ -56,7 +61,10 @@ class BaseWorker(object):
logger.info(f"Worker will use {self.work_dir} as working directory")
self.use_cache = use_cache
self.support_cache = support_cache
# use_cache will be updated in configure() if the cache is supported and if there
# is at least one available sqlite database either given or in the parent tasks
self.use_cache = False
@property
def is_read_only(self):
......@@ -108,6 +116,14 @@ class BaseWorker(object):
logger.debug(f"Connected as {user['display_name']} - {user['email']}")
self.features = user["features"]
# Load process information
assert os.environ.get(
"ARKINDEX_PROCESS_ID"
), "ARKINDEX_PROCESS_ID environment variable is not defined"
self.process_information = self.request(
"RetrieveDataImport", id=os.environ["ARKINDEX_PROCESS_ID"]
)
if self.worker_version_id:
# Retrieve initial configuration from API
worker_version = self.request(
......@@ -133,38 +149,39 @@ class BaseWorker(object):
# Load all required secrets
self.secrets = {name: self.load_secret(name) for name in required_secrets}
if self.args.database is not None:
task_id = os.environ.get("PONOS_TASK")
paths = None
if self.support_cache and self.args.database is not None:
self.use_cache = True
elif self.support_cache and task_id:
task = self.request("RetrieveTaskFromAgent", id=task_id)
paths = retrieve_parents_cache_path(
task["parents"],
data_dir=os.environ.get("PONOS_DATA", "/data"),
chunk=os.environ.get("ARKINDEX_TASK_CHUNK"),
)
self.use_cache = len(paths) > 0
task_id = os.environ.get("PONOS_TASK")
if self.use_cache is True:
if self.use_cache:
if self.args.database is not None:
assert os.path.isfile(
self.args.database
), f"Database in {self.args.database} does not exist"
self.cache_path = self.args.database
elif task_id:
else:
cache_dir = os.path.join(os.environ.get("PONOS_DATA", "/data"), task_id)
assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}"
self.cache_path = os.path.join(cache_dir, "db.sqlite")
else:
self.cache_path = os.path.join(os.getcwd(), "db.sqlite")
init_cache_db(self.cache_path)
create_tables()
# Merging parents caches (if there are any) in the current task local cache, unless the database got overridden
if self.args.database is None and paths is not None:
merge_parents_cache(paths, self.cache_path)
else:
logger.debug("Cache is disabled")
# Merging parents caches (if there are any) in the current task local cache, unless the database got overridden
if self.use_cache and self.args.database is None and task_id is not None:
task = self.request("RetrieveTaskFromAgent", id=task_id)
merge_parents_cache(
task["parents"],
self.cache_path,
data_dir=os.environ.get("PONOS_DATA", "/data"),
chunk=os.environ.get("ARKINDEX_TASK_CHUNK"),
)
def load_secret(self, name):
"""Load all secrets described in the worker configuration"""
secret = None
......
arkindex-client==1.0.6
peewee==3.14.4
Pillow==8.1.0
Pillow==8.2.0
python-gitlab==2.6.0
python-gnupg==0.4.7
sh==1.14.1
sh==1.14.2
tenacity==7.0.0
pytest==6.2.3
pytest-mock==3.5.1
pytest-responses==0.4.0
pytest==6.2.4
pytest-mock==3.6.1
pytest-responses==0.5.0
......@@ -92,20 +92,21 @@ def setup_api(responses, monkeypatch, cache_yaml):
@pytest.fixture(autouse=True)
def temp_working_directory(monkeypatch, tmp_path):
def _getcwd():
return str(tmp_path)
monkeypatch.setattr(os, "getcwd", _getcwd)
def give_env_variable(monkeypatch):
"""Defines required environment variables"""
monkeypatch.setenv("WORKER_VERSION_ID", "12341234-1234-1234-1234-123412341234")
monkeypatch.setenv("ARKINDEX_PROCESS_ID", "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff")
monkeypatch.setenv("ARKINDEX_CORPUS_ID", "11111111-1111-1111-1111-111111111111")
@pytest.fixture(autouse=True)
def give_worker_version_id_env_variable(monkeypatch):
monkeypatch.setenv("WORKER_VERSION_ID", "12341234-1234-1234-1234-123412341234")
@pytest.fixture
def mock_config_api(mock_worker_version_api, mock_process_api, mock_user_api):
"""Mock all API endpoints required to configure a worker"""
pass
@pytest.fixture
def mock_worker_version_api(responses, mock_user_api):
def mock_worker_version_api(responses):
"""Provide a mock API response to get worker configuration"""
payload = {
"id": "12341234-1234-1234-1234-123412341234",
......@@ -137,18 +138,58 @@ def mock_worker_version_api(responses, mock_user_api):
)
@pytest.fixture
def mock_process_api(responses):
"""Provide a mock of the API response to get information on a process. Workers activity is enabled"""
payload = {
"name": None,
"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"state": "running",
"mode": "workers",
"corpus": "11111111-1111-1111-1111-111111111111",
"workflow": "http://testserver/ponos/v1/workflow/12341234-1234-1234-1234-123412341234/",
"files": [],
"revision": None,
"element": {
"id": "12341234-1234-1234-1234-123412341234",
"type": "folder",
"name": "Test folder",
"corpus": {
"id": "11111111-1111-1111-1111-111111111111",
"name": "John Doe project",
"public": False,
},
"thumbnail_url": "http://testserver/thumbnail.png",
"zone": None,
"thumbnail_put_url": "http://testserver/thumbnail.png",
},
"folder_type": None,
"element_type": "page",
"element_name_contains": None,
"load_children": True,
"use_cache": False,
"activity_state": "ready",
}
responses.add(
responses.GET,
"http://testserver/api/v1/imports/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff/",
status=200,
body=json.dumps(payload),
content_type="application/json",
)
@pytest.fixture
def mock_user_api(responses):
"""
Provide a mock API response to retrieve user details
Workers Activity is disabled in this mock
Signup is disabled in this mock
"""
payload = {
"id": 1,
"email": "bot@teklia.com",
"display_name": "Bender",
"features": {
"workers_activity": False,
"signup": False,
},
}
......@@ -162,10 +203,9 @@ def mock_user_api(responses):
@pytest.fixture
def mock_elements_worker(monkeypatch, mock_worker_version_api):
def mock_elements_worker(monkeypatch, mock_config_api):
"""Build and configure an ElementsWorker with fixed CLI parameters to avoid issues with pytest"""
monkeypatch.setattr(sys, "argv", ["worker"])
monkeypatch.setenv("ARKINDEX_CORPUS_ID", "11111111-1111-1111-1111-111111111111")
worker = ElementsWorker()
worker.configure()
......@@ -173,22 +213,23 @@ def mock_elements_worker(monkeypatch, mock_worker_version_api):
@pytest.fixture
def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_version_api):
def mock_base_worker_with_cache(mocker, monkeypatch, mock_config_api):
"""Build a BaseWorker using SQLite cache, also mocking a PONOS_TASK"""
monkeypatch.setattr(sys, "argv", ["worker"])
worker = BaseWorker(use_cache=True)
worker = BaseWorker(support_cache=True)
monkeypatch.setenv("PONOS_TASK", "my_task")
return worker
@pytest.fixture
def mock_elements_worker_with_cache(monkeypatch, mock_worker_version_api):
def mock_elements_worker_with_cache(monkeypatch, mock_config_api, tmp_path):
"""Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest"""
monkeypatch.setattr(sys, "argv", ["worker"])
monkeypatch.setenv("ARKINDEX_CORPUS_ID", "11111111-1111-1111-1111-111111111111")
cache_path = tmp_path / "db.sqlite"
cache_path.touch()
monkeypatch.setattr(sys, "argv", ["worker", "-d", str(cache_path)])
worker = ElementsWorker(use_cache=True)
worker = ElementsWorker(support_cache=True)
worker.configure()
return worker
......
......@@ -29,11 +29,11 @@ def test_init_default_xdg_data_home(monkeypatch):
def test_init_with_local_cache(monkeypatch):
worker = BaseWorker(use_cache=True)
worker = BaseWorker(support_cache=True)
assert worker.work_dir == os.path.expanduser("~/.local/share/arkindex")
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
assert worker.use_cache is True
assert worker.support_cache is True
def test_init_var_ponos_data_given(monkeypatch):
......@@ -45,7 +45,9 @@ def test_init_var_ponos_data_given(monkeypatch):
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
def test_init_var_worker_version_id_missing(monkeypatch, mock_user_api):
def test_init_var_worker_version_id_missing(
monkeypatch, mock_user_api, mock_process_api
):
monkeypatch.setattr(sys, "argv", ["worker"])
monkeypatch.delenv("WORKER_VERSION_ID")
worker = BaseWorker()
......@@ -55,7 +57,9 @@ def test_init_var_worker_version_id_missing(monkeypatch, mock_user_api):
assert worker.config == {} # default empty case
def test_init_var_worker_local_file(monkeypatch, tmp_path, mock_user_api):
def test_init_var_worker_local_file(
monkeypatch, tmp_path, mock_user_api, mock_process_api
):
# Build a dummy yaml config file
config = tmp_path / "config.yml"
config.write_text("---\nlocalKey: abcdef123")
......@@ -71,7 +75,7 @@ def test_init_var_worker_local_file(monkeypatch, tmp_path, mock_user_api):
config.unlink()
def test_cli_default(mocker, mock_worker_version_api, mock_user_api):
def test_cli_default(mocker, mock_config_api):
worker = BaseWorker()
spy = mocker.spy(worker, "add_arguments")
assert not spy.called
......@@ -93,7 +97,7 @@ def test_cli_default(mocker, mock_worker_version_api, mock_user_api):
logger.setLevel(logging.NOTSET)
def test_cli_arg_verbose_given(mocker, mock_worker_version_api, mock_user_api):
def test_cli_arg_verbose_given(mocker, mock_config_api):
worker = BaseWorker()
spy = mocker.spy(worker, "add_arguments")
assert not spy.called
......
......@@ -78,13 +78,15 @@ CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT
[
# No max_size: no resize
(400, 600, 400, 600, None, "http://something/full/full/0/default.jpg"),
# No max_size: resize on bbox
(400, 600, 200, 100, None, "http://something/0,0,200,100/full/0/default.jpg"),
# max_size equal to the image size, no resize
(400, 600, 400, 600, 600, "http://something/full/full/0/default.jpg"),
(600, 400, 600, 400, 600, "http://something/full/full/0/default.jpg"),
(400, 400, 400, 400, 400, "http://something/full/full/0/default.jpg"),
# max_size is smaller than the image, resize
(400, 600, 400, 600, 400, "http://something/full/266,400/0/default.jpg"),
(400, 600, 200, 600, 400, "http://something/full/266,400/0/default.jpg"),
(400, 600, 200, 600, 400, "http://something/0,0,200,600/full/0/default.jpg"),
(600, 400, 600, 400, 400, "http://something/full/400,266/0/default.jpg"),
(400, 400, 400, 400, 200, "http://something/full/200,200/0/default.jpg"),
# max_size above the image size, no resize
......@@ -118,9 +120,9 @@ def test_element_open_image(
image=image,
polygon=[
[0, 0],
[image_width, 0],
[image_width, image_height],
[0, image_height],
[polygon_width, 0],
[polygon_width, polygon_height],
[0, polygon_height],
[0, 0],
],
)
......
......@@ -62,7 +62,7 @@ def test_open_image(mocker):
}
}
)
assert elt.open_image() == "an image!"
assert elt.open_image(use_full_image=True) == "an image!"
assert open_mock.call_count == 1
assert open_mock.call_args == mocker.call(
"http://something/full/full/0/default.jpg"
......@@ -87,19 +87,19 @@ def test_open_image_resize_portrait(mocker):
}
)
# Resize = original size
assert elt.open_image(max_size=600) == "an image!"
assert elt.open_image(max_size=600, use_full_image=True) == "an image!"
assert open_mock.call_count == 1
assert open_mock.call_args == mocker.call(
"http://something/full/full/0/default.jpg"
)
# Resize = smaller height
assert elt.open_image(max_size=400) == "an image!"
assert elt.open_image(max_size=400, use_full_image=True) == "an image!"
assert open_mock.call_count == 2
assert open_mock.call_args == mocker.call(
"http://something/full/266,400/0/default.jpg"
)
# Resize = bigger height
assert elt.open_image(max_size=800) == "an image!"
assert elt.open_image(max_size=800, use_full_image=True) == "an image!"
assert open_mock.call_count == 3
assert open_mock.call_args == mocker.call(
"http://something/full/full/0/default.jpg"
......@@ -123,7 +123,7 @@ def test_open_image_resize_partial_element(mocker):
}
}
)
assert elt.open_image(max_size=400) == "an image!"
assert elt.open_image(max_size=400, use_full_image=True) == "an image!"
assert open_mock.call_count == 1
assert open_mock.call_args == mocker.call(
"http://something/full/full/0/default.jpg"
......@@ -148,19 +148,19 @@ def test_open_image_resize_landscape(mocker):
}
)
# Resize = original size
assert elt.open_image(max_size=600) == "an image!"
assert elt.open_image(max_size=600, use_full_image=True) == "an image!"
assert open_mock.call_count == 1
assert open_mock.call_args == mocker.call(
"http://something/full/full/0/default.jpg"
)
# Resize = smaller width
assert elt.open_image(max_size=400) == "an image!"
assert elt.open_image(max_size=400, use_full_image=True) == "an image!"
assert open_mock.call_count == 2
assert open_mock.call_args == mocker.call(
"http://something/full/400,266/0/default.jpg"
)
# Resize = bigger width
assert elt.open_image(max_size=800) == "an image!"
assert elt.open_image(max_size=800, use_full_image=True) == "an image!"
assert open_mock.call_count == 3
assert open_mock.call_args == mocker.call(
"http://something/full/full/0/default.jpg"
......@@ -185,19 +185,19 @@ def test_open_image_resize_square(mocker):
}
)
# Resize = original size
assert elt.open_image(max_size=400) == "an image!"
assert elt.open_image(max_size=400, use_full_image=True) == "an image!"
assert open_mock.call_count == 1
assert open_mock.call_args == mocker.call(
"http://something/full/full/0/default.jpg"
)
# Resize = smaller
assert elt.open_image(max_size=200) == "an image!"
assert elt.open_image(max_size=200, use_full_image=True) == "an image!"
assert open_mock.call_count == 2
assert open_mock.call_args == mocker.call(
"http://something/full/200,200/0/default.jpg"
)
# Resize = bigger
assert elt.open_image(max_size=800) == "an image!"
assert elt.open_image(max_size=800, use_full_image=True) == "an image!"
assert open_mock.call_count == 3
assert open_mock.call_args == mocker.call(
"http://something/full/full/0/default.jpg"
......@@ -234,7 +234,7 @@ def test_open_image_s3(mocker):
elt = Element(
{"zone": {"image": {"url": "http://something", "s3_url": "http://s3url"}}}
)
assert elt.open_image() == "an image!"
assert elt.open_image(use_full_image=True) == "an image!"
assert open_mock.call_count == 1
assert open_mock.call_args == mocker.call("http://s3url")
......@@ -256,7 +256,7 @@ def test_open_image_s3_retry(mocker):
)
with pytest.raises(NotImplementedError):
elt.open_image()
elt.open_image(use_full_image=True)
def test_open_image_s3_retry_once(mocker):
......@@ -274,4 +274,49 @@ def test_open_image_s3_retry_once(mocker):
)
with pytest.raises(NotImplementedError):
elt.open_image()
elt.open_image(use_full_image=True)
def test_open_image_use_full_image_false(mocker):
open_mock = mocker.patch(
"arkindex_worker.models.open_image", return_value="an image!"
)
elt = Element(
{
"zone": {
"image": {"url": "http://something", "s3_url": "http://s3url"},
"url": "http://zoneurl/0,0,400,600/full/0/default.jpg",
}
}
)
assert elt.open_image(use_full_image=False) == "an image!"
assert open_mock.call_count == 1
assert open_mock.call_args == mocker.call(
"http://zoneurl/0,0,400,600/full/0/default.jpg"
)
def test_open_image_resize_use_full_image_false(mocker):
open_mock = mocker.patch(
"arkindex_worker.models.open_image", return_value="an image!"
)
elt = Element(
{
"zone": {
"image": {
"url": "http://something",
"width": 400,
"height": 600,
"server": {"max_width": None, "max_height": None},
},
"polygon": [[0, 0], [400, 0], [400, 600], [0, 600], [0, 0]],
"url": "http://zoneurl/0,0,400,600/full/0/default.jpg",
}
}
)
# Resize = smaller
assert elt.open_image(max_size=200, use_full_image=False) == "an image!"
assert open_mock.call_count == 1
assert open_mock.call_args == mocker.call(
"http://zoneurl/0,0,400,600/133,200/0/default.jpg"
)
# -*- coding: utf-8 -*-
# API calls during worker configuration
BASE_API_CALLS = [
("GET", "http://testserver/api/v1/user/"),
("GET", "http://testserver/api/v1/imports/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff/"),
(
"GET",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
),
]
......@@ -8,6 +8,8 @@ from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedClassification, CachedElement
from arkindex_worker.models import Element
from . import BASE_API_CALLS
def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
corpus_id = "12341234-1234-1234-1234-123412341234"
......@@ -31,11 +33,11 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
assert not mock_elements_worker.classes
ml_class_id = mock_elements_worker.get_ml_class_id(corpus_id, "good")
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
]
assert mock_elements_worker.classes == {
"12341234-1234-1234-1234-123412341234": {"good": "0000"}
......@@ -134,19 +136,16 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
# Simply request class 2, it should be reloaded
assert mock_elements_worker.get_ml_class_id(corpus_id, "class2") == "class2_id"
assert len(responses.calls) == 5
assert len(responses.calls) == len(BASE_API_CALLS) + 3
assert mock_elements_worker.classes == {
corpus_id: {
"class1": "class1_id",
"class2": "class2_id",
}
}
assert [(call.request.method, call.request.url) for call in responses.calls] == [
("GET", "http://testserver/api/v1/user/"),
(
"GET",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
),
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
("POST", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
......@@ -350,16 +349,16 @@ def test_create_classification_api_error(responses, mock_elements_worker):
high_confidence=True,
)
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
"http://testserver/api/v1/classifications/",
"http://testserver/api/v1/classifications/",
"http://testserver/api/v1/classifications/",
"http://testserver/api/v1/classifications/",
"http://testserver/api/v1/classifications/",
("POST", "http://testserver/api/v1/classifications/"),
("POST", "http://testserver/api/v1/classifications/"),
("POST", "http://testserver/api/v1/classifications/"),
("POST", "http://testserver/api/v1/classifications/"),
("POST", "http://testserver/api/v1/classifications/"),
]
......@@ -381,14 +380,14 @@ def test_create_classification(responses, mock_elements_worker):
high_confidence=True,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/classifications/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/classifications/"),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"element": "12341234-1234-1234-1234-123412341234",
"ml_class": "0000",
"worker_version": "12341234-1234-1234-1234-123412341234",
......@@ -430,14 +429,14 @@ def test_create_classification_with_cache(responses, mock_elements_worker_with_c
high_confidence=True,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/classifications/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/classifications/"),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"element": "12341234-1234-1234-1234-123412341234",
"ml_class": "0000",
"worker_version": "12341234-1234-1234-1234-123412341234",
......@@ -486,14 +485,14 @@ def test_create_classification_duplicate(responses, mock_elements_worker):
high_confidence=True,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/classifications/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/classifications/"),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"element": "12341234-1234-1234-1234-123412341234",
"ml_class": "0000",
"worker_version": "12341234-1234-1234-1234-123412341234",
......
......@@ -10,7 +10,7 @@ import pytest
from arkindex_worker.worker import ElementsWorker
def test_cli_default(monkeypatch, mock_worker_version_api):
def test_cli_default(monkeypatch, mock_config_api):
_, path = tempfile.mkstemp()
with open(path, "w") as f:
json.dump(
......@@ -33,7 +33,7 @@ def test_cli_default(monkeypatch, mock_worker_version_api):
os.unlink(path)
def test_cli_arg_elements_list_given(mocker, mock_worker_version_api):
def test_cli_arg_elements_list_given(mocker, mock_config_api):
_, path = tempfile.mkstemp()
with open(path, "w") as f:
json.dump(
......@@ -55,14 +55,14 @@ def test_cli_arg_elements_list_given(mocker, mock_worker_version_api):
os.unlink(path)
def test_cli_arg_element_one_given_not_uuid(mocker, mock_elements_worker):
def test_cli_arg_element_one_given_not_uuid(mocker):
mocker.patch.object(sys, "argv", ["worker", "--element", "1234"])
worker = ElementsWorker()
with pytest.raises(SystemExit):
worker.configure()
def test_cli_arg_element_one_given(mocker, mock_elements_worker):
def test_cli_arg_element_one_given(mocker, mock_config_api):
mocker.patch.object(
sys, "argv", ["worker", "--element", "12341234-1234-1234-1234-123412341234"]
)
......@@ -74,7 +74,7 @@ def test_cli_arg_element_one_given(mocker, mock_elements_worker):
assert not worker.args.elements_list
def test_cli_arg_element_many_given(mocker, mock_elements_worker):
def test_cli_arg_element_many_given(mocker, mock_config_api):
mocker.patch.object(
sys,
"argv",
......
......@@ -12,6 +12,8 @@ from arkindex_worker.cache import CachedElement, CachedImage
from arkindex_worker.models import Element
from arkindex_worker.worker import ElementsWorker
from . import BASE_API_CALLS
def test_list_elements_elements_list_arg_wrong_type(monkeypatch, mock_elements_worker):
_, path = tempfile.mkstemp()
......@@ -145,7 +147,7 @@ def test_database_arg(mocker, mock_elements_worker, tmp_path):
),
)
worker = ElementsWorker()
worker = ElementsWorker(support_cache=True)
worker.configure()
assert worker.use_cache is True
......@@ -166,16 +168,16 @@ def test_load_corpus_classes_api_error(responses, mock_elements_worker):
):
mock_elements_worker.load_corpus_classes(corpus_id)
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We do 5 retries
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
]
assert not mock_elements_worker.classes
......@@ -212,11 +214,11 @@ def test_load_corpus_classes(responses, mock_elements_worker):
assert not mock_elements_worker.classes
mock_elements_worker.load_corpus_classes(corpus_id)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
]
assert mock_elements_worker.classes == {
"12341234-1234-1234-1234-123412341234": {
......@@ -371,16 +373,16 @@ def test_create_sub_element_api_error(responses, mock_elements_worker):
polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
)
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
"http://testserver/api/v1/elements/create/",
"http://testserver/api/v1/elements/create/",
"http://testserver/api/v1/elements/create/",
"http://testserver/api/v1/elements/create/",
"http://testserver/api/v1/elements/create/",
("POST", "http://testserver/api/v1/elements/create/"),
("POST", "http://testserver/api/v1/elements/create/"),
("POST", "http://testserver/api/v1/elements/create/"),
("POST", "http://testserver/api/v1/elements/create/"),
("POST", "http://testserver/api/v1/elements/create/"),
]
......@@ -406,13 +408,13 @@ def test_create_sub_element(responses, mock_elements_worker):
polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/elements/create/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/elements/create/"),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"type": "something",
"name": "0",
"image": "22222222-2222-2222-2222-222222222222",
......@@ -697,16 +699,31 @@ def test_create_elements_api_error(responses, mock_elements_worker):
],
)
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
),
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
),
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
),
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
),
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
),
]
......@@ -741,13 +758,16 @@ def test_create_elements_cached_element(responses, mock_elements_worker_with_cac
],
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"elements": [
{
"name": "0",
......@@ -805,13 +825,16 @@ def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
],
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"elements": [
{
"name": "0",
......@@ -1036,16 +1059,31 @@ def test_list_element_children_api_error(responses, mock_elements_worker):
):
next(mock_elements_worker.list_element_children(element=elt))
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We do 5 retries
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
(
"GET",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
),
(
"GET",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
),
(
"GET",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
),
(
"GET",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
),
(
"GET",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
),
]
......@@ -1102,11 +1140,14 @@ def test_list_element_children(responses, mock_elements_worker):
):
assert child == expected_children[idx]
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"GET",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
),
]
......@@ -1186,8 +1227,7 @@ def test_list_element_children_with_cache(
assert child.id == UUID(expected_id)
# Check the worker never hits the API for elements
assert len(responses.calls) == 2
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
]
assert len(responses.calls) == len(BASE_API_CALLS)
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS
......@@ -14,6 +14,8 @@ from arkindex_worker.cache import (
from arkindex_worker.models import Element
from arkindex_worker.worker import EntityType
from . import BASE_API_CALLS
def test_create_entity_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
......@@ -173,16 +175,16 @@ def test_create_entity_api_error(responses, mock_elements_worker):
corpus="12341234-1234-1234-1234-123412341234",
)
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
"http://testserver/api/v1/entity/",
"http://testserver/api/v1/entity/",
"http://testserver/api/v1/entity/",
"http://testserver/api/v1/entity/",
"http://testserver/api/v1/entity/",
("POST", "http://testserver/api/v1/entity/"),
("POST", "http://testserver/api/v1/entity/"),
("POST", "http://testserver/api/v1/entity/"),
("POST", "http://testserver/api/v1/entity/"),
("POST", "http://testserver/api/v1/entity/"),
]
......@@ -202,13 +204,13 @@ def test_create_entity(responses, mock_elements_worker):
corpus="12341234-1234-1234-1234-123412341234",
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/entity/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/entity/"),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"name": "Bob Bob",
"type": "person",
"metas": None,
......@@ -235,14 +237,14 @@ def test_create_entity_with_cache(responses, mock_elements_worker_with_cache):
corpus="12341234-1234-1234-1234-123412341234",
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/entity/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/entity/"),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"name": "Bob Bob",
"type": "person",
"metas": None,
......@@ -387,16 +389,31 @@ def test_create_transcription_entity_api_error(responses, mock_elements_worker):
length=10,
)
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
"http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
"http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
"http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
"http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
"http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
(
"POST",
"http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
),
(
"POST",
"http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
),
(
"POST",
"http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
),
(
"POST",
"http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
),
(
"POST",
"http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
),
]
......@@ -419,13 +436,16 @@ def test_create_transcription_entity(responses, mock_elements_worker):
length=10,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"POST",
"http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"entity": "11111111-1111-1111-1111-111111111111",
"offset": 5,
"length": 10,
......@@ -471,13 +491,16 @@ def test_create_transcription_entity_with_cache(
length=10,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"POST",
"http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"entity": "11111111-1111-1111-1111-111111111111",
"offset": 5,
"length": 10,
......
......@@ -7,6 +7,8 @@ from apistar.exceptions import ErrorResponse
from arkindex_worker.models import Element
from arkindex_worker.worker import MetaType
from . import BASE_API_CALLS
def test_create_metadata_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
......@@ -133,16 +135,31 @@ def test_create_metadata_api_error(responses, mock_elements_worker):
value="La Turbine, Grenoble 38000",
)
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
),
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
),
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
),
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
),
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
),
]
......@@ -162,13 +179,16 @@ def test_create_metadata(responses, mock_elements_worker):
value="La Turbine, Grenoble 38000",
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"type": "location",
"name": "Teklia",
"value": "La Turbine, Grenoble 38000",
......
......@@ -8,6 +8,8 @@ from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedElement, CachedTranscription
from arkindex_worker.models import Element
from . import BASE_API_CALLS
TRANSCRIPTIONS_SAMPLE = [
{
"polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
......@@ -130,16 +132,16 @@ def test_create_transcription_api_error(responses, mock_elements_worker):
score=0.42,
)
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
f"http://testserver/api/v1/element/{elt.id}/transcription/",
f"http://testserver/api/v1/element/{elt.id}/transcription/",
f"http://testserver/api/v1/element/{elt.id}/transcription/",
f"http://testserver/api/v1/element/{elt.id}/transcription/",
f"http://testserver/api/v1/element/{elt.id}/transcription/",
("POST", f"http://testserver/api/v1/element/{elt.id}/transcription/"),
("POST", f"http://testserver/api/v1/element/{elt.id}/transcription/"),
("POST", f"http://testserver/api/v1/element/{elt.id}/transcription/"),
("POST", f"http://testserver/api/v1/element/{elt.id}/transcription/"),
("POST", f"http://testserver/api/v1/element/{elt.id}/transcription/"),
]
......@@ -164,14 +166,14 @@ def test_create_transcription(responses, mock_elements_worker):
score=0.42,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
f"http://testserver/api/v1/element/{elt.id}/transcription/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", f"http://testserver/api/v1/element/{elt.id}/transcription/"),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"text": "i am a line",
"worker_version": "12341234-1234-1234-1234-123412341234",
"score": 0.42,
......@@ -200,14 +202,14 @@ def test_create_transcription_with_cache(responses, mock_elements_worker_with_ca
score=0.42,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
f"http://testserver/api/v1/element/{elt.id}/transcription/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", f"http://testserver/api/v1/element/{elt.id}/transcription/"),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"text": "i am a line",
"worker_version": "12341234-1234-1234-1234-123412341234",
"score": 0.42,
......@@ -478,16 +480,16 @@ def test_create_transcriptions_api_error(responses, mock_elements_worker):
with pytest.raises(ErrorResponse):
mock_elements_worker.create_transcriptions(transcriptions=trans)
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
"http://testserver/api/v1/transcription/bulk/",
"http://testserver/api/v1/transcription/bulk/",
"http://testserver/api/v1/transcription/bulk/",
"http://testserver/api/v1/transcription/bulk/",
"http://testserver/api/v1/transcription/bulk/",
("POST", "http://testserver/api/v1/transcription/bulk/"),
("POST", "http://testserver/api/v1/transcription/bulk/"),
("POST", "http://testserver/api/v1/transcription/bulk/"),
("POST", "http://testserver/api/v1/transcription/bulk/"),
("POST", "http://testserver/api/v1/transcription/bulk/"),
]
......@@ -533,14 +535,14 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache):
transcriptions=trans,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/transcription/bulk/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/transcription/bulk/"),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"worker_version": "12341234-1234-1234-1234-123412341234",
"transcriptions": trans,
}
......@@ -955,16 +957,16 @@ def test_create_element_transcriptions_api_error(responses, mock_elements_worker
transcriptions=TRANSCRIPTIONS_SAMPLE,
)
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
("POST", f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/"),
("POST", f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/"),
("POST", f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/"),
("POST", f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/"),
("POST", f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/"),
]
......@@ -999,14 +1001,14 @@ def test_create_element_transcriptions(responses, mock_elements_worker):
transcriptions=TRANSCRIPTIONS_SAMPLE,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/"),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"element_type": "page",
"worker_version": "12341234-1234-1234-1234-123412341234",
"transcriptions": TRANSCRIPTIONS_SAMPLE,
......@@ -1065,14 +1067,14 @@ def test_create_element_transcriptions_with_cache(
transcriptions=TRANSCRIPTIONS_SAMPLE,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/"),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"element_type": "page",
"worker_version": "12341234-1234-1234-1234-123412341234",
"transcriptions": TRANSCRIPTIONS_SAMPLE,
......@@ -1200,16 +1202,31 @@ def test_list_transcriptions_api_error(responses, mock_elements_worker):
):
next(mock_elements_worker.list_transcriptions(element=elt))
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We do 5 retries
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
(
"GET",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
),
(
"GET",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
),
(
"GET",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
),
(
"GET",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
),
(
"GET",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
),
]
......@@ -1254,11 +1271,14 @@ def test_list_transcriptions(responses, mock_elements_worker):
):
assert transcription == trans[idx]
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"GET",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
),
]
......@@ -1360,8 +1380,7 @@ def test_list_transcriptions_with_cache(
assert transcription.id == UUID(expected_id)
# Check the worker never hits the API for elements
assert len(responses.calls) == 2
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
]
assert len(responses.calls) == len(BASE_API_CALLS)
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS
# -*- coding: utf-8 -*-
# . -*- coding: utf-8 -*-
import json
import sys
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedElement
from arkindex_worker.worker import ActivityState
from arkindex_worker.worker import ActivityState, ElementsWorker
# Common API calls for all workers
BASE_API_CALLS = [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
]
from . import BASE_API_CALLS
TEST_VERSION_ID = "test_123"
TEST_SLUG = "some_slug"
......@@ -65,45 +62,45 @@ def test_get_worker_version_slug_none(fake_dummy_worker):
assert str(excinfo.value) == "No worker version ID"
def test_defaults(responses, mock_elements_worker):
"""Test the default values from mocked calls"""
assert not mock_elements_worker.is_read_only
assert mock_elements_worker.features == {
"workers_activity": False,
"signup": False,
}
assert len(responses.calls) == 2
assert [call.request.url for call in responses.calls] == BASE_API_CALLS
def test_feature_disabled(responses, mock_elements_worker):
"""Test disabled calls do not trigger any API calls"""
assert not mock_elements_worker.is_read_only
out = mock_elements_worker.update_activity("1234-deadbeef", ActivityState.Processed)
assert out is None
assert len(responses.calls) == 2
assert [call.request.url for call in responses.calls] == BASE_API_CALLS
def test_readonly(responses, mock_elements_worker):
"""Test readonly worker does not trigger any API calls"""
# Setup the worker as read-only, but with workers_activity enabled
# Setup the worker as read-only
mock_elements_worker.worker_version_id = None
assert mock_elements_worker.is_read_only is True
mock_elements_worker.features["workers_activity"] = True
out = mock_elements_worker.update_activity("1234-deadbeef", ActivityState.Processed)
assert out is None
assert len(responses.calls) == 2
assert [call.request.url for call in responses.calls] == BASE_API_CALLS
assert len(responses.calls) == len(BASE_API_CALLS)
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS
def test_update_call(responses, mock_elements_worker):
def test_activities_disabled(
responses, monkeypatch, mock_worker_version_api, mock_user_api
):
"""Test worker process elements without updating activities when they are disabled for the process"""
responses.add(
responses.GET,
"http://testserver/api/v1/imports/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff/",
status=200,
body=json.dumps({"activity_state": "disabled"}),
content_type="application/json",
)
monkeypatch.setattr(sys, "argv", ["worker"])
worker = ElementsWorker()
worker.configure()
assert not worker.is_read_only
assert len(responses.calls) == len(BASE_API_CALLS)
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS
def test_update_call(responses, mock_elements_worker, mock_process_api):
"""Test an update call with feature enabled triggers an API call"""
responses.add(
responses.PUT,
......@@ -111,29 +108,34 @@ def test_update_call(responses, mock_elements_worker):
status=200,
json={
"element_id": "1234-deadbeef",
"process_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"state": "processed",
},
)
# Enable worker activity
mock_elements_worker.features["workers_activity"] = True
out = mock_elements_worker.update_activity("1234-deadbeef", ActivityState.Processed)
# Check the response received by worker
assert out == {
"element_id": "1234-deadbeef",
"process_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"state": "processed",
}
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == BASE_API_CALLS + [
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/activity/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"PUT",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/activity/",
),
]
# Check the request sent by worker
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"element_id": "1234-deadbeef",
"process_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"state": "processed",
}
......@@ -185,6 +187,7 @@ def test_run(
status=200,
json={
"element_id": "1234-deadbeef",
"process_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"state": "started",
},
)
......@@ -194,13 +197,12 @@ def test_run(
status=200,
json={
"element_id": "1234-deadbeef",
"process_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"state": final_state,
},
)
# Enable worker activity
assert mock_elements_worker.is_read_only is False
mock_elements_worker.features["workers_activity"] = True
# Mock exception in process_element
if process_exception:
......@@ -217,20 +219,30 @@ def test_run(
# Simply run the process
mock_elements_worker.run()
assert len(responses.calls) == 5
assert [call.request.url for call in responses.calls] == BASE_API_CALLS + [
"http://testserver/api/v1/element/1234-deadbeef/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/activity/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/activity/",
assert len(responses.calls) == len(BASE_API_CALLS) + 3
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("GET", "http://testserver/api/v1/element/1234-deadbeef/"),
(
"PUT",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/activity/",
),
(
"PUT",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/activity/",
),
]
# Check the requests sent by worker
assert json.loads(responses.calls[3].request.body) == {
assert json.loads(responses.calls[-2].request.body) == {
"element_id": "1234-deadbeef",
"process_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"state": "started",
}
assert json.loads(responses.calls[4].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"element_id": "1234-deadbeef",
"process_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"state": final_state,
}
......