Skip to content
Snippets Groups Projects
Commit bce57a31 authored by Eva Bardou's avatar Eva Bardou
Browse files

Fix some review related issues

parent 3347ded3
No related branches found
No related tags found
No related merge requests found
Pipeline #79042 passed
......@@ -38,6 +38,14 @@ class JSONField(Field):
return json.loads(value)
class Version(Model):
    """
    Peewee model for the table recording the version of the cache database.
    """

    # The version number doubles as the primary key of the table
    version = IntegerField(primary_key=True)

    class Meta:
        # Bind the model to the module-level database object
        database = db
        table_name = "version"
class CachedImage(Model):
id = UUIDField(primary_key=True)
width = IntegerField()
......@@ -185,7 +193,6 @@ MODELS = [
CachedTranscriptionEntity,
]
SQL_VERSION = 1
SQL_VERSION_TABLE = "CREATE TABLE IF NOT EXISTS version AS SELECT ? AS version"
SQL_VERSION_QUERY = "SELECT version FROM version"
......@@ -207,15 +214,28 @@ def create_tables():
Creates the tables in the cache DB only if they do not already exist.
"""
db.create_tables(MODELS)
# If the version table already exists (e.g. from tasks), nothing will be added
db.execute_sql(SQL_VERSION_TABLE, (SQL_VERSION,))
def check_version():
db_results = db.execute_sql(SQL_VERSION_QUERY).fetchall()
def create_version_table():
    """
    Create the table backing the ``Version`` model and insert a single row
    holding ``SQL_VERSION``.
    """
    version_models = [Version]
    db.create_tables(version_models)
    Version.create(version=SQL_VERSION)
def check_version(cache_path):
    """
    Check that the SQLite cache stored at ``cache_path`` has the expected version.

    The file is opened through a dedicated ``sqlite3`` connection and queried
    with ``SQL_VERSION_QUERY``. When that query fails with an
    ``sqlite3.OperationalError``, the cache is treated as having no version.

    :param cache_path: Path of the SQLite database to check.
    :raises AssertionError: When the cache does not hold exactly one version
        row equal to ``SQL_VERSION``.
    """
    provided_db = sqlite3.connect(cache_path)
    try:
        provided_db.row_factory = sqlite3.Row
        cursor = provided_db.cursor()
        try:
            db_results = cursor.execute(SQL_VERSION_QUERY).fetchall()
        except sqlite3.OperationalError:
            # The version query could not be executed on this database
            db_results = []
        # Evaluate the rows while the connection is still open
        has_expected_version = (
            len(db_results) == 1 and db_results[0]["version"] == SQL_VERSION
        )
    finally:
        # Always release the connection, including when the check is about to fail
        provided_db.close()

    # Raised explicitly rather than through an ``assert`` statement so that the
    # check is still enforced when Python runs with assertions disabled (-O)
    if not has_expected_version:
        raise AssertionError(
            f"The SQLite database {cache_path} does not have the correct cache version, it should be {SQL_VERSION}"
        )
def retrieve_parents_cache_path(parent_ids, data_dir="/data", chunk=None):
......@@ -259,6 +279,9 @@ def merge_parents_cache(paths, current_database):
# Merge each table into the local database
for idx, path in enumerate(paths):
# Check that the parent cache uses a compatible version
check_version(path)
with SqliteDatabase(path) as source:
with source.bind_ctx(MODELS):
source.create_tables(MODELS)
......
......@@ -21,6 +21,7 @@ from arkindex_worker import logger
from arkindex_worker.cache import (
check_version,
create_tables,
create_version_table,
init_cache_db,
merge_parents_cache,
retrieve_parents_cache_path,
......@@ -204,8 +205,13 @@ class BaseWorker(object):
self.cache_path = os.path.join(cache_dir, "db.sqlite")
init_cache_db(self.cache_path)
if self.args.database is not None:
check_version(self.cache_path)
else:
create_version_table()
create_tables()
check_version()
# Merging parents caches (if there are any) in the current task local cache, unless the database got overridden
if self.args.database is None and paths is not None:
......
......@@ -12,7 +12,15 @@ import yaml
from peewee import SqliteDatabase
from arkindex.mock import MockApiClient
from arkindex_worker.cache import MODELS, CachedElement, CachedTranscription
from arkindex_worker.cache import (
MODELS,
SQL_VERSION,
CachedElement,
CachedTranscription,
Version,
create_version_table,
init_cache_db,
)
from arkindex_worker.git import GitHelper, GitlabHelper
from arkindex_worker.worker import BaseWorker, ElementsWorker
from arkindex_worker.worker.transcription import TextOrientation
......@@ -260,7 +268,8 @@ def mock_base_worker_with_cache(mocker, monkeypatch, mock_config_api):
def mock_elements_worker_with_cache(monkeypatch, mock_config_api, tmp_path):
"""Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest"""
cache_path = tmp_path / "db.sqlite"
cache_path.touch()
init_cache_db(cache_path)
create_version_table()
monkeypatch.setattr(sys, "argv", ["worker", "-d", str(cache_path)])
worker = ElementsWorker(support_cache=True)
......@@ -433,9 +442,11 @@ def mock_databases(tmpdir):
path = tmpdir / name / filename
(tmpdir / name).mkdir()
local_db = SqliteDatabase(path)
with local_db.bind_ctx(MODELS):
with local_db.bind_ctx(MODELS + [Version]):
# Create tables on the current local database
# by binding temporarily the models on that database
local_db.create_tables([Version])
Version.create(version=SQL_VERSION)
local_db.create_tables(MODELS)
out[name] = {"path": path, "db": local_db}
......
......@@ -7,11 +7,12 @@ from peewee import OperationalError
from arkindex_worker.cache import (
SQL_VERSION,
SQL_VERSION_TABLE,
CachedElement,
CachedImage,
Version,
check_version,
create_tables,
create_version_table,
db,
init_cache_db,
)
......@@ -62,8 +63,7 @@ CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type
CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL)
CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))
CREATE TABLE version(version)"""
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
actual_schema = "\n".join(
[
......@@ -77,49 +77,74 @@ CREATE TABLE version(version)"""
assert expected_schema == actual_schema
def test_create_version_table(tmp_path):
    """
    Creating the version table yields the expected schema and a single row
    holding the current cache version
    """
    db_path = f"{tmp_path}/db.sqlite"
    init_cache_db(db_path)
    create_version_table()

    expected_schema = 'CREATE TABLE "version" ("version" INTEGER NOT NULL PRIMARY KEY)'

    # Collect the SQL definition of every table present in the database
    cursor = db.connection().execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' ORDER BY name"
    )
    table_definitions = []
    for row in cursor.fetchall():
        table_definitions.append(row[0])
    actual_schema = "\n".join(table_definitions)

    assert expected_schema == actual_schema
    assert Version.select().count() == 1
    assert Version.get() == Version(version=SQL_VERSION)
def test_check_version_unset_version(tmp_path):
    """
    The cache misses the version table
    """
    db_path = f"{tmp_path}/db.sqlite"
    init_cache_db(db_path)
    # Only the regular cache tables are created: the version table stays absent
    create_tables()

    # check_version requires the path of the cache to inspect
    with pytest.raises(AssertionError) as e:
        check_version(db_path)
    assert (
        str(e.value)
        == f"The SQLite database {db_path} does not have the correct cache version, it should be {SQL_VERSION}"
    )
def test_check_version_differing_version(tmp_path):
    """
    The cache has a differing version
    """
    db_path = f"{tmp_path}/db.sqlite"
    init_cache_db(db_path)

    # Populate the cache with a version that differs from the expected one,
    # going through the Version model only so that a single row is stored
    fake_version = 420
    assert fake_version != SQL_VERSION
    db.create_tables([Version])
    Version.create(version=fake_version)
    db.close()

    init_cache_db(db_path)
    create_tables()

    with pytest.raises(AssertionError) as e:
        check_version(db_path)
    assert (
        str(e.value)
        == f"The SQLite database {db_path} does not have the correct cache version, it should be {SQL_VERSION}"
    )
def test_check_version_same_version(tmp_path):
    """
    The cache has the expected version
    """
    db_path = f"{tmp_path}/db.sqlite"
    init_cache_db(db_path)

    # Populate the cache with the expected version; the version row is inserted
    # once only, otherwise the check would see two rows and fail
    create_version_table()
    db.close()

    init_cache_db(db_path)
    create_tables()

    # No exception is expected here
    check_version(db_path)
@pytest.mark.parametrize(
......
......@@ -6,7 +6,13 @@ from uuid import UUID
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedElement, CachedImage
from arkindex_worker.cache import (
SQL_VERSION,
CachedElement,
CachedImage,
create_version_table,
init_cache_db,
)
from arkindex_worker.models import Element
from arkindex_worker.worker import ElementsWorker
from arkindex_worker.worker.element import MissingTypeError
......@@ -177,7 +183,8 @@ def test_list_elements_both_args_error(mocker, mock_elements_worker, tmp_path):
def test_database_arg(mocker, mock_elements_worker, tmp_path):
database_path = tmp_path / "my_database.sqlite"
database_path.touch()
init_cache_db(database_path)
create_version_table()
mocker.patch(
"arkindex_worker.worker.base.argparse.ArgumentParser.parse_args",
......@@ -197,6 +204,32 @@ def test_database_arg(mocker, mock_elements_worker, tmp_path):
assert worker.cache_path == str(database_path)
def test_database_arg_cache_missing_version_table(
    mocker, mock_elements_worker, tmp_path
):
    """
    Configuring a worker with a database that has no version table must fail
    """
    database_path = tmp_path / "my_database.sqlite"
    # An empty file: no table at all, hence no version table
    database_path.touch()

    cli_args = Namespace(
        element=["volumeid", "pageid"],
        verbose=False,
        elements_list=None,
        database=str(database_path),
        dev=False,
    )
    mocker.patch(
        "arkindex_worker.worker.base.argparse.ArgumentParser.parse_args",
        return_value=cli_args,
    )

    worker = ElementsWorker(support_cache=True)
    with pytest.raises(AssertionError) as e:
        worker.configure()

    expected_message = f"The SQLite database {database_path} does not have the correct cache version, it should be {SQL_VERSION}"
    assert str(e.value) == expected_message
def test_load_corpus_classes_api_error(responses, mock_elements_worker):
corpus_id = "12341234-1234-1234-1234-123412341234"
responses.add(
......
......@@ -5,12 +5,14 @@ import pytest
from arkindex_worker.cache import (
MODELS,
SQL_VERSION,
CachedClassification,
CachedElement,
CachedEntity,
CachedImage,
CachedTranscription,
CachedTranscriptionEntity,
Version,
merge_parents_cache,
retrieve_parents_cache_path,
)
......@@ -196,3 +198,30 @@ def test_merge_from_worker(
assert [t.id for t in CachedTranscription.select().order_by("id")] == [
UUID("11111111-1111-1111-1111-111111111111"),
]
def test_merge_conflicting_versions(mock_databases, tmpdir):
    """
    Merging databases with differing versions should not be allowed
    """
    fake_version = 420
    second_db = mock_databases["second"]["db"]
    target = mock_databases["target"]

    # Overwrite the version stored in the "second" parent database
    with second_db.bind_ctx([Version]):
        assert Version.get() == Version(version=SQL_VERSION)
        Version.update(version=fake_version).execute()
        assert Version.get() == Version(version=fake_version)

    # The target database keeps the expected version
    with target["db"].bind_ctx([Version]):
        assert Version.get() == Version(version=SQL_VERSION)

    # Retrieve the paths of the parent databases
    paths = retrieve_parents_cache_path(["first", "second"], data_dir=tmpdir)

    # Merge all requested parent databases into our target; the "second" parent has a differing version
    with pytest.raises(AssertionError) as e:
        merge_parents_cache(paths, target["path"])

    expected_message = f"The SQLite database {paths[1]} does not have the correct cache version, it should be {SQL_VERSION}"
    assert str(e.value) == expected_message
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment