Skip to content
Snippets Groups Projects
Commit 3347ded3 authored by Eva Bardou's avatar Eva Bardou
Browse files

Add Version table in SQLite cache + Check compatibility from tasks

parent 6b7c21f1
No related branches found
No related tags found
No related merge requests found
Pipeline #79033 passed
...@@ -184,6 +184,9 @@ MODELS = [ ...@@ -184,6 +184,9 @@ MODELS = [
CachedEntity, CachedEntity,
CachedTranscriptionEntity, CachedTranscriptionEntity,
] ]
SQL_VERSION = 1
SQL_VERSION_TABLE = "CREATE TABLE IF NOT EXISTS version AS SELECT ? AS version"
SQL_VERSION_QUERY = "SELECT version FROM version"
def init_cache_db(path): def init_cache_db(path):
...@@ -204,6 +207,15 @@ def create_tables(): ...@@ -204,6 +207,15 @@ def create_tables():
Creates the tables in the cache DB only if they do not already exist. Creates the tables in the cache DB only if they do not already exist.
""" """
db.create_tables(MODELS) db.create_tables(MODELS)
# If the version table already exists (e.g from tasks) nothing will be added
db.execute_sql(SQL_VERSION_TABLE, (SQL_VERSION,))
def check_version():
db_results = db.execute_sql(SQL_VERSION_QUERY).fetchall()
assert (
len(db_results) == 1 and db_results[0][0] == SQL_VERSION
), "The SQLite database does not have the correct cache version"
def retrieve_parents_cache_path(parent_ids, data_dir="/data", chunk=None): def retrieve_parents_cache_path(parent_ids, data_dir="/data", chunk=None):
......
...@@ -19,6 +19,7 @@ from tenacity import ( ...@@ -19,6 +19,7 @@ from tenacity import (
from arkindex import ArkindexClient, options_from_env from arkindex import ArkindexClient, options_from_env
from arkindex_worker import logger from arkindex_worker import logger
from arkindex_worker.cache import ( from arkindex_worker.cache import (
check_version,
create_tables, create_tables,
init_cache_db, init_cache_db,
merge_parents_cache, merge_parents_cache,
...@@ -204,6 +205,7 @@ class BaseWorker(object): ...@@ -204,6 +205,7 @@ class BaseWorker(object):
init_cache_db(self.cache_path) init_cache_db(self.cache_path)
create_tables() create_tables()
check_version()
# Merging parents caches (if there are any) in the current task local cache, unless the database got overridden # Merging parents caches (if there are any) in the current task local cache, unless the database got overridden
if self.args.database is None and paths is not None: if self.args.database is None and paths is not None:
......
...@@ -6,8 +6,11 @@ import pytest ...@@ -6,8 +6,11 @@ import pytest
from peewee import OperationalError from peewee import OperationalError
from arkindex_worker.cache import ( from arkindex_worker.cache import (
SQL_VERSION,
SQL_VERSION_TABLE,
CachedElement, CachedElement,
CachedImage, CachedImage,
check_version,
create_tables, create_tables,
db, db,
init_cache_db, init_cache_db,
...@@ -59,7 +62,8 @@ CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type ...@@ -59,7 +62,8 @@ CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type
CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL) CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL)
CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL) CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id")) CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))""" CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))
CREATE TABLE version(version)"""
actual_schema = "\n".join( actual_schema = "\n".join(
[ [
...@@ -73,6 +77,51 @@ CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT ...@@ -73,6 +77,51 @@ CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT
assert expected_schema == actual_schema assert expected_schema == actual_schema
def test_check_version_unset_version(tmp_path):
"""
No cache was provided, the check should always pass
"""
db_path = f"{tmp_path}/db.sqlite"
init_cache_db(db_path)
create_tables()
check_version()
def test_check_version_differing_version(tmp_path):
"""
The provided cache has a differing version
"""
db_path = f"{tmp_path}/db.sqlite"
init_cache_db(db_path)
# Populating the provided cache
fake_version = 420
assert fake_version != SQL_VERSION
db.execute_sql(SQL_VERSION_TABLE, (fake_version,))
db.close()
init_cache_db(db_path)
create_tables()
with pytest.raises(AssertionError) as e:
check_version()
assert str(e.value) == "The SQLite database does not have the correct cache version"
def test_check_version_same_version(tmp_path):
"""
The provied cache has the expected version
"""
db_path = f"{tmp_path}/db.sqlite"
init_cache_db(db_path)
# Populating the provided cache
db.execute_sql(SQL_VERSION_TABLE, (SQL_VERSION,))
db.close()
init_cache_db(db_path)
create_tables()
check_version()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"image_width,image_height,polygon_x,polygon_y,polygon_width,polygon_height,max_size,expected_url", "image_width,image_height,polygon_x,polygon_y,polygon_width,polygon_height,max_size,expected_url",
[ [
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment