Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
Commits on Source (2)
# -*- coding: utf-8 -*-
import sqlite3
from collections import namedtuple
from arkindex_worker import logger
# Schema of the local elements cache. IDs are stored as 32-character hex
# strings (UUIDs without dashes — see convert_str_uuid_to_hex usage).
SQL_ELEMENTS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS elements (
id VARCHAR(32) PRIMARY KEY,
parent_id VARCHAR(32),
type TEXT NOT NULL,
polygon TEXT,
initial BOOLEAN DEFAULT 0 NOT NULL,
worker_version_id VARCHAR(32)
)"""

# Row wrapper matching the columns of the `elements` table.
# `defaults` are right-aligned over the field list: parent_id defaults to
# None and initial to 0; the other four fields are required.
CachedElement = namedtuple(
    "CachedElement",
    ["id", "type", "polygon", "worker_version_id", "parent_id", "initial"],
    defaults=[None, 0],
)
class LocalDB(object):
def __init__(self, path):
self.db = sqlite3.connect(path)
self.db.row_factory = sqlite3.Row
self.cursor = self.db.cursor()
logger.info(f"Connection to local cache {path} established.")
def create_tables(self):
self.cursor.execute(SQL_ELEMENTS_TABLE_CREATION)
def insert(self, table, lines):
if not lines:
return
columns = ", ".join(lines[0]._fields)
placeholders = ", ".join("?" * len(lines[0]))
values = [tuple(line) for line in lines]
self.cursor.executemany(
f"INSERT INTO {table} ({columns}) VALUES ({placeholders})", values
)
self.db.commit()
# -*- coding: utf-8 -*-
import datetime
import uuid
from timeit import default_timer
......@@ -19,3 +20,7 @@ class Timer(object):
end = self.timer()
self.elapsed = end - self.start
self.delta = datetime.timedelta(seconds=self.elapsed)
def convert_str_uuid_to_hex(id):
    """Convert a UUID to its 32-character hexadecimal form (no dashes).

    :param id: UUID as a string in any format accepted by :class:`uuid.UUID`
        (e.g. ``"12341234-1234-1234-1234-123412341234"``). A ``uuid.UUID``
        instance is also accepted.
    :return: The 32-character lowercase hex string.
    :raises ValueError: if ``id`` is not a valid UUID.
    """
    # str() makes the conversion tolerant of uuid.UUID inputs while keeping
    # the behavior for strings unchanged.
    return uuid.UUID(str(id)).hex
......@@ -3,6 +3,7 @@ import argparse
import json
import logging
import os
import sqlite3
import sys
import uuid
import warnings
......@@ -16,14 +17,17 @@ from apistar.exceptions import ErrorResponse
from arkindex import ArkindexClient, options_from_env
from arkindex_worker import logger
from arkindex_worker.cache import CachedElement, LocalDB
from arkindex_worker.models import Element
from arkindex_worker.reporting import Reporter
from arkindex_worker.utils import convert_str_uuid_to_hex
MANUAL_SLUG = "manual"
CACHE_DIR = f"/data/{os.environ.get('TASK_ID')}"
class BaseWorker(object):
def __init__(self, description="Arkindex Base Worker"):
def __init__(self, description="Arkindex Base Worker", use_cache=False):
self.parser = argparse.ArgumentParser(description=description)
# Setup workdir either in Ponos environment or on host's home
......@@ -46,6 +50,18 @@ class BaseWorker(object):
logger.info(f"Worker will use {self.work_dir} as working directory")
if use_cache is True:
if os.environ.get("TASK_ID") and os.path.isdir(CACHE_DIR):
cache_path = os.path.join(CACHE_DIR, "db.sqlite")
else:
cache_path = os.path.join(os.getcwd(), "db.sqlite")
self.cache = LocalDB(cache_path)
self.cache.create_tables()
else:
self.cache = None
logger.debug("Cache is disabled")
@property
def is_read_only(self):
"""Worker cannot publish anything without a worker version ID"""
......@@ -202,8 +218,8 @@ class ActivityState(Enum):
class ElementsWorker(BaseWorker):
def __init__(self, description="Arkindex Elements Worker"):
super().__init__(description)
def __init__(self, description="Arkindex Elements Worker", use_cache=False):
super().__init__(description, use_cache)
# Add report concerning elements
self.report = Reporter("unknown worker")
......@@ -451,6 +467,25 @@ class ElementsWorker(BaseWorker):
for element in elements:
self.report.add_element(parent.id, element["type"])
if self.cache:
# Store elements in local cache
try:
parent_id_hex = convert_str_uuid_to_hex(parent.id)
worker_version_id_hex = convert_str_uuid_to_hex(self.worker_version_id)
to_insert = [
CachedElement(
id=convert_str_uuid_to_hex(created_ids[idx]["id"]),
parent_id=parent_id_hex,
type=element["type"],
polygon=json.dumps(element["polygon"]),
worker_version_id=worker_version_id_hex,
)
for idx, element in enumerate(elements)
]
self.cache.insert("elements", to_insert)
except sqlite3.IntegrityError as e:
logger.warning(f"Couldn't save created elements in local cache: {e}")
return created_ids
def create_transcription(self, element, text, type=None, score=None):
......
......@@ -13,6 +13,8 @@ from arkindex_worker.git import GitHelper, GitlabHelper
from arkindex_worker.worker import ElementsWorker
FIXTURES_DIR = Path(__file__).resolve().parent / "data"
CACHE_DIR = str(Path(__file__).resolve().parent / "data/cache")
CACHE_FILE = os.path.join(CACHE_DIR, "db.sqlite")
__yaml_cache = {}
......@@ -77,6 +79,19 @@ def setup_api(responses, monkeypatch, cache_yaml):
monkeypatch.setenv("ARKINDEX_API_TOKEN", "unittest1234")
@pytest.fixture(autouse=True)
def handle_cache_file(monkeypatch):
    """Point os.getcwd at the fixtures cache dir and clean the DB afterwards."""
    monkeypatch.setattr(os, "getcwd", lambda: CACHE_DIR)
    yield
    # Drop the SQLite file a test may have created so runs stay isolated
    if os.path.isfile(CACHE_FILE):
        os.remove(CACHE_FILE)
@pytest.fixture(autouse=True)
def give_worker_version_id_env_variable(monkeypatch):
    """Autouse fixture: expose a fixed WORKER_VERSION_ID to every test."""
    monkeypatch.setenv("WORKER_VERSION_ID", "12341234-1234-1234-1234-123412341234")
......@@ -149,6 +164,16 @@ def mock_elements_worker(monkeypatch, mock_worker_version_api):
return worker
@pytest.fixture
def mock_elements_worker_with_cache(monkeypatch, mock_worker_version_api):
    """Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest"""
    # Replace pytest's own CLI arguments before argparse parses them
    monkeypatch.setattr(sys, "argv", ["worker"])
    worker = ElementsWorker(use_cache=True)
    worker.configure()
    return worker
@pytest.fixture
def fake_page_element():
with open(FIXTURES_DIR / "page_element.json", "r") as f:
......
File added
File added
......@@ -12,7 +12,7 @@ from arkindex_worker import logger
from arkindex_worker.worker import BaseWorker
def test_init_default_local_share():
def test_init_default_local_share(monkeypatch):
worker = BaseWorker()
assert worker.work_dir == os.path.expanduser("~/.local/share/arkindex")
......@@ -28,6 +28,14 @@ def test_init_default_xdg_data_home(monkeypatch):
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
def test_init_with_local_cache(monkeypatch):
    """Enabling use_cache keeps the default setup and attaches a cache."""
    worker = BaseWorker(use_cache=True)
    expected_dir = os.path.expanduser("~/.local/share/arkindex")
    assert worker.work_dir == expected_dir
    assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
    assert worker.cache is not None
def test_init_var_ponos_data_given(monkeypatch):
path = str(Path(__file__).absolute().parent)
monkeypatch.setenv("PONOS_DATA", path)
......
# -*- coding: utf-8 -*-
import json
import os
import sqlite3
from pathlib import Path
import pytest
from arkindex_worker.cache import CachedElement, LocalDB
from arkindex_worker.utils import convert_str_uuid_to_hex
FIXTURES = Path(__file__).absolute().parent / "data/cache"
# Two reference rows sharing the same parent and worker version but with
# distinct primary keys; IDs are stored as dash-less hex, polygons as JSON.
ELEMENTS_TO_INSERT = [
    CachedElement(
        id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
        parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
        type="something",
        polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]),
        worker_version_id=convert_str_uuid_to_hex(
            "56785678-5678-5678-5678-567856785678"
        ),
    ),
    CachedElement(
        id=convert_str_uuid_to_hex("22222222-2222-2222-2222-222222222222"),
        parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
        type="something",
        polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]),
        worker_version_id=convert_str_uuid_to_hex(
            "56785678-5678-5678-5678-567856785678"
        ),
    ),
]
def test_init_non_existent_path():
    """Opening a cache in a missing directory raises an OperationalError."""
    missing_path = "path/not/found.sqlite"
    with pytest.raises(sqlite3.OperationalError) as excinfo:
        LocalDB(missing_path)
    assert str(excinfo.value) == "unable to open database file"
def test_init():
    """Connecting through LocalDB creates the database file on disk."""
    db_path = f"{FIXTURES}/db.sqlite"
    LocalDB(db_path)
    assert Path(db_path).is_file()
def test_create_tables_existing_table():
    """create_tables is a no-op on a database that already has the tables."""
    db_path = f"{FIXTURES}/tables.sqlite"
    cache = LocalDB(db_path)
    # Compare the raw file bytes before and after to prove nothing changed
    before = Path(db_path).read_bytes()
    cache.create_tables()
    after = Path(db_path).read_bytes()
    assert before == after, "Cache was modified"
def test_create_tables():
    """Tables built by create_tables match the reference schema in tables.sqlite."""
    db_path = f"{FIXTURES}/db.sqlite"
    cache = LocalDB(db_path)
    cache.create_tables()
    expected_cache = LocalDB(f"{FIXTURES}/tables.sqlite")
    # For each table in our new generated cache, we are checking that its structure
    # is the same as the one saved in data/tables.sqlite.
    # Fetch all names up front: the original code iterated `cache.cursor` while
    # re-executing queries on that same cursor, which resets it and silently
    # ends the loop after the first table.
    tables = cache.cursor.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table'"
    ).fetchall()
    assert tables, "No tables were created"
    for table in tables:
        name = table["name"]
        # Bind the name as a parameter instead of interpolating it into SQL
        expected_table = expected_cache.cursor.execute(
            "SELECT sql FROM sqlite_master WHERE name = ?", (name,)
        ).fetchone()["sql"]
        generated_table = cache.cursor.execute(
            "SELECT sql FROM sqlite_master WHERE name = ?", (name,)
        ).fetchone()["sql"]
        assert expected_table == generated_table
def test_insert_empty_lines():
    """Inserting an empty list is a no-op that leaves the table untouched."""
    cache = LocalDB(f"{FIXTURES}/db.sqlite")
    cache.create_tables()
    cache.insert("elements", [])
    reference = LocalDB(f"{FIXTURES}/tables.sqlite")
    generated = cache.cursor.execute("SELECT * FROM elements").fetchall()
    expected = reference.cursor.execute("SELECT * FROM elements").fetchall()
    assert generated == expected
def test_insert_existing_lines():
    """Re-inserting existing IDs fails and leaves the database file unchanged."""
    db_path = f"{FIXTURES}/lines.sqlite"
    cache = LocalDB(db_path)
    cache.create_tables()
    before = Path(db_path).read_bytes()
    with pytest.raises(sqlite3.IntegrityError) as excinfo:
        cache.insert("elements", ELEMENTS_TO_INSERT)
    assert str(excinfo.value) == "UNIQUE constraint failed: elements.id"
    after = Path(db_path).read_bytes()
    assert before == after, "Cache was modified"
def test_insert():
    """Rows inserted via LocalDB.insert round-trip and match the reference DB."""
    cache = LocalDB(f"{FIXTURES}/db.sqlite")
    cache.create_tables()
    cache.insert("elements", ELEMENTS_TO_INSERT)
    generated_rows = cache.cursor.execute("SELECT * FROM elements").fetchall()
    reference = LocalDB(f"{FIXTURES}/lines.sqlite")
    expected_rows = reference.cursor.execute("SELECT * FROM elements").fetchall()
    assert generated_rows == expected_rows
    # Rebuild CachedElement tuples to also validate column names and order
    assert [CachedElement(**dict(row)) for row in generated_rows] == ELEMENTS_TO_INSERT
......@@ -3,13 +3,18 @@ import json
import os
import tempfile
from argparse import Namespace
from pathlib import Path
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedElement
from arkindex_worker.models import Element
from arkindex_worker.utils import convert_str_uuid_to_hex
from arkindex_worker.worker import ElementsWorker
CACHE_DIR = Path(__file__).absolute().parent.parent / "data/cache"
def test_list_elements_elements_list_arg_wrong_type(monkeypatch, mock_elements_worker):
_, path = tempfile.mkstemp()
......@@ -637,7 +642,7 @@ def test_create_elements_api_error(responses, mock_elements_worker):
]
def test_create_elements(responses, mock_elements_worker):
def test_create_elements(responses, mock_elements_worker_with_cache):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
......@@ -646,7 +651,7 @@ def test_create_elements(responses, mock_elements_worker):
json=[{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}],
)
created_ids = mock_elements_worker.create_elements(
created_ids = mock_elements_worker_with_cache.create_elements(
parent=elt,
elements=[
{
......@@ -675,6 +680,25 @@ def test_create_elements(responses, mock_elements_worker):
}
assert created_ids == [{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}]
# Check that created elements were properly stored in SQLite cache
cache_path = f"{CACHE_DIR}/db.sqlite"
assert os.path.isfile(cache_path)
rows = mock_elements_worker_with_cache.cache.cursor.execute(
"SELECT * FROM elements"
).fetchall()
assert [CachedElement(**dict(row)) for row in rows] == [
CachedElement(
id=convert_str_uuid_to_hex("497f6eca-6276-4993-bfeb-53cbbbba6f08"),
parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
type="something",
polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]),
worker_version_id=convert_str_uuid_to_hex(
"12341234-1234-1234-1234-123412341234"
),
)
]
def test_list_element_children_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
......