# -*- coding: utf-8 -*-
import json
import logging
import os
import sys
from pathlib import Path

import gnupg
import pytest

from arkindex.mock import MockApiClient
from arkindex_worker import logger
from arkindex_worker.cache import CachedElement, CachedTranscription, LocalDB
from arkindex_worker.utils import convert_str_uuid_to_hex
from arkindex_worker.worker import BaseWorker

CACHE_DIR = str(Path(__file__).resolve().parent / "data/cache")
FIRST_PARENT_CACHE = f"{CACHE_DIR}/first_parent_id/db.sqlite"
SECOND_PARENT_CACHE = f"{CACHE_DIR}/second_parent_id/db.sqlite"

# Identifiers shared by every fixture row below, converted once to the
# hex form stored in the local SQLite cache.
_FIRST_ID = convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111")
_SECOND_ID = convert_str_uuid_to_hex("22222222-2222-2222-2222-222222222222")
_PARENT_ID = convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234")
_WORKER_VERSION_ID = convert_str_uuid_to_hex("56785678-5678-5678-5678-567856785678")
_POLYGON = json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]])

# Element rows used to populate parent caches in the merging tests.
FIRST_ELEM_TO_INSERT = CachedElement(
    id=_FIRST_ID,
    parent_id=_PARENT_ID,
    type="something",
    polygon=_POLYGON,
    worker_version_id=_WORKER_VERSION_ID,
)
SECOND_ELEM_TO_INSERT = CachedElement(
    id=_SECOND_ID,
    parent_id=_PARENT_ID,
    type="something",
    polygon=_POLYGON,
    worker_version_id=_WORKER_VERSION_ID,
)

# Transcription rows attached to the elements above.
FIRST_TR_TO_INSERT = CachedTranscription(
    id=_FIRST_ID,
    element_id=_FIRST_ID,
    text="Hello!",
    confidence=0.42,
    worker_version_id=_WORKER_VERSION_ID,
)
SECOND_TR_TO_INSERT = CachedTranscription(
    id=_SECOND_ID,
    element_id=_SECOND_ID,
    text="How are you?",
    confidence=0.42,
    worker_version_id=_WORKER_VERSION_ID,
)


def test_init_default_local_share(monkeypatch):
    """Without any environment override, the worker uses ~/.local/share/arkindex.

    The overriding variables are explicitly removed: the original test
    requested `monkeypatch` but never used it, so a host environment with
    XDG_DATA_HOME or PONOS_DATA set would make the assertions fail.
    """
    monkeypatch.delenv("XDG_DATA_HOME", raising=False)
    monkeypatch.delenv("PONOS_DATA", raising=False)
    worker = BaseWorker()

    assert worker.work_dir == os.path.expanduser("~/.local/share/arkindex")
    assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"


def test_init_default_xdg_data_home(monkeypatch):
    """XDG_DATA_HOME overrides the root of the worker's working directory."""
    data_home = str(Path(__file__).absolute().parent)
    monkeypatch.setenv("XDG_DATA_HOME", data_home)

    worker = BaseWorker()

    assert worker.work_dir == f"{data_home}/arkindex"
    assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"


def test_init_with_local_cache(monkeypatch):
    """use_cache=True initializes a local cache alongside the default work dir.

    As in test_init_default_local_share, the environment overrides are
    cleared so the default-path assertion holds on any host (`monkeypatch`
    was requested but unused before).
    """
    monkeypatch.delenv("XDG_DATA_HOME", raising=False)
    monkeypatch.delenv("PONOS_DATA", raising=False)
    worker = BaseWorker(use_cache=True)

    assert worker.work_dir == os.path.expanduser("~/.local/share/arkindex")
    assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
    assert worker.cache is not None


