From 3347ded3d228e4df1d3649c4ad7ddb6f8f076bcb Mon Sep 17 00:00:00 2001
From: Eva Bardou <ebardou@teklia.com>
Date: Mon, 14 Mar 2022 17:31:11 +0100
Subject: [PATCH] Add Version table in SQLite cache + Check compatibility from
 tasks

---
 arkindex_worker/cache.py       | 12 ++++++++
 arkindex_worker/worker/base.py |  2 ++
 tests/test_cache.py            | 51 +++++++++++++++++++++++++++++++++-
 3 files changed, 64 insertions(+), 1 deletion(-)

diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py
index b1be02e2..7be518c7 100644
--- a/arkindex_worker/cache.py
+++ b/arkindex_worker/cache.py
@@ -184,6 +184,9 @@ MODELS = [
     CachedEntity,
     CachedTranscriptionEntity,
 ]
+SQL_VERSION = 1
+SQL_VERSION_TABLE = "CREATE TABLE IF NOT EXISTS version AS SELECT ? AS version"
+SQL_VERSION_QUERY = "SELECT version FROM version"
 
 
 def init_cache_db(path):
@@ -204,6 +207,15 @@ def create_tables():
     Creates the tables in the cache DB only if they do not already exist.
     """
     db.create_tables(MODELS)
+    # If the version table already exists (e.g from tasks) nothing will be added
+    db.execute_sql(SQL_VERSION_TABLE, (SQL_VERSION,))
+
+
+def check_version():
+    db_results = db.execute_sql(SQL_VERSION_QUERY).fetchall()
+    assert (
+        len(db_results) == 1 and db_results[0][0] == SQL_VERSION
+    ), "The SQLite database does not have the correct cache version"
 
 
 def retrieve_parents_cache_path(parent_ids, data_dir="/data", chunk=None):
diff --git a/arkindex_worker/worker/base.py b/arkindex_worker/worker/base.py
index 4ef07c09..e5b2b0bb 100644
--- a/arkindex_worker/worker/base.py
+++ b/arkindex_worker/worker/base.py
@@ -19,6 +19,7 @@ from tenacity import (
 from arkindex import ArkindexClient, options_from_env
 from arkindex_worker import logger
 from arkindex_worker.cache import (
+    check_version,
     create_tables,
     init_cache_db,
     merge_parents_cache,
@@ -204,6 +205,7 @@ class BaseWorker(object):
 
             init_cache_db(self.cache_path)
             create_tables()
+            check_version()
 
             # Merging parents caches (if there are any) in the current task local cache, unless the database got overridden
             if self.args.database is None and paths is not None:
diff --git a/tests/test_cache.py b/tests/test_cache.py
index fcc7e8e7..fb2eb799 100644
--- a/tests/test_cache.py
+++ b/tests/test_cache.py
@@ -6,8 +6,11 @@ import pytest
 from peewee import OperationalError
 
 from arkindex_worker.cache import (
+    SQL_VERSION,
+    SQL_VERSION_TABLE,
     CachedElement,
     CachedImage,
+    check_version,
     create_tables,
     db,
     init_cache_db,
@@ -59,7 +62,8 @@ CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type
 CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL)
 CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
 CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
-CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
+CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))
+CREATE TABLE version(version)"""
 
     actual_schema = "\n".join(
         [
@@ -73,6 +77,51 @@ CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT
     assert expected_schema == actual_schema
 
 
+def test_check_version_unset_version(tmp_path):
+    """
+    No cache was provided, the check should always pass
+    """
+    db_path = f"{tmp_path}/db.sqlite"
+    init_cache_db(db_path)
+    create_tables()
+    check_version()
+
+
+def test_check_version_differing_version(tmp_path):
+    """
+    The provided cache has a differing version
+    """
+    db_path = f"{tmp_path}/db.sqlite"
+    init_cache_db(db_path)
+    # Populating the provided cache
+    fake_version = 420
+    assert fake_version != SQL_VERSION
+    db.execute_sql(SQL_VERSION_TABLE, (fake_version,))
+    db.close()
+
+    init_cache_db(db_path)
+    create_tables()
+    with pytest.raises(AssertionError) as e:
+        check_version()
+
+    assert str(e.value) == "The SQLite database does not have the correct cache version"
+
+
+def test_check_version_same_version(tmp_path):
+    """
+    The provied cache has the expected version
+    """
+    db_path = f"{tmp_path}/db.sqlite"
+    init_cache_db(db_path)
+    # Populating the provided cache
+    db.execute_sql(SQL_VERSION_TABLE, (SQL_VERSION,))
+    db.close()
+
+    init_cache_db(db_path)
+    create_tables()
+    check_version()
+
+
 @pytest.mark.parametrize(
     "image_width,image_height,polygon_x,polygon_y,polygon_width,polygon_height,max_size,expected_url",
     [
-- 
GitLab