Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
Commits on Source (3)
......@@ -8,4 +8,4 @@ line_length = 88
default_section=FIRSTPARTY
known_first_party = arkindex,arkindex_common
known_third_party =PIL,apistar,gitlab,gnupg,pytest,requests,setuptools,sh,tenacity,yaml
known_third_party =PIL,apistar,gitlab,gnupg,peewee,pytest,requests,setuptools,sh,tenacity,yaml
# -*- coding: utf-8 -*-
import sqlite3
from collections import namedtuple
from arkindex_worker import logger
SQL_ELEMENTS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS elements (
id VARCHAR(32) PRIMARY KEY,
parent_id VARCHAR(32),
type TEXT NOT NULL,
polygon TEXT,
initial BOOLEAN DEFAULT 0 NOT NULL,
worker_version_id VARCHAR(32)
)"""
SQL_TRANSCRIPTIONS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS transcriptions (
id VARCHAR(32) PRIMARY KEY,
element_id VARCHAR(32) NOT NULL,
text TEXT NOT NULL,
confidence REAL NOT NULL,
worker_version_id VARCHAR(32) NOT NULL,
FOREIGN KEY(element_id) REFERENCES elements(id)
)"""
CachedElement = namedtuple(
"CachedElement",
["id", "type", "polygon", "worker_version_id", "parent_id", "initial"],
defaults=[None, 0],
)
CachedTranscription = namedtuple(
"CachedTranscription",
["id", "element_id", "text", "confidence", "worker_version_id"],
import json
import logging
from peewee import (
BooleanField,
CharField,
Field,
FloatField,
ForeignKeyField,
Model,
SqliteDatabase,
TextField,
UUIDField,
)
logger = logging.getLogger(__name__)
def convert_table_tuple(table):
if table == "elements":
return CachedElement
else:
raise NotImplementedError
db = SqliteDatabase(None)
class LocalDB(object):
def __init__(self, path):
self.db = sqlite3.connect(path)
self.db.row_factory = sqlite3.Row
self.cursor = self.db.cursor()
logger.info(f"Connection to local cache {path} established.")
class JSONField(Field):
field_type = "text"
def create_tables(self):
self.cursor.execute(SQL_ELEMENTS_TABLE_CREATION)
self.cursor.execute(SQL_TRANSCRIPTIONS_TABLE_CREATION)
def db_value(self, value):
if value is None:
return
return json.dumps(value)
def insert(self, table, lines):
if not lines:
def python_value(self, value):
if value is None:
return
columns = ", ".join(lines[0]._fields)
placeholders = ", ".join("?" * len(lines[0]))
values = [tuple(line) for line in lines]
self.cursor.executemany(
f"INSERT INTO {table} ({columns}) VALUES ({placeholders})", values
)
self.db.commit()
def fetch(self, table, where=[]):
"""
where parameter is a list containing 3-values tuples defining an SQL WHERE condition.
e.g: where=[("id", "LIKE", "%0000%"), ("id", "NOT LIKE", "%1111%")]
stands for "WHERE id LIKE '%0000%' AND id NOT LIKE '%1111%'" in SQL.
This method only supports 'AND' SQL conditions.
"""
sql = f"SELECT * FROM {table}"
if where:
assert isinstance(where, list), "where should be a list"
assert all(
isinstance(condition, tuple) and len(condition) == 3
for condition in where
), "where conditions should be tuples of 3 values"
sql += " WHERE "
sql += " AND ".join(
[f"{field} {operator} (?)" for field, operator, _ in where]
)
self.cursor.execute(sql, [value for _, _, value in where])
tuple_type = convert_table_tuple(table)
return [tuple_type(**dict(row)) for row in self.cursor.fetchall()]
return json.loads(value)
class CachedElement(Model):
id = UUIDField(primary_key=True)
parent_id = UUIDField(null=True)
type = CharField(max_length=50)
polygon = JSONField(null=True)
initial = BooleanField(default=False)
worker_version_id = UUIDField(null=True)
class Meta:
database = db
table_name = "elements"
class CachedTranscription(Model):
id = UUIDField(primary_key=True)
element_id = ForeignKeyField(CachedElement, backref="transcriptions")
text = TextField()
confidence = FloatField()
worker_version_id = UUIDField()
class Meta:
database = db
table_name = "transcriptions"
def init_cache_db(path):
db.init(
path,
pragmas={
# SQLite ignores foreign keys and check constraints by default!
"foreign_keys": 1,
"ignore_check_constraints": 0,
},
)
db.connect()
logger.info(f"Connected to cache on {path}")
def create_tables():
"""
Creates the tables in the cache DB only if they do not already exist.
"""
db.create_tables([CachedElement, CachedTranscription])
# -*- coding: utf-8 -*-
import datetime
import uuid
from timeit import default_timer
......@@ -20,7 +19,3 @@ class Timer(object):
end = self.timer()
self.elapsed = end - self.start
self.delta = datetime.timedelta(seconds=self.elapsed)
def convert_str_uuid_to_hex(id):
return uuid.UUID(id).hex
This diff is collapsed.
arkindex-client==1.0.6
peewee==3.14.4
Pillow==8.1.0
python-gitlab==2.6.0
python-gnupg==0.4.6
sh==1.14.1
tenacity==6.3.1
tenacity==7.0.0
......@@ -3,22 +3,33 @@ import hashlib
import json
import os
import sys
import time
from pathlib import Path
from uuid import UUID
import pytest
import yaml
from arkindex.mock import MockApiClient
from arkindex_worker.cache import CachedElement, CachedTranscription
from arkindex_worker.git import GitHelper, GitlabHelper
from arkindex_worker.worker import ElementsWorker
FIXTURES_DIR = Path(__file__).resolve().parent / "data"
CACHE_DIR = str(Path(__file__).resolve().parent / "data/cache")
CACHE_FILE = os.path.join(CACHE_DIR, "db.sqlite")
__yaml_cache = {}
@pytest.fixture(autouse=True)
def disable_sleep(monkeypatch):
"""
Do not sleep at all in between API executions
when errors occur in unit tests.
This speeds up the test execution a lot
"""
monkeypatch.setattr(time, "sleep", lambda x: None)
@pytest.fixture
def cache_yaml(monkeypatch):
"""
......@@ -80,17 +91,12 @@ def setup_api(responses, monkeypatch, cache_yaml):
@pytest.fixture(autouse=True)
def handle_cache_file(monkeypatch):
def temp_working_directory(monkeypatch, tmp_path):
def _getcwd():
return CACHE_DIR
return str(tmp_path)
monkeypatch.setattr(os, "getcwd", _getcwd)
yield
if os.path.isfile(CACHE_FILE):
os.remove(CACHE_FILE)
@pytest.fixture(autouse=True)
def give_worker_version_id_env_variable(monkeypatch):
......@@ -224,3 +230,48 @@ def fake_gitlab_helper_factory():
)
return run
@pytest.fixture
def mock_cached_elements():
"""Insert few elements in local cache"""
CachedElement.create(
id=UUID("11111111-1111-1111-1111-111111111111"),
parent_id="12341234-1234-1234-1234-123412341234",
type="something",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedElement.create(
id=UUID("22222222-2222-2222-2222-222222222222"),
parent_id=UUID("12341234-1234-1234-1234-123412341234"),
type="page",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
assert CachedElement.select().count() == 2
@pytest.fixture
def mock_cached_transcriptions():
"""Insert few transcriptions in local cache, on a shared element"""
CachedElement.create(
id=UUID("12341234-1234-1234-1234-123412341234"),
type="page",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedTranscription.create(
id=UUID("11111111-1111-1111-1111-111111111111"),
element_id=UUID("12341234-1234-1234-1234-123412341234"),
text="Hello!",
confidence=0.42,
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedTranscription.create(
id=UUID("22222222-2222-2222-2222-222222222222"),
element_id=UUID("12341234-1234-1234-1234-123412341234"),
text="How are you?",
confidence=0.42,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
File deleted
......@@ -33,7 +33,7 @@ def test_init_with_local_cache(monkeypatch):
assert worker.work_dir == os.path.expanduser("~/.local/share/arkindex")
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
assert worker.cache is not None
assert worker.use_cache is True
def test_init_var_ponos_data_given(monkeypatch):
......
# -*- coding: utf-8 -*-
import json
import os
import sqlite3
from pathlib import Path
import pytest
from peewee import OperationalError
from arkindex_worker.cache import CachedElement, CachedTranscription, LocalDB
from arkindex_worker.utils import convert_str_uuid_to_hex
FIXTURES = Path(__file__).absolute().parent / "data/cache"
ELEMENTS_TO_INSERT = [
CachedElement(
id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
type="something",
polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]),
worker_version_id=convert_str_uuid_to_hex(
"56785678-5678-5678-5678-567856785678"
),
),
CachedElement(
id=convert_str_uuid_to_hex("22222222-2222-2222-2222-222222222222"),
parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
type="something",
polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]),
worker_version_id=convert_str_uuid_to_hex(
"56785678-5678-5678-5678-567856785678"
),
),
]
TRANSCRIPTIONS_TO_INSERT = [
CachedTranscription(
id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
element_id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
text="Hello!",
confidence=0.42,
worker_version_id=convert_str_uuid_to_hex(
"56785678-5678-5678-5678-567856785678"
),
),
CachedTranscription(
id=convert_str_uuid_to_hex("22222222-2222-2222-2222-222222222222"),
element_id=convert_str_uuid_to_hex("22222222-2222-2222-2222-222222222222"),
text="How are you?",
confidence=0.42,
worker_version_id=convert_str_uuid_to_hex(
"56785678-5678-5678-5678-567856785678"
),
),
]
from arkindex_worker.cache import create_tables, db, init_cache_db
def test_init_non_existent_path():
with pytest.raises(sqlite3.OperationalError) as e:
LocalDB("path/not/found.sqlite")
with pytest.raises(OperationalError) as e:
init_cache_db("path/not/found.sqlite")
assert str(e.value) == "unable to open database file"
def test_init():
db_path = f"{FIXTURES}/db.sqlite"
LocalDB(db_path)
def test_init(tmp_path):
db_path = f"{tmp_path}/db.sqlite"
init_cache_db(db_path)
assert os.path.isfile(db_path)
def test_create_tables_existing_table():
db_path = f"{FIXTURES}/tables.sqlite"
cache = LocalDB(db_path)
def test_create_tables_existing_table(tmp_path):
db_path = f"{tmp_path}/db.sqlite"
with open(db_path, "rb") as before_file:
before = before_file.read()
cache.create_tables()
with open(db_path, "rb") as after_file:
after = after_file.read()
assert before == after, "Cache was modified"
def test_create_tables():
db_path = f"{FIXTURES}/db.sqlite"
cache = LocalDB(db_path)
cache.create_tables()
expected_cache = LocalDB(f"{FIXTURES}/tables.sqlite")
# For each table in our new generated cache, we are checking that its structure
# is the same as the one saved in data/tables.sqlite
for table in cache.cursor.execute(
"SELECT name FROM sqlite_master WHERE type = 'table'"
):
name = table["name"]
expected_table = expected_cache.cursor.execute(
f"SELECT sql FROM sqlite_master WHERE name = '{name}'"
).fetchone()["sql"]
generated_table = cache.cursor.execute(
f"SELECT sql FROM sqlite_master WHERE name = '{name}'"
).fetchone()["sql"]
assert expected_table == generated_table
def test_insert_empty_lines():
db_path = f"{FIXTURES}/db.sqlite"
cache = LocalDB(db_path)
cache.create_tables()
cache.insert("elements", [])
expected_cache = LocalDB(f"{FIXTURES}/tables.sqlite")
assert (
cache.cursor.execute("SELECT * FROM elements").fetchall()
== expected_cache.cursor.execute("SELECT * FROM elements").fetchall()
)
def test_insert_existing_lines():
db_path = f"{FIXTURES}/lines.sqlite"
cache = LocalDB(db_path)
cache.create_tables()
# Create the tables once…
init_cache_db(db_path)
create_tables()
db.close()
with open(db_path, "rb") as before_file:
before = before_file.read()
with pytest.raises(sqlite3.IntegrityError) as e:
cache.insert("elements", ELEMENTS_TO_INSERT)
assert str(e.value) == "UNIQUE constraint failed: elements.id"
with pytest.raises(sqlite3.IntegrityError) as e:
cache.insert("transcriptions", TRANSCRIPTIONS_TO_INSERT)
assert str(e.value) == "UNIQUE constraint failed: transcriptions.id"
# Create them again
init_cache_db(db_path)
create_tables()
with open(db_path, "rb") as after_file:
after = after_file.read()
assert before == after, "Cache was modified"
assert before == after, "Existing table structure was modified"
def test_insert():
db_path = f"{FIXTURES}/db.sqlite"
cache = LocalDB(db_path)
cache.create_tables()
cache.insert("elements", ELEMENTS_TO_INSERT)
generated_rows = cache.cursor.execute("SELECT * FROM elements").fetchall()
def test_create_tables(tmp_path):
db_path = f"{tmp_path}/db.sqlite"
init_cache_db(db_path)
create_tables()
expected_cache = LocalDB(f"{FIXTURES}/lines.sqlite")
assert (
generated_rows
== expected_cache.cursor.execute("SELECT * FROM elements").fetchall()
)
expected_schema = """CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "polygon" text, "initial" INTEGER NOT NULL, "worker_version_id" TEXT)
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
assert [CachedElement(**dict(row)) for row in generated_rows] == ELEMENTS_TO_INSERT
cache.insert("transcriptions", TRANSCRIPTIONS_TO_INSERT)
generated_rows = cache.cursor.execute("SELECT * FROM transcriptions").fetchall()
expected_cache = LocalDB(f"{FIXTURES}/lines.sqlite")
assert (
generated_rows
== expected_cache.cursor.execute("SELECT * FROM transcriptions").fetchall()
actual_schema = "\n".join(
[
row[0]
for row in db.connection()
.execute("SELECT sql FROM sqlite_master WHERE type = 'table' ORDER BY name")
.fetchall()
]
)
assert [
CachedTranscription(**dict(row)) for row in generated_rows
] == TRANSCRIPTIONS_TO_INSERT
def test_fetch_all():
db_path = f"{FIXTURES}/lines.sqlite"
cache = LocalDB(db_path)
cache.create_tables()
children = cache.fetch("elements")
assert children == ELEMENTS_TO_INSERT
def test_fetch_with_where():
db_path = f"{FIXTURES}/lines.sqlite"
cache = LocalDB(db_path)
cache.create_tables()
children = cache.fetch(
"elements",
where=[
(
"parent_id",
"=",
convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
),
("id", "LIKE", "%1111%"),
],
)
assert children == [ELEMENTS_TO_INSERT[0]]
assert expected_schema == actual_schema
......@@ -362,10 +362,15 @@ def test_create_classification_api_error(responses, mock_elements_worker):
high_confidence=True,
)
assert len(responses.calls) == 3
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
# We retry 5 times the API call
"http://testserver/api/v1/classifications/",
"http://testserver/api/v1/classifications/",
"http://testserver/api/v1/classifications/",
"http://testserver/api/v1/classifications/",
"http://testserver/api/v1/classifications/",
]
......
......@@ -3,47 +3,15 @@ import json
import os
import tempfile
from argparse import Namespace
from pathlib import Path
from uuid import UUID
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedElement
from arkindex_worker.models import Element
from arkindex_worker.utils import convert_str_uuid_to_hex
from arkindex_worker.worker import ElementsWorker
CACHE_DIR = Path(__file__).absolute().parent.parent / "data/cache"
ELEMENTS_TO_INSERT = [
CachedElement(
id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
type="something",
polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]),
worker_version_id=convert_str_uuid_to_hex(
"56785678-5678-5678-5678-567856785678"
),
),
CachedElement(
id=convert_str_uuid_to_hex("22222222-2222-2222-2222-222222222222"),
parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
type="page",
polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]),
worker_version_id=convert_str_uuid_to_hex(
"56785678-5678-5678-5678-567856785678"
),
),
CachedElement(
id=convert_str_uuid_to_hex("33333333-3333-3333-3333-333333333333"),
parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
type="something",
polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]),
worker_version_id=convert_str_uuid_to_hex(
"90129012-9012-9012-9012-901290129012"
),
),
]
def test_list_elements_elements_list_arg_wrong_type(monkeypatch, mock_elements_worker):
_, path = tempfile.mkstemp()
......@@ -378,10 +346,15 @@ def test_create_sub_element_api_error(responses, mock_elements_worker):
polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
)
assert len(responses.calls) == 3
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
# We retry 5 times the API call
"http://testserver/api/v1/elements/create/",
"http://testserver/api/v1/elements/create/",
"http://testserver/api/v1/elements/create/",
"http://testserver/api/v1/elements/create/",
"http://testserver/api/v1/elements/create/",
]
......@@ -663,15 +636,20 @@ def test_create_elements_api_error(responses, mock_elements_worker):
],
)
assert len(responses.calls) == 3
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
# We retry 5 times the API call
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
]
def test_create_elements(responses, mock_elements_worker_with_cache):
def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
......@@ -710,25 +688,66 @@ def test_create_elements(responses, mock_elements_worker_with_cache):
assert created_ids == [{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}]
# Check that created elements were properly stored in SQLite cache
cache_path = f"{CACHE_DIR}/db.sqlite"
assert os.path.isfile(cache_path)
assert os.path.isfile(tmp_path / "db.sqlite")
rows = mock_elements_worker_with_cache.cache.cursor.execute(
"SELECT * FROM elements"
).fetchall()
assert [CachedElement(**dict(row)) for row in rows] == [
assert list(CachedElement.select()) == [
CachedElement(
id=convert_str_uuid_to_hex("497f6eca-6276-4993-bfeb-53cbbbba6f08"),
parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
id=UUID("497f6eca-6276-4993-bfeb-53cbbbba6f08"),
parent_id=UUID("12341234-1234-1234-1234-123412341234"),
type="something",
polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]),
worker_version_id=convert_str_uuid_to_hex(
"12341234-1234-1234-1234-123412341234"
),
polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
)
]
def test_create_elements_integrity_error(
responses, mock_elements_worker_with_cache, caplog
):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
status=200,
json=[
# Duplicate IDs, which will cause an IntegrityError when stored in the cache
{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
],
)
elements = [
{
"name": "0",
"type": "something",
"polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
},
{
"name": "1",
"type": "something",
"polygon": [[1, 1], [3, 3], [3, 1], [1, 3]],
},
]
created_ids = mock_elements_worker_with_cache.create_elements(
parent=elt,
elements=elements,
)
assert created_ids == [
{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
]
assert len(caplog.records) == 1
assert caplog.records[0].levelname == "WARNING"
assert caplog.records[0].message.startswith(
"Couldn't save created elements in local cache:"
)
assert list(CachedElement.select()) == []
def test_list_element_children_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
mock_elements_worker.list_element_children(element=None)
......@@ -951,51 +970,67 @@ def test_list_element_children_with_cache_unhandled_param(
)
def test_list_element_children_with_cache(responses, mock_elements_worker_with_cache):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
for idx, child in enumerate(
mock_elements_worker_with_cache.list_element_children(element=elt)
):
assert child == []
# Initialize SQLite cache with some elements
mock_elements_worker_with_cache.cache.insert("elements", ELEMENTS_TO_INSERT)
expected_children = ELEMENTS_TO_INSERT
for idx, child in enumerate(
mock_elements_worker_with_cache.list_element_children(element=elt)
):
assert child == expected_children[idx]
expected_children = [ELEMENTS_TO_INSERT[1]]
for idx, child in enumerate(
mock_elements_worker_with_cache.list_element_children(element=elt, type="page")
):
assert child == expected_children[idx]
expected_children = ELEMENTS_TO_INSERT[:2]
for idx, child in enumerate(
mock_elements_worker_with_cache.list_element_children(
element=elt, worker_version="56785678-5678-5678-5678-567856785678"
)
):
assert child == expected_children[idx]
@pytest.mark.parametrize(
"filters, expected_ids",
(
# Filter on element should give all elements inserted
(
{
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
},
(
"11111111-1111-1111-1111-111111111111",
"22222222-2222-2222-2222-222222222222",
),
),
# Filter on element and page should give the second element
(
{
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
"type": "page",
},
("22222222-2222-2222-2222-222222222222",),
),
# Filter on element and worker version should give all elements
(
{
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
"worker_version": "56785678-5678-5678-5678-567856785678",
},
(
"11111111-1111-1111-1111-111111111111",
"22222222-2222-2222-2222-222222222222",
),
),
# Filter on element, type something and worker version should give first
(
{
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
"type": "something",
"worker_version": "56785678-5678-5678-5678-567856785678",
},
("11111111-1111-1111-1111-111111111111",),
),
),
)
def test_list_element_children_with_cache(
responses,
mock_elements_worker_with_cache,
mock_cached_elements,
filters,
expected_ids,
):
expected_children = [ELEMENTS_TO_INSERT[0]]
# Check we have 2 elements already present in database
assert CachedElement.select().count() == 2
for idx, child in enumerate(
mock_elements_worker_with_cache.list_element_children(
element=elt,
type="something",
worker_version="56785678-5678-5678-5678-567856785678",
)
):
assert child == expected_children[idx]
# Query database through cache
elements = mock_elements_worker_with_cache.list_element_children(**filters)
assert elements.count() == len(expected_ids)
for child, expected_id in zip(elements.order_by("id"), expected_ids):
assert child.id == UUID(expected_id)
# Check the worker never hits the API for elements
assert len(responses.calls) == 2
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
......
......@@ -147,10 +147,15 @@ def test_create_entity_api_error(responses, mock_elements_worker):
corpus="12341234-1234-1234-1234-123412341234",
)
assert len(responses.calls) == 3
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
# We retry 5 times the API call
"http://testserver/api/v1/entity/",
"http://testserver/api/v1/entity/",
"http://testserver/api/v1/entity/",
"http://testserver/api/v1/entity/",
"http://testserver/api/v1/entity/",
]
......
......@@ -133,10 +133,15 @@ def test_create_metadata_api_error(responses, mock_elements_worker):
value="La Turbine, Grenoble 38000",
)
assert len(responses.calls) == 3
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
# We retry 5 times the API call
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
]
......
# -*- coding: utf-8 -*-
import json
import os
from pathlib import Path
from uuid import UUID
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedElement, CachedTranscription
from arkindex_worker.models import Element
from arkindex_worker.utils import convert_str_uuid_to_hex
CACHE_DIR = Path(__file__).absolute().parent.parent / "data/cache"
TRANSCRIPTIONS_SAMPLE = [
{
"polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
......@@ -37,7 +34,10 @@ def test_create_transcription_wrong_element(mock_elements_worker):
text="i am a line",
score=0.42,
)
assert str(e.value) == "element shouldn't be null and should be of type Element"
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcription(
......@@ -45,7 +45,10 @@ def test_create_transcription_wrong_element(mock_elements_worker):
text="i am a line",
score=0.42,
)
assert str(e.value) == "element shouldn't be null and should be of type Element"
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
def test_create_transcription_wrong_text(mock_elements_worker):
......@@ -127,16 +130,22 @@ def test_create_transcription_api_error(responses, mock_elements_worker):
score=0.42,
)
assert len(responses.calls) == 3
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
# We retry 5 times the API call
f"http://testserver/api/v1/element/{elt.id}/transcription/",
f"http://testserver/api/v1/element/{elt.id}/transcription/",
f"http://testserver/api/v1/element/{elt.id}/transcription/",
f"http://testserver/api/v1/element/{elt.id}/transcription/",
f"http://testserver/api/v1/element/{elt.id}/transcription/",
]
def test_create_transcription(responses, mock_elements_worker_with_cache):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
responses.add(
responses.POST,
f"http://testserver/api/v1/element/{elt.id}/transcription/",
......@@ -170,21 +179,13 @@ def test_create_transcription(responses, mock_elements_worker_with_cache):
}
# Check that created transcription was properly stored in SQLite cache
cache_path = f"{CACHE_DIR}/db.sqlite"
assert os.path.isfile(cache_path)
rows = mock_elements_worker_with_cache.cache.cursor.execute(
"SELECT * FROM transcriptions"
).fetchall()
assert [CachedTranscription(**dict(row)) for row in rows] == [
assert list(CachedTranscription.select()) == [
CachedTranscription(
id=convert_str_uuid_to_hex("56785678-5678-5678-5678-567856785678"),
element_id=convert_str_uuid_to_hex(elt.id),
id=UUID("56785678-5678-5678-5678-567856785678"),
element_id=UUID(elt.id),
text="i am a line",
confidence=0.42,
worker_version_id=convert_str_uuid_to_hex(
"12341234-1234-1234-1234-123412341234"
),
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
)
]
......@@ -442,15 +443,21 @@ def test_create_transcriptions_api_error(responses, mock_elements_worker):
with pytest.raises(ErrorResponse):
mock_elements_worker.create_transcriptions(transcriptions=trans)
assert len(responses.calls) == 3
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
# We retry 5 times the API call
"http://testserver/api/v1/transcription/bulk/",
"http://testserver/api/v1/transcription/bulk/",
"http://testserver/api/v1/transcription/bulk/",
"http://testserver/api/v1/transcription/bulk/",
"http://testserver/api/v1/transcription/bulk/",
]
def test_create_transcriptions(responses, mock_elements_worker_with_cache):
CachedElement.create(id="11111111-1111-1111-1111-111111111111", type="thing")
trans = [
{
"element_id": "11111111-1111-1111-1111-111111111111",
......@@ -504,30 +511,20 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache):
}
# Check that created transcriptions were properly stored in SQLite cache
cache_path = f"{CACHE_DIR}/db.sqlite"
assert os.path.isfile(cache_path)
rows = mock_elements_worker_with_cache.cache.cursor.execute(
"SELECT * FROM transcriptions"
).fetchall()
assert [CachedTranscription(**dict(row)) for row in rows] == [
assert list(CachedTranscription.select()) == [
CachedTranscription(
id=convert_str_uuid_to_hex("00000000-0000-0000-0000-000000000000"),
element_id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
id=UUID("00000000-0000-0000-0000-000000000000"),
element_id=UUID("11111111-1111-1111-1111-111111111111"),
text="The",
confidence=0.75,
worker_version_id=convert_str_uuid_to_hex(
"12341234-1234-1234-1234-123412341234"
),
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
),
CachedTranscription(
id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
element_id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
id=UUID("11111111-1111-1111-1111-111111111111"),
element_id=UUID("11111111-1111-1111-1111-111111111111"),
text="word",
confidence=0.42,
worker_version_id=convert_str_uuid_to_hex(
"12341234-1234-1234-1234-123412341234"
),
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
),
]
......@@ -539,7 +536,10 @@ def test_create_element_transcriptions_wrong_element(mock_elements_worker):
sub_element_type="page",
transcriptions=TRANSCRIPTIONS_SAMPLE,
)
assert str(e.value) == "element shouldn't be null and should be of type Element"
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_element_transcriptions(
......@@ -547,7 +547,10 @@ def test_create_element_transcriptions_wrong_element(mock_elements_worker):
sub_element_type="page",
transcriptions=TRANSCRIPTIONS_SAMPLE,
)
assert str(e.value) == "element shouldn't be null and should be of type Element"
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
def test_create_element_transcriptions_wrong_sub_element_type(mock_elements_worker):
......@@ -917,16 +920,22 @@ def test_create_element_transcriptions_api_error(responses, mock_elements_worker
transcriptions=TRANSCRIPTIONS_SAMPLE,
)
assert len(responses.calls) == 3
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
# We retry 5 times the API call
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
]
def test_create_element_transcriptions(responses, mock_elements_worker_with_cache):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
elt = CachedElement(id="12341234-1234-1234-1234-123412341234", type="thing")
responses.add(
responses.POST,
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
......@@ -988,53 +997,43 @@ def test_create_element_transcriptions(responses, mock_elements_worker_with_cach
]
# Check that created transcriptions and elements were properly stored in SQLite cache
cache_path = f"{CACHE_DIR}/db.sqlite"
assert os.path.isfile(cache_path)
rows = mock_elements_worker_with_cache.cache.cursor.execute(
"SELECT * FROM elements"
).fetchall()
assert [CachedElement(**dict(row)) for row in rows] == [
assert list(CachedElement.select()) == [
CachedElement(
id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
id=UUID("11111111-1111-1111-1111-111111111111"),
parent_id=UUID("12341234-1234-1234-1234-123412341234"),
type="page",
polygon=json.dumps([[100, 150], [700, 150], [700, 200], [100, 200]]),
worker_version_id=convert_str_uuid_to_hex(
"12341234-1234-1234-1234-123412341234"
),
)
polygon="[[100, 150], [700, 150], [700, 200], [100, 200]]",
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
),
CachedElement(
id=UUID("22222222-2222-2222-2222-222222222222"),
parent_id=UUID("12341234-1234-1234-1234-123412341234"),
type="page",
polygon="[[0, 0], [2000, 0], [2000, 3000], [0, 3000]]",
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
),
]
rows = mock_elements_worker_with_cache.cache.cursor.execute(
"SELECT * FROM transcriptions"
).fetchall()
assert [CachedTranscription(**dict(row)) for row in rows] == [
assert list(CachedTranscription.select()) == [
CachedTranscription(
id=convert_str_uuid_to_hex("56785678-5678-5678-5678-567856785678"),
element_id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
id=UUID("56785678-5678-5678-5678-567856785678"),
element_id=UUID("11111111-1111-1111-1111-111111111111"),
text="The",
confidence=0.5,
worker_version_id=convert_str_uuid_to_hex(
"12341234-1234-1234-1234-123412341234"
),
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
),
CachedTranscription(
id=convert_str_uuid_to_hex("67896789-6789-6789-6789-678967896789"),
element_id=convert_str_uuid_to_hex("22222222-2222-2222-2222-222222222222"),
id=UUID("67896789-6789-6789-6789-678967896789"),
element_id=UUID("22222222-2222-2222-2222-222222222222"),
text="first",
confidence=0.75,
worker_version_id=convert_str_uuid_to_hex(
"12341234-1234-1234-1234-123412341234"
),
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
),
CachedTranscription(
id=convert_str_uuid_to_hex("78907890-7890-7890-7890-789078907890"),
element_id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
id=UUID("78907890-7890-7890-7890-789078907890"),
element_id=UUID("11111111-1111-1111-1111-111111111111"),
text="line",
confidence=0.9,
worker_version_id=convert_str_uuid_to_hex(
"12341234-1234-1234-1234-123412341234"
),
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
),
]
......@@ -1155,3 +1154,103 @@ def test_list_transcriptions(responses, mock_elements_worker):
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
]
def test_list_transcriptions_with_cache_unhandled_param(
responses, mock_elements_worker_with_cache
):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
with pytest.raises(AssertionError) as e:
mock_elements_worker_with_cache.list_transcriptions(
element=elt, element_type="page"
)
assert (
str(e.value)
== "When using the local cache, you can only filter by 'worker_version'"
)
def test_list_transcriptions_with_cache_skip_recursive(
responses, mock_elements_worker_with_cache
):
# When the local cache is activated and the user defines the recursive filter, we should fallback to the API
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
trans = [
{
"id": "0000",
"text": "hey",
"confidence": 0.42,
"worker_version_id": "56785678-5678-5678-5678-567856785678",
"element": None,
},
]
responses.add(
responses.GET,
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/?recursive=True",
status=200,
json={
"count": 3,
"next": None,
"results": trans,
},
)
for idx, transcription in enumerate(
mock_elements_worker_with_cache.list_transcriptions(element=elt, recursive=True)
):
assert transcription == trans[idx]
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/?recursive=True",
]
@pytest.mark.parametrize(
"filters, expected_ids",
(
# Filter on element should give all elements inserted
(
{
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
},
(
"11111111-1111-1111-1111-111111111111",
"22222222-2222-2222-2222-222222222222",
),
),
# Filter on element and worker version should give first element
(
{
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
"worker_version": "56785678-5678-5678-5678-567856785678",
},
("11111111-1111-1111-1111-111111111111",),
),
),
)
def test_list_transcriptions_with_cache(
responses,
mock_elements_worker_with_cache,
mock_cached_transcriptions,
filters,
expected_ids,
):
# Check we have 2 elements already present in database
assert CachedTranscription.select().count() == 2
# Query database through cache
transcriptions = mock_elements_worker_with_cache.list_transcriptions(**filters)
assert transcriptions.count() == len(expected_ids)
for transcription, expected_id in zip(transcriptions.order_by("id"), expected_ids):
assert transcription.id == UUID(expected_id)
# Check the worker never hits the API for elements
assert len(responses.calls) == 2
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
]