def test_init_var_ponos_data_given(monkeypatch):
    """When PONOS_DATA is set, the work dir becomes <PONOS_DATA>/current."""
    ponos_data = str(Path(__file__).absolute().parent)
    monkeypatch.setenv("PONOS_DATA", ponos_data)

    worker = BaseWorker()

    assert worker.work_dir == f"{ponos_data}/current"
    assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"


def test_init_var_worker_version_id_missing(monkeypatch, mock_user_api):
    """A missing WORKER_VERSION_ID puts the worker in read-only mode."""
    monkeypatch.setattr(sys, "argv", ["worker"])
    monkeypatch.delenv("WORKER_VERSION_ID")

    worker = BaseWorker()
    worker.configure()

    assert worker.worker_version_id is None
    assert worker.is_read_only is True
    assert worker.config == {}  # default empty case


def test_init_var_worker_local_file(monkeypatch, tmp_path, mock_user_api):
    """A local YAML file passed with -c supplies the config in read-only mode."""
    # Build a dummy yaml config file
    config = tmp_path / "config.yml"
    config.write_text("---\nlocalKey: abcdef123")

    monkeypatch.setattr(sys, "argv", ["worker", "-c", str(config)])
    monkeypatch.delenv("WORKER_VERSION_ID")
    worker = BaseWorker()
    worker.configure()
    assert worker.worker_version_id is None
    assert worker.is_read_only is True
    assert worker.config == {"localKey": "abcdef123"}  # Use a local file for devs
    # No manual cleanup: pytest removes tmp_path automatically, and the
    # previous explicit unlink() was skipped whenever an assertion failed.


def test_cli_default(mocker, mock_worker_version_api, mock_user_api):
    """Default CLI run: quiet logging, API client created, config from the API."""
    worker = BaseWorker()
    add_arguments_spy = mocker.spy(worker, "add_arguments")

    # Pre-conditions: nothing is configured before configure() runs
    assert not add_arguments_spy.called
    assert logger.level == logging.NOTSET
    assert not hasattr(worker, "api_client")

    mocker.patch.object(sys, "argv", ["worker"])
    worker.configure()

    assert add_arguments_spy.called
    assert add_arguments_spy.call_count == 1
    assert not worker.args.verbose
    assert logger.level == logging.NOTSET
    assert worker.api_client
    assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
    assert worker.is_read_only is False
    assert worker.config == {"someKey": "someValue"}  # from API

    # Reset the shared module logger so later tests start from a clean level
    logger.setLevel(logging.NOTSET)


def test_cli_arg_verbose_given(mocker, mock_worker_version_api, mock_user_api):
    """The -v flag switches the shared logger to DEBUG level."""
    worker = BaseWorker()
    add_arguments_spy = mocker.spy(worker, "add_arguments")

    # Pre-conditions: nothing is configured before configure() runs
    assert not add_arguments_spy.called
    assert logger.level == logging.NOTSET
    assert not hasattr(worker, "api_client")

    mocker.patch.object(sys, "argv", ["worker", "-v"])
    worker.configure()

    assert add_arguments_spy.called
    assert add_arguments_spy.call_count == 1
    assert worker.args.verbose
    assert logger.level == logging.DEBUG
    assert worker.api_client
    assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
    assert worker.is_read_only is False
    assert worker.config == {"someKey": "someValue"}  # from API

    # Reset the shared module logger so later tests start from a clean level
    logger.setLevel(logging.NOTSET)


def test_configure_cache_merging_no_parent(responses, mock_base_worker_with_cache):
    """With no parent tasks, configure() must leave the local cache untouched."""
    responses.add(
        responses.GET,
        "http://testserver/ponos/v1/task/my_task/from-agent/",
        status=200,
        json={"parents": []},
    )

    db_path = mock_base_worker_with_cache.cache.path
    bytes_before = Path(db_path).read_bytes()

    mock_base_worker_with_cache.configure()

    bytes_after = Path(db_path).read_bytes()
    assert bytes_before == bytes_after, "Cache was modified"

    assert len(responses.calls) == 3
    assert [call.request.url for call in responses.calls] == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        "http://testserver/ponos/v1/task/my_task/from-agent/",
    ]


