Skip to content
Snippets Groups Projects
Commit 9bc3eb43 authored by Eva Bardou's avatar Eva Bardou Committed by Erwan Rouchet
Browse files

Add Version table in SQLite cache + Check compatibility from tasks

parent e141a461
No related branches found
No related tags found
1 merge request!151Add Version table in SQLite cache + Check compatibility from tasks
Pipeline #79052 passed
......@@ -13,6 +13,7 @@ from peewee import (
ForeignKeyField,
IntegerField,
Model,
OperationalError,
SqliteDatabase,
TextField,
UUIDField,
......@@ -38,6 +39,14 @@ class JSONField(Field):
return json.loads(value)
class Version(Model):
version = IntegerField(primary_key=True)
class Meta:
database = db
table_name = "version"
class CachedImage(Model):
id = UUIDField(primary_key=True)
width = IntegerField()
......@@ -184,6 +193,7 @@ MODELS = [
CachedEntity,
CachedTranscriptionEntity,
]
SQL_VERSION = 1
def init_cache_db(path):
......@@ -206,6 +216,30 @@ def create_tables():
db.create_tables(MODELS)
def create_version_table():
"""
Creates the Version table in the cache DB.
This step must be independent from other tables creation since we only
want to create the table and add the one and only Version entry when the
cache is created from scratch.
"""
db.create_tables([Version])
Version.create(version=SQL_VERSION)
def check_version(cache_path):
with SqliteDatabase(cache_path) as provided_db:
with provided_db.bind_ctx([Version]):
try:
version = Version.get().version
except OperationalError:
version = None
assert (
version == SQL_VERSION
), f"The SQLite database {cache_path} does not have the correct cache version, it should be {SQL_VERSION}"
def retrieve_parents_cache_path(parent_ids, data_dir="/data", chunk=None):
assert isinstance(parent_ids, list)
assert os.path.isdir(data_dir)
......@@ -247,6 +281,9 @@ def merge_parents_cache(paths, current_database):
# Merge each table into the local database
for idx, path in enumerate(paths):
# Check that the parent cache uses a compatible version
check_version(path)
with SqliteDatabase(path) as source:
with source.bind_ctx(MODELS):
source.create_tables(MODELS)
......
......@@ -19,7 +19,9 @@ from tenacity import (
from arkindex import ArkindexClient, options_from_env
from arkindex_worker import logger
from arkindex_worker.cache import (
check_version,
create_tables,
create_version_table,
init_cache_db,
merge_parents_cache,
retrieve_parents_cache_path,
......@@ -203,6 +205,12 @@ class BaseWorker(object):
self.cache_path = os.path.join(cache_dir, "db.sqlite")
init_cache_db(self.cache_path)
if self.args.database is not None:
check_version(self.cache_path)
else:
create_version_table()
create_tables()
# Merging parents caches (if there are any) in the current task local cache, unless the database got overridden
......
......@@ -12,7 +12,15 @@ import yaml
from peewee import SqliteDatabase
from arkindex.mock import MockApiClient
from arkindex_worker.cache import MODELS, CachedElement, CachedTranscription
from arkindex_worker.cache import (
MODELS,
SQL_VERSION,
CachedElement,
CachedTranscription,
Version,
create_version_table,
init_cache_db,
)
from arkindex_worker.git import GitHelper, GitlabHelper
from arkindex_worker.worker import BaseWorker, ElementsWorker
from arkindex_worker.worker.transcription import TextOrientation
......@@ -260,7 +268,8 @@ def mock_base_worker_with_cache(mocker, monkeypatch, mock_config_api):
def mock_elements_worker_with_cache(monkeypatch, mock_config_api, tmp_path):
"""Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest"""
cache_path = tmp_path / "db.sqlite"
cache_path.touch()
init_cache_db(cache_path)
create_version_table()
monkeypatch.setattr(sys, "argv", ["worker", "-d", str(cache_path)])
worker = ElementsWorker(support_cache=True)
......@@ -433,9 +442,11 @@ def mock_databases(tmpdir):
path = tmpdir / name / filename
(tmpdir / name).mkdir()
local_db = SqliteDatabase(path)
with local_db.bind_ctx(MODELS):
with local_db.bind_ctx(MODELS + [Version]):
# Create tables on the current local database
# by binding temporarily the models on that database
local_db.create_tables([Version])
Version.create(version=SQL_VERSION)
local_db.create_tables(MODELS)
out[name] = {"path": path, "db": local_db}
......
......@@ -6,9 +6,13 @@ import pytest
from peewee import OperationalError
from arkindex_worker.cache import (
SQL_VERSION,
CachedElement,
CachedImage,
Version,
check_version,
create_tables,
create_version_table,
db,
init_cache_db,
)
......@@ -73,6 +77,76 @@ CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT
assert expected_schema == actual_schema
def test_create_version_table(tmp_path):
db_path = f"{tmp_path}/db.sqlite"
init_cache_db(db_path)
create_version_table()
expected_schema = 'CREATE TABLE "version" ("version" INTEGER NOT NULL PRIMARY KEY)'
actual_schema = "\n".join(
[
row[0]
for row in db.connection()
.execute("SELECT sql FROM sqlite_master WHERE type = 'table' ORDER BY name")
.fetchall()
]
)
assert expected_schema == actual_schema
assert Version.select().count() == 1
assert Version.get() == Version(version=SQL_VERSION)
def test_check_version_unset_version(tmp_path):
"""
The cache misses the version table
"""
db_path = f"{tmp_path}/db.sqlite"
init_cache_db(db_path)
with pytest.raises(AssertionError) as e:
check_version(db_path)
assert (
str(e.value)
== f"The SQLite database {db_path} does not have the correct cache version, it should be {SQL_VERSION}"
)
def test_check_version_differing_version(tmp_path):
"""
The cache has a differing version
"""
db_path = f"{tmp_path}/db.sqlite"
init_cache_db(db_path)
fake_version = 420
assert fake_version != SQL_VERSION
db.create_tables([Version])
Version.create(version=fake_version)
db.close()
with pytest.raises(AssertionError) as e:
check_version(db_path)
assert (
str(e.value)
== f"The SQLite database {db_path} does not have the correct cache version, it should be {SQL_VERSION}"
)
def test_check_version_same_version(tmp_path):
"""
The cache has the expected version
"""
db_path = f"{tmp_path}/db.sqlite"
init_cache_db(db_path)
create_version_table()
db.close()
check_version(db_path)
@pytest.mark.parametrize(
"image_width,image_height,polygon_x,polygon_y,polygon_width,polygon_height,max_size,expected_url",
[
......
......@@ -6,7 +6,13 @@ from uuid import UUID
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedElement, CachedImage
from arkindex_worker.cache import (
SQL_VERSION,
CachedElement,
CachedImage,
create_version_table,
init_cache_db,
)
from arkindex_worker.models import Element
from arkindex_worker.worker import ElementsWorker
from arkindex_worker.worker.element import MissingTypeError
......@@ -177,7 +183,8 @@ def test_list_elements_both_args_error(mocker, mock_elements_worker, tmp_path):
def test_database_arg(mocker, mock_elements_worker, tmp_path):
database_path = tmp_path / "my_database.sqlite"
database_path.touch()
init_cache_db(database_path)
create_version_table()
mocker.patch(
"arkindex_worker.worker.base.argparse.ArgumentParser.parse_args",
......@@ -197,6 +204,32 @@ def test_database_arg(mocker, mock_elements_worker, tmp_path):
assert worker.cache_path == str(database_path)
def test_database_arg_cache_missing_version_table(
mocker, mock_elements_worker, tmp_path
):
database_path = tmp_path / "my_database.sqlite"
database_path.touch()
mocker.patch(
"arkindex_worker.worker.base.argparse.ArgumentParser.parse_args",
return_value=Namespace(
element=["volumeid", "pageid"],
verbose=False,
elements_list=None,
database=str(database_path),
dev=False,
),
)
worker = ElementsWorker(support_cache=True)
with pytest.raises(AssertionError) as e:
worker.configure()
assert (
str(e.value)
== f"The SQLite database {database_path} does not have the correct cache version, it should be {SQL_VERSION}"
)
def test_load_corpus_classes_api_error(responses, mock_elements_worker):
corpus_id = "12341234-1234-1234-1234-123412341234"
responses.add(
......
......@@ -5,12 +5,14 @@ import pytest
from arkindex_worker.cache import (
MODELS,
SQL_VERSION,
CachedClassification,
CachedElement,
CachedEntity,
CachedImage,
CachedTranscription,
CachedTranscriptionEntity,
Version,
merge_parents_cache,
retrieve_parents_cache_path,
)
......@@ -196,3 +198,30 @@ def test_merge_from_worker(
assert [t.id for t in CachedTranscription.select().order_by("id")] == [
UUID("11111111-1111-1111-1111-111111111111"),
]
def test_merge_conflicting_versions(mock_databases, tmpdir):
"""
Merging databases with differing versions should not be allowed
"""
with mock_databases["second"]["db"].bind_ctx([Version]):
assert Version.get() == Version(version=SQL_VERSION)
fake_version = 420
Version.update(version=fake_version).execute()
assert Version.get() == Version(version=fake_version)
with mock_databases["target"]["db"].bind_ctx([Version]):
assert Version.get() == Version(version=SQL_VERSION)
# Retrieve parents databases paths
paths = retrieve_parents_cache_path(["first", "second"], data_dir=tmpdir)
# Merge all requested parents databases into our target, the "second" parent have a differing version
with pytest.raises(AssertionError) as e:
merge_parents_cache(paths, mock_databases["target"]["path"])
assert (
str(e.value)
== f"The SQLite database {paths[1]} does not have the correct cache version, it should be {SQL_VERSION}"
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment