Compare revisions
Target project: workers/base-worker
Commits on Source (11)
Showing with 1560 additions and 250 deletions
......@@ -8,4 +8,4 @@ line_length = 88
default_section=FIRSTPARTY
known_first_party = arkindex,arkindex_common
known_third_party =PIL,apistar,gitlab,gnupg,pytest,requests,setuptools,sh,tenacity,yaml
known_third_party =PIL,apistar,gitlab,gnupg,peewee,pytest,requests,setuptools,sh,tenacity,yaml
# -*- coding: utf-8 -*-
import json
import logging
import os
import sqlite3

from peewee import (
    BooleanField,
    CharField,
    Field,
    FloatField,
    ForeignKeyField,
    Model,
    SqliteDatabase,
    TextField,
    UUIDField,
)

logger = logging.getLogger(__name__)

db = SqliteDatabase(None)


class JSONField(Field):
    field_type = "text"

    def db_value(self, value):
        if value is None:
            return
        return json.dumps(value)

    def python_value(self, value):
        if value is None:
            return
        return json.loads(value)


class CachedElement(Model):
    id = UUIDField(primary_key=True)
    parent_id = UUIDField(null=True)
    type = CharField(max_length=50)
    polygon = JSONField(null=True)
    initial = BooleanField(default=False)
    worker_version_id = UUIDField(null=True)

    class Meta:
        database = db
        table_name = "elements"


class CachedTranscription(Model):
    id = UUIDField(primary_key=True)
    element_id = ForeignKeyField(CachedElement, backref="transcriptions")
    text = TextField()
    confidence = FloatField()
    worker_version_id = UUIDField()

    class Meta:
        database = db
        table_name = "transcriptions"


# Add all the managed models to this list.
# It is used here, but also in unit tests.
MODELS = [CachedElement, CachedTranscription]


def init_cache_db(path):
    db.init(
        path,
        pragmas={
            # SQLite ignores foreign keys and check constraints by default!
            "foreign_keys": 1,
            "ignore_check_constraints": 0,
        },
    )
    db.connect()
    logger.info(f"Connected to cache on {path}")


def create_tables():
    """
    Create the tables in the cache DB, only if they do not already exist.
    """
    db.create_tables(MODELS)


def merge_parents_cache(parent_ids, current_database, data_dir="/data", chunk=None):
    """
    Merge all the potential parent tasks' databases into the existing local one.
    """
    assert isinstance(parent_ids, list)
    assert os.path.isdir(data_dir)
    assert os.path.exists(current_database)

    # Handle a possible chunk in the parent task name.
    # This is needed to support the init_elements databases.
    filenames = [
        "db.sqlite",
    ]
    if chunk is not None:
        filenames.append(f"db_{chunk}.sqlite")

    # Find all the paths for these databases
    paths = list(
        filter(
            lambda p: os.path.isfile(p),
            [
                os.path.join(data_dir, parent, name)
                for parent in parent_ids
                for name in filenames
            ],
        )
    )

    if not paths:
        logger.info("No parents cache to use")
        return

    # Open a connection on the current database
    connection = sqlite3.connect(current_database)
    cursor = connection.cursor()

    # Merge each parent database into the local one
    for idx, path in enumerate(paths):
        logger.info(f"Merging parent db {path} into {current_database}")
        statements = [
            "PRAGMA page_size=80000;",
            "PRAGMA synchronous=OFF;",
            f"ATTACH DATABASE '{path}' AS source_{idx};",
            f"REPLACE INTO elements SELECT * FROM source_{idx}.elements;",
            f"REPLACE INTO transcriptions SELECT * FROM source_{idx}.transcriptions;",
        ]

        for statement in statements:
            cursor.execute(statement)
        connection.commit()
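As a quick orientation for reviewers, here is a minimal usage sketch of the new cache module. It is illustrative only and not part of the diff: the database path and field values are made up, and only init_cache_db, create_tables and the models come from the file above.

# Illustrative sketch, not part of the diff
from uuid import uuid4

from arkindex_worker.cache import CachedElement, create_tables, init_cache_db

init_cache_db("/tmp/demo.sqlite")  # binds the module-level `db` handle, with foreign keys enabled
create_tables()                    # creates the "elements" and "transcriptions" tables if missing

# JSONField transparently serializes the polygon into the TEXT column
CachedElement.create(
    id=uuid4(),
    type="page",
    polygon=[[0, 0], [100, 0], [100, 200], [0, 200]],
)
assert CachedElement.select().count() == 1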
# -*- coding: utf-8 -*-
import json
import traceback
import warnings
from collections import Counter
from datetime import datetime
......@@ -83,35 +82,12 @@ class Reporter(object):
        )
        element["classifications"] = dict(counter)

    def add_transcription(self, element_id, type=None, type_count=None):
    def add_transcription(self, element_id, count=1):
        """
        Report creating a transcription on an element.
        Multiple transcriptions with the same parent can be declared with the type_count parameter.
        """
        if type_count is None:
            if isinstance(type, int):
                type_count, type = type, None
            else:
                type_count = 1
        if type is not None:
            warnings.warn(
                "Transcription types have been deprecated and will be removed in the next release.",
                FutureWarning,
            )
        self._get_element(element_id)["transcriptions"] += type_count

    def add_transcriptions(self, element_id, transcriptions):
        """
        Report one or more transcriptions at once.
        """
        assert isinstance(transcriptions, list), "A list is required for transcriptions"
        warnings.warn(
            "Reporter.add_transcriptions is deprecated due to transcription types being removed. Please use Reporter.add_transcription(element_id, count) instead.",
            FutureWarning,
        )
        self.add_transcription(element_id, len(transcriptions))
        self._get_element(element_id)["transcriptions"] += count

    def add_entity(self, element_id, entity_id, type, name):
        """
......
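For reference, the simplified reporting call now looks like this. A sketch, assuming Reporter is importable from arkindex_worker.reporting; the element id is made up.

from arkindex_worker.reporting import Reporter

reporter = Reporter("worker")
reporter.add_transcription("myelement")            # one transcription on the element
reporter.add_transcription("myelement", count=10)  # several at once, replacing add_transcriptions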
This diff is collapsed.
arkindex-client==1.0.6
peewee==3.14.4
Pillow==8.1.0
python-gitlab==2.6.0
python-gnupg==0.4.6
sh==1.14.1
tenacity==6.3.1
tenacity==7.0.0
......@@ -3,20 +3,34 @@ import hashlib
import json
import os
import sys
import time
from pathlib import Path
from uuid import UUID

import pytest
import yaml
from peewee import SqliteDatabase

from arkindex.mock import MockApiClient
from arkindex_worker.cache import MODELS, CachedElement, CachedTranscription
from arkindex_worker.git import GitHelper, GitlabHelper
from arkindex_worker.worker import ElementsWorker
from arkindex_worker.worker import BaseWorker, ElementsWorker

FIXTURES_DIR = Path(__file__).resolve().parent / "data"

__yaml_cache = {}


@pytest.fixture(autouse=True)
def disable_sleep(monkeypatch):
    """
    Do not sleep at all in between API executions
    when errors occur in unit tests.
    This speeds up the test execution a lot
    """
    monkeypatch.setattr(time, "sleep", lambda x: None)


@pytest.fixture
def cache_yaml(monkeypatch):
    """
......@@ -77,6 +91,14 @@ def setup_api(responses, monkeypatch, cache_yaml):
    monkeypatch.setenv("ARKINDEX_API_TOKEN", "unittest1234")


@pytest.fixture(autouse=True)
def temp_working_directory(monkeypatch, tmp_path):
    def _getcwd():
        return str(tmp_path)

    monkeypatch.setattr(os, "getcwd", _getcwd)


@pytest.fixture(autouse=True)
def give_worker_version_id_env_variable(monkeypatch):
    monkeypatch.setenv("WORKER_VERSION_ID", "12341234-1234-1234-1234-123412341234")
......@@ -149,6 +171,26 @@ def mock_elements_worker(monkeypatch, mock_worker_version_api):
    return worker


@pytest.fixture
def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_version_api):
    """Build a BaseWorker using SQLite cache, also mocking a TASK_ID"""
    monkeypatch.setattr(sys, "argv", ["worker"])

    worker = BaseWorker(use_cache=True)
    monkeypatch.setenv("TASK_ID", "my_task")
    return worker


@pytest.fixture
def mock_elements_worker_with_cache(monkeypatch, mock_worker_version_api):
    """Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest"""
    monkeypatch.setattr(sys, "argv", ["worker"])

    worker = ElementsWorker(use_cache=True)
    worker.configure()
    return worker


@pytest.fixture
def fake_page_element():
    with open(FIXTURES_DIR / "page_element.json", "r") as f:
......@@ -199,3 +241,127 @@ def fake_gitlab_helper_factory():
        )

    return run
@pytest.fixture
def mock_cached_elements():
    """Insert a few elements in the local cache"""
    CachedElement.create(
        id=UUID("11111111-1111-1111-1111-111111111111"),
        parent_id="12341234-1234-1234-1234-123412341234",
        type="something",
        polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
        worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
    )
    CachedElement.create(
        id=UUID("22222222-2222-2222-2222-222222222222"),
        parent_id=UUID("12341234-1234-1234-1234-123412341234"),
        type="page",
        polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
        worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
    )
    assert CachedElement.select().count() == 2


@pytest.fixture
def mock_cached_transcriptions():
    """Insert a few transcriptions in the local cache, on a shared element"""
    CachedElement.create(
        id=UUID("12341234-1234-1234-1234-123412341234"),
        type="page",
        polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
        worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
    )
    CachedTranscription.create(
        id=UUID("11111111-1111-1111-1111-111111111111"),
        element_id=UUID("12341234-1234-1234-1234-123412341234"),
        text="Hello!",
        confidence=0.42,
        worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
    )
    CachedTranscription.create(
        id=UUID("22222222-2222-2222-2222-222222222222"),
        element_id=UUID("12341234-1234-1234-1234-123412341234"),
        text="How are you?",
        confidence=0.42,
        worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
    )
@pytest.fixture(scope="function")
def mock_databases(tmpdir):
    """
    Initialize several temporary databases
    to help test the merge algorithm
    """
    out = {}
    for name in ("target", "first", "second", "conflict", "chunk_42"):
        # Build a local database in a sub-directory
        # for each required name
        filename = "db_42.sqlite" if name == "chunk_42" else "db.sqlite"
        path = tmpdir / name / filename
        (tmpdir / name).mkdir()
        local_db = SqliteDatabase(path)
        with local_db.bind_ctx(MODELS):
            # Create tables on the current local database
            # by temporarily binding the models to that database
            local_db.create_tables(MODELS)
        out[name] = {"path": path, "db": local_db}

    # Add an element in the first parent database
    with out["first"]["db"].bind_ctx(MODELS):
        CachedElement.create(
            id=UUID("12341234-1234-1234-1234-123412341234"),
            type="page",
            polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
            worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
        )
        CachedElement.create(
            id=UUID("56785678-5678-5678-5678-567856785678"),
            type="page",
            polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
            worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
        )

    # Add another element with a transcription in the second parent database
    with out["second"]["db"].bind_ctx(MODELS):
        CachedElement.create(
            id=UUID("42424242-4242-4242-4242-424242424242"),
            type="page",
            polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
            worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
        )
        CachedTranscription.create(
            id=UUID("11111111-1111-1111-1111-111111111111"),
            element_id=UUID("42424242-4242-4242-4242-424242424242"),
            text="Hello!",
            confidence=0.42,
            worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
        )

    # Add a conflicting element
    with out["conflict"]["db"].bind_ctx(MODELS):
        CachedElement.create(
            id=UUID("42424242-4242-4242-4242-424242424242"),
            type="page",
            polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
            initial=True,
        )
        CachedTranscription.create(
            id=UUID("22222222-2222-2222-2222-222222222222"),
            element_id=UUID("42424242-4242-4242-4242-424242424242"),
            text="Hello again neighbor !",
            confidence=0.42,
            worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
        )

    # Add an element in the chunk parent database
    with out["chunk_42"]["db"].bind_ctx(MODELS):
        CachedElement.create(
            id=UUID("42424242-4242-4242-4242-424242424242"),
            type="page",
            polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
            initial=True,
        )

    return out
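The mock_databases fixture above leans on peewee's bind_ctx to point the shared models at throwaway databases. Here is a condensed sketch of that pattern; the file path is hypothetical and not part of the diff.

from peewee import SqliteDatabase

from arkindex_worker.cache import MODELS, CachedElement

scratch = SqliteDatabase("/tmp/scratch.sqlite")
with scratch.bind_ctx(MODELS):
    # Inside the context, CachedElement reads and writes /tmp/scratch.sqlite
    scratch.create_tables(MODELS)
    CachedElement.create(id="11111111-1111-1111-1111-111111111111", type="page")
# Outside the context, the models are bound back to the module-level `db` handle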
File added
......@@ -12,7 +12,7 @@ from arkindex_worker import logger
from arkindex_worker.worker import BaseWorker


def test_init_default_local_share():
def test_init_default_local_share(monkeypatch):
    worker = BaseWorker()

    assert worker.work_dir == os.path.expanduser("~/.local/share/arkindex")
......@@ -28,6 +28,14 @@ def test_init_default_xdg_data_home(monkeypatch):
    assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"


def test_init_with_local_cache(monkeypatch):
    worker = BaseWorker(use_cache=True)

    assert worker.work_dir == os.path.expanduser("~/.local/share/arkindex")
    assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
    assert worker.use_cache is True


def test_init_var_ponos_data_given(monkeypatch):
    path = str(Path(__file__).absolute().parent)
    monkeypatch.setenv("PONOS_DATA", path)
......
# -*- coding: utf-8 -*-
import os

import pytest
from peewee import OperationalError

from arkindex_worker.cache import create_tables, db, init_cache_db


def test_init_non_existent_path():
    with pytest.raises(OperationalError) as e:
        init_cache_db("path/not/found.sqlite")

    assert str(e.value) == "unable to open database file"


def test_init(tmp_path):
    db_path = f"{tmp_path}/db.sqlite"
    init_cache_db(db_path)
    assert os.path.isfile(db_path)


def test_create_tables_existing_table(tmp_path):
    db_path = f"{tmp_path}/db.sqlite"

    # Create the tables once…
    init_cache_db(db_path)
    create_tables()
    db.close()

    with open(db_path, "rb") as before_file:
        before = before_file.read()

    # Create them again
    init_cache_db(db_path)
    create_tables()

    with open(db_path, "rb") as after_file:
        after = after_file.read()

    assert before == after, "Existing table structure was modified"


def test_create_tables(tmp_path):
    db_path = f"{tmp_path}/db.sqlite"
    init_cache_db(db_path)
    create_tables()

    expected_schema = """CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "polygon" text, "initial" INTEGER NOT NULL, "worker_version_id" TEXT)
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""

    actual_schema = "\n".join(
        [
            row[0]
            for row in db.connection()
            .execute("SELECT sql FROM sqlite_master WHERE type = 'table' ORDER BY name")
            .fetchall()
        ]
    )

    assert expected_schema == actual_schema
......@@ -362,10 +362,15 @@ def test_create_classification_api_error(responses, mock_elements_worker):
            high_confidence=True,
        )

    assert len(responses.calls) == 3
    assert len(responses.calls) == 7
    assert [call.request.url for call in responses.calls] == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        # The API call is retried 5 times
        "http://testserver/api/v1/classifications/",
        "http://testserver/api/v1/classifications/",
        "http://testserver/api/v1/classifications/",
        "http://testserver/api/v1/classifications/",
        "http://testserver/api/v1/classifications/",
    ]
......
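The retry assertions in this and the following test hunks expect each failing endpoint to be called 5 times. The worker-side change is not visible in this compare, but the behaviour being tested can be pictured as a tenacity decorator along these lines; the function name and retry predicate are illustrative, not the actual implementation.

from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential

@retry(
    retry=retry_if_exception(lambda e: True),  # illustrative: retry on any error
    stop=stop_after_attempt(5),                # hence 5 calls per failing endpoint in the tests
    wait=wait_exponential(),                   # the disable_sleep fixture neutralizes this wait in tests
    reraise=True,
)
def call_api(client, *args, **kwargs):
    # Hypothetical wrapper around the API client request
    return client.request(*args, **kwargs)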
......@@ -3,10 +3,12 @@ import json
import os
import tempfile
from argparse import Namespace
from uuid import UUID

import pytest
from apistar.exceptions import ErrorResponse

from arkindex_worker.cache import CachedElement
from arkindex_worker.models import Element
from arkindex_worker.worker import ElementsWorker
......@@ -344,10 +346,15 @@ def test_create_sub_element_api_error(responses, mock_elements_worker):
            polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
        )

    assert len(responses.calls) == 3
    assert len(responses.calls) == 7
    assert [call.request.url for call in responses.calls] == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        # The API call is retried 5 times
        "http://testserver/api/v1/elements/create/",
        "http://testserver/api/v1/elements/create/",
        "http://testserver/api/v1/elements/create/",
        "http://testserver/api/v1/elements/create/",
        "http://testserver/api/v1/elements/create/",
    ]
......@@ -629,15 +636,20 @@ def test_create_elements_api_error(responses, mock_elements_worker):
            ],
        )

    assert len(responses.calls) == 3
    assert len(responses.calls) == 7
    assert [call.request.url for call in responses.calls] == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        # The API call is retried 5 times
        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
    ]
def test_create_elements(responses, mock_elements_worker):
def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})

    responses.add(
        responses.POST,
......@@ -646,7 +658,7 @@ def test_create_elements(responses, mock_elements_worker):
        json=[{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}],
    )

    created_ids = mock_elements_worker.create_elements(
    created_ids = mock_elements_worker_with_cache.create_elements(
        parent=elt,
        elements=[
            {
......@@ -675,6 +687,66 @@ def test_create_elements(responses, mock_elements_worker):
    }

    assert created_ids == [{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}]

    # Check that created elements were properly stored in SQLite cache
    assert os.path.isfile(tmp_path / "db.sqlite")

    assert list(CachedElement.select()) == [
        CachedElement(
            id=UUID("497f6eca-6276-4993-bfeb-53cbbbba6f08"),
            parent_id=UUID("12341234-1234-1234-1234-123412341234"),
            type="something",
            polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
            worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
        )
    ]
def test_create_elements_integrity_error(
    responses, mock_elements_worker_with_cache, caplog
):
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    responses.add(
        responses.POST,
        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
        status=200,
        json=[
            # Duplicate IDs, which will cause an IntegrityError when stored in the cache
            {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
            {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
        ],
    )

    elements = [
        {
            "name": "0",
            "type": "something",
            "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
        },
        {
            "name": "1",
            "type": "something",
            "polygon": [[1, 1], [3, 3], [3, 1], [1, 3]],
        },
    ]

    created_ids = mock_elements_worker_with_cache.create_elements(
        parent=elt,
        elements=elements,
    )

    assert created_ids == [
        {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
        {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
    ]

    assert len(caplog.records) == 1
    assert caplog.records[0].levelname == "WARNING"
    assert caplog.records[0].message.startswith(
        "Couldn't save created elements in local cache:"
    )

    assert list(CachedElement.select()) == []
def test_list_element_children_wrong_element(mock_elements_worker):
    with pytest.raises(AssertionError) as e:
......@@ -881,3 +953,86 @@ def test_list_element_children(responses, mock_elements_worker):
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
    ]
def test_list_element_children_with_cache_unhandled_param(
    mock_elements_worker_with_cache,
):
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    with pytest.raises(AssertionError) as e:
        mock_elements_worker_with_cache.list_element_children(
            element=elt, with_corpus=True
        )
    assert (
        str(e.value)
        == "When using the local cache, you can only filter by 'type' and/or 'worker_version'"
    )


@pytest.mark.parametrize(
    "filters, expected_ids",
    (
        # Filtering on the element alone should give all inserted elements
        (
            {
                "element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
            },
            (
                "11111111-1111-1111-1111-111111111111",
                "22222222-2222-2222-2222-222222222222",
            ),
        ),
        # Filtering on the element and the page type should give the second element
        (
            {
                "element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
                "type": "page",
            },
            ("22222222-2222-2222-2222-222222222222",),
        ),
        # Filtering on the element and the worker version should give all elements
        (
            {
                "element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
                "worker_version": "56785678-5678-5678-5678-567856785678",
            },
            (
                "11111111-1111-1111-1111-111111111111",
                "22222222-2222-2222-2222-222222222222",
            ),
        ),
        # Filtering on the element, the "something" type and the worker version should give the first element
        (
            {
                "element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
                "type": "something",
                "worker_version": "56785678-5678-5678-5678-567856785678",
            },
            ("11111111-1111-1111-1111-111111111111",),
        ),
    ),
)
def test_list_element_children_with_cache(
    responses,
    mock_elements_worker_with_cache,
    mock_cached_elements,
    filters,
    expected_ids,
):
    # Check we have 2 elements already present in the database
    assert CachedElement.select().count() == 2

    # Query the database through the cache
    elements = mock_elements_worker_with_cache.list_element_children(**filters)
    assert elements.count() == len(expected_ids)
    for child, expected_id in zip(elements.order_by("id"), expected_ids):
        assert child.id == UUID(expected_id)

    # Check the worker never hits the API for elements
    assert len(responses.calls) == 2
    assert [call.request.url for call in responses.calls] == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
    ]
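The cache-backed branch of list_element_children is not visible in this compare, but these tests describe its contract: children are read from CachedElement instead of the API, with only type and worker_version allowed as filters. A rough peewee equivalent, hypothetical and not the worker's actual code:

from arkindex_worker.cache import CachedElement

def cached_children(element, type=None, worker_version=None):
    # Start from every cached child of the given element
    query = CachedElement.select().where(CachedElement.parent_id == element.id)
    if type is not None:
        query = query.where(CachedElement.type == type)
    if worker_version is not None:
        query = query.where(CachedElement.worker_version_id == worker_version)
    return query.order_by(CachedElement.id)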
......@@ -147,10 +147,15 @@ def test_create_entity_api_error(responses, mock_elements_worker):
            corpus="12341234-1234-1234-1234-123412341234",
        )

    assert len(responses.calls) == 3
    assert len(responses.calls) == 7
    assert [call.request.url for call in responses.calls] == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        # The API call is retried 5 times
        "http://testserver/api/v1/entity/",
        "http://testserver/api/v1/entity/",
        "http://testserver/api/v1/entity/",
        "http://testserver/api/v1/entity/",
        "http://testserver/api/v1/entity/",
    ]
......
......@@ -133,10 +133,15 @@ def test_create_metadata_api_error(responses, mock_elements_worker):
            value="La Turbine, Grenoble 38000",
        )

    assert len(responses.calls) == 3
    assert len(responses.calls) == 7
    assert [call.request.url for call in responses.calls] == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        # The API call is retried 5 times
        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
    ]
......
# -*- coding: utf-8 -*-
from uuid import UUID

import pytest

from arkindex_worker.cache import (
    MODELS,
    CachedElement,
    CachedTranscription,
    merge_parents_cache,
)


@pytest.mark.parametrize(
    "parents, expected_elements, expected_transcriptions",
    (
        # Nothing happens when no parents are available
        ([], [], []),
        # Nothing happens when the parent file does not exist
        (
            [
                "missing",
            ],
            [],
            [],
        ),
        # When one parent is available, its data is reused
        (
            [
                "first",
            ],
            [
                UUID("12341234-1234-1234-1234-123412341234"),
                UUID("56785678-5678-5678-5678-567856785678"),
            ],
            [],
        ),
        # When 2 parents are available, their data is merged
        (
            [
                "first",
                "second",
            ],
            [
                UUID("12341234-1234-1234-1234-123412341234"),
                UUID("56785678-5678-5678-5678-567856785678"),
                UUID("42424242-4242-4242-4242-424242424242"),
            ],
            [
                UUID("11111111-1111-1111-1111-111111111111"),
            ],
        ),
        # When N parents are available, their data is merged, and conflicts are supported
        (
            [
                "first",
                "second",
                "conflict",
            ],
            [
                UUID("12341234-1234-1234-1234-123412341234"),
                UUID("56785678-5678-5678-5678-567856785678"),
                UUID("42424242-4242-4242-4242-424242424242"),
            ],
            [
                UUID("11111111-1111-1111-1111-111111111111"),
                UUID("22222222-2222-2222-2222-222222222222"),
            ],
        ),
    ),
)
def test_merge_databases(
    mock_databases, tmpdir, parents, expected_elements, expected_transcriptions
):
    """Test multiple database merge scenarios"""

    # We always start with an empty database
    with mock_databases["target"]["db"].bind_ctx(MODELS):
        assert CachedElement.select().count() == 0
        assert CachedTranscription.select().count() == 0

    # Merge all the requested parent databases into our target
    merge_parents_cache(
        parents,
        mock_databases["target"]["path"],
        data_dir=tmpdir,
    )

    # The target should now have the expected elements and transcriptions
    with mock_databases["target"]["db"].bind_ctx(MODELS):
        assert CachedElement.select().count() == len(expected_elements)
        assert CachedTranscription.select().count() == len(expected_transcriptions)
        assert [
            e.id for e in CachedElement.select().order_by("id")
        ] == expected_elements
        assert [
            t.id for t in CachedTranscription.select().order_by("id")
        ] == expected_transcriptions
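The "conflicts are supported" cases above come straight from SQLite's REPLACE semantics used by merge_parents_cache: when two parents carry a row with the same primary key, the row merged last wins instead of raising a constraint error. A minimal sketch with made-up values, shown via the sqlite3 module already used by the cache code:

import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE elements (id TEXT PRIMARY KEY, type TEXT)")
con.execute("INSERT INTO elements VALUES ('42', 'page')")
# REPLACE deletes the existing row with the same primary key, then inserts the new one
con.execute("REPLACE INTO elements VALUES ('42', 'paragraph')")
assert con.execute("SELECT type FROM elements").fetchall() == [("paragraph",)]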
def test_merge_chunk(mock_databases, tmpdir, monkeypatch):
    """
    Check that the db merge algorithm supports two parents
    when one of them has a chunk
    """
    # At first we have nothing in the target
    with mock_databases["target"]["db"].bind_ctx(MODELS):
        assert CachedElement.select().count() == 0
        assert CachedTranscription.select().count() == 0

    # Check filenames
    assert mock_databases["chunk_42"]["path"].basename == "db_42.sqlite"
    assert mock_databases["second"]["path"].basename == "db.sqlite"

    merge_parents_cache(
        [
            "chunk_42",
            "first",
        ],
        mock_databases["target"]["path"],
        data_dir=tmpdir,
        chunk="42",
    )

    # The target should now have 3 elements and 0 transcriptions
    with mock_databases["target"]["db"].bind_ctx(MODELS):
        assert CachedElement.select().count() == 3
        assert CachedTranscription.select().count() == 0
        assert [e.id for e in CachedElement.select().order_by("id")] == [
            UUID("42424242-4242-4242-4242-424242424242"),
            UUID("12341234-1234-1234-1234-123412341234"),
            UUID("56785678-5678-5678-5678-567856785678"),
        ]


def test_merge_from_worker(
    responses, mock_base_worker_with_cache, mock_databases, tmpdir, monkeypatch
):
    """
    High-level merge from the base worker
    """
    responses.add(
        responses.GET,
        "http://testserver/ponos/v1/task/my_task/from-agent/",
        status=200,
        json={"parents": ["first", "second"]},
    )

    # At first we have no data in our main database
    assert CachedElement.select().count() == 0
    assert CachedTranscription.select().count() == 0

    # Configure the worker with a specific data directory
    monkeypatch.setenv("PONOS_DATA_DIR", str(tmpdir))
    mock_base_worker_with_cache.configure()

    # We should then have 3 elements and a transcription
    assert CachedElement.select().count() == 3
    assert CachedTranscription.select().count() == 1
    assert [e.id for e in CachedElement.select().order_by("id")] == [
        UUID("12341234-1234-1234-1234-123412341234"),
        UUID("56785678-5678-5678-5678-567856785678"),
        UUID("42424242-4242-4242-4242-424242424242"),
    ]
    assert [t.id for t in CachedTranscription.select().order_by("id")] == [
        UUID("11111111-1111-1111-1111-111111111111"),
    ]
......@@ -126,36 +126,12 @@ def test_add_transcription():
    }


def test_add_transcription_warning():
    reporter = Reporter("worker")
    with pytest.warns(FutureWarning) as w:
        reporter.add_transcription("myelement", "word")
    assert len(w) == 1
    assert (
        w[0].message.args[0]
        == "Transcription types have been deprecated and will be removed in the next release."
    )
    assert "myelement" in reporter.report_data["elements"]
    element_data = reporter.report_data["elements"]["myelement"]
    del element_data["started"]
    assert element_data == {
        "elements": {},
        "transcriptions": 1,
        "classifications": {},
        "entities": [],
        "metadata": [],
        "errors": [],
    }


def test_add_transcription_count():
    """
    Report multiple transcriptions with the same element and type
    """
    reporter = Reporter("worker")
    reporter.add_transcription("myelement", type_count=1337)
    reporter.add_transcription("myelement", 1337)
    assert "myelement" in reporter.report_data["elements"]
    element_data = reporter.report_data["elements"]["myelement"]
    del element_data["started"]
......@@ -169,39 +145,6 @@ def test_add_transcription_count():
    }


def test_add_transcriptions():
    reporter = Reporter("worker")
    with pytest.raises(AssertionError):
        reporter.add_transcriptions("myelement", {"not": "a list"})

    with pytest.warns(FutureWarning) as w:
        reporter.add_transcriptions("myelement", [{"type": "word"}, {"type": "line"}])
        reporter.add_transcriptions(
            "myelement",
            [
                {"type": "word"},
                {"type": "line", "text": "something"},
                {"type": "word", "confidence": 0.42},
            ],
        )
    assert len(w) == 2
    assert set(warning.message.args[0] for warning in w) == {
        "Reporter.add_transcriptions is deprecated due to transcription types being removed. Please use Reporter.add_transcription(element_id, count) instead."
    }
    assert "myelement" in reporter.report_data["elements"]
    element_data = reporter.report_data["elements"]["myelement"]
    del element_data["started"]
    assert element_data == {
        "elements": {},
        "transcriptions": 5,
        "classifications": {},
        "entities": [],
        "metadata": [],
        "errors": [],
    }


def test_add_entity():
    reporter = Reporter("worker")
    reporter.add_entity(
......