Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: workers/base-worker
Showing 431 additions and 824 deletions
# -*- coding: utf-8 -*-
import json
import logging
import sys
......@@ -15,7 +14,7 @@ from arkindex_worker.worker.base import ExtrasDirNotFoundError
from tests.conftest import FIXTURES_DIR
def test_init_default_local_share(monkeypatch):
def test_init_default_local_share():
worker = BaseWorker()
assert worker.work_dir == Path("~/.local/share/arkindex").expanduser()
......@@ -29,7 +28,7 @@ def test_init_default_xdg_data_home(monkeypatch):
assert str(worker.work_dir) == f"{path}/arkindex"
def test_init_with_local_cache(monkeypatch):
def test_init_with_local_cache():
worker = BaseWorker(support_cache=True)
assert worker.work_dir == Path("~/.local/share/arkindex").expanduser()
......@@ -72,7 +71,8 @@ def test_init_var_worker_local_file(monkeypatch, tmp_path):
config.unlink()
def test_cli_default(mocker, mock_worker_run_api):
@pytest.mark.usefixtures("_mock_worker_run_api")
def test_cli_default(mocker):
worker = BaseWorker()
assert logger.level == logging.NOTSET
......@@ -91,7 +91,8 @@ def test_cli_default(mocker, mock_worker_run_api):
logger.setLevel(logging.NOTSET)
def test_cli_arg_verbose_given(mocker, mock_worker_run_api):
@pytest.mark.usefixtures("_mock_worker_run_api")
def test_cli_arg_verbose_given(mocker):
worker = BaseWorker()
assert logger.level == logging.NOTSET
......@@ -110,7 +111,8 @@ def test_cli_arg_verbose_given(mocker, mock_worker_run_api):
logger.setLevel(logging.NOTSET)
def test_cli_envvar_debug_given(mocker, monkeypatch, mock_worker_run_api):
@pytest.mark.usefixtures("_mock_worker_run_api")
def test_cli_envvar_debug_given(mocker, monkeypatch):
worker = BaseWorker()
assert logger.level == logging.NOTSET
......@@ -129,7 +131,7 @@ def test_cli_envvar_debug_given(mocker, monkeypatch, mock_worker_run_api):
logger.setLevel(logging.NOTSET)
def test_configure_dev_mode(mocker, monkeypatch):
def test_configure_dev_mode(mocker):
"""
Configuring a worker in developer mode avoids retrieving process information
"""
......@@ -145,7 +147,7 @@ def test_configure_dev_mode(mocker, monkeypatch):
assert worker.user_configuration == {}
def test_configure_worker_run(mocker, monkeypatch, responses, caplog):
def test_configure_worker_run(mocker, responses, caplog):
# Capture log messages
caplog.set_level(logging.INFO)
......@@ -214,9 +216,8 @@ def test_configure_worker_run(mocker, monkeypatch, responses, caplog):
assert worker.user_configuration == {"a": "b"}
def test_configure_worker_run_no_revision(
mocker, monkeypatch, mock_worker_run_no_revision_api, caplog
):
@pytest.mark.usefixtures("_mock_worker_run_no_revision_api")
def test_configure_worker_run_no_revision(mocker, caplog):
worker = BaseWorker()
mocker.patch.object(sys, "argv", ["worker"])
......@@ -234,11 +235,7 @@ def test_configure_worker_run_no_revision(
]
def test_configure_user_configuration_defaults(
mocker,
monkeypatch,
responses,
):
def test_configure_user_configuration_defaults(mocker, responses):
worker = BaseWorker()
mocker.patch.object(sys, "argv")
worker.args = worker.parser.parse_args()
......@@ -300,8 +297,8 @@ def test_configure_user_configuration_defaults(
}
@pytest.mark.parametrize("debug", (True, False))
def test_configure_user_config_debug(mocker, monkeypatch, responses, debug):
@pytest.mark.parametrize("debug", [True, False])
def test_configure_user_config_debug(mocker, responses, debug):
worker = BaseWorker()
mocker.patch.object(sys, "argv", ["worker"])
assert logger.level == logging.NOTSET
......@@ -347,7 +344,7 @@ def test_configure_user_config_debug(mocker, monkeypatch, responses, debug):
logger.setLevel(logging.NOTSET)
def test_configure_worker_run_missing_conf(mocker, monkeypatch, responses):
def test_configure_worker_run_missing_conf(mocker, responses):
worker = BaseWorker()
mocker.patch.object(sys, "argv", ["worker"])
......@@ -392,7 +389,7 @@ def test_configure_worker_run_missing_conf(mocker, monkeypatch, responses):
assert worker.user_configuration == {}
def test_configure_worker_run_no_worker_run_conf(mocker, monkeypatch, responses):
def test_configure_worker_run_no_worker_run_conf(mocker, responses):
"""
No configuration is provided, but the worker should not crash
"""
......@@ -434,7 +431,7 @@ def test_configure_worker_run_no_worker_run_conf(mocker, monkeypatch, responses)
assert worker.user_configuration == {}
def test_configure_load_model_configuration(mocker, monkeypatch, responses):
def test_configure_load_model_configuration(mocker, responses):
worker = BaseWorker()
mocker.patch.object(sys, "argv", ["worker"])
payload = {
......@@ -585,7 +582,7 @@ def test_load_local_secret(monkeypatch, tmp_path):
secret.write_text("this is a local secret value", encoding="utf-8")
# Mock GPG decryption
class GpgDecrypt(object):
class GpgDecrypt:
def __init__(self, fd):
self.ok = True
self.data = fd.read()
......@@ -638,15 +635,15 @@ def test_find_extras_directory_from_config(monkeypatch):
@pytest.mark.parametrize(
"extras_path, exists, error",
(
[
("extras_path", "exists", "error"),
[
(
None,
True,
"No path to the directory for extra files was provided. Please provide extras_dir either through configuration or as CLI argument.",
],
["extra_files", False, "The path extra_files does not link to any directory"],
),
),
("extra_files", False, "The path extra_files does not link to any directory"),
],
)
def test_find_extras_directory_not_found(monkeypatch, extras_path, exists, error):
if extras_path:
......@@ -673,7 +670,9 @@ def test_find_parents_file_paths(responses, mock_base_worker_with_cache, tmp_pat
)
filename = Path("my_file.txt")
for parent_id, content in zip(["first", "third"], ["Some text", "Other text"]):
for parent_id, content in zip(
["first", "third"], ["Some text", "Other text"], strict=True
):
(tmp_path / parent_id).mkdir()
file_path = tmp_path / parent_id / filename
with file_path.open("w", encoding="utf-8") as f:
......@@ -749,7 +748,7 @@ def test_corpus_id_not_set_read_only_mode(
with pytest.raises(
Exception, match="Missing ARKINDEX_CORPUS_ID environment variable"
):
mock_elements_worker_read_only.corpus_id
_ = mock_elements_worker_read_only.corpus_id
def test_corpus_id_set_read_only_mode(
......
# -*- coding: utf-8 -*-
from pathlib import Path
from uuid import UUID
......@@ -31,22 +30,20 @@ def test_init(tmp_path):
def test_create_tables_existing_table(tmp_path):
db_path = f"{tmp_path}/db.sqlite"
db_path = tmp_path / "db.sqlite"
# Create the tables once…
init_cache_db(db_path)
create_tables()
db.close()
with open(db_path, "rb") as before_file:
before = before_file.read()
before = db_path.read_bytes()
# Create them again
init_cache_db(db_path)
create_tables()
with open(db_path, "rb") as after_file:
after = after_file.read()
after = db_path.read_bytes()
assert before == after, "Existing table structure was modified"
......@@ -56,6 +53,9 @@ def test_create_tables(tmp_path):
init_cache_db(db_path)
create_tables()
# WARNING: If you are updating this schema following a development you have made
# in base-worker, make sure to upgrade the arkindex_worker.cache.SQL_VERSION in
# the same merge request as your changes.
expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_run_id" TEXT, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))
CREATE TABLE "dataset_elements" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "dataset_id" TEXT NOT NULL, "set_name" VARCHAR(255) NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"), FOREIGN KEY ("dataset_id") REFERENCES "datasets" ("id"))
CREATE TABLE "datasets" ("id" TEXT NOT NULL PRIMARY KEY, "name" VARCHAR(255) NOT NULL, "state" VARCHAR(255) NOT NULL DEFAULT 'open', "sets" TEXT NOT NULL)
......@@ -144,7 +144,17 @@ def test_check_version_same_version(tmp_path):
@pytest.mark.parametrize(
"image_width,image_height,polygon_x,polygon_y,polygon_width,polygon_height,max_width,max_height,expected_url",
(
"image_width",
"image_height",
"polygon_x",
"polygon_y",
"polygon_width",
"polygon_height",
"max_width",
"max_height",
"expected_url",
),
[
# No max_size: no resize
(
......
......@@ -413,7 +413,7 @@ def test_list_datasets(responses, mock_dataset_worker):
]
@pytest.mark.parametrize("generator", (True, False))
@pytest.mark.parametrize("generator", [True, False])
def test_run_no_datasets(mocker, caplog, mock_dataset_worker, generator):
mocker.patch("arkindex_worker.worker.DatasetWorker.list_datasets", return_value=[])
mock_dataset_worker.generator = generator
......@@ -428,7 +428,7 @@ def test_run_no_datasets(mocker, caplog, mock_dataset_worker, generator):
@pytest.mark.parametrize(
"generator, error",
("generator", "error"),
[
(True, "When generating a new dataset, its state should be Open."),
(False, "When processing an existing dataset, its state should be Complete."),
......@@ -657,7 +657,7 @@ def test_run_no_downloaded_artifact_error(
@pytest.mark.parametrize(
"generator, state", [(True, DatasetState.Open), (False, DatasetState.Complete)]
("generator", "state"), [(True, DatasetState.Open), (False, DatasetState.Complete)]
)
def test_run(
mocker,
......@@ -749,7 +749,7 @@ def test_run(
@pytest.mark.parametrize(
"generator, state", [(True, DatasetState.Open), (False, DatasetState.Complete)]
("generator", "state"), [(True, DatasetState.Open), (False, DatasetState.Complete)]
)
def test_run_read_only(
mocker,
......
# -*- coding: utf-8 -*-
import pytest
from requests import HTTPError
......
# -*- coding: utf-8 -*-
# API calls during worker configuration
BASE_API_CALLS = [
(
......
# -*- coding: utf-8 -*-
import json
import re
from uuid import UUID, uuid4
......
# -*- coding: utf-8 -*-
import json
import os
import sys
import tempfile
from pathlib import Path
from uuid import UUID
import pytest
......@@ -10,49 +9,53 @@ import pytest
from arkindex_worker.worker import ElementsWorker
def test_cli_default(monkeypatch, mock_worker_run_api):
@pytest.mark.usefixtures("_mock_worker_run_api")
def test_cli_default(monkeypatch):
_, path = tempfile.mkstemp()
with open(path, "w") as f:
json.dump(
path = Path(path)
path.write_text(
json.dumps(
[
{"id": "volumeid", "type": "volume"},
{"id": "pageid", "type": "page"},
{"id": "actid", "type": "act"},
{"id": "surfaceid", "type": "surface"},
],
f,
)
)
monkeypatch.setenv("TASK_ELEMENTS", path)
monkeypatch.setattr(sys, "argv", ["worker"])
worker = ElementsWorker()
worker.configure()
assert worker.args.elements_list.name == path
assert worker.args.elements_list.name == str(path)
assert not worker.args.element
os.unlink(path)
path.unlink()
def test_cli_arg_elements_list_given(mocker, mock_worker_run_api):
@pytest.mark.usefixtures("_mock_worker_run_api")
def test_cli_arg_elements_list_given(mocker):
_, path = tempfile.mkstemp()
with open(path, "w") as f:
json.dump(
path = Path(path)
path.write_text(
json.dumps(
[
{"id": "volumeid", "type": "volume"},
{"id": "pageid", "type": "page"},
{"id": "actid", "type": "act"},
{"id": "surfaceid", "type": "surface"},
],
f,
)
)
mocker.patch.object(sys, "argv", ["worker", "--elements-list", path])
mocker.patch.object(sys, "argv", ["worker", "--elements-list", str(path)])
worker = ElementsWorker()
worker.configure()
assert worker.args.elements_list.name == path
assert worker.args.elements_list.name == str(path)
assert not worker.args.element
os.unlink(path)
path.unlink()
def test_cli_arg_element_one_given_not_uuid(mocker):
......@@ -62,7 +65,8 @@ def test_cli_arg_element_one_given_not_uuid(mocker):
worker.configure()
def test_cli_arg_element_one_given(mocker, mock_worker_run_api):
@pytest.mark.usefixtures("_mock_worker_run_api")
def test_cli_arg_element_one_given(mocker):
mocker.patch.object(
sys, "argv", ["worker", "--element", "12341234-1234-1234-1234-123412341234"]
)
......@@ -74,7 +78,8 @@ def test_cli_arg_element_one_given(mocker, mock_worker_run_api):
assert not worker.args.elements_list
def test_cli_arg_element_many_given(mocker, mock_worker_run_api):
@pytest.mark.usefixtures("_mock_worker_run_api")
def test_cli_arg_element_many_given(mocker):
mocker.patch.object(
sys,
"argv",
......
# -*- coding: utf-8 -*-
import json
import logging
......@@ -107,8 +106,8 @@ def test_list_process_datasets(
@pytest.mark.parametrize(
"payload, error",
(
("payload", "error"),
[
# Dataset
(
{"dataset": None},
......@@ -118,7 +117,7 @@ def test_list_process_datasets(
{"dataset": "not Dataset type"},
"dataset shouldn't be null and should be a Dataset",
),
),
],
)
def test_list_dataset_elements_wrong_param_dataset(mock_dataset_worker, payload, error):
with pytest.raises(AssertionError, match=error):
......@@ -265,8 +264,8 @@ def test_list_dataset_elements(
@pytest.mark.parametrize(
"payload, error",
(
("payload", "error"),
[
# Dataset
(
{"dataset": None},
......@@ -276,7 +275,7 @@ def test_list_dataset_elements(
{"dataset": "not dataset type"},
"dataset shouldn't be null and should be a Dataset",
),
),
],
)
def test_update_dataset_state_wrong_param_dataset(
mock_dataset_worker, default_dataset, payload, error
......@@ -292,8 +291,8 @@ def test_update_dataset_state_wrong_param_dataset(
@pytest.mark.parametrize(
"payload, error",
(
("payload", "error"),
[
# DatasetState
(
{"state": None},
......@@ -303,7 +302,7 @@ def test_update_dataset_state_wrong_param_dataset(
{"state": "not dataset type"},
"state shouldn't be null and should be a str from DatasetState",
),
),
],
)
def test_update_dataset_state_wrong_param_state(
mock_dataset_worker, default_dataset, payload, error
......
# -*- coding: utf-8 -*-
import json
import re
from argparse import Namespace
......@@ -429,19 +428,7 @@ def test_create_sub_element_wrong_name(mock_elements_worker):
def test_create_sub_element_wrong_polygon(mock_elements_worker):
elt = Element({"zone": None})
with pytest.raises(
AssertionError, match="polygon shouldn't be null and should be of type list"
):
mock_elements_worker.create_sub_element(
element=elt,
type="something",
name="0",
polygon=None,
)
with pytest.raises(
AssertionError, match="polygon shouldn't be null and should be of type list"
):
with pytest.raises(AssertionError, match="polygon should be None or a list"):
mock_elements_worker.create_sub_element(
element=elt,
type="something",
......@@ -505,6 +492,42 @@ def test_create_sub_element_wrong_confidence(mock_elements_worker, confidence):
)
@pytest.mark.parametrize(
("image", "error_type", "error_message"),
[
(1, AssertionError, "image should be None or string"),
("not a uuid", ValueError, "image is not a valid uuid."),
],
)
def test_create_sub_element_wrong_image(
mock_elements_worker, image, error_type, error_message
):
with pytest.raises(error_type, match=re.escape(error_message)):
mock_elements_worker.create_sub_element(
element=Element({"zone": None}),
type="something",
name="blah",
polygon=[[0, 0], [0, 10], [10, 10], [10, 0], [0, 0]],
image=image,
)
def test_create_sub_element_wrong_image_and_polygon(mock_elements_worker):
with pytest.raises(
AssertionError,
match=re.escape(
"An image or a parent with an image is required to create an element with a polygon."
),
):
mock_elements_worker.create_sub_element(
element=Element({"zone": None}),
type="something",
name="blah",
polygon=[[0, 0], [0, 10], [10, 10], [10, 0], [0, 0]],
image=None,
)
def test_create_sub_element_api_error(responses, mock_elements_worker):
elt = Element(
{
......@@ -581,7 +604,7 @@ def test_create_sub_element(responses, mock_elements_worker, slim_output):
assert json.loads(responses.calls[-1].request.body) == {
"type": "something",
"name": "0",
"image": "22222222-2222-2222-2222-222222222222",
"image": None,
"corpus": "11111111-1111-1111-1111-111111111111",
"polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
"parent": "12341234-1234-1234-1234-123412341234",
......@@ -626,7 +649,7 @@ def test_create_sub_element_confidence(responses, mock_elements_worker):
assert json.loads(responses.calls[-1].request.body) == {
"type": "something",
"name": "0",
"image": "22222222-2222-2222-2222-222222222222",
"image": None,
"corpus": "11111111-1111-1111-1111-111111111111",
"polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
"parent": "12341234-1234-1234-1234-123412341234",
......@@ -1219,8 +1242,96 @@ def test_create_elements_integrity_error(
@pytest.mark.parametrize(
"payload, error",
(
("params", "error_message"),
[
(
{"parent": None, "child": None},
"parent shouldn't be null and should be of type Element",
),
(
{"parent": "not an element", "child": None},
"parent shouldn't be null and should be of type Element",
),
(
{"parent": Element(zone=None), "child": None},
"child shouldn't be null and should be of type Element",
),
(
{"parent": Element(zone=None), "child": "not an element"},
"child shouldn't be null and should be of type Element",
),
],
)
def test_create_element_parent_invalid_params(
mock_elements_worker, params, error_message
):
with pytest.raises(AssertionError, match=re.escape(error_message)):
mock_elements_worker.create_element_parent(**params)
def test_create_element_parent_api_error(responses, mock_elements_worker):
parent = Element({"id": "12341234-1234-1234-1234-123412341234"})
child = Element({"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"})
responses.add(
responses.POST,
"http://testserver/api/v1/element/497f6eca-6276-4993-bfeb-53cbbbba6f08/parent/12341234-1234-1234-1234-123412341234/",
status=500,
)
with pytest.raises(ErrorResponse):
mock_elements_worker.create_element_parent(
parent=parent,
child=child,
)
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
(
"POST",
"http://testserver/api/v1/element/497f6eca-6276-4993-bfeb-53cbbbba6f08/parent/12341234-1234-1234-1234-123412341234/",
),
] * 5
def test_create_element_parent(responses, mock_elements_worker):
parent = Element({"id": "12341234-1234-1234-1234-123412341234"})
child = Element({"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"})
responses.add(
responses.POST,
"http://testserver/api/v1/element/497f6eca-6276-4993-bfeb-53cbbbba6f08/parent/12341234-1234-1234-1234-123412341234/",
status=200,
json={
"parent": "12341234-1234-1234-1234-123412341234",
"child": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
},
)
created_element_parent = mock_elements_worker.create_element_parent(
parent=parent,
child=child,
)
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"POST",
"http://testserver/api/v1/element/497f6eca-6276-4993-bfeb-53cbbbba6f08/parent/12341234-1234-1234-1234-123412341234/",
),
]
assert created_element_parent == {
"parent": "12341234-1234-1234-1234-123412341234",
"child": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
}
@pytest.mark.parametrize(
("payload", "error"),
[
# Element
(
{"element": None},
......@@ -1230,7 +1341,7 @@ def test_create_elements_integrity_error(
{"element": "not element type"},
"element shouldn't be null and should be an Element or CachedElement",
),
),
],
)
def test_partial_update_element_wrong_param_element(
mock_elements_worker, payload, error
......@@ -1247,12 +1358,12 @@ def test_partial_update_element_wrong_param_element(
@pytest.mark.parametrize(
"payload, error",
(
("payload", "error"),
[
# Type
({"type": 1234}, "type should be a str"),
({"type": None}, "type should be a str"),
),
],
)
def test_partial_update_element_wrong_param_type(mock_elements_worker, payload, error):
api_payload = {
......@@ -1267,12 +1378,12 @@ def test_partial_update_element_wrong_param_type(mock_elements_worker, payload,
@pytest.mark.parametrize(
"payload, error",
(
("payload", "error"),
[
# Name
({"name": 1234}, "name should be a str"),
({"name": None}, "name should be a str"),
),
],
)
def test_partial_update_element_wrong_param_name(mock_elements_worker, payload, error):
api_payload = {
......@@ -1287,8 +1398,8 @@ def test_partial_update_element_wrong_param_name(mock_elements_worker, payload,
@pytest.mark.parametrize(
"payload, error",
(
("payload", "error"),
[
# Polygon
({"polygon": "not a polygon"}, "polygon should be a list"),
({"polygon": None}, "polygon should be a list"),
......@@ -1305,7 +1416,7 @@ def test_partial_update_element_wrong_param_name(mock_elements_worker, payload,
{"polygon": [["not a coord", 1], [2, 2], [2, 1], [1, 2]]},
"polygon points should be lists of two numbers",
),
),
],
)
def test_partial_update_element_wrong_param_polygon(
mock_elements_worker, payload, error
......@@ -1322,8 +1433,8 @@ def test_partial_update_element_wrong_param_polygon(
@pytest.mark.parametrize(
"payload, error",
(
("payload", "error"),
[
# Confidence
({"confidence": "lol"}, "confidence should be None or a float in [0..1] range"),
({"confidence": "0.2"}, "confidence should be None or a float in [0..1] range"),
......@@ -1333,7 +1444,7 @@ def test_partial_update_element_wrong_param_polygon(
{"confidence": float("inf")},
"confidence should be None or a float in [0..1] range",
),
),
],
)
def test_partial_update_element_wrong_param_conf(mock_elements_worker, payload, error):
api_payload = {
......@@ -1348,14 +1459,14 @@ def test_partial_update_element_wrong_param_conf(mock_elements_worker, payload,
@pytest.mark.parametrize(
"payload, error",
(
("payload", "error"),
[
# Rotation angle
({"rotation_angle": "lol"}, "rotation_angle should be a positive integer"),
({"rotation_angle": -1}, "rotation_angle should be a positive integer"),
({"rotation_angle": 0.5}, "rotation_angle should be a positive integer"),
({"rotation_angle": None}, "rotation_angle should be a positive integer"),
),
],
)
def test_partial_update_element_wrong_param_rota(mock_elements_worker, payload, error):
api_payload = {
......@@ -1370,13 +1481,13 @@ def test_partial_update_element_wrong_param_rota(mock_elements_worker, payload,
@pytest.mark.parametrize(
"payload, error",
(
("payload", "error"),
[
# Mirrored
({"mirrored": "lol"}, "mirrored should be a boolean"),
({"mirrored": 1234}, "mirrored should be a boolean"),
({"mirrored": None}, "mirrored should be a boolean"),
),
],
)
def test_partial_update_element_wrong_param_mir(mock_elements_worker, payload, error):
api_payload = {
......@@ -1391,13 +1502,13 @@ def test_partial_update_element_wrong_param_mir(mock_elements_worker, payload, e
@pytest.mark.parametrize(
"payload, error",
(
("payload", "error"),
[
# Image
({"image": "lol"}, "image should be a UUID"),
({"image": 1234}, "image should be a UUID"),
({"image": None}, "image should be a UUID"),
),
],
)
def test_partial_update_element_wrong_param_image(mock_elements_worker, payload, error):
api_payload = {
......@@ -1440,9 +1551,10 @@ def test_partial_update_element_api_error(responses, mock_elements_worker):
]
@pytest.mark.usefixtures("_mock_cached_elements", "_mock_cached_images")
@pytest.mark.parametrize(
"payload",
(
[
(
{
"polygon": [[10, 10], [20, 20], [20, 10], [10, 20]],
......@@ -1463,15 +1575,9 @@ def test_partial_update_element_api_error(responses, mock_elements_worker):
"mirrored": False,
}
),
),
],
)
def test_partial_update_element(
responses,
mock_elements_worker_with_cache,
mock_cached_elements,
mock_cached_images,
payload,
):
def test_partial_update_element(responses, mock_elements_worker_with_cache, payload):
elt = CachedElement.select().first()
new_image = CachedImage.select().first()
......@@ -1516,9 +1622,10 @@ def test_partial_update_element(
assert getattr(cached_element, param) == elt_response[param]
@pytest.mark.parametrize("confidence", (None, 0.42))
@pytest.mark.usefixtures("_mock_cached_elements")
@pytest.mark.parametrize("confidence", [None, 0.42])
def test_partial_update_element_confidence(
responses, mock_elements_worker_with_cache, mock_cached_elements, confidence
responses, mock_elements_worker_with_cache, confidence
):
elt = CachedElement.select().first()
elt_response = {
......@@ -1661,13 +1768,13 @@ def test_list_element_children_wrong_with_metadata(mock_elements_worker):
@pytest.mark.parametrize(
"param, value",
(
("param", "value"),
[
("worker_version", 1234),
("worker_run", 1234),
("transcription_worker_version", 1234),
("transcription_worker_run", 1234),
),
],
)
def test_list_element_children_wrong_worker_version(mock_elements_worker, param, value):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
......@@ -1681,12 +1788,12 @@ def test_list_element_children_wrong_worker_version(mock_elements_worker, param,
@pytest.mark.parametrize(
"param",
(
("worker_version"),
("worker_run"),
("transcription_worker_version"),
("transcription_worker_run"),
),
[
"worker_version",
"worker_run",
"transcription_worker_version",
"transcription_worker_run",
],
)
def test_list_element_children_wrong_bool_worker_version(mock_elements_worker, param):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
......@@ -1908,9 +2015,10 @@ def test_list_element_children_with_cache_unhandled_param(
)
@pytest.mark.usefixtures("_mock_cached_elements")
@pytest.mark.parametrize(
"filters, expected_ids",
(
("filters", "expected_ids"),
[
# Filter on element should give all elements inserted
(
{
......@@ -1977,12 +2085,11 @@ def test_list_element_children_with_cache_unhandled_param(
"33333333-3333-3333-3333-333333333333",
),
),
),
],
)
def test_list_element_children_with_cache(
responses,
mock_elements_worker_with_cache,
mock_cached_elements,
filters,
expected_ids,
):
......@@ -1992,7 +2099,7 @@ def test_list_element_children_with_cache(
# Query database through cache
elements = mock_elements_worker_with_cache.list_element_children(**filters)
assert elements.count() == len(expected_ids)
for child, expected_id in zip(elements.order_by("id"), expected_ids):
for child, expected_id in zip(elements.order_by("id"), expected_ids, strict=True):
assert child.id == UUID(expected_id)
# Check the worker never hits the API for elements
......@@ -2109,13 +2216,13 @@ def test_list_element_parents_wrong_with_metadata(mock_elements_worker):
@pytest.mark.parametrize(
"param, value",
(
("param", "value"),
[
("worker_version", 1234),
("worker_run", 1234),
("transcription_worker_version", 1234),
("transcription_worker_run", 1234),
),
],
)
def test_list_element_parents_wrong_worker_version(mock_elements_worker, param, value):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
......@@ -2129,12 +2236,12 @@ def test_list_element_parents_wrong_worker_version(mock_elements_worker, param,
@pytest.mark.parametrize(
"param",
(
("worker_version"),
("worker_run"),
("transcription_worker_version"),
("transcription_worker_run"),
),
[
"worker_version",
"worker_run",
"transcription_worker_version",
"transcription_worker_run",
],
)
def test_list_element_parents_wrong_bool_worker_version(mock_elements_worker, param):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
......@@ -2356,9 +2463,10 @@ def test_list_element_parents_with_cache_unhandled_param(
)
@pytest.mark.usefixtures("_mock_cached_elements")
@pytest.mark.parametrize(
"filters, expected_id",
(
("filters", "expected_id"),
[
# Filter on element
(
{
......@@ -2415,12 +2523,11 @@ def test_list_element_parents_with_cache_unhandled_param(
},
"99999999-9999-9999-9999-999999999999",
),
),
],
)
def test_list_element_parents_with_cache(
responses,
mock_elements_worker_with_cache,
mock_cached_elements,
filters,
expected_id,
):
......
# -*- coding: utf-8 -*-
import json
import re
from uuid import UUID
......@@ -56,7 +55,7 @@ def test_create_entity_wrong_type(mock_elements_worker):
)
def test_create_entity_wrong_corpus(monkeypatch, mock_elements_worker):
def test_create_entity_wrong_corpus(mock_elements_worker):
# Trigger an error on the metas param; not giving a corpus should work since the
# ARKINDEX_CORPUS_ID environment variable is set on mock_elements_worker
with pytest.raises(AssertionError, match="metas should be of type dict"):
......@@ -742,12 +741,13 @@ def test_list_corpus_entities(responses, mock_elements_worker):
},
)
# list is required to actually do the request
assert list(mock_elements_worker.list_corpus_entities()) == [
{
mock_elements_worker.list_corpus_entities()
assert mock_elements_worker.entities == {
"fake_entity_id": {
"id": "fake_entity_id",
}
]
}
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
......@@ -760,22 +760,13 @@ def test_list_corpus_entities(responses, mock_elements_worker):
]
@pytest.mark.parametrize(
"wrong_name",
[
1234,
12.5,
],
)
@pytest.mark.parametrize("wrong_name", [1234, 12.5])
def test_list_corpus_entities_wrong_name(mock_elements_worker, wrong_name):
with pytest.raises(AssertionError, match="name should be of type str"):
mock_elements_worker.list_corpus_entities(name=wrong_name)
@pytest.mark.parametrize(
"wrong_parent",
[{"id": "element_id"}, 12.5, "blabla"],
)
@pytest.mark.parametrize("wrong_parent", [{"id": "element_id"}, 12.5, "blabla"])
def test_list_corpus_entities_wrong_parent(mock_elements_worker, wrong_parent):
with pytest.raises(AssertionError, match="parent should be of type Element"):
mock_elements_worker.list_corpus_entities(parent=wrong_parent)
......@@ -850,7 +841,7 @@ def test_check_required_entity_types_no_creation_allowed(
] == BASE_API_CALLS
@pytest.mark.parametrize("transcription", (None, "not a transcription", 1))
@pytest.mark.parametrize("transcription", [None, "not a transcription", 1])
def test_create_transcription_entities_wrong_transcription(
mock_elements_worker, transcription
):
......@@ -865,8 +856,8 @@ def test_create_transcription_entities_wrong_transcription(
@pytest.mark.parametrize(
"entities, error",
(
("entities", "error"),
[
(None, "entities shouldn't be null and should be of type list"),
(
"not a list of entities",
......@@ -886,7 +877,7 @@ def test_create_transcription_entities_wrong_transcription(
* 2,
"entities should be unique",
),
),
],
)
def test_create_transcription_entities_wrong_entities(
mock_elements_worker, entities, error
......@@ -909,8 +900,8 @@ def test_create_transcription_entities_wrong_entities_subtype(mock_elements_work
@pytest.mark.parametrize(
"entity, error",
(
("entity", "error"),
[
(
{
"name": None,
......@@ -989,7 +980,7 @@ def test_create_transcription_entities_wrong_entities_subtype(mock_elements_work
},
"Entity at index 0 in entities: confidence should be None or a float in [0..1] range",
),
),
],
)
def test_create_transcription_entities_wrong_entity(
mock_elements_worker, entity, error
......
# -*- coding: utf-8 -*-
import json
import re
......@@ -247,22 +246,20 @@ def test_create_metadata_cached_element(responses, mock_elements_worker_with_cac
@pytest.mark.parametrize(
"metadatas",
"metadata_list",
[
([{"type": MetaType.Text, "name": "fake_name", "value": "fake_value"}]),
(
[
{
"type": MetaType.Text,
"name": "fake_name",
"value": "fake_value",
"entity_id": "fake_entity_id",
}
]
),
[{"type": MetaType.Text, "name": "fake_name", "value": "fake_value"}],
[
{
"type": MetaType.Text,
"name": "fake_name",
"value": "fake_value",
"entity_id": "fake_entity_id",
}
],
],
)
def test_create_metadatas(responses, mock_elements_worker, metadatas):
def test_create_metadatas(responses, mock_elements_worker, metadata_list):
element = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
......@@ -273,17 +270,19 @@ def test_create_metadatas(responses, mock_elements_worker, metadatas):
"metadata_list": [
{
"id": "fake_metadata_id",
"type": metadatas[0]["type"].value,
"name": metadatas[0]["name"],
"value": metadatas[0]["value"],
"type": metadata_list[0]["type"].value,
"name": metadata_list[0]["name"],
"value": metadata_list[0]["value"],
"dates": [],
"entity_id": metadatas[0].get("entity_id"),
"entity_id": metadata_list[0].get("entity_id"),
}
],
},
)
created_metadatas = mock_elements_worker.create_metadatas(element, metadatas)
created_metadata_list = mock_elements_worker.create_metadatas(
element, metadata_list
)
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
......@@ -296,42 +295,40 @@ def test_create_metadatas(responses, mock_elements_worker, metadatas):
]
assert json.loads(responses.calls[-1].request.body)["metadata_list"] == [
{
"type": metadatas[0]["type"].value,
"name": metadatas[0]["name"],
"value": metadatas[0]["value"],
"entity_id": metadatas[0].get("entity_id"),
"type": metadata_list[0]["type"].value,
"name": metadata_list[0]["name"],
"value": metadata_list[0]["value"],
"entity_id": metadata_list[0].get("entity_id"),
}
]
assert created_metadatas == [
assert created_metadata_list == [
{
"id": "fake_metadata_id",
"type": metadatas[0]["type"].value,
"name": metadatas[0]["name"],
"value": metadatas[0]["value"],
"type": metadata_list[0]["type"].value,
"name": metadata_list[0]["name"],
"value": metadata_list[0]["value"],
"dates": [],
"entity_id": metadatas[0].get("entity_id"),
"entity_id": metadata_list[0].get("entity_id"),
}
]
@pytest.mark.parametrize(
"metadatas",
"metadata_list",
[
([{"type": MetaType.Text, "name": "fake_name", "value": "fake_value"}]),
(
[
{
"type": MetaType.Text,
"name": "fake_name",
"value": "fake_value",
"entity_id": "fake_entity_id",
}
]
),
[{"type": MetaType.Text, "name": "fake_name", "value": "fake_value"}],
[
{
"type": MetaType.Text,
"name": "fake_name",
"value": "fake_value",
"entity_id": "fake_entity_id",
}
],
],
)
def test_create_metadatas_cached_element(
responses, mock_elements_worker_with_cache, metadatas
responses, mock_elements_worker_with_cache, metadata_list
):
element = CachedElement.create(
id="12341234-1234-1234-1234-123412341234", type="thing"
......@@ -345,18 +342,18 @@ def test_create_metadatas_cached_element(
"metadata_list": [
{
"id": "fake_metadata_id",
"type": metadatas[0]["type"].value,
"name": metadatas[0]["name"],
"value": metadatas[0]["value"],
"type": metadata_list[0]["type"].value,
"name": metadata_list[0]["name"],
"value": metadata_list[0]["value"],
"dates": [],
"entity_id": metadatas[0].get("entity_id"),
"entity_id": metadata_list[0].get("entity_id"),
}
],
},
)
created_metadatas = mock_elements_worker_with_cache.create_metadatas(
element, metadatas
created_metadata_list = mock_elements_worker_with_cache.create_metadatas(
element, metadata_list
)
assert len(responses.calls) == len(BASE_API_CALLS) + 1
......@@ -370,35 +367,27 @@ def test_create_metadatas_cached_element(
]
assert json.loads(responses.calls[-1].request.body)["metadata_list"] == [
{
"type": metadatas[0]["type"].value,
"name": metadatas[0]["name"],
"value": metadatas[0]["value"],
"entity_id": metadatas[0].get("entity_id"),
"type": metadata_list[0]["type"].value,
"name": metadata_list[0]["name"],
"value": metadata_list[0]["value"],
"entity_id": metadata_list[0].get("entity_id"),
}
]
assert created_metadatas == [
assert created_metadata_list == [
{
"id": "fake_metadata_id",
"type": metadatas[0]["type"].value,
"name": metadatas[0]["name"],
"value": metadatas[0]["value"],
"type": metadata_list[0]["type"].value,
"name": metadata_list[0]["name"],
"value": metadata_list[0]["value"],
"dates": [],
"entity_id": metadatas[0].get("entity_id"),
"entity_id": metadata_list[0].get("entity_id"),
}
]
@pytest.mark.parametrize(
"wrong_element",
[
None,
"not_element_type",
1234,
12.5,
],
)
@pytest.mark.parametrize("wrong_element", [None, "not_element_type", 1234, 12.5])
def test_create_metadatas_wrong_element(mock_elements_worker, wrong_element):
wrong_metadatas = [
wrong_metadata_list = [
{"type": MetaType.Text, "name": "fake_name", "value": "fake_value"}
]
with pytest.raises(
......@@ -406,48 +395,42 @@ def test_create_metadatas_wrong_element(mock_elements_worker, wrong_element):
match="element shouldn't be null and should be of type Element or CachedElement",
):
mock_elements_worker.create_metadatas(
element=wrong_element, metadatas=wrong_metadatas
element=wrong_element, metadatas=wrong_metadata_list
)
@pytest.mark.parametrize(
"wrong_type",
[
None,
"not_metadata_type",
1234,
12.5,
],
)
@pytest.mark.parametrize("wrong_type", [None, "not_metadata_type", 1234, 12.5])
def test_create_metadatas_wrong_type(mock_elements_worker, wrong_type):
element = Element({"id": "12341234-1234-1234-1234-123412341234"})
wrong_metadatas = [{"type": wrong_type, "name": "fake_name", "value": "fake_value"}]
wrong_metadata_list = [
{"type": wrong_type, "name": "fake_name", "value": "fake_value"}
]
with pytest.raises(
AssertionError, match="type shouldn't be null and should be of type MetaType"
):
mock_elements_worker.create_metadatas(
element=element, metadatas=wrong_metadatas
element=element, metadatas=wrong_metadata_list
)
@pytest.mark.parametrize("wrong_name", [(None), (1234), (12.5), ([1, 2, 3, 4])])
@pytest.mark.parametrize("wrong_name", [None, 1234, 12.5, [1, 2, 3, 4]])
def test_create_metadatas_wrong_name(mock_elements_worker, wrong_name):
element = Element({"id": "fake_element_id"})
wrong_metadatas = [
wrong_metadata_list = [
{"type": MetaType.Text, "name": wrong_name, "value": "fake_value"}
]
with pytest.raises(
AssertionError, match="name shouldn't be null and should be of type str"
):
mock_elements_worker.create_metadatas(
element=element, metadatas=wrong_metadatas
element=element, metadatas=wrong_metadata_list
)
@pytest.mark.parametrize("wrong_value", [(None), ([1, 2, 3, 4])])
@pytest.mark.parametrize("wrong_value", [None, [1, 2, 3, 4]])
def test_create_metadatas_wrong_value(mock_elements_worker, wrong_value):
element = Element({"id": "fake_element_id"})
wrong_metadatas = [
wrong_metadata_list = [
{"type": MetaType.Text, "name": "fake_name", "value": wrong_value}
]
with pytest.raises(
......@@ -457,21 +440,14 @@ def test_create_metadatas_wrong_value(mock_elements_worker, wrong_value):
),
):
mock_elements_worker.create_metadatas(
element=element, metadatas=wrong_metadatas
element=element, metadatas=wrong_metadata_list
)
@pytest.mark.parametrize(
"wrong_entity",
[
[1, 2, 3, 4],
1234,
12.5,
],
)
@pytest.mark.parametrize("wrong_entity", [[1, 2, 3, 4], 1234, 12.5])
def test_create_metadatas_wrong_entity(mock_elements_worker, wrong_entity):
element = Element({"id": "fake_element_id"})
wrong_metadatas = [
wrong_metadata_list = [
{
"type": MetaType.Text,
"name": "fake_name",
......@@ -481,13 +457,13 @@ def test_create_metadatas_wrong_entity(mock_elements_worker, wrong_entity):
]
with pytest.raises(AssertionError, match="entity_id should be None or a str"):
mock_elements_worker.create_metadatas(
element=element, metadatas=wrong_metadatas
element=element, metadatas=wrong_metadata_list
)
def test_create_metadatas_api_error(responses, mock_elements_worker):
element = Element({"id": "12341234-1234-1234-1234-123412341234"})
metadatas = [
metadata_list = [
{
"type": MetaType.Text,
"name": "fake_name",
......@@ -502,7 +478,7 @@ def test_create_metadatas_api_error(responses, mock_elements_worker):
)
with pytest.raises(ErrorResponse):
mock_elements_worker.create_metadatas(element, metadatas)
mock_elements_worker.create_metadatas(element, metadata_list)
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
......
# -*- coding: utf-8 -*-
import uuid
import pytest
......@@ -12,8 +11,8 @@ TASK_ID = uuid.UUID("cafecafe-cafe-cafe-cafe-cafecafecafe")
@pytest.mark.parametrize(
"payload, error",
(
("payload", "error"),
[
# Task ID
(
{"task_id": None},
......@@ -23,7 +22,7 @@ TASK_ID = uuid.UUID("cafecafe-cafe-cafe-cafe-cafecafecafe")
{"task_id": "12341234-1234-1234-1234-123412341234"},
"task_id shouldn't be null and should be an UUID",
),
),
],
)
def test_list_artifacts_wrong_param_task_id(mock_dataset_worker, payload, error):
with pytest.raises(AssertionError, match=error):
......@@ -97,8 +96,8 @@ def test_list_artifacts(
@pytest.mark.parametrize(
"payload, error",
(
("payload", "error"),
[
# Task ID
(
{"task_id": None},
......@@ -108,7 +107,7 @@ def test_list_artifacts(
{"task_id": "12341234-1234-1234-1234-123412341234"},
"task_id shouldn't be null and should be an UUID",
),
),
],
)
def test_download_artifact_wrong_param_task_id(
mock_dataset_worker, default_artifact, payload, error
......@@ -124,8 +123,8 @@ def test_download_artifact_wrong_param_task_id(
@pytest.mark.parametrize(
"payload, error",
(
("payload", "error"),
[
# Artifact
(
{"artifact": None},
......@@ -135,7 +134,7 @@ def test_download_artifact_wrong_param_task_id(
{"artifact": "not artifact type"},
"artifact shouldn't be null and should be an Artifact",
),
),
],
)
def test_download_artifact_wrong_param_artifact(
mock_dataset_worker, default_artifact, payload, error
......
# -*- coding: utf-8 -*-
import logging
import sys
import pytest
import responses
from arkindex.mock import MockApiClient
from arkindex_worker.worker import BaseWorker
from arkindex_worker.worker.training import TrainingMixin, create_archive
@pytest.fixture
@pytest.fixture()
def mock_training_worker(monkeypatch):
class TrainingWorker(BaseWorker, TrainingMixin):
"""
......@@ -24,7 +22,7 @@ def mock_training_worker(monkeypatch):
return training_worker
@pytest.fixture
@pytest.fixture()
def default_model_version():
return {
"id": "model_version_id",
......@@ -79,23 +77,32 @@ def test_create_archive_with_subfolder(model_file_dir_with_subfolder):
assert not zst_archive_path.exists(), "Auto removal failed"
def test_handle_s3_uploading_errors(mock_training_worker, model_file_dir):
def test_handle_s3_uploading_errors(responses, mock_training_worker, model_file_dir):
s3_endpoint_url = "http://s3.localhost.com"
responses.add_passthru(s3_endpoint_url)
responses.add(responses.Response(method="PUT", url=s3_endpoint_url, status=400))
responses.add(responses.PUT, s3_endpoint_url, status=400)
mock_training_worker.model_version = {
"state": "Created",
"s3_put_url": s3_endpoint_url,
}
file_path = model_file_dir / "model_file.pth"
with pytest.raises(Exception):
mock_training_worker.upload_to_s3(file_path, {"s3_put_url": s3_endpoint_url})
with pytest.raises(
Exception,
match="400 Client Error: Bad Request for url: http://s3.localhost.com/",
):
mock_training_worker.upload_to_s3(file_path)
@pytest.mark.parametrize(
"method",
[
("publish_model_version"),
("create_model_version"),
("update_model_version"),
("upload_to_s3"),
("validate_model_version"),
"publish_model_version",
"create_model_version",
"update_model_version",
"upload_to_s3",
"validate_model_version",
],
)
def test_training_mixin_read_only(mock_training_worker, method, caplog):
......
# -*- coding: utf-8 -*-
import json
import re
from uuid import UUID
......@@ -1867,9 +1866,10 @@ def test_list_transcriptions_manual_worker_version(responses, mock_elements_work
]
@pytest.mark.usefixtures("_mock_cached_transcriptions")
@pytest.mark.parametrize(
"filters, expected_ids",
(
("filters", "expected_ids"),
[
# Filter on element should give first and sixth transcription
(
{
......@@ -1963,14 +1963,10 @@ def test_list_transcriptions_manual_worker_version(responses, mock_elements_work
},
("66666666-6666-6666-6666-666666666666",),
),
),
],
)
def test_list_transcriptions_with_cache(
responses,
mock_elements_worker_with_cache,
mock_cached_transcriptions,
filters,
expected_ids,
responses, mock_elements_worker_with_cache, filters, expected_ids
):
# Check we have 6 transcriptions already present in database
assert CachedTranscription.select().count() == 6
......@@ -1979,7 +1975,7 @@ def test_list_transcriptions_with_cache(
transcriptions = mock_elements_worker_with_cache.list_transcriptions(**filters)
assert transcriptions.count() == len(expected_ids)
for transcription, expected_id in zip(
transcriptions.order_by(CachedTranscription.id), expected_ids
transcriptions.order_by(CachedTranscription.id), expected_ids, strict=True
):
assert transcription.id == UUID(expected_id)
......
# -*- coding: utf-8 -*-
import json
import sys
......@@ -78,7 +77,8 @@ def test_readonly(responses, mock_elements_worker):
] == BASE_API_CALLS
def test_activities_disabled(responses, monkeypatch, mock_worker_run_api):
@pytest.mark.usefixtures("_mock_worker_run_api")
def test_activities_disabled(responses, monkeypatch):
"""Test worker process elements without updating activities when they are disabled for the process"""
monkeypatch.setattr(sys, "argv", ["worker"])
worker = ElementsWorker()
......@@ -105,7 +105,8 @@ def test_activities_dev_mode(mocker):
assert worker.store_activity is False
def test_update_call(responses, mock_elements_worker, mock_worker_run_api):
@pytest.mark.usefixtures("_mock_worker_run_api")
def test_update_call(responses, mock_elements_worker):
"""Test an update call with feature enabled triggers an API call"""
responses.add(
responses.PUT,
......@@ -141,8 +142,9 @@ def test_update_call(responses, mock_elements_worker, mock_worker_run_api):
}
@pytest.mark.usefixtures("_mock_activity_calls")
@pytest.mark.parametrize(
"process_exception, final_state",
("process_exception", "final_state"),
[
# Successful process_element
(None, "processed"),
......@@ -161,7 +163,6 @@ def test_run(
responses,
process_exception,
final_state,
mock_activity_calls,
):
"""Check the normal runtime sends 2 API calls to update activity"""
# Disable second configure call from run()
......@@ -210,13 +211,8 @@ def test_run(
}
def test_run_cache(
monkeypatch,
mocker,
mock_elements_worker_with_cache,
mock_cached_elements,
mock_activity_calls,
):
@pytest.mark.usefixtures("_mock_cached_elements", "_mock_activity_calls")
def test_run_cache(monkeypatch, mocker, mock_elements_worker_with_cache):
# Disable second configure call from run()
monkeypatch.setattr(mock_elements_worker_with_cache, "configure", lambda: None)
......@@ -310,8 +306,14 @@ def test_start_activity_error(
@pytest.mark.parametrize(
"wk_version_config,wk_version_user_config,frontend_user_config,model_config,expected_config",
(
"wk_version_config",
"wk_version_user_config",
"frontend_user_config",
"model_config",
"expected_config",
),
[
({}, {}, {}, {}, {}),
# Keep parameters from worker version configuration
({"parameter": 0}, {}, {}, {}, {"parameter": 0}),
......@@ -411,7 +413,7 @@ def test_start_activity_error(
{"parameter": 2},
{"parameter": 3},
),
),
],
)
def test_worker_config_multiple_source(
monkeypatch,
......
# -*- coding: utf-8 -*-
from pathlib import Path
import pytest
from gitlab import GitlabCreateError, GitlabError
from requests import ConnectionError
from responses import matchers
from arkindex_worker.git import GitlabHelper
PROJECT_ID = 21259233
MERGE_REQUEST_ID = 7
SOURCE_BRANCH = "new_branch"
TARGET_BRANCH = "master"
MR_TITLE = "merge request title"
CREATE_MR_RESPONSE_JSON = {
"id": 107,
"iid": MERGE_REQUEST_ID,
"project_id": PROJECT_ID,
"title": MR_TITLE,
"target_branch": TARGET_BRANCH,
"source_branch": SOURCE_BRANCH,
# several fields omitted
}
@pytest.fixture
def fake_responses(responses):
responses.add(
responses.GET,
"https://gitlab.com/api/v4/projects/balsac_exporter%2Fbalsac-exported-xmls-testing",
json={
"id": PROJECT_ID,
# several fields omitted
},
)
return responses
def test_clone_done(fake_git_helper):
assert not fake_git_helper.is_clone_finished
fake_git_helper._clone_done(None, None, None)
assert fake_git_helper.is_clone_finished
def test_clone(fake_git_helper):
command = fake_git_helper.run_clone_in_background()
cmd_str = " ".join(list(map(str, command.cmd)))
assert "git" in cmd_str
assert "clone" in cmd_str
def _get_fn_name_from_call(call):
# call.add(2, 3) => "add"
return str(call)[len("call.") :].split("(")[0]
def test_save_files(fake_git_helper, mocker):
mocker.patch("sh.wc", return_value=2)
fake_git_helper._git = mocker.MagicMock()
fake_git_helper.is_clone_finished = True
fake_git_helper.success = True
fake_git_helper.save_files(Path("/tmp/test_1234/tmp/"))
expected_calls = ["checkout", "add", "commit", "show", "push"]
actual_calls = list(map(_get_fn_name_from_call, fake_git_helper._git.mock_calls))
assert actual_calls == expected_calls
assert fake_git_helper.gitlab_helper.merge.call_count == 1
def test_save_files__fail_with_failed_clone(fake_git_helper, mocker):
mocker.patch("sh.wc", return_value=2)
fake_git_helper._git = mocker.MagicMock()
fake_git_helper.is_clone_finished = True
with pytest.raises(Exception) as execinfo:
fake_git_helper.save_files(Path("/tmp/test_1234/tmp/"))
assert execinfo.value.args[0] == "Clone was not a success"
def test_merge(mocker):
api = mocker.MagicMock()
project = mocker.MagicMock()
api.projects.get.return_value = project
merge_request = mocker.MagicMock()
project.mergerequests.create.return_value = merge_request
mocker.patch("gitlab.Gitlab", return_value=api)
gitlab_helper = GitlabHelper("project_id", "url", "token", "branch")
gitlab_helper._wait_for_rebase_to_finish = mocker.MagicMock()
gitlab_helper._wait_for_rebase_to_finish.return_value = True
success = gitlab_helper.merge("source", "merge title")
assert success
assert project.mergerequests.create.call_count == 1
assert merge_request.merge.call_count == 1
def test_merge__rebase_failed(mocker):
api = mocker.MagicMock()
project = mocker.MagicMock()
api.projects.get.return_value = project
merge_request = mocker.MagicMock()
project.mergerequests.create.return_value = merge_request
mocker.patch("gitlab.Gitlab", return_value=api)
gitlab_helper = GitlabHelper("project_id", "url", "token", "branch")
gitlab_helper._wait_for_rebase_to_finish = mocker.MagicMock()
gitlab_helper._wait_for_rebase_to_finish.return_value = False
success = gitlab_helper.merge("source", "merge title")
assert not success
assert project.mergerequests.create.call_count == 1
assert merge_request.merge.call_count == 0
def test_wait_for_rebase_to_finish(fake_responses, fake_gitlab_helper_factory):
get_mr_url = f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}?include_rebase_in_progress=True"
fake_responses.add(
fake_responses.GET,
get_mr_url,
json={
"rebase_in_progress": True,
"merge_error": None,
},
)
fake_responses.add(
fake_responses.GET,
get_mr_url,
json={
"rebase_in_progress": True,
"merge_error": None,
},
)
fake_responses.add(
fake_responses.GET,
get_mr_url,
json={
"rebase_in_progress": False,
"merge_error": None,
},
)
gitlab_helper = fake_gitlab_helper_factory()
success = gitlab_helper._wait_for_rebase_to_finish(MERGE_REQUEST_ID)
assert success
assert len(fake_responses.calls) == 4
assert gitlab_helper.is_rebase_finished
def test_wait_for_rebase_to_finish__fail_connection_error(
fake_responses, fake_gitlab_helper_factory
):
get_mr_url = f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}?include_rebase_in_progress=True"
fake_responses.add(
fake_responses.GET,
get_mr_url,
body=ConnectionError(),
)
gitlab_helper = fake_gitlab_helper_factory()
with pytest.raises(ConnectionError):
gitlab_helper._wait_for_rebase_to_finish(MERGE_REQUEST_ID)
assert len(fake_responses.calls) == 2
assert not gitlab_helper.is_rebase_finished
def test_wait_for_rebase_to_finish__fail_server_error(
fake_responses, fake_gitlab_helper_factory
):
get_mr_url = f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}?include_rebase_in_progress=True"
fake_responses.add(
fake_responses.GET,
get_mr_url,
body="Service Unavailable",
status=503,
)
gitlab_helper = fake_gitlab_helper_factory()
with pytest.raises(GitlabError):
gitlab_helper._wait_for_rebase_to_finish(MERGE_REQUEST_ID)
assert len(fake_responses.calls) == 2
assert not gitlab_helper.is_rebase_finished
def test_merge_request(fake_responses, fake_gitlab_helper_factory, mocker):
fake_responses.add(
fake_responses.POST,
f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests",
json=CREATE_MR_RESPONSE_JSON,
)
fake_responses.add(
fake_responses.PUT,
f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}/rebase",
json={},
)
fake_responses.add(
fake_responses.PUT,
f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}/merge",
json={
"iid": MERGE_REQUEST_ID,
"state": "merged",
# several fields omitted
},
match=[matchers.json_params_matcher({"should_remove_source_branch": True})],
)
# the fake_responses are defined in the same order as they are expected to be called
expected_http_methods = [r.method for r in fake_responses.registered()]
expected_urls = [r.url for r in fake_responses.registered()]
gitlab_helper = fake_gitlab_helper_factory()
gitlab_helper._wait_for_rebase_to_finish = mocker.MagicMock()
gitlab_helper._wait_for_rebase_to_finish.return_value = True
success = gitlab_helper.merge(SOURCE_BRANCH, MR_TITLE)
assert success
assert len(fake_responses.calls) == 4
assert [c.request.method for c in fake_responses.calls] == expected_http_methods
assert [c.request.url for c in fake_responses.calls] == expected_urls
def test_merge_request_fail(fake_responses, fake_gitlab_helper_factory, mocker):
fake_responses.add(
fake_responses.POST,
f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests",
json=CREATE_MR_RESPONSE_JSON,
)
fake_responses.add(
fake_responses.PUT,
f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}/rebase",
json={},
)
fake_responses.add(
fake_responses.PUT,
f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}/merge",
json={"error": "Method not allowed"},
status=405,
match=[matchers.json_params_matcher({"should_remove_source_branch": True})],
)
# the fake_responses are defined in the same order as they are expected to be called
expected_http_methods = [r.method for r in fake_responses.registered()]
expected_urls = [r.url for r in fake_responses.registered()]
gitlab_helper = fake_gitlab_helper_factory()
gitlab_helper._wait_for_rebase_to_finish = mocker.MagicMock()
gitlab_helper._wait_for_rebase_to_finish.return_value = True
success = gitlab_helper.merge(SOURCE_BRANCH, MR_TITLE)
assert not success
assert len(fake_responses.calls) == 4
assert [c.request.method for c in fake_responses.calls] == expected_http_methods
assert [c.request.url for c in fake_responses.calls] == expected_urls
def test_merge_request__success_after_errors(
fake_responses, fake_gitlab_helper_factory
):
fake_responses.add(
fake_responses.POST,
f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests",
json=CREATE_MR_RESPONSE_JSON,
)
rebase_url = f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}/rebase"
fake_responses.add(
fake_responses.PUT,
rebase_url,
json={"rebase_in_progress": True},
)
get_mr_url = f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}?include_rebase_in_progress=True"
fake_responses.add(
fake_responses.GET,
get_mr_url,
body="Service Unavailable",
status=503,
)
fake_responses.add(
fake_responses.PUT,
rebase_url,
json={"rebase_in_progress": True},
)
fake_responses.add(
fake_responses.GET,
get_mr_url,
body=ConnectionError(),
)
fake_responses.add(
fake_responses.PUT,
rebase_url,
json={"rebase_in_progress": True},
)
fake_responses.add(
fake_responses.GET,
get_mr_url,
json={
"rebase_in_progress": True,
"merge_error": None,
},
)
fake_responses.add(
fake_responses.GET,
get_mr_url,
json={
"rebase_in_progress": False,
"merge_error": None,
},
)
fake_responses.add(
fake_responses.PUT,
f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}/merge",
json={
"iid": MERGE_REQUEST_ID,
"state": "merged",
# several fields omitted
},
match=[matchers.json_params_matcher({"should_remove_source_branch": True})],
)
# the fake_responses are defined in the same order as they are expected to be called
expected_http_methods = [r.method for r in fake_responses.registered()]
expected_urls = [r.url for r in fake_responses.registered()]
gitlab_helper = fake_gitlab_helper_factory()
success = gitlab_helper.merge(SOURCE_BRANCH, MR_TITLE)
assert success
assert len(fake_responses.calls) == 10
assert [c.request.method for c in fake_responses.calls] == expected_http_methods
assert [c.request.url for c in fake_responses.calls] == expected_urls
def test_merge_request__fail_bad_request(fake_responses, fake_gitlab_helper_factory):
fake_responses.add(
fake_responses.POST,
f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests",
json=CREATE_MR_RESPONSE_JSON,
)
rebase_url = f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}/rebase"
fake_responses.add(
fake_responses.PUT,
rebase_url,
json={"rebase_in_progress": True},
)
get_mr_url = f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}?include_rebase_in_progress=True"
fake_responses.add(
fake_responses.GET,
get_mr_url,
body="Bad Request",
status=400,
)
# the fake_responses are defined in the same order as they are expected to be called
expected_http_methods = [r.method for r in fake_responses.registered()]
expected_urls = [r.url for r in fake_responses.registered()]
gitlab_helper = fake_gitlab_helper_factory()
with pytest.raises(GitlabError):
gitlab_helper.merge(SOURCE_BRANCH, MR_TITLE)
assert len(fake_responses.calls) == 4
assert [c.request.method for c in fake_responses.calls] == expected_http_methods
assert [c.request.url for c in fake_responses.calls] == expected_urls
def test_create_merge_request__no_retry_5xx_error(
fake_responses, fake_gitlab_helper_factory
):
request_url = f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests"
fake_responses.add(
fake_responses.POST,
request_url,
body="Service Unavailable",
status=503,
)
# the fake_responses are defined in the same order as they are expected to be called
expected_http_methods = [r.method for r in fake_responses.registered()]
expected_urls = [r.url for r in fake_responses.registered()]
gitlab_helper = fake_gitlab_helper_factory()
with pytest.raises(GitlabCreateError):
gitlab_helper.project.mergerequests.create(
{
"source_branch": "branch",
"target_branch": gitlab_helper.branch,
"title": "MR title",
}
)
assert len(fake_responses.calls) == 2
assert [c.request.method for c in fake_responses.calls] == expected_http_methods
assert [c.request.url for c in fake_responses.calls] == expected_urls
def test_create_merge_request__retry_5xx_error(
fake_responses, fake_gitlab_helper_factory
):
request_url = f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests"
fake_responses.add(
fake_responses.POST,
request_url,
body="Service Unavailable",
status=503,
)
fake_responses.add(
fake_responses.POST,
request_url,
body="Service Unavailable",
status=503,
)
fake_responses.add(
fake_responses.POST,
request_url,
json=CREATE_MR_RESPONSE_JSON,
)
# the fake_responses are defined in the same order as they are expected to be called
expected_http_methods = [r.method for r in fake_responses.registered()]
expected_urls = [r.url for r in fake_responses.registered()]
gitlab_helper = fake_gitlab_helper_factory()
gitlab_helper.project.mergerequests.create(
{
"source_branch": "branch",
"target_branch": gitlab_helper.branch,
"title": "MR title",
},
retry_transient_errors=True,
)
assert len(fake_responses.calls) == 4
assert [c.request.method for c in fake_responses.calls] == expected_http_methods
assert [c.request.url for c in fake_responses.calls] == expected_urls
# -*- coding: utf-8 -*-
import math
import unittest
import uuid
......@@ -124,13 +123,13 @@ def test_download_tiles_small(responses):
@pytest.mark.parametrize(
"path, is_local",
(
("path", "is_local"),
[
("http://somewhere/test.jpg", False),
("https://somewhere/test.jpg", False),
("path/to/something", True),
("/absolute/path/to/something", True),
),
],
)
def test_open_image(path, is_local, monkeypatch):
"""Check if the path triggers a local load or a remote one"""
......@@ -149,13 +148,13 @@ def test_open_image(path, is_local, monkeypatch):
@pytest.mark.parametrize(
"rotation_angle,mirrored,expected_path",
(
("rotation_angle", "mirrored", "expected_path"),
[
(0, False, TILE),
(45, False, ROTATED_IMAGE),
(0, True, MIRRORED_IMAGE),
(45, True, ROTATED_MIRRORED_IMAGE),
),
],
)
def test_open_image_rotate_mirror(rotation_angle, mirrored, expected_path):
expected = Image.open(expected_path).convert("RGB")
......@@ -245,8 +244,9 @@ class TestTrimPolygon(unittest.TestCase):
[99, 208],
]
}
with self.assertRaises(
AssertionError, msg="Input polygon must be a valid list or tuple of points."
with pytest.raises(
AssertionError,
match="Input polygon must be a valid list or tuple of points.",
):
trim_polygon(bad_polygon, TEST_IMAGE["width"], TEST_IMAGE["height"])
......@@ -305,8 +305,8 @@ class TestTrimPolygon(unittest.TestCase):
[997, 206],
[999, 200],
]
with self.assertRaises(
AssertionError, msg="This polygon is entirely outside the image's bounds."
with pytest.raises(
AssertionError, match="This polygon is entirely outside the image's bounds."
):
trim_polygon(bad_polygon, TEST_IMAGE["width"], TEST_IMAGE["height"])
......@@ -328,8 +328,8 @@ class TestTrimPolygon(unittest.TestCase):
[197, 206],
[99, 20.8],
]
with self.assertRaises(
AssertionError, msg="Polygon point coordinates must be integers."
with pytest.raises(
AssertionError, match="Polygon point coordinates must be integers."
):
trim_polygon(bad_polygon, TEST_IMAGE["width"], TEST_IMAGE["height"])
......@@ -347,8 +347,8 @@ class TestTrimPolygon(unittest.TestCase):
[72, 57],
[12, 56],
]
with self.assertRaises(
AssertionError, msg="Polygon points must be tuples or lists."
with pytest.raises(
AssertionError, match="Polygon points must be tuples or lists."
):
trim_polygon(bad_polygon, TEST_IMAGE["width"], TEST_IMAGE["height"])
......@@ -366,15 +366,16 @@ class TestTrimPolygon(unittest.TestCase):
[72, 57],
[12, 56],
]
with self.assertRaises(
AssertionError, msg="Polygon points must be tuples or lists of 2 elements."
with pytest.raises(
AssertionError,
match="Polygon points must be tuples or lists of 2 elements.",
):
trim_polygon(bad_polygon, TEST_IMAGE["width"], TEST_IMAGE["height"])
@pytest.mark.parametrize(
"angle, mirrored, updated_bounds, reverse",
(
("angle", "mirrored", "updated_bounds", "reverse"),
[
(
0,
False,
......@@ -471,7 +472,7 @@ class TestTrimPolygon(unittest.TestCase):
{"x": 11, "y": 295, "width": 47, "height": 111}, # upper right
False,
),
),
],
)
def test_revert_orientation(angle, mirrored, updated_bounds, reverse, tmp_path):
"""Test cases, for both Elements and CachedElements:
......
# -*- coding: utf-8 -*-
from uuid import UUID
import pytest
......@@ -18,8 +17,8 @@ from arkindex_worker.cache import (
@pytest.mark.parametrize(
"parents, expected_elements, expected_transcriptions",
(
("parents", "expected_elements", "expected_transcriptions"),
[
# Nothing happen when no parents are available
([], [], []),
# Nothing happen when the parent file does not exist
......@@ -73,7 +72,7 @@ from arkindex_worker.cache import (
UUID("22222222-2222-2222-2222-222222222222"),
],
),
),
],
)
def test_merge_databases(
mock_databases, tmp_path, parents, expected_elements, expected_transcriptions
......@@ -114,7 +113,7 @@ def test_merge_databases(
] == expected_transcriptions
def test_merge_chunk(mock_databases, tmp_path, monkeypatch):
def test_merge_chunk(mock_databases, tmp_path):
"""
Check that the db merge algorithm supports two parents
when one of them has a chunk
......@@ -155,7 +154,7 @@ def test_merge_chunk(mock_databases, tmp_path, monkeypatch):
def test_merge_from_worker(
responses, mock_base_worker_with_cache, mock_databases, tmp_path, monkeypatch
responses, mock_base_worker_with_cache, mock_databases, tmp_path
):
"""
High level merge from the base worker
......
# -*- coding: utf-8 -*-
from pathlib import Path
from arkindex_worker.utils import close_delete_file, extract_tar_zst_archive
......
......@@ -4,5 +4,5 @@
"description": "{{ cookiecutter.description }}",
"worker_type": "{{ cookiecutter.worker_type }}",
"author": "{{ cookiecutter.author }}",
"email": "{{ cookiecutter.email}}"
"email": "{{ cookiecutter.email }}"
}