def test_configure_cache_merging_one_parent_without_file(
    responses, mock_base_worker_with_cache, first_parent_folder
):
    """A parent without a cache file contributes nothing: the cache is unchanged."""
    responses.add(
        responses.GET,
        "http://testserver/ponos/v1/task/my_task/from-agent/",
        status=200,
        json={"parents": ["first_parent_id"]},
    )

    db_path = mock_base_worker_with_cache.cache.path
    bytes_before = Path(db_path).read_bytes()

    mock_base_worker_with_cache.configure()

    bytes_after = Path(db_path).read_bytes()
    assert bytes_before == bytes_after, "Cache was modified"

    assert len(responses.calls) == 3
    assert [call.request.url for call in responses.calls] == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        "http://testserver/ponos/v1/task/my_task/from-agent/",
    ]


def test_configure_cache_merging_one_parent(
    responses, mock_base_worker_with_cache, first_parent_cache
):
    """Rows from a single parent cache are copied into the worker's cache."""
    source_db = LocalDB(FIRST_PARENT_CACHE)
    source_db.insert("elements", [FIRST_ELEM_TO_INSERT])
    source_db.insert("transcriptions", [FIRST_TR_TO_INSERT])

    responses.add(
        responses.GET,
        "http://testserver/ponos/v1/task/my_task/from-agent/",
        status=200,
        json={"parents": ["first_parent_id"]},
    )

    mock_base_worker_with_cache.configure()

    worker_cursor = mock_base_worker_with_cache.cache.cursor
    # Check each table both against the parent DB rows and the fixture objects
    for table, model, expected in (
        ("elements", CachedElement, [FIRST_ELEM_TO_INSERT]),
        ("transcriptions", CachedTranscription, [FIRST_TR_TO_INSERT]),
    ):
        query = f"SELECT * FROM {table}"
        merged_rows = worker_cursor.execute(query).fetchall()
        assert merged_rows == source_db.cursor.execute(query).fetchall()
        assert [model(**dict(row)) for row in merged_rows] == expected

    assert len(responses.calls) == 3
    assert [call.request.url for call in responses.calls] == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        "http://testserver/ponos/v1/task/my_task/from-agent/",
    ]


def test_configure_cache_merging_multiple_parents_one_file(
    responses, mock_base_worker_with_cache, first_parent_cache, second_parent_folder
):
    """With two parents but a single cache file, only that file's rows are merged."""
    source_db = LocalDB(FIRST_PARENT_CACHE)
    source_db.insert("elements", [FIRST_ELEM_TO_INSERT])
    source_db.insert("transcriptions", [FIRST_TR_TO_INSERT])

    responses.add(
        responses.GET,
        "http://testserver/ponos/v1/task/my_task/from-agent/",
        status=200,
        json={"parents": ["first_parent_id", "second_parent_id"]},
    )

    mock_base_worker_with_cache.configure()

    worker_cursor = mock_base_worker_with_cache.cache.cursor
    # Check each table both against the parent DB rows and the fixture objects
    for table, model, expected in (
        ("elements", CachedElement, [FIRST_ELEM_TO_INSERT]),
        ("transcriptions", CachedTranscription, [FIRST_TR_TO_INSERT]),
    ):
        query = f"SELECT * FROM {table}"
        merged_rows = worker_cursor.execute(query).fetchall()
        assert merged_rows == source_db.cursor.execute(query).fetchall()
        assert [model(**dict(row)) for row in merged_rows] == expected

    assert len(responses.calls) == 3
    assert [call.request.url for call in responses.calls] == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        "http://testserver/ponos/v1/task/my_task/from-agent/",
    ]


