Skip to content
Snippets Groups Projects
Commit 9bc3eb43 authored by Eva Bardou's avatar Eva Bardou Committed by Erwan Rouchet
Browse files

Add Version table in SQLite cache + Check compatibility from tasks

parent e141a461
No related branches found
No related tags found
1 merge request!151Add Version table in SQLite cache + Check compatibility from tasks
Pipeline #79052 passed
...@@ -13,6 +13,7 @@ from peewee import ( ...@@ -13,6 +13,7 @@ from peewee import (
ForeignKeyField, ForeignKeyField,
IntegerField, IntegerField,
Model, Model,
OperationalError,
SqliteDatabase, SqliteDatabase,
TextField, TextField,
UUIDField, UUIDField,
...@@ -38,6 +39,14 @@ class JSONField(Field): ...@@ -38,6 +39,14 @@ class JSONField(Field):
return json.loads(value) return json.loads(value)
class Version(Model):
version = IntegerField(primary_key=True)
class Meta:
database = db
table_name = "version"
class CachedImage(Model): class CachedImage(Model):
id = UUIDField(primary_key=True) id = UUIDField(primary_key=True)
width = IntegerField() width = IntegerField()
...@@ -184,6 +193,7 @@ MODELS = [ ...@@ -184,6 +193,7 @@ MODELS = [
CachedEntity, CachedEntity,
CachedTranscriptionEntity, CachedTranscriptionEntity,
] ]
SQL_VERSION = 1
def init_cache_db(path): def init_cache_db(path):
...@@ -206,6 +216,30 @@ def create_tables(): ...@@ -206,6 +216,30 @@ def create_tables():
db.create_tables(MODELS) db.create_tables(MODELS)
def create_version_table():
"""
Creates the Version table in the cache DB.
This step must be independent from other tables creation since we only
want to create the table and add the one and only Version entry when the
cache is created from scratch.
"""
db.create_tables([Version])
Version.create(version=SQL_VERSION)
def check_version(cache_path):
with SqliteDatabase(cache_path) as provided_db:
with provided_db.bind_ctx([Version]):
try:
version = Version.get().version
except OperationalError:
version = None
assert (
version == SQL_VERSION
), f"The SQLite database {cache_path} does not have the correct cache version, it should be {SQL_VERSION}"
def retrieve_parents_cache_path(parent_ids, data_dir="/data", chunk=None): def retrieve_parents_cache_path(parent_ids, data_dir="/data", chunk=None):
assert isinstance(parent_ids, list) assert isinstance(parent_ids, list)
assert os.path.isdir(data_dir) assert os.path.isdir(data_dir)
...@@ -247,6 +281,9 @@ def merge_parents_cache(paths, current_database): ...@@ -247,6 +281,9 @@ def merge_parents_cache(paths, current_database):
# Merge each table into the local database # Merge each table into the local database
for idx, path in enumerate(paths): for idx, path in enumerate(paths):
# Check that the parent cache uses a compatible version
check_version(path)
with SqliteDatabase(path) as source: with SqliteDatabase(path) as source:
with source.bind_ctx(MODELS): with source.bind_ctx(MODELS):
source.create_tables(MODELS) source.create_tables(MODELS)
......
...@@ -19,7 +19,9 @@ from tenacity import ( ...@@ -19,7 +19,9 @@ from tenacity import (
from arkindex import ArkindexClient, options_from_env from arkindex import ArkindexClient, options_from_env
from arkindex_worker import logger from arkindex_worker import logger
from arkindex_worker.cache import ( from arkindex_worker.cache import (
check_version,
create_tables, create_tables,
create_version_table,
init_cache_db, init_cache_db,
merge_parents_cache, merge_parents_cache,
retrieve_parents_cache_path, retrieve_parents_cache_path,
...@@ -203,6 +205,12 @@ class BaseWorker(object): ...@@ -203,6 +205,12 @@ class BaseWorker(object):
self.cache_path = os.path.join(cache_dir, "db.sqlite") self.cache_path = os.path.join(cache_dir, "db.sqlite")
init_cache_db(self.cache_path) init_cache_db(self.cache_path)
if self.args.database is not None:
check_version(self.cache_path)
else:
create_version_table()
create_tables() create_tables()
# Merging parents caches (if there are any) in the current task local cache, unless the database got overridden # Merging parents caches (if there are any) in the current task local cache, unless the database got overridden
......
...@@ -12,7 +12,15 @@ import yaml ...@@ -12,7 +12,15 @@ import yaml
from peewee import SqliteDatabase from peewee import SqliteDatabase
from arkindex.mock import MockApiClient from arkindex.mock import MockApiClient
from arkindex_worker.cache import MODELS, CachedElement, CachedTranscription from arkindex_worker.cache import (
MODELS,
SQL_VERSION,
CachedElement,
CachedTranscription,
Version,
create_version_table,
init_cache_db,
)
from arkindex_worker.git import GitHelper, GitlabHelper from arkindex_worker.git import GitHelper, GitlabHelper
from arkindex_worker.worker import BaseWorker, ElementsWorker from arkindex_worker.worker import BaseWorker, ElementsWorker
from arkindex_worker.worker.transcription import TextOrientation from arkindex_worker.worker.transcription import TextOrientation
...@@ -260,7 +268,8 @@ def mock_base_worker_with_cache(mocker, monkeypatch, mock_config_api): ...@@ -260,7 +268,8 @@ def mock_base_worker_with_cache(mocker, monkeypatch, mock_config_api):
def mock_elements_worker_with_cache(monkeypatch, mock_config_api, tmp_path): def mock_elements_worker_with_cache(monkeypatch, mock_config_api, tmp_path):
"""Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest""" """Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest"""
cache_path = tmp_path / "db.sqlite" cache_path = tmp_path / "db.sqlite"
cache_path.touch() init_cache_db(cache_path)
create_version_table()
monkeypatch.setattr(sys, "argv", ["worker", "-d", str(cache_path)]) monkeypatch.setattr(sys, "argv", ["worker", "-d", str(cache_path)])
worker = ElementsWorker(support_cache=True) worker = ElementsWorker(support_cache=True)
...@@ -433,9 +442,11 @@ def mock_databases(tmpdir): ...@@ -433,9 +442,11 @@ def mock_databases(tmpdir):
path = tmpdir / name / filename path = tmpdir / name / filename
(tmpdir / name).mkdir() (tmpdir / name).mkdir()
local_db = SqliteDatabase(path) local_db = SqliteDatabase(path)
with local_db.bind_ctx(MODELS): with local_db.bind_ctx(MODELS + [Version]):
# Create tables on the current local database # Create tables on the current local database
# by binding temporarily the models on that database # by binding temporarily the models on that database
local_db.create_tables([Version])
Version.create(version=SQL_VERSION)
local_db.create_tables(MODELS) local_db.create_tables(MODELS)
out[name] = {"path": path, "db": local_db} out[name] = {"path": path, "db": local_db}
......
...@@ -6,9 +6,13 @@ import pytest ...@@ -6,9 +6,13 @@ import pytest
from peewee import OperationalError from peewee import OperationalError
from arkindex_worker.cache import ( from arkindex_worker.cache import (
SQL_VERSION,
CachedElement, CachedElement,
CachedImage, CachedImage,
Version,
check_version,
create_tables, create_tables,
create_version_table,
db, db,
init_cache_db, init_cache_db,
) )
...@@ -73,6 +77,76 @@ CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT ...@@ -73,6 +77,76 @@ CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT
assert expected_schema == actual_schema assert expected_schema == actual_schema
def test_create_version_table(tmp_path):
db_path = f"{tmp_path}/db.sqlite"
init_cache_db(db_path)
create_version_table()
expected_schema = 'CREATE TABLE "version" ("version" INTEGER NOT NULL PRIMARY KEY)'
actual_schema = "\n".join(
[
row[0]
for row in db.connection()
.execute("SELECT sql FROM sqlite_master WHERE type = 'table' ORDER BY name")
.fetchall()
]
)
assert expected_schema == actual_schema
assert Version.select().count() == 1
assert Version.get() == Version(version=SQL_VERSION)
def test_check_version_unset_version(tmp_path):
"""
The cache misses the version table
"""
db_path = f"{tmp_path}/db.sqlite"
init_cache_db(db_path)
with pytest.raises(AssertionError) as e:
check_version(db_path)
assert (
str(e.value)
== f"The SQLite database {db_path} does not have the correct cache version, it should be {SQL_VERSION}"
)
def test_check_version_differing_version(tmp_path):
"""
The cache has a differing version
"""
db_path = f"{tmp_path}/db.sqlite"
init_cache_db(db_path)
fake_version = 420
assert fake_version != SQL_VERSION
db.create_tables([Version])
Version.create(version=fake_version)
db.close()
with pytest.raises(AssertionError) as e:
check_version(db_path)
assert (
str(e.value)
== f"The SQLite database {db_path} does not have the correct cache version, it should be {SQL_VERSION}"
)
def test_check_version_same_version(tmp_path):
"""
The cache has the expected version
"""
db_path = f"{tmp_path}/db.sqlite"
init_cache_db(db_path)
create_version_table()
db.close()
check_version(db_path)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"image_width,image_height,polygon_x,polygon_y,polygon_width,polygon_height,max_size,expected_url", "image_width,image_height,polygon_x,polygon_y,polygon_width,polygon_height,max_size,expected_url",
[ [
......
...@@ -6,7 +6,13 @@ from uuid import UUID ...@@ -6,7 +6,13 @@ from uuid import UUID
import pytest import pytest
from apistar.exceptions import ErrorResponse from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedElement, CachedImage from arkindex_worker.cache import (
SQL_VERSION,
CachedElement,
CachedImage,
create_version_table,
init_cache_db,
)
from arkindex_worker.models import Element from arkindex_worker.models import Element
from arkindex_worker.worker import ElementsWorker from arkindex_worker.worker import ElementsWorker
from arkindex_worker.worker.element import MissingTypeError from arkindex_worker.worker.element import MissingTypeError
...@@ -177,7 +183,8 @@ def test_list_elements_both_args_error(mocker, mock_elements_worker, tmp_path): ...@@ -177,7 +183,8 @@ def test_list_elements_both_args_error(mocker, mock_elements_worker, tmp_path):
def test_database_arg(mocker, mock_elements_worker, tmp_path): def test_database_arg(mocker, mock_elements_worker, tmp_path):
database_path = tmp_path / "my_database.sqlite" database_path = tmp_path / "my_database.sqlite"
database_path.touch() init_cache_db(database_path)
create_version_table()
mocker.patch( mocker.patch(
"arkindex_worker.worker.base.argparse.ArgumentParser.parse_args", "arkindex_worker.worker.base.argparse.ArgumentParser.parse_args",
...@@ -197,6 +204,32 @@ def test_database_arg(mocker, mock_elements_worker, tmp_path): ...@@ -197,6 +204,32 @@ def test_database_arg(mocker, mock_elements_worker, tmp_path):
assert worker.cache_path == str(database_path) assert worker.cache_path == str(database_path)
def test_database_arg_cache_missing_version_table(
mocker, mock_elements_worker, tmp_path
):
database_path = tmp_path / "my_database.sqlite"
database_path.touch()
mocker.patch(
"arkindex_worker.worker.base.argparse.ArgumentParser.parse_args",
return_value=Namespace(
element=["volumeid", "pageid"],
verbose=False,
elements_list=None,
database=str(database_path),
dev=False,
),
)
worker = ElementsWorker(support_cache=True)
with pytest.raises(AssertionError) as e:
worker.configure()
assert (
str(e.value)
== f"The SQLite database {database_path} does not have the correct cache version, it should be {SQL_VERSION}"
)
def test_load_corpus_classes_api_error(responses, mock_elements_worker): def test_load_corpus_classes_api_error(responses, mock_elements_worker):
corpus_id = "12341234-1234-1234-1234-123412341234" corpus_id = "12341234-1234-1234-1234-123412341234"
responses.add( responses.add(
......
...@@ -5,12 +5,14 @@ import pytest ...@@ -5,12 +5,14 @@ import pytest
from arkindex_worker.cache import ( from arkindex_worker.cache import (
MODELS, MODELS,
SQL_VERSION,
CachedClassification, CachedClassification,
CachedElement, CachedElement,
CachedEntity, CachedEntity,
CachedImage, CachedImage,
CachedTranscription, CachedTranscription,
CachedTranscriptionEntity, CachedTranscriptionEntity,
Version,
merge_parents_cache, merge_parents_cache,
retrieve_parents_cache_path, retrieve_parents_cache_path,
) )
...@@ -196,3 +198,30 @@ def test_merge_from_worker( ...@@ -196,3 +198,30 @@ def test_merge_from_worker(
assert [t.id for t in CachedTranscription.select().order_by("id")] == [ assert [t.id for t in CachedTranscription.select().order_by("id")] == [
UUID("11111111-1111-1111-1111-111111111111"), UUID("11111111-1111-1111-1111-111111111111"),
] ]
def test_merge_conflicting_versions(mock_databases, tmpdir):
"""
Merging databases with differing versions should not be allowed
"""
with mock_databases["second"]["db"].bind_ctx([Version]):
assert Version.get() == Version(version=SQL_VERSION)
fake_version = 420
Version.update(version=fake_version).execute()
assert Version.get() == Version(version=fake_version)
with mock_databases["target"]["db"].bind_ctx([Version]):
assert Version.get() == Version(version=SQL_VERSION)
# Retrieve parents databases paths
paths = retrieve_parents_cache_path(["first", "second"], data_dir=tmpdir)
# Merge all requested parents databases into our target, the "second" parent have a differing version
with pytest.raises(AssertionError) as e:
merge_parents_cache(paths, mock_databases["target"]["path"])
assert (
str(e.value)
== f"The SQLite database {paths[1]} does not have the correct cache version, it should be {SQL_VERSION}"
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment