Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
Commits on Source (6)
0.2.3-rc4
0.2.3-rc5
......@@ -13,6 +13,7 @@ from peewee import (
ForeignKeyField,
IntegerField,
Model,
OperationalError,
SqliteDatabase,
TextField,
UUIDField,
......@@ -38,6 +39,14 @@ class JSONField(Field):
return json.loads(value)
class Version(Model):
version = IntegerField(primary_key=True)
class Meta:
database = db
table_name = "version"
class CachedImage(Model):
id = UUIDField(primary_key=True)
width = IntegerField()
......@@ -184,6 +193,7 @@ MODELS = [
CachedEntity,
CachedTranscriptionEntity,
]
SQL_VERSION = 1
def init_cache_db(path):
......@@ -206,6 +216,30 @@ def create_tables():
db.create_tables(MODELS)
def create_version_table():
"""
Creates the Version table in the cache DB.
This step must be independent from other tables creation since we only
want to create the table and add the one and only Version entry when the
cache is created from scratch.
"""
db.create_tables([Version])
Version.create(version=SQL_VERSION)
def check_version(cache_path):
with SqliteDatabase(cache_path) as provided_db:
with provided_db.bind_ctx([Version]):
try:
version = Version.get().version
except OperationalError:
version = None
assert (
version == SQL_VERSION
), f"The SQLite database {cache_path} does not have the correct cache version, it should be {SQL_VERSION}"
def retrieve_parents_cache_path(parent_ids, data_dir="/data", chunk=None):
assert isinstance(parent_ids, list)
assert os.path.isdir(data_dir)
......@@ -247,6 +281,9 @@ def merge_parents_cache(paths, current_database):
# Merge each table into the local database
for idx, path in enumerate(paths):
# Check that the parent cache uses a compatible version
check_version(path)
with SqliteDatabase(path) as source:
with source.bind_ctx(MODELS):
source.create_tables(MODELS)
......
......@@ -19,7 +19,9 @@ from tenacity import (
from arkindex import ArkindexClient, options_from_env
from arkindex_worker import logger
from arkindex_worker.cache import (
check_version,
create_tables,
create_version_table,
init_cache_db,
merge_parents_cache,
retrieve_parents_cache_path,
......@@ -203,6 +205,12 @@ class BaseWorker(object):
self.cache_path = os.path.join(cache_dir, "db.sqlite")
init_cache_db(self.cache_path)
if self.args.database is not None:
check_version(self.cache_path)
else:
create_version_table()
create_tables()
# Merging parents caches (if there are any) in the current task local cache, unless the database got overridden
......
......@@ -18,7 +18,11 @@ class TextOrientation(Enum):
class TranscriptionMixin(object):
def create_transcription(
self, element, text, score, orientation=TextOrientation.HorizontalLeftToRight
self,
element,
text,
confidence,
orientation=TextOrientation.HorizontalLeftToRight,
):
"""
Create a transcription on the given element through the API.
......@@ -33,8 +37,8 @@ class TranscriptionMixin(object):
orientation, TextOrientation
), "orientation shouldn't be null and should be of type TextOrientation"
assert (
isinstance(score, float) and 0 <= score <= 1
), "score shouldn't be null and should be a float in [0..1] range"
isinstance(confidence, float) and 0 <= confidence <= 1
), "confidence shouldn't be null and should be a float in [0..1] range"
if self.is_read_only:
logger.warning(
......@@ -48,7 +52,7 @@ class TranscriptionMixin(object):
body={
"text": text,
"worker_version": self.worker_version_id,
"score": score,
"confidence": confidence,
"orientation": orientation.value,
},
)
......@@ -99,10 +103,12 @@ class TranscriptionMixin(object):
text, str
), f"Transcription at index {index} in transcriptions: text shouldn't be null and should be of type str"
score = transcription.get("score")
confidence = transcription.get("confidence")
assert (
score is not None and isinstance(score, float) and 0 <= score <= 1
), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range"
confidence is not None
and isinstance(confidence, float)
and 0 <= confidence <= 1
), f"Transcription at index {index} in transcriptions: confidence shouldn't be null and should be a float in [0..1] range"
orientation = transcription.get(
"orientation", TextOrientation.HorizontalLeftToRight
......@@ -169,10 +175,12 @@ class TranscriptionMixin(object):
text, str
), f"Transcription at index {index} in transcriptions: text shouldn't be null and should be of type str"
score = transcription.get("score")
confidence = transcription.get("confidence")
assert (
score is not None and isinstance(score, float) and 0 <= score <= 1
), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range"
confidence is not None
and isinstance(confidence, float)
and 0 <= confidence <= 1
), f"Transcription at index {index} in transcriptions: confidence shouldn't be null and should be a float in [0..1] range"
orientation = transcription.get(
"orientation", TextOrientation.HorizontalLeftToRight
......@@ -255,7 +263,7 @@ class TranscriptionMixin(object):
"id": annotation["id"],
"element_id": annotation["element_id"],
"text": transcription["text"],
"confidence": transcription["score"],
"confidence": transcription["confidence"],
"orientation": transcription.get(
"orientation", TextOrientation.HorizontalLeftToRight
).value,
......
arkindex-client==1.0.8
peewee==3.14.4
peewee==3.14.10
Pillow==9.0.1
python-gitlab==2.7.1
python-gnupg==0.4.7
python-gnupg==0.4.8
sh==1.14.2
tenacity==8.0.1
pytest==7.0.0
pytest==7.1.1
pytest-mock==3.7.0
pytest-responses==0.5.0
......@@ -12,7 +12,15 @@ import yaml
from peewee import SqliteDatabase
from arkindex.mock import MockApiClient
from arkindex_worker.cache import MODELS, CachedElement, CachedTranscription
from arkindex_worker.cache import (
MODELS,
SQL_VERSION,
CachedElement,
CachedTranscription,
Version,
create_version_table,
init_cache_db,
)
from arkindex_worker.git import GitHelper, GitlabHelper
from arkindex_worker.worker import BaseWorker, ElementsWorker
from arkindex_worker.worker.transcription import TextOrientation
......@@ -260,7 +268,8 @@ def mock_base_worker_with_cache(mocker, monkeypatch, mock_config_api):
def mock_elements_worker_with_cache(monkeypatch, mock_config_api, tmp_path):
"""Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest"""
cache_path = tmp_path / "db.sqlite"
cache_path.touch()
init_cache_db(cache_path)
create_version_table()
monkeypatch.setattr(sys, "argv", ["worker", "-d", str(cache_path)])
worker = ElementsWorker(support_cache=True)
......@@ -433,9 +442,11 @@ def mock_databases(tmpdir):
path = tmpdir / name / filename
(tmpdir / name).mkdir()
local_db = SqliteDatabase(path)
with local_db.bind_ctx(MODELS):
with local_db.bind_ctx(MODELS + [Version]):
# Create tables on the current local database
# by binding temporarily the models on that database
local_db.create_tables([Version])
Version.create(version=SQL_VERSION)
local_db.create_tables(MODELS)
out[name] = {"path": path, "db": local_db}
......
......@@ -8,7 +8,7 @@
"id": "008691ae-8133-48c4-88d5-d4cc9f65c06c",
"type": "line",
"text": "J . Caron &",
"score": 0.4781,
"confidence": 0.4781,
"zone": null,
"source": null,
"worker_version_id": "3ca4a8e3-91d1-4b78-8d83-d8bbbf487996",
......
......@@ -6,9 +6,13 @@ import pytest
from peewee import OperationalError
from arkindex_worker.cache import (
SQL_VERSION,
CachedElement,
CachedImage,
Version,
check_version,
create_tables,
create_version_table,
db,
init_cache_db,
)
......@@ -73,6 +77,76 @@ CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT
assert expected_schema == actual_schema
def test_create_version_table(tmp_path):
db_path = f"{tmp_path}/db.sqlite"
init_cache_db(db_path)
create_version_table()
expected_schema = 'CREATE TABLE "version" ("version" INTEGER NOT NULL PRIMARY KEY)'
actual_schema = "\n".join(
[
row[0]
for row in db.connection()
.execute("SELECT sql FROM sqlite_master WHERE type = 'table' ORDER BY name")
.fetchall()
]
)
assert expected_schema == actual_schema
assert Version.select().count() == 1
assert Version.get() == Version(version=SQL_VERSION)
def test_check_version_unset_version(tmp_path):
"""
The cache misses the version table
"""
db_path = f"{tmp_path}/db.sqlite"
init_cache_db(db_path)
with pytest.raises(AssertionError) as e:
check_version(db_path)
assert (
str(e.value)
== f"The SQLite database {db_path} does not have the correct cache version, it should be {SQL_VERSION}"
)
def test_check_version_differing_version(tmp_path):
"""
The cache has a differing version
"""
db_path = f"{tmp_path}/db.sqlite"
init_cache_db(db_path)
fake_version = 420
assert fake_version != SQL_VERSION
db.create_tables([Version])
Version.create(version=fake_version)
db.close()
with pytest.raises(AssertionError) as e:
check_version(db_path)
assert (
str(e.value)
== f"The SQLite database {db_path} does not have the correct cache version, it should be {SQL_VERSION}"
)
def test_check_version_same_version(tmp_path):
"""
The cache has the expected version
"""
db_path = f"{tmp_path}/db.sqlite"
init_cache_db(db_path)
create_version_table()
db.close()
check_version(db_path)
@pytest.mark.parametrize(
"image_width,image_height,polygon_x,polygon_y,polygon_width,polygon_height,max_size,expected_url",
[
......
......@@ -6,7 +6,13 @@ from uuid import UUID
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedElement, CachedImage
from arkindex_worker.cache import (
SQL_VERSION,
CachedElement,
CachedImage,
create_version_table,
init_cache_db,
)
from arkindex_worker.models import Element
from arkindex_worker.worker import ElementsWorker
from arkindex_worker.worker.element import MissingTypeError
......@@ -177,7 +183,8 @@ def test_list_elements_both_args_error(mocker, mock_elements_worker, tmp_path):
def test_database_arg(mocker, mock_elements_worker, tmp_path):
database_path = tmp_path / "my_database.sqlite"
database_path.touch()
init_cache_db(database_path)
create_version_table()
mocker.patch(
"arkindex_worker.worker.base.argparse.ArgumentParser.parse_args",
......@@ -197,6 +204,32 @@ def test_database_arg(mocker, mock_elements_worker, tmp_path):
assert worker.cache_path == str(database_path)
def test_database_arg_cache_missing_version_table(
mocker, mock_elements_worker, tmp_path
):
database_path = tmp_path / "my_database.sqlite"
database_path.touch()
mocker.patch(
"arkindex_worker.worker.base.argparse.ArgumentParser.parse_args",
return_value=Namespace(
element=["volumeid", "pageid"],
verbose=False,
elements_list=None,
database=str(database_path),
dev=False,
),
)
worker = ElementsWorker(support_cache=True)
with pytest.raises(AssertionError) as e:
worker.configure()
assert (
str(e.value)
== f"The SQLite database {database_path} does not have the correct cache version, it should be {SQL_VERSION}"
)
def test_load_corpus_classes_api_error(responses, mock_elements_worker):
corpus_id = "12341234-1234-1234-1234-123412341234"
responses.add(
......
......@@ -5,12 +5,14 @@ import pytest
from arkindex_worker.cache import (
MODELS,
SQL_VERSION,
CachedClassification,
CachedElement,
CachedEntity,
CachedImage,
CachedTranscription,
CachedTranscriptionEntity,
Version,
merge_parents_cache,
retrieve_parents_cache_path,
)
......@@ -196,3 +198,30 @@ def test_merge_from_worker(
assert [t.id for t in CachedTranscription.select().order_by("id")] == [
UUID("11111111-1111-1111-1111-111111111111"),
]
def test_merge_conflicting_versions(mock_databases, tmpdir):
"""
Merging databases with differing versions should not be allowed
"""
with mock_databases["second"]["db"].bind_ctx([Version]):
assert Version.get() == Version(version=SQL_VERSION)
fake_version = 420
Version.update(version=fake_version).execute()
assert Version.get() == Version(version=fake_version)
with mock_databases["target"]["db"].bind_ctx([Version]):
assert Version.get() == Version(version=SQL_VERSION)
# Retrieve parents databases paths
paths = retrieve_parents_cache_path(["first", "second"], data_dir=tmpdir)
# Merge all requested parents databases into our target, the "second" parent have a differing version
with pytest.raises(AssertionError) as e:
merge_parents_cache(paths, mock_databases["target"]["path"])
assert (
str(e.value)
== f"The SQLite database {paths[1]} does not have the correct cache version, it should be {SQL_VERSION}"
)