def test_configure_cache_merging_multiple_parents_differing_lines(
    responses, mock_base_worker_with_cache, first_parent_cache, second_parent_cache
):
    """Distinct rows from two parent caches are concatenated, first parent first.

    Fix: the original body constructed ``LocalDB(FIRST_PARENT_CACHE)`` twice
    on consecutive lines; the duplicate statement is removed.
    """
    # Inserting differing lines in both parents caches
    parent_cache = LocalDB(FIRST_PARENT_CACHE)
    parent_cache.insert("elements", [FIRST_ELEM_TO_INSERT])
    parent_cache.insert("transcriptions", [FIRST_TR_TO_INSERT])
    second_parent_cache = LocalDB(SECOND_PARENT_CACHE)
    second_parent_cache.insert("elements", [SECOND_ELEM_TO_INSERT])
    second_parent_cache.insert("transcriptions", [SECOND_TR_TO_INSERT])

    responses.add(
        responses.GET,
        "http://testserver/ponos/v1/task/my_task/from-agent/",
        status=200,
        json={"parents": ["first_parent_id", "second_parent_id"]},
    )

    mock_base_worker_with_cache.configure()

    # Merged elements must be the concatenation of both parents' rows, in order
    stored_rows = mock_base_worker_with_cache.cache.cursor.execute(
        "SELECT * FROM elements"
    ).fetchall()
    assert (
        stored_rows
        == parent_cache.cursor.execute("SELECT * FROM elements").fetchall()
        + second_parent_cache.cursor.execute("SELECT * FROM elements").fetchall()
    )
    assert [CachedElement(**dict(row)) for row in stored_rows] == [
        FIRST_ELEM_TO_INSERT,
        SECOND_ELEM_TO_INSERT,
    ]

    # Same concatenation property for transcriptions
    stored_rows = mock_base_worker_with_cache.cache.cursor.execute(
        "SELECT * FROM transcriptions"
    ).fetchall()
    assert (
        stored_rows
        == parent_cache.cursor.execute("SELECT * FROM transcriptions").fetchall()
        + second_parent_cache.cursor.execute("SELECT * FROM transcriptions").fetchall()
    )
    assert [CachedTranscription(**dict(row)) for row in stored_rows] == [
        FIRST_TR_TO_INSERT,
        SECOND_TR_TO_INSERT,
    ]

    assert len(responses.calls) == 3
    assert [call.request.url for call in responses.calls] == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        "http://testserver/ponos/v1/task/my_task/from-agent/",
    ]


def test_configure_cache_merging_multiple_parents_identical_lines(
    responses, mock_base_worker_with_cache, first_parent_cache, second_parent_cache
):
    """When both parents hold identical rows, merging must deduplicate them."""
    # Inserting identical lines in both parents caches
    first_db = LocalDB(FIRST_PARENT_CACHE)
    first_db.insert("elements", [FIRST_ELEM_TO_INSERT, SECOND_ELEM_TO_INSERT])
    first_db.insert("transcriptions", [FIRST_TR_TO_INSERT, SECOND_TR_TO_INSERT])
    second_db = LocalDB(SECOND_PARENT_CACHE)
    second_db.insert("elements", [FIRST_ELEM_TO_INSERT, SECOND_ELEM_TO_INSERT])
    second_db.insert("transcriptions", [FIRST_TR_TO_INSERT, SECOND_TR_TO_INSERT])

    responses.add(
        responses.GET,
        "http://testserver/ponos/v1/task/my_task/from-agent/",
        status=200,
        json={"parents": ["first_parent_id", "second_parent_id"]},
    )

    mock_base_worker_with_cache.configure()

    worker_cursor = mock_base_worker_with_cache.cache.cursor
    # The merged table must equal each parent's rows (no duplicated lines)
    for table, model, expected in (
        ("elements", CachedElement, [FIRST_ELEM_TO_INSERT, SECOND_ELEM_TO_INSERT]),
        (
            "transcriptions",
            CachedTranscription,
            [FIRST_TR_TO_INSERT, SECOND_TR_TO_INSERT],
        ),
    ):
        query = f"SELECT * FROM {table}"
        merged_rows = worker_cursor.execute(query).fetchall()
        assert merged_rows == first_db.cursor.execute(query).fetchall()
        assert merged_rows == second_db.cursor.execute(query).fetchall()
        assert [model(**dict(row)) for row in merged_rows] == expected

    assert len(responses.calls) == 3
    assert [call.request.url for call in responses.calls] == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        "http://testserver/ponos/v1/task/my_task/from-agent/",
    ]


