diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py index b1be02e2787f3c06c53465bffbee91d737288ca4..aa40bbcbecd77f345a488607082747c55d6fa8b5 100644 --- a/arkindex_worker/cache.py +++ b/arkindex_worker/cache.py @@ -13,6 +13,7 @@ from peewee import ( ForeignKeyField, IntegerField, Model, + OperationalError, SqliteDatabase, TextField, UUIDField, @@ -38,6 +39,14 @@ class JSONField(Field): return json.loads(value) +class Version(Model): + version = IntegerField(primary_key=True) + + class Meta: + database = db + table_name = "version" + + class CachedImage(Model): id = UUIDField(primary_key=True) width = IntegerField() @@ -184,6 +193,7 @@ MODELS = [ CachedEntity, CachedTranscriptionEntity, ] +SQL_VERSION = 1 def init_cache_db(path): @@ -206,6 +216,30 @@ def create_tables(): db.create_tables(MODELS) +def create_version_table(): + """ + Creates the Version table in the cache DB. + This step must be independent from other tables creation since we only + want to create the table and add the one and only Version entry when the + cache is created from scratch. + """ + db.create_tables([Version]) + Version.create(version=SQL_VERSION) + + +def check_version(cache_path): + with SqliteDatabase(cache_path) as provided_db: + with provided_db.bind_ctx([Version]): + try: + version = Version.get().version + except OperationalError: + version = None + + assert ( + version == SQL_VERSION + ), f"The SQLite database {cache_path} does not have the correct cache version, it should be {SQL_VERSION}" + + def retrieve_parents_cache_path(parent_ids, data_dir="/data", chunk=None): assert isinstance(parent_ids, list) assert os.path.isdir(data_dir) @@ -247,6 +281,9 @@ def merge_parents_cache(paths, current_database): # Merge each table into the local database for idx, path in enumerate(paths): + # Check that the parent cache uses a compatible version + check_version(path) + with SqliteDatabase(path) as source: with source.bind_ctx(MODELS): source.create_tables(MODELS) diff --git a/arkindex_worker/worker/base.py b/arkindex_worker/worker/base.py index 4ef07c0922aff671997a3d93a25bbbe80cb4dbf6..db01c555e1c79d2cf0b1fc9292a65e750cf7ecd0 100644 --- a/arkindex_worker/worker/base.py +++ b/arkindex_worker/worker/base.py @@ -19,7 +19,9 @@ from tenacity import ( from arkindex import ArkindexClient, options_from_env from arkindex_worker import logger from arkindex_worker.cache import ( + check_version, create_tables, + create_version_table, init_cache_db, merge_parents_cache, retrieve_parents_cache_path, @@ -203,6 +205,12 @@ class BaseWorker(object): self.cache_path = os.path.join(cache_dir, "db.sqlite") init_cache_db(self.cache_path) + + if self.args.database is not None: + check_version(self.cache_path) + else: + create_version_table() + create_tables() # Merging parents caches (if there are any) in the current task local cache, unless the database got overridden diff --git a/tests/conftest.py b/tests/conftest.py index aa7e521f605b9561f60cab613efc794451c63667..04c8c7949f7102b1e8a492fb4f85971bbf995413 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,15 @@ import yaml from peewee import SqliteDatabase from arkindex.mock import MockApiClient -from arkindex_worker.cache import MODELS, CachedElement, CachedTranscription +from arkindex_worker.cache import ( + MODELS, + SQL_VERSION, + CachedElement, + CachedTranscription, + Version, + create_version_table, + init_cache_db, +) from arkindex_worker.git import GitHelper, GitlabHelper from arkindex_worker.worker import BaseWorker, ElementsWorker from arkindex_worker.worker.transcription import TextOrientation @@ -260,7 +268,8 @@ def mock_base_worker_with_cache(mocker, monkeypatch, mock_config_api): def mock_elements_worker_with_cache(monkeypatch, mock_config_api, tmp_path): """Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest""" cache_path = tmp_path / "db.sqlite" - cache_path.touch() + init_cache_db(cache_path) + create_version_table() monkeypatch.setattr(sys, "argv", ["worker", "-d", str(cache_path)]) worker = ElementsWorker(support_cache=True) @@ -433,9 +442,11 @@ def mock_databases(tmpdir): path = tmpdir / name / filename (tmpdir / name).mkdir() local_db = SqliteDatabase(path) - with local_db.bind_ctx(MODELS): + with local_db.bind_ctx(MODELS + [Version]): # Create tables on the current local database # by binding temporarily the models on that database + local_db.create_tables([Version]) + Version.create(version=SQL_VERSION) local_db.create_tables(MODELS) out[name] = {"path": path, "db": local_db} diff --git a/tests/test_cache.py b/tests/test_cache.py index fcc7e8e7406ec1d7b524e7d00b90efcdc0dc1812..94a2a3e4bb84e1db42fe6bc8d82475b3d10a711d 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -6,9 +6,13 @@ import pytest from peewee import OperationalError from arkindex_worker.cache import ( + SQL_VERSION, CachedElement, CachedImage, + Version, + check_version, create_tables, + create_version_table, db, init_cache_db, ) @@ -73,6 +77,76 @@ CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT assert expected_schema == actual_schema +def test_create_version_table(tmp_path): + db_path = f"{tmp_path}/db.sqlite" + init_cache_db(db_path) + create_version_table() + + expected_schema = 'CREATE TABLE "version" ("version" INTEGER NOT NULL PRIMARY KEY)' + actual_schema = "\n".join( + [ + row[0] + for row in db.connection() + .execute("SELECT sql FROM sqlite_master WHERE type = 'table' ORDER BY name") + .fetchall() + ] + ) + + assert expected_schema == actual_schema + assert Version.select().count() == 1 + assert Version.get() == Version(version=SQL_VERSION) + + +def test_check_version_unset_version(tmp_path): + """ + The cache misses the version table + """ + db_path = f"{tmp_path}/db.sqlite" + init_cache_db(db_path) + + with pytest.raises(AssertionError) as e: + check_version(db_path) + + assert ( + str(e.value) + == f"The SQLite database {db_path} does not have the correct cache version, it should be {SQL_VERSION}" + ) + + +def test_check_version_differing_version(tmp_path): + """ + The cache has a differing version + """ + db_path = f"{tmp_path}/db.sqlite" + init_cache_db(db_path) + + fake_version = 420 + assert fake_version != SQL_VERSION + db.create_tables([Version]) + Version.create(version=fake_version) + db.close() + + with pytest.raises(AssertionError) as e: + check_version(db_path) + + assert ( + str(e.value) + == f"The SQLite database {db_path} does not have the correct cache version, it should be {SQL_VERSION}" + ) + + +def test_check_version_same_version(tmp_path): + """ + The cache has the expected version + """ + db_path = f"{tmp_path}/db.sqlite" + init_cache_db(db_path) + create_version_table() + db.close() + + check_version(db_path) + + @pytest.mark.parametrize( "image_width,image_height,polygon_x,polygon_y,polygon_width,polygon_height,max_size,expected_url", [ diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py index 6e2f132bfc313d85567accea00f83dcb2e2eb5e2..fba5fe56cb0104796381d8d0b3fdc2c1e3a6865c 100644 --- a/tests/test_elements_worker/test_elements.py +++ b/tests/test_elements_worker/test_elements.py @@ -6,7 +6,13 @@ from uuid import UUID import pytest from apistar.exceptions import ErrorResponse -from arkindex_worker.cache import CachedElement, CachedImage +from arkindex_worker.cache import ( + SQL_VERSION, + CachedElement, + CachedImage, + create_version_table, + init_cache_db, +) from arkindex_worker.models import Element from arkindex_worker.worker import ElementsWorker from arkindex_worker.worker.element import MissingTypeError @@ -177,7 +183,8 @@ def test_list_elements_both_args_error(mocker, mock_elements_worker, tmp_path): def test_database_arg(mocker, mock_elements_worker, tmp_path): database_path = tmp_path / "my_database.sqlite" - database_path.touch() + init_cache_db(database_path) + create_version_table() mocker.patch( "arkindex_worker.worker.base.argparse.ArgumentParser.parse_args", @@ -197,6 +204,32 @@ def test_database_arg(mocker, mock_elements_worker, tmp_path): assert worker.cache_path == str(database_path) +def test_database_arg_cache_missing_version_table( + mocker, mock_elements_worker, tmp_path +): + database_path = tmp_path / "my_database.sqlite" + database_path.touch() + + mocker.patch( + "arkindex_worker.worker.base.argparse.ArgumentParser.parse_args", + return_value=Namespace( + element=["volumeid", "pageid"], + verbose=False, + elements_list=None, + database=str(database_path), + dev=False, + ), + ) + + worker = ElementsWorker(support_cache=True) + with pytest.raises(AssertionError) as e: + worker.configure() + assert ( + str(e.value) + == f"The SQLite database {database_path} does not have the correct cache version, it should be {SQL_VERSION}" + ) + + def test_load_corpus_classes_api_error(responses, mock_elements_worker): corpus_id = "12341234-1234-1234-1234-123412341234" responses.add( diff --git a/tests/test_merge.py b/tests/test_merge.py index f02e32daf9e1af9196f29f034e3cde9dca194e56..687b7792526a9973e31f33da80a8651a1a87e730 100644 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -5,12 +5,14 @@ import pytest from arkindex_worker.cache import ( MODELS, + SQL_VERSION, CachedClassification, CachedElement, CachedEntity, CachedImage, CachedTranscription, CachedTranscriptionEntity, + Version, merge_parents_cache, retrieve_parents_cache_path, ) @@ -196,3 +198,30 @@ def test_merge_from_worker( assert [t.id for t in CachedTranscription.select().order_by("id")] == [ UUID("11111111-1111-1111-1111-111111111111"), ] + + +def test_merge_conflicting_versions(mock_databases, tmpdir): + """ + Merging databases with differing versions should not be allowed + """ + + with mock_databases["second"]["db"].bind_ctx([Version]): + assert Version.get() == Version(version=SQL_VERSION) + fake_version = 420 + Version.update(version=fake_version).execute() + assert Version.get() == Version(version=fake_version) + + with mock_databases["target"]["db"].bind_ctx([Version]): + assert Version.get() == Version(version=SQL_VERSION) + + # Retrieve parents databases paths + paths = retrieve_parents_cache_path(["first", "second"], data_dir=tmpdir) + + # Merge all requested parents databases into our target, the "second" parent have a differing version + with pytest.raises(AssertionError) as e: + merge_parents_cache(paths, mock_databases["target"]["path"]) + + assert ( + str(e.value) + == f"The SQLite database {paths[1]} does not have the correct cache version, it should be {SQL_VERSION}" + )