def test_load_missing_secret():
    """An unknown secret must raise instead of returning a value."""
    worker = BaseWorker()
    worker.api_client = MockApiClient()

    expected_message = (
        "Secret missing/secret is not available on the API nor locally"
    )
    with pytest.raises(Exception, match=expected_message):
        worker.load_secret("missing/secret")


def test_load_remote_secret():
    """A plain-text secret served by the API is returned verbatim."""
    worker = BaseWorker()
    worker.api_client = MockApiClient()
    worker.api_client.add_response(
        "RetrieveSecret",
        name="testRemote",
        response={"content": "this is a secret value !"},
    )

    secret_value = worker.load_secret("testRemote")
    assert secret_value == "this is a secret value !"

    # The one mocked call has been used
    assert len(worker.api_client.history) == 1
    assert len(worker.api_client.responses) == 0


def test_load_json_secret():
    """A secret with a .json name is parsed into a Python dict."""
    worker = BaseWorker()
    worker.api_client = MockApiClient()
    worker.api_client.add_response(
        "RetrieveSecret",
        name="path/to/file.json",
        response={"content": '{"key": "value", "number": 42}'},
    )

    parsed = worker.load_secret("path/to/file.json")
    assert parsed == {"key": "value", "number": 42}

    # The one mocked call has been used
    assert len(worker.api_client.history) == 1
    assert len(worker.api_client.responses) == 0


def test_load_yaml_secret():
    """A secret with a .yaml name is parsed into scalars, lists and mappings."""
    worker = BaseWorker()
    worker.api_client = MockApiClient()
    # YAML payload reproduced exactly; indentation inside the string matters
    worker.api_client.add_response(
        "RetrieveSecret",
        name="path/to/file.yaml",
        response={
            "content": """---
somekey: value
aList:
  - A
  - B
  - C
struct:
 level:
   X
"""
        },
    )

    parsed = worker.load_secret("path/to/file.yaml")
    assert parsed == {
        "aList": ["A", "B", "C"],
        "somekey": "value",
        "struct": {"level": "X"},
    }

    # The one mocked call has been used
    assert len(worker.api_client.history) == 1
    assert len(worker.api_client.responses) == 0


def test_load_local_secret(monkeypatch, tmpdir):
    """A secret absent from the API is decrypted from the local config dir."""
    # Setup arkindex config dir in a temp directory
    monkeypatch.setenv("XDG_CONFIG_HOME", str(tmpdir))

    # Write a dummy secret
    secrets_dir = tmpdir / "arkindex" / "secrets"
    os.makedirs(secrets_dir)
    secret_file = secrets_dir / "testLocal"
    secret_file.write_text("this is a local secret value", encoding="utf-8")

    # Mock GPG decryption: pretend decryption succeeded and pass the
    # file contents through unchanged.
    class FakeDecrypted(object):
        def __init__(self, fd):
            self.ok = True
            self.data = fd.read()

    monkeypatch.setattr(gnupg.GPG, "decrypt_file", lambda gpg, f: FakeDecrypted(f))

    worker = BaseWorker()
    worker.api_client = MockApiClient()

    assert worker.load_secret("testLocal") == "this is a local secret value"

    # The remote api is checked first
    assert len(worker.api_client.history) == 1
    assert worker.api_client.history[0].operation == "RetrieveSecret"