Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
Commits on Source (25)
Showing
with 989 additions and 258 deletions
0.2.0-rc2
0.2.1-beta2
......@@ -6,6 +6,8 @@ import sqlite3
from peewee import (
BooleanField,
CharField,
Check,
CompositeKey,
Field,
FloatField,
ForeignKeyField,
......@@ -70,19 +72,28 @@ class CachedElement(Model):
"""
if not self.image_id or not self.polygon:
raise ValueError(f"Element {self.id} has no image")
# Always fetch the image from the bounding box when size differs from full image
bounding_box = polygon_bounding_box(self.polygon)
if (
bounding_box.width != self.image.width
or bounding_box.height != self.image.height
):
box = f"{bounding_box.x},{bounding_box.y},{bounding_box.x + bounding_box.width},{bounding_box.y + bounding_box.height}"
else:
box = "full"
if max_size is None:
resize = "full"
else:
bounding_box = polygon_bounding_box(self.polygon)
# Do not resize for polygons that do not exactly match the images
# as the resize is made directly by the IIIF server using the box parameter
if (
bounding_box.width != self.image.width
or bounding_box.height != self.image.height
):
resize = "full"
logger.warning(
"Only full size elements covered, downloading full size image"
)
# Do not resize when the image is below the maximum size
elif self.image.width <= max_size and self.image.height <= max_size:
resize = "full"
......@@ -97,7 +108,7 @@ class CachedElement(Model):
if not url.endswith("/"):
url += "/"
return open_image(f"{url}full/{resize}/0/default.jpg", *args, **kwargs)
return open_image(f"{url}{box}/{resize}/0/default.jpg", *args, **kwargs)
class CachedTranscription(Model):
......@@ -125,9 +136,43 @@ class CachedClassification(Model):
table_name = "classifications"
class CachedEntity(Model):
    """Local SQLite copy of an Arkindex entity, used by workers as a write-through cache."""

    # UUID of the entity on the Arkindex backend
    id = UUIDField(primary_key=True)
    # Entity type slug (e.g. "person") — max length mirrors the backend field; TODO confirm
    type = CharField(max_length=50)
    name = TextField()
    # Whether the entity has been validated; new entities default to not validated
    validated = BooleanField(default=False)
    # Arbitrary JSON metadata attached to the entity; nullable
    metas = JSONField(null=True)
    # Version of the worker that created this entity
    worker_version_id = UUIDField()

    class Meta:
        database = db
        table_name = "entities"
class CachedTranscriptionEntity(Model):
    """Local cache of a transcription-entity link (many-to-many join table)."""

    transcription = ForeignKeyField(
        CachedTranscription, backref="transcription_entities"
    )
    entity = ForeignKeyField(CachedEntity, backref="transcription_entities")
    # 0-based character position of the entity within the transcription text
    offset = IntegerField(constraints=[Check("offset >= 0")])
    # Number of characters covered by the entity; must be strictly positive
    length = IntegerField(constraints=[Check("length > 0")])

    class Meta:
        # A given entity may only be linked once to a given transcription
        primary_key = CompositeKey("transcription", "entity")
        database = db
        table_name = "transcription_entities"
# Add all the managed models in that list
# It's used here, but also in unit tests
MODELS = [CachedImage, CachedElement, CachedTranscription, CachedClassification]
MODELS = [
CachedImage,
CachedElement,
CachedTranscription,
CachedClassification,
CachedEntity,
CachedTranscriptionEntity,
]
def init_cache_db(path):
......@@ -150,13 +195,9 @@ def create_tables():
db.create_tables(MODELS)
def merge_parents_cache(parent_ids, current_database, data_dir="/data", chunk=None):
"""
Merge all the potential parent task's databases into the existing local one
"""
def retrieve_parents_cache_path(parent_ids, data_dir="/data", chunk=None):
assert isinstance(parent_ids, list)
assert os.path.isdir(data_dir)
assert os.path.exists(current_database)
# Handle possible chunk in parent task name
# This is needed to support the init_elements databases
......@@ -167,7 +208,7 @@ def merge_parents_cache(parent_ids, current_database, data_dir="/data", chunk=No
filenames.append(f"db_{chunk}.sqlite")
# Find all the paths for these databases
paths = list(
return list(
filter(
lambda p: os.path.isfile(p),
[
......@@ -178,6 +219,13 @@ def merge_parents_cache(parent_ids, current_database, data_dir="/data", chunk=No
)
)
def merge_parents_cache(paths, current_database):
"""
Merge all the potential parent task's databases into the existing local one
"""
assert os.path.exists(current_database)
if not paths:
logger.info("No parents cache to use")
return
......@@ -188,13 +236,19 @@ def merge_parents_cache(parent_ids, current_database, data_dir="/data", chunk=No
# Merge each table into the local database
for idx, path in enumerate(paths):
with SqliteDatabase(path) as source:
with source.bind_ctx(MODELS):
source.create_tables(MODELS)
logger.info(f"Merging parent db {path} into {current_database}")
statements = [
"PRAGMA page_size=80000;",
"PRAGMA synchronous=OFF;",
f"ATTACH DATABASE '{path}' AS source_{idx};",
f"REPLACE INTO images SELECT * FROM source_{idx}.images;",
f"REPLACE INTO elements SELECT * FROM source_{idx}.elements;",
f"REPLACE INTO transcriptions SELECT * FROM source_{idx}.transcriptions;",
f"REPLACE INTO classifications SELECT * FROM source_{idx}.classifications;",
]
for statement in statements:
......
......@@ -351,7 +351,10 @@ class GitHelper:
while keeping the same directory structure
"""
file_count = 0
for file in export_out_dir.rglob("*.*"):
file_names = [
file_name for file_name in export_out_dir.rglob("*") if file_name.is_file()
]
for file in file_names:
rel_file_path = file.relative_to(export_out_dir)
out_file = self.export_path / rel_file_path
if not out_file.exists():
......
......@@ -55,6 +55,14 @@ class Element(MagicDict):
Describes any kind of element.
"""
def resize_zone_url(self, size="full"):
    """
    Build the URL of this element's zone image at the requested size.

    :param size: IIIF size parameter; "full" returns the stored zone URL unchanged.
    :return: The zone URL with its size segment replaced.
    """
    if size == "full":
        # The stored zone URL already points at the full-size crop
        return self.zone.url
    # IIIF URLs end with .../{region}/{size}/{rotation}/{quality}.{format},
    # so the size segment is the third one from the end.
    segments = self.zone.url.split("/")
    segments[-3] = size
    return "/".join(segments)
def image_url(self, size="full"):
"""
When possible, will return the S3 URL for images, so an ML worker can bypass IIIF servers
......@@ -79,10 +87,16 @@ class Element(MagicDict):
bounding_box = polygon_bounding_box(self.zone.polygon)
return bounding_box.width > max_width or bounding_box.height > max_height
def open_image(self, *args, max_size=None, **kwargs):
def open_image(self, *args, max_size=None, use_full_image=False, **kwargs):
"""
Open this element's image as a Pillow image.
This does not crop the image to the element's polygon.
If use_full_image is False:
Use zone url, instead of the full image url to be able to get
single page from double page image.
Else:
This does not crop the image to the element's polygon.
:param max_size: Subresolution of the image.
"""
if not self.get("zone"):
......@@ -119,7 +133,10 @@ class Element(MagicDict):
resize = "full"
try:
return open_image(self.image_url(resize), *args, **kwargs)
if use_full_image:
return open_image(self.image_url(resize), *args, **kwargs)
else:
return open_image(self.resize_zone_url(resize), *args, **kwargs)
except HTTPError as e:
if (
self.zone.image.get("s3_url") is not None
......
......@@ -36,8 +36,8 @@ class ElementsWorker(
EntityMixin,
MetaDataMixin,
):
def __init__(self, description="Arkindex Elements Worker", use_cache=False):
super().__init__(description, use_cache)
def __init__(self, description="Arkindex Elements Worker", support_cache=False):
super().__init__(description, support_cache)
# Add report concerning elements
self.report = Reporter("unknown worker")
......@@ -84,6 +84,13 @@ class ElementsWorker(
return out
@property
def store_activity(self):
    """
    Whether element activities should be reported to the backend for this process.

    Requires the worker to have been configured (``self.process_information``
    fetched from the API) before being accessed.
    """
    assert (
        self.process_information
    ), "Worker must be configured to access its process activity state"
    # Activities are stored only once the backend marked them "ready" on the process
    return self.process_information.get("activity_state") == "ready"
def run(self):
"""
Process every elements from the provided list
......@@ -97,6 +104,11 @@ class ElementsWorker(
logger.warning("No elements to process, stopping.")
sys.exit(1)
if not self.store_activity:
logger.info(
"No worker activity will be stored as it is disabled for this process"
)
# Process every element
count = len(elements)
failed = 0
......@@ -112,7 +124,7 @@ class ElementsWorker(
logger.info(f"Processing {element} ({i}/{count})")
# Report start of process, run process, then report end of process
# Process the element and report its progress if activities are enabled
self.update_activity(element.id, ActivityState.Started)
self.process_element(element)
self.update_activity(element.id, ActivityState.Processed)
......@@ -156,15 +168,17 @@ class ElementsWorker(
Update worker activity for this element
This method should not raise a runtime exception, but simply warn users
"""
if not self.store_activity:
logger.debug(
"Activity is not stored as the feature is disabled on this process"
)
return
assert element_id and isinstance(
element_id, (uuid.UUID, str)
), "element_id shouldn't be null and should be an UUID or str"
assert isinstance(state, ActivityState), "state should be an ActivityState"
if not self.features.get("workers_activity"):
logger.debug("Skipping Worker activity update as it's disabled on backend")
return
if self.is_read_only:
logger.warning("Cannot update activity as this worker is in read-only mode")
return
......@@ -175,6 +189,7 @@ class ElementsWorker(
id=self.worker_version_id,
body={
"element_id": str(element_id),
"process_id": self.process_information["id"],
"state": state.value,
},
)
......
......@@ -18,7 +18,12 @@ from tenacity import (
from arkindex import ArkindexClient, options_from_env
from arkindex_worker import logger
from arkindex_worker.cache import create_tables, init_cache_db, merge_parents_cache
from arkindex_worker.cache import (
create_tables,
init_cache_db,
merge_parents_cache,
retrieve_parents_cache_path,
)
def _is_500_error(exc):
......@@ -33,7 +38,7 @@ def _is_500_error(exc):
class BaseWorker(object):
def __init__(self, description="Arkindex Base Worker", use_cache=False):
def __init__(self, description="Arkindex Base Worker", support_cache=False):
self.parser = argparse.ArgumentParser(description=description)
# Setup workdir either in Ponos environment or on host's home
......@@ -56,7 +61,10 @@ class BaseWorker(object):
logger.info(f"Worker will use {self.work_dir} as working directory")
self.use_cache = use_cache
self.support_cache = support_cache
# use_cache will be updated in configure() if the cache is supported and if there
# is at least one available sqlite database either given or in the parent tasks
self.use_cache = False
@property
def is_read_only(self):
......@@ -108,6 +116,14 @@ class BaseWorker(object):
logger.debug(f"Connected as {user['display_name']} - {user['email']}")
self.features = user["features"]
# Load process information
assert os.environ.get(
"ARKINDEX_PROCESS_ID"
), "ARKINDEX_PROCESS_ID environment variable is not defined"
self.process_information = self.request(
"RetrieveDataImport", id=os.environ["ARKINDEX_PROCESS_ID"]
)
if self.worker_version_id:
# Retrieve initial configuration from API
worker_version = self.request(
......@@ -133,38 +149,39 @@ class BaseWorker(object):
# Load all required secrets
self.secrets = {name: self.load_secret(name) for name in required_secrets}
if self.args.database is not None:
task_id = os.environ.get("PONOS_TASK")
paths = None
if self.support_cache and self.args.database is not None:
self.use_cache = True
elif self.support_cache and task_id:
task = self.request("RetrieveTaskFromAgent", id=task_id)
paths = retrieve_parents_cache_path(
task["parents"],
data_dir=os.environ.get("PONOS_DATA", "/data"),
chunk=os.environ.get("ARKINDEX_TASK_CHUNK"),
)
self.use_cache = len(paths) > 0
task_id = os.environ.get("PONOS_TASK")
if self.use_cache is True:
if self.use_cache:
if self.args.database is not None:
assert os.path.isfile(
self.args.database
), f"Database in {self.args.database} does not exist"
self.cache_path = self.args.database
elif task_id:
else:
cache_dir = os.path.join(os.environ.get("PONOS_DATA", "/data"), task_id)
assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}"
self.cache_path = os.path.join(cache_dir, "db.sqlite")
else:
self.cache_path = os.path.join(os.getcwd(), "db.sqlite")
init_cache_db(self.cache_path)
create_tables()
# Merging parents caches (if there are any) in the current task local cache, unless the database got overridden
if self.args.database is None and paths is not None:
merge_parents_cache(paths, self.cache_path)
else:
logger.debug("Cache is disabled")
# Merging parents caches (if there are any) in the current task local cache, unless the database got overridden
if self.use_cache and self.args.database is None and task_id is not None:
task = self.request("RetrieveTaskFromAgent", id=task_id)
merge_parents_cache(
task["parents"],
self.cache_path,
data_dir=os.environ.get("PONOS_DATA", "/data"),
chunk=os.environ.get("ARKINDEX_TASK_CHUNK"),
)
def load_secret(self, name):
"""Load all secrets described in the worker configuration"""
secret = None
......
# -*- coding: utf-8 -*-
import os
from enum import Enum
from peewee import IntegrityError
from arkindex_worker import logger
from arkindex_worker.cache import CachedElement, CachedEntity, CachedTranscriptionEntity
from arkindex_worker.models import Element
......@@ -18,14 +20,19 @@ class EntityType(Enum):
class EntityMixin(object):
def create_entity(self, element, name, type, corpus, metas=None, validated=None):
def create_entity(
self, element, name, type, corpus=None, metas=None, validated=None
):
"""
Create an entity on the given corpus through API
Return the ID of the created entity
"""
if corpus is None:
corpus = os.environ.get("ARKINDEX_CORPUS_ID")
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
element, (Element, CachedElement)
), "element shouldn't be null and should be an Element or CachedElement"
assert name and isinstance(
name, str
), "name shouldn't be null and should be of type str"
......@@ -56,4 +63,71 @@ class EntityMixin(object):
)
self.report.add_entity(element.id, entity["id"], type.value, name)
if self.use_cache:
# Store entity in local cache
try:
to_insert = [
{
"id": entity["id"],
"type": type.value,
"name": name,
"validated": validated if validated is not None else False,
"metas": metas,
"worker_version_id": self.worker_version_id,
}
]
CachedEntity.insert_many(to_insert).execute()
except IntegrityError as e:
logger.warning(f"Couldn't save created entity in local cache: {e}")
return entity["id"]
def create_transcription_entity(self, transcription, entity, offset, length):
    """
    Create a link between an existing entity and an existing transcription through API.

    :param transcription: UUID (str) of the transcription on Arkindex.
    :param entity: UUID (str) of the entity to link to the transcription.
    :param offset: 0-based character position of the entity in the transcription text.
    :param length: Number of characters covered by the entity; strictly positive.
    :return: None. Does nothing but warn when the worker is in read-only mode.
    """
    assert transcription and isinstance(
        transcription, str
    ), "transcription shouldn't be null and should be of type str"
    assert entity and isinstance(
        entity, str
    ), "entity shouldn't be null and should be of type str"
    assert (
        offset is not None and isinstance(offset, int) and offset >= 0
    ), "offset shouldn't be null and should be a positive integer"
    assert (
        length is not None and isinstance(length, int) and length > 0
    ), "length shouldn't be null and should be a strictly positive integer"
    if self.is_read_only:
        logger.warning(
            "Cannot create transcription entity as this worker is in read-only mode"
        )
        return

    self.request(
        "CreateTranscriptionEntity",
        id=transcription,
        body={
            "entity": entity,
            "length": length,
            "offset": offset,
        },
    )
    # TODO: Report transcription entity creation

    if self.use_cache:
        # Store transcription entity in local cache; a failed insert (e.g. a
        # duplicate primary key) is logged but never aborts the worker, since
        # the API creation above already succeeded.
        try:
            to_insert = [
                {
                    "transcription": transcription,
                    "entity": entity,
                    "offset": offset,
                    "length": length,
                }
            ]
            CachedTranscriptionEntity.insert_many(to_insert).execute()
        except IntegrityError as e:
            logger.warning(
                f"Couldn't save created transcription entity in local cache: {e}"
            )
......@@ -248,24 +248,39 @@ class TranscriptionMixin(object):
), "worker_version should be of type str"
query_params["worker_version"] = worker_version
if self.use_cache and recursive is None:
# Checking that we only received query_params handled by the cache
assert set(query_params.keys()) <= {
"worker_version",
}, "When using the local cache, you can only filter by 'worker_version'"
transcriptions = CachedTranscription.select().where(
CachedTranscription.element_id == element.id
)
if self.use_cache:
if not recursive:
# In this case we don't have to return anything, it's easier to use an
# impossible condition (False) rather than filtering by type for nothing
if element_type and element_type != element.type:
return CachedTranscription.select().where(False)
transcriptions = CachedTranscription.select().where(
CachedTranscription.element_id == element.id
)
else:
base_case = (
CachedElement.select()
.where(CachedElement.id == element.id)
.cte("base", recursive=True)
)
recursive = CachedElement.select().join(
base_case, on=(CachedElement.parent_id == base_case.c.id)
)
cte = base_case.union_all(recursive)
transcriptions = (
CachedTranscription.select()
.join(cte, on=(CachedTranscription.element_id == cte.c.id))
.with_cte(cte)
)
if element_type:
transcriptions = transcriptions.where(cte.c.type == element_type)
if worker_version:
transcriptions = transcriptions.where(
CachedTranscription.worker_version_id == worker_version
)
else:
if self.use_cache:
logger.warning(
"'recursive' filter was set, results will be retrieved from the API since the local cache doesn't handle this filter."
)
transcriptions = self.api_client.paginate(
"ListTranscriptions", id=element.id, **query_params
)
......
......@@ -3,5 +3,5 @@ peewee==3.14.4
Pillow==8.1.0
python-gitlab==2.6.0
python-gnupg==0.4.7
sh==1.14.1
sh==1.14.2
tenacity==7.0.0
pytest==6.2.3
pytest==6.2.4
pytest-mock==3.5.1
pytest-responses==0.4.0
pytest-responses==0.5.0
......@@ -92,20 +92,21 @@ def setup_api(responses, monkeypatch, cache_yaml):
@pytest.fixture(autouse=True)
def temp_working_directory(monkeypatch, tmp_path):
def _getcwd():
return str(tmp_path)
monkeypatch.setattr(os, "getcwd", _getcwd)
def give_env_variable(monkeypatch):
"""Defines required environment variables"""
monkeypatch.setenv("WORKER_VERSION_ID", "12341234-1234-1234-1234-123412341234")
monkeypatch.setenv("ARKINDEX_PROCESS_ID", "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff")
monkeypatch.setenv("ARKINDEX_CORPUS_ID", "11111111-1111-1111-1111-111111111111")
@pytest.fixture(autouse=True)
def give_worker_version_id_env_variable(monkeypatch):
monkeypatch.setenv("WORKER_VERSION_ID", "12341234-1234-1234-1234-123412341234")
@pytest.fixture
def mock_config_api(mock_worker_version_api, mock_process_api, mock_user_api):
"""Mock all API endpoints required to configure a worker"""
pass
@pytest.fixture
def mock_worker_version_api(responses, mock_user_api):
def mock_worker_version_api(responses):
"""Provide a mock API response to get worker configuration"""
payload = {
"id": "12341234-1234-1234-1234-123412341234",
......@@ -137,18 +138,58 @@ def mock_worker_version_api(responses, mock_user_api):
)
@pytest.fixture
def mock_process_api(responses):
    """Provide a mock of the API response to get information on a process. Workers activity is enabled"""
    # Fixed payload mirroring the RetrieveDataImport endpoint response shape;
    # IDs match the ARKINDEX_PROCESS_ID / corpus IDs used by the other fixtures.
    payload = {
        "name": None,
        "id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
        "state": "running",
        "mode": "workers",
        "corpus": "11111111-1111-1111-1111-111111111111",
        "workflow": "http://testserver/ponos/v1/workflow/12341234-1234-1234-1234-123412341234/",
        "files": [],
        "revision": None,
        "element": {
            "id": "12341234-1234-1234-1234-123412341234",
            "type": "folder",
            "name": "Test folder",
            "corpus": {
                "id": "11111111-1111-1111-1111-111111111111",
                "name": "John Doe project",
                "public": False,
            },
            "thumbnail_url": "http://testserver/thumbnail.png",
            "zone": None,
            "thumbnail_put_url": "http://testserver/thumbnail.png",
        },
        "folder_type": None,
        "element_type": "page",
        "element_name_contains": None,
        "load_children": True,
        "use_cache": False,
        # "ready" makes workers report their element activities
        "activity_state": "ready",
    }

    responses.add(
        responses.GET,
        "http://testserver/api/v1/imports/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff/",
        status=200,
        body=json.dumps(payload),
        content_type="application/json",
    )
@pytest.fixture
def mock_user_api(responses):
"""
Provide a mock API response to retrieve user details
Workers Activity is disabled in this mock
Signup is disabled in this mock
"""
payload = {
"id": 1,
"email": "bot@teklia.com",
"display_name": "Bender",
"features": {
"workers_activity": False,
"signup": False,
},
}
......@@ -162,10 +203,9 @@ def mock_user_api(responses):
@pytest.fixture
def mock_elements_worker(monkeypatch, mock_worker_version_api):
def mock_elements_worker(monkeypatch, mock_config_api):
"""Build and configure an ElementsWorker with fixed CLI parameters to avoid issues with pytest"""
monkeypatch.setattr(sys, "argv", ["worker"])
monkeypatch.setenv("ARKINDEX_CORPUS_ID", "11111111-1111-1111-1111-111111111111")
worker = ElementsWorker()
worker.configure()
......@@ -173,22 +213,23 @@ def mock_elements_worker(monkeypatch, mock_worker_version_api):
@pytest.fixture
def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_version_api):
def mock_base_worker_with_cache(mocker, monkeypatch, mock_config_api):
"""Build a BaseWorker using SQLite cache, also mocking a PONOS_TASK"""
monkeypatch.setattr(sys, "argv", ["worker"])
worker = BaseWorker(use_cache=True)
worker = BaseWorker(support_cache=True)
monkeypatch.setenv("PONOS_TASK", "my_task")
return worker
@pytest.fixture
def mock_elements_worker_with_cache(monkeypatch, mock_worker_version_api):
def mock_elements_worker_with_cache(monkeypatch, mock_config_api, tmp_path):
"""Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest"""
monkeypatch.setattr(sys, "argv", ["worker"])
monkeypatch.setenv("ARKINDEX_CORPUS_ID", "11111111-1111-1111-1111-111111111111")
cache_path = tmp_path / "db.sqlite"
cache_path.touch()
monkeypatch.setattr(sys, "argv", ["worker", "-d", str(cache_path)])
worker = ElementsWorker(use_cache=True)
worker = ElementsWorker(support_cache=True)
worker.configure()
return worker
......@@ -269,22 +310,71 @@ def mock_cached_elements():
def mock_cached_transcriptions():
"""Insert few transcriptions in local cache, on a shared element"""
CachedElement.create(
id=UUID("12341234-1234-1234-1234-123412341234"),
id=UUID("11111111-1111-1111-1111-111111111111"),
type="page",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedElement.create(
id=UUID("22222222-2222-2222-2222-222222222222"),
type="something_else",
parent_id=UUID("11111111-1111-1111-1111-111111111111"),
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedElement.create(
id=UUID("33333333-3333-3333-3333-333333333333"),
type="page",
parent_id=UUID("11111111-1111-1111-1111-111111111111"),
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedElement.create(
id=UUID("44444444-4444-4444-4444-444444444444"),
type="something_else",
parent_id=UUID("22222222-2222-2222-2222-222222222222"),
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedElement.create(
id=UUID("55555555-5555-5555-5555-555555555555"),
type="something_else",
parent_id=UUID("44444444-4444-4444-4444-444444444444"),
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedTranscription.create(
id=UUID("11111111-1111-1111-1111-111111111111"),
element_id=UUID("12341234-1234-1234-1234-123412341234"),
text="Hello!",
element_id=UUID("11111111-1111-1111-1111-111111111111"),
text="This",
confidence=0.42,
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
CachedTranscription.create(
id=UUID("22222222-2222-2222-2222-222222222222"),
element_id=UUID("12341234-1234-1234-1234-123412341234"),
text="How are you?",
element_id=UUID("22222222-2222-2222-2222-222222222222"),
text="is",
confidence=0.42,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
CachedTranscription.create(
id=UUID("33333333-3333-3333-3333-333333333333"),
element_id=UUID("33333333-3333-3333-3333-333333333333"),
text="a",
confidence=0.42,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
CachedTranscription.create(
id=UUID("44444444-4444-4444-4444-444444444444"),
element_id=UUID("44444444-4444-4444-4444-444444444444"),
text="good",
confidence=0.42,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
CachedTranscription.create(
id=UUID("55555555-5555-5555-5555-555555555555"),
element_id=UUID("55555555-5555-5555-5555-555555555555"),
text="test",
confidence=0.42,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
......
......@@ -29,11 +29,11 @@ def test_init_default_xdg_data_home(monkeypatch):
def test_init_with_local_cache(monkeypatch):
worker = BaseWorker(use_cache=True)
worker = BaseWorker(support_cache=True)
assert worker.work_dir == os.path.expanduser("~/.local/share/arkindex")
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
assert worker.use_cache is True
assert worker.support_cache is True
def test_init_var_ponos_data_given(monkeypatch):
......@@ -45,7 +45,9 @@ def test_init_var_ponos_data_given(monkeypatch):
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
def test_init_var_worker_version_id_missing(monkeypatch, mock_user_api):
def test_init_var_worker_version_id_missing(
monkeypatch, mock_user_api, mock_process_api
):
monkeypatch.setattr(sys, "argv", ["worker"])
monkeypatch.delenv("WORKER_VERSION_ID")
worker = BaseWorker()
......@@ -55,7 +57,9 @@ def test_init_var_worker_version_id_missing(monkeypatch, mock_user_api):
assert worker.config == {} # default empty case
def test_init_var_worker_local_file(monkeypatch, tmp_path, mock_user_api):
def test_init_var_worker_local_file(
monkeypatch, tmp_path, mock_user_api, mock_process_api
):
# Build a dummy yaml config file
config = tmp_path / "config.yml"
config.write_text("---\nlocalKey: abcdef123")
......@@ -71,7 +75,7 @@ def test_init_var_worker_local_file(monkeypatch, tmp_path, mock_user_api):
config.unlink()
def test_cli_default(mocker, mock_worker_version_api, mock_user_api):
def test_cli_default(mocker, mock_config_api):
worker = BaseWorker()
spy = mocker.spy(worker, "add_arguments")
assert not spy.called
......@@ -93,7 +97,7 @@ def test_cli_default(mocker, mock_worker_version_api, mock_user_api):
logger.setLevel(logging.NOTSET)
def test_cli_arg_verbose_given(mocker, mock_worker_version_api, mock_user_api):
def test_cli_arg_verbose_given(mocker, mock_config_api):
worker = BaseWorker()
spy = mocker.spy(worker, "add_arguments")
assert not spy.called
......
......@@ -56,7 +56,9 @@ def test_create_tables(tmp_path):
expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))
CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id"))
CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL)
CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
actual_schema = "\n".join(
......@@ -76,13 +78,15 @@ CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT
[
# No max_size: no resize
(400, 600, 400, 600, None, "http://something/full/full/0/default.jpg"),
# No max_size: resize on bbox
(400, 600, 200, 100, None, "http://something/0,0,200,100/full/0/default.jpg"),
# max_size equal to the image size, no resize
(400, 600, 400, 600, 600, "http://something/full/full/0/default.jpg"),
(600, 400, 600, 400, 600, "http://something/full/full/0/default.jpg"),
(400, 400, 400, 400, 400, "http://something/full/full/0/default.jpg"),
# max_size is smaller than the image, resize
(400, 600, 400, 600, 400, "http://something/full/266,400/0/default.jpg"),
(400, 600, 200, 600, 400, "http://something/full/266,400/0/default.jpg"),
(400, 600, 200, 600, 400, "http://something/0,0,200,600/full/0/default.jpg"),
(600, 400, 600, 400, 400, "http://something/full/400,266/0/default.jpg"),
(400, 400, 400, 400, 200, "http://something/full/200,200/0/default.jpg"),
# max_size above the image size, no resize
......@@ -116,9 +120,9 @@ def test_element_open_image(
image=image,
polygon=[
[0, 0],
[image_width, 0],
[image_width, image_height],
[0, image_height],
[polygon_width, 0],
[polygon_width, polygon_height],
[0, polygon_height],
[0, 0],
],
)
......
......@@ -62,7 +62,7 @@ def test_open_image(mocker):
}
}
)
assert elt.open_image() == "an image!"
assert elt.open_image(use_full_image=True) == "an image!"
assert open_mock.call_count == 1
assert open_mock.call_args == mocker.call(
"http://something/full/full/0/default.jpg"
......@@ -87,19 +87,19 @@ def test_open_image_resize_portrait(mocker):
}
)
# Resize = original size
assert elt.open_image(max_size=600) == "an image!"
assert elt.open_image(max_size=600, use_full_image=True) == "an image!"
assert open_mock.call_count == 1
assert open_mock.call_args == mocker.call(
"http://something/full/full/0/default.jpg"
)
# Resize = smaller height
assert elt.open_image(max_size=400) == "an image!"
assert elt.open_image(max_size=400, use_full_image=True) == "an image!"
assert open_mock.call_count == 2
assert open_mock.call_args == mocker.call(
"http://something/full/266,400/0/default.jpg"
)
# Resize = bigger height
assert elt.open_image(max_size=800) == "an image!"
assert elt.open_image(max_size=800, use_full_image=True) == "an image!"
assert open_mock.call_count == 3
assert open_mock.call_args == mocker.call(
"http://something/full/full/0/default.jpg"
......@@ -123,7 +123,7 @@ def test_open_image_resize_partial_element(mocker):
}
}
)
assert elt.open_image(max_size=400) == "an image!"
assert elt.open_image(max_size=400, use_full_image=True) == "an image!"
assert open_mock.call_count == 1
assert open_mock.call_args == mocker.call(
"http://something/full/full/0/default.jpg"
......@@ -148,19 +148,19 @@ def test_open_image_resize_landscape(mocker):
}
)
# Resize = original size
assert elt.open_image(max_size=600) == "an image!"
assert elt.open_image(max_size=600, use_full_image=True) == "an image!"
assert open_mock.call_count == 1
assert open_mock.call_args == mocker.call(
"http://something/full/full/0/default.jpg"
)
# Resize = smaller width
assert elt.open_image(max_size=400) == "an image!"
assert elt.open_image(max_size=400, use_full_image=True) == "an image!"
assert open_mock.call_count == 2
assert open_mock.call_args == mocker.call(
"http://something/full/400,266/0/default.jpg"
)
# Resize = bigger width
assert elt.open_image(max_size=800) == "an image!"
assert elt.open_image(max_size=800, use_full_image=True) == "an image!"
assert open_mock.call_count == 3
assert open_mock.call_args == mocker.call(
"http://something/full/full/0/default.jpg"
......@@ -185,19 +185,19 @@ def test_open_image_resize_square(mocker):
}
)
# Resize = original size
assert elt.open_image(max_size=400) == "an image!"
assert elt.open_image(max_size=400, use_full_image=True) == "an image!"
assert open_mock.call_count == 1
assert open_mock.call_args == mocker.call(
"http://something/full/full/0/default.jpg"
)
# Resize = smaller
assert elt.open_image(max_size=200) == "an image!"
assert elt.open_image(max_size=200, use_full_image=True) == "an image!"
assert open_mock.call_count == 2
assert open_mock.call_args == mocker.call(
"http://something/full/200,200/0/default.jpg"
)
# Resize = bigger
assert elt.open_image(max_size=800) == "an image!"
assert elt.open_image(max_size=800, use_full_image=True) == "an image!"
assert open_mock.call_count == 3
assert open_mock.call_args == mocker.call(
"http://something/full/full/0/default.jpg"
......@@ -234,7 +234,7 @@ def test_open_image_s3(mocker):
elt = Element(
{"zone": {"image": {"url": "http://something", "s3_url": "http://s3url"}}}
)
assert elt.open_image() == "an image!"
assert elt.open_image(use_full_image=True) == "an image!"
assert open_mock.call_count == 1
assert open_mock.call_args == mocker.call("http://s3url")
......@@ -256,7 +256,7 @@ def test_open_image_s3_retry(mocker):
)
with pytest.raises(NotImplementedError):
elt.open_image()
elt.open_image(use_full_image=True)
def test_open_image_s3_retry_once(mocker):
......@@ -274,4 +274,49 @@ def test_open_image_s3_retry_once(mocker):
)
with pytest.raises(NotImplementedError):
elt.open_image()
elt.open_image(use_full_image=True)
def test_open_image_use_full_image_false(mocker):
    """With use_full_image disabled, the zone URL is opened verbatim."""
    zone_url = "http://zoneurl/0,0,400,600/full/0/default.jpg"
    open_mock = mocker.patch(
        "arkindex_worker.models.open_image", return_value="an image!"
    )
    element = Element(
        {
            "zone": {
                "image": {"url": "http://something", "s3_url": "http://s3url"},
                "url": zone_url,
            }
        }
    )

    assert element.open_image(use_full_image=False) == "an image!"
    assert open_mock.call_count == 1
    assert open_mock.call_args == mocker.call(zone_url)
def test_open_image_resize_use_full_image_false(mocker):
    """With use_full_image disabled, resizing rewrites the size segment of the zone URL."""
    open_mock = mocker.patch(
        "arkindex_worker.models.open_image", return_value="an image!"
    )
    element = Element(
        {
            "zone": {
                "image": {
                    "url": "http://something",
                    "width": 400,
                    "height": 600,
                    "server": {"max_width": None, "max_height": None},
                },
                "polygon": [[0, 0], [400, 0], [400, 600], [0, 600], [0, 0]],
                "url": "http://zoneurl/0,0,400,600/full/0/default.jpg",
            }
        }
    )

    # Resize = smaller: the 400x600 zone is capped at 200 on its longest side,
    # so the width scales down proportionally to 133.
    assert element.open_image(max_size=200, use_full_image=False) == "an image!"
    assert open_mock.call_count == 1
    assert open_mock.call_args == mocker.call(
        "http://zoneurl/0,0,400,600/133,200/0/default.jpg"
    )
# -*- coding: utf-8 -*-
# API calls during worker configuration
# (HTTP method, URL) pairs, in order: current user, the import/process being
# run, then the worker version details.
# Test modules compare `responses.calls` against BASE_API_CALLS + their own
# expected requests instead of repeating these URLs in every test.
BASE_API_CALLS = [
    ("GET", "http://testserver/api/v1/user/"),
    ("GET", "http://testserver/api/v1/imports/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff/"),
    (
        "GET",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
    ),
]
......@@ -8,6 +8,8 @@ from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedClassification, CachedElement
from arkindex_worker.models import Element
from . import BASE_API_CALLS
def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
corpus_id = "12341234-1234-1234-1234-123412341234"
......@@ -31,11 +33,11 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
assert not mock_elements_worker.classes
ml_class_id = mock_elements_worker.get_ml_class_id(corpus_id, "good")
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
]
assert mock_elements_worker.classes == {
"12341234-1234-1234-1234-123412341234": {"good": "0000"}
......@@ -134,19 +136,16 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
# Simply request class 2, it should be reloaded
assert mock_elements_worker.get_ml_class_id(corpus_id, "class2") == "class2_id"
assert len(responses.calls) == 5
assert len(responses.calls) == len(BASE_API_CALLS) + 3
assert mock_elements_worker.classes == {
corpus_id: {
"class1": "class1_id",
"class2": "class2_id",
}
}
assert [(call.request.method, call.request.url) for call in responses.calls] == [
("GET", "http://testserver/api/v1/user/"),
(
"GET",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
),
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
("POST", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
......@@ -350,16 +349,16 @@ def test_create_classification_api_error(responses, mock_elements_worker):
high_confidence=True,
)
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
"http://testserver/api/v1/classifications/",
"http://testserver/api/v1/classifications/",
"http://testserver/api/v1/classifications/",
"http://testserver/api/v1/classifications/",
"http://testserver/api/v1/classifications/",
("POST", "http://testserver/api/v1/classifications/"),
("POST", "http://testserver/api/v1/classifications/"),
("POST", "http://testserver/api/v1/classifications/"),
("POST", "http://testserver/api/v1/classifications/"),
("POST", "http://testserver/api/v1/classifications/"),
]
......@@ -381,14 +380,14 @@ def test_create_classification(responses, mock_elements_worker):
high_confidence=True,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/classifications/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/classifications/"),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"element": "12341234-1234-1234-1234-123412341234",
"ml_class": "0000",
"worker_version": "12341234-1234-1234-1234-123412341234",
......@@ -430,14 +429,14 @@ def test_create_classification_with_cache(responses, mock_elements_worker_with_c
high_confidence=True,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/classifications/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/classifications/"),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"element": "12341234-1234-1234-1234-123412341234",
"ml_class": "0000",
"worker_version": "12341234-1234-1234-1234-123412341234",
......@@ -486,14 +485,14 @@ def test_create_classification_duplicate(responses, mock_elements_worker):
high_confidence=True,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/classifications/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/classifications/"),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"element": "12341234-1234-1234-1234-123412341234",
"ml_class": "0000",
"worker_version": "12341234-1234-1234-1234-123412341234",
......
......@@ -10,7 +10,7 @@ import pytest
from arkindex_worker.worker import ElementsWorker
def test_cli_default(monkeypatch, mock_worker_version_api):
def test_cli_default(monkeypatch, mock_config_api):
_, path = tempfile.mkstemp()
with open(path, "w") as f:
json.dump(
......@@ -33,7 +33,7 @@ def test_cli_default(monkeypatch, mock_worker_version_api):
os.unlink(path)
def test_cli_arg_elements_list_given(mocker, mock_worker_version_api):
def test_cli_arg_elements_list_given(mocker, mock_config_api):
_, path = tempfile.mkstemp()
with open(path, "w") as f:
json.dump(
......@@ -55,14 +55,14 @@ def test_cli_arg_elements_list_given(mocker, mock_worker_version_api):
os.unlink(path)
def test_cli_arg_element_one_given_not_uuid(mocker, mock_elements_worker):
def test_cli_arg_element_one_given_not_uuid(mocker):
    """configure() must abort (SystemExit) when --element is not a valid UUID."""
    mocker.patch.object(sys, "argv", ["worker", "--element", "1234"])
    elements_worker = ElementsWorker()
    with pytest.raises(SystemExit):
        elements_worker.configure()
def test_cli_arg_element_one_given(mocker, mock_elements_worker):
def test_cli_arg_element_one_given(mocker, mock_config_api):
mocker.patch.object(
sys, "argv", ["worker", "--element", "12341234-1234-1234-1234-123412341234"]
)
......@@ -74,7 +74,7 @@ def test_cli_arg_element_one_given(mocker, mock_elements_worker):
assert not worker.args.elements_list
def test_cli_arg_element_many_given(mocker, mock_elements_worker):
def test_cli_arg_element_many_given(mocker, mock_config_api):
mocker.patch.object(
sys,
"argv",
......
......@@ -12,6 +12,8 @@ from arkindex_worker.cache import CachedElement, CachedImage
from arkindex_worker.models import Element
from arkindex_worker.worker import ElementsWorker
from . import BASE_API_CALLS
def test_list_elements_elements_list_arg_wrong_type(monkeypatch, mock_elements_worker):
_, path = tempfile.mkstemp()
......@@ -145,7 +147,7 @@ def test_database_arg(mocker, mock_elements_worker, tmp_path):
),
)
worker = ElementsWorker()
worker = ElementsWorker(support_cache=True)
worker.configure()
assert worker.use_cache is True
......@@ -166,16 +168,16 @@ def test_load_corpus_classes_api_error(responses, mock_elements_worker):
):
mock_elements_worker.load_corpus_classes(corpus_id)
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We do 5 retries
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
]
assert not mock_elements_worker.classes
......@@ -212,11 +214,11 @@ def test_load_corpus_classes(responses, mock_elements_worker):
assert not mock_elements_worker.classes
mock_elements_worker.load_corpus_classes(corpus_id)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
]
assert mock_elements_worker.classes == {
"12341234-1234-1234-1234-123412341234": {
......@@ -371,16 +373,16 @@ def test_create_sub_element_api_error(responses, mock_elements_worker):
polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
)
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
"http://testserver/api/v1/elements/create/",
"http://testserver/api/v1/elements/create/",
"http://testserver/api/v1/elements/create/",
"http://testserver/api/v1/elements/create/",
"http://testserver/api/v1/elements/create/",
("POST", "http://testserver/api/v1/elements/create/"),
("POST", "http://testserver/api/v1/elements/create/"),
("POST", "http://testserver/api/v1/elements/create/"),
("POST", "http://testserver/api/v1/elements/create/"),
("POST", "http://testserver/api/v1/elements/create/"),
]
......@@ -406,13 +408,13 @@ def test_create_sub_element(responses, mock_elements_worker):
polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/elements/create/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/elements/create/"),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"type": "something",
"name": "0",
"image": "22222222-2222-2222-2222-222222222222",
......@@ -697,16 +699,31 @@ def test_create_elements_api_error(responses, mock_elements_worker):
],
)
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
),
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
),
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
),
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
),
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
),
]
......@@ -741,13 +758,16 @@ def test_create_elements_cached_element(responses, mock_elements_worker_with_cac
],
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"elements": [
{
"name": "0",
......@@ -805,13 +825,16 @@ def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
],
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"elements": [
{
"name": "0",
......@@ -1036,16 +1059,31 @@ def test_list_element_children_api_error(responses, mock_elements_worker):
):
next(mock_elements_worker.list_element_children(element=elt))
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We do 5 retries
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
(
"GET",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
),
(
"GET",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
),
(
"GET",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
),
(
"GET",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
),
(
"GET",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
),
]
......@@ -1102,11 +1140,14 @@ def test_list_element_children(responses, mock_elements_worker):
):
assert child == expected_children[idx]
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"GET",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
),
]
......@@ -1186,8 +1227,7 @@ def test_list_element_children_with_cache(
assert child.id == UUID(expected_id)
# Check the worker never hits the API for elements
assert len(responses.calls) == 2
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
]
assert len(responses.calls) == len(BASE_API_CALLS)
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS
# -*- coding: utf-8 -*-
import json
from uuid import UUID
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import (
CachedElement,
CachedEntity,
CachedTranscription,
CachedTranscriptionEntity,
)
from arkindex_worker.models import Element
from arkindex_worker.worker import EntityType
from . import BASE_API_CALLS
def test_create_entity_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
......@@ -16,7 +25,10 @@ def test_create_entity_wrong_element(mock_elements_worker):
type=EntityType.Person,
corpus="12341234-1234-1234-1234-123412341234",
)
assert str(e.value) == "element shouldn't be null and should be of type Element"
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_entity(
......@@ -25,7 +37,10 @@ def test_create_entity_wrong_element(mock_elements_worker):
type=EntityType.Person,
corpus="12341234-1234-1234-1234-123412341234",
)
assert str(e.value) == "element shouldn't be null and should be of type Element"
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
def test_create_entity_wrong_name(mock_elements_worker):
......@@ -81,9 +96,22 @@ def test_create_entity_wrong_type(mock_elements_worker):
assert str(e.value) == "type shouldn't be null and should be of type EntityType"
def test_create_entity_wrong_corpus(mock_elements_worker):
def test_create_entity_wrong_corpus(monkeypatch, mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
# Triggering an error on metas param, not giving corpus should work since
# ARKINDEX_CORPUS_ID environment variable is set on mock_elements_worker
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_entity(
element=elt,
name="Bob Bob",
type=EntityType.Person,
metas="wrong metas",
)
assert str(e.value) == "metas should be of type dict"
# Removing ARKINDEX_CORPUS_ID environment variable should give an error when corpus=None
monkeypatch.delenv("ARKINDEX_CORPUS_ID")
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_entity(
element=elt,
......@@ -147,16 +175,16 @@ def test_create_entity_api_error(responses, mock_elements_worker):
corpus="12341234-1234-1234-1234-123412341234",
)
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
"http://testserver/api/v1/entity/",
"http://testserver/api/v1/entity/",
"http://testserver/api/v1/entity/",
"http://testserver/api/v1/entity/",
"http://testserver/api/v1/entity/",
("POST", "http://testserver/api/v1/entity/"),
("POST", "http://testserver/api/v1/entity/"),
("POST", "http://testserver/api/v1/entity/"),
("POST", "http://testserver/api/v1/entity/"),
("POST", "http://testserver/api/v1/entity/"),
]
......@@ -176,13 +204,47 @@ def test_create_entity(responses, mock_elements_worker):
corpus="12341234-1234-1234-1234-123412341234",
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/entity/"),
]
assert json.loads(responses.calls[-1].request.body) == {
"name": "Bob Bob",
"type": "person",
"metas": None,
"validated": None,
"corpus": "12341234-1234-1234-1234-123412341234",
"worker_version": "12341234-1234-1234-1234-123412341234",
}
assert entity_id == "12345678-1234-1234-1234-123456789123"
def test_create_entity_with_cache(responses, mock_elements_worker_with_cache):
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
responses.add(
responses.POST,
"http://testserver/api/v1/entity/",
status=200,
json={"id": "12345678-1234-1234-1234-123456789123"},
)
entity_id = mock_elements_worker_with_cache.create_entity(
element=elt,
name="Bob Bob",
type=EntityType.Person,
corpus="12341234-1234-1234-1234-123412341234",
)
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/entity/"),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"name": "Bob Bob",
"type": "person",
"metas": None,
......@@ -191,3 +253,265 @@ def test_create_entity(responses, mock_elements_worker):
"worker_version": "12341234-1234-1234-1234-123412341234",
}
assert entity_id == "12345678-1234-1234-1234-123456789123"
# Check that created entity was properly stored in SQLite cache
assert list(CachedEntity.select()) == [
CachedEntity(
id=UUID("12345678-1234-1234-1234-123456789123"),
type="person",
name="Bob Bob",
validated=False,
metas=None,
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
)
]
def test_create_transcription_entity_wrong_transcription(mock_elements_worker):
    """A null or non-string transcription ID must be rejected."""
    for bad_transcription in (None, 1234):
        with pytest.raises(AssertionError) as e:
            mock_elements_worker.create_transcription_entity(
                transcription=bad_transcription,
                entity="11111111-1111-1111-1111-111111111111",
                offset=5,
                length=10,
            )
        assert (
            str(e.value) == "transcription shouldn't be null and should be of type str"
        )
def test_create_transcription_entity_wrong_entity(mock_elements_worker):
    """A null or non-string entity ID must be rejected."""
    for bad_entity in (None, 1234):
        with pytest.raises(AssertionError) as e:
            mock_elements_worker.create_transcription_entity(
                transcription="11111111-1111-1111-1111-111111111111",
                entity=bad_entity,
                offset=5,
                length=10,
            )
        assert str(e.value) == "entity shouldn't be null and should be of type str"
def test_create_transcription_entity_wrong_offset(mock_elements_worker):
    """Offsets that are null, non-integer or negative must be rejected."""
    for bad_offset in (None, "not an int", -1):
        with pytest.raises(AssertionError) as e:
            mock_elements_worker.create_transcription_entity(
                transcription="11111111-1111-1111-1111-111111111111",
                entity="11111111-1111-1111-1111-111111111111",
                offset=bad_offset,
                length=10,
            )
        assert (
            str(e.value) == "offset shouldn't be null and should be a positive integer"
        )
def test_create_transcription_entity_wrong_length(mock_elements_worker):
    """Lengths that are null, non-integer or zero must be rejected."""
    for bad_length in (None, "not an int", 0):
        with pytest.raises(AssertionError) as e:
            mock_elements_worker.create_transcription_entity(
                transcription="11111111-1111-1111-1111-111111111111",
                entity="11111111-1111-1111-1111-111111111111",
                offset=5,
                length=bad_length,
            )
        assert (
            str(e.value)
            == "length shouldn't be null and should be a strictly positive integer"
        )
def test_create_transcription_entity_api_error(responses, mock_elements_worker):
    """Server errors are retried five times before ErrorResponse is raised."""
    url = (
        "http://testserver/api/v1/transcription/"
        "11111111-1111-1111-1111-111111111111/entity/"
    )
    responses.add(responses.POST, url, status=500)

    with pytest.raises(ErrorResponse):
        mock_elements_worker.create_transcription_entity(
            transcription="11111111-1111-1111-1111-111111111111",
            entity="11111111-1111-1111-1111-111111111111",
            offset=5,
            length=10,
        )

    assert len(responses.calls) == len(BASE_API_CALLS) + 5
    # We retry 5 times the API call
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == BASE_API_CALLS + [("POST", url)] * 5
def test_create_transcription_entity(responses, mock_elements_worker):
    """The entity span is POSTed once to the transcription's entity endpoint."""
    url = (
        "http://testserver/api/v1/transcription/"
        "11111111-1111-1111-1111-111111111111/entity/"
    )
    # The mocked API echoes the payload back, so one dict serves both sides.
    body = {
        "entity": "11111111-1111-1111-1111-111111111111",
        "offset": 5,
        "length": 10,
    }
    responses.add(responses.POST, url, status=200, json=body)

    mock_elements_worker.create_transcription_entity(
        transcription="11111111-1111-1111-1111-111111111111",
        entity="11111111-1111-1111-1111-111111111111",
        offset=5,
        length=10,
    )

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == BASE_API_CALLS + [("POST", url)]
    assert json.loads(responses.calls[-1].request.body) == body
def test_create_transcription_entity_with_cache(
    responses, mock_elements_worker_with_cache
):
    """The transcription entity is created via the API and mirrored in the SQLite cache."""
    element_id = UUID("12341234-1234-1234-1234-123412341234")
    transcription_id = UUID("11111111-1111-1111-1111-111111111111")
    entity_id = UUID("11111111-1111-1111-1111-111111111111")
    worker_version_id = UUID("12341234-1234-1234-1234-123412341234")

    # Seed the cache with the element, transcription and entity the link refers to
    CachedElement.create(id=element_id, type="page")
    CachedTranscription.create(
        id=transcription_id,
        element=element_id,
        text="Hello, it's me.",
        confidence=0.42,
        worker_version_id=worker_version_id,
    )
    CachedEntity.create(
        id=entity_id,
        type="person",
        name="Bob Bob",
        worker_version_id=worker_version_id,
    )

    url = (
        "http://testserver/api/v1/transcription/"
        "11111111-1111-1111-1111-111111111111/entity/"
    )
    body = {
        "entity": "11111111-1111-1111-1111-111111111111",
        "offset": 5,
        "length": 10,
    }
    responses.add(responses.POST, url, status=200, json=body)

    mock_elements_worker_with_cache.create_transcription_entity(
        transcription="11111111-1111-1111-1111-111111111111",
        entity="11111111-1111-1111-1111-111111111111",
        offset=5,
        length=10,
    )

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == BASE_API_CALLS + [("POST", url)]
    assert json.loads(responses.calls[-1].request.body) == body

    # Check that created transcription entity was properly stored in SQLite cache
    assert list(CachedTranscriptionEntity.select()) == [
        CachedTranscriptionEntity(
            transcription=transcription_id,
            entity=entity_id,
            offset=5,
            length=10,
        )
    ]
......@@ -7,6 +7,8 @@ from apistar.exceptions import ErrorResponse
from arkindex_worker.models import Element
from arkindex_worker.worker import MetaType
from . import BASE_API_CALLS
def test_create_metadata_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
......@@ -133,16 +135,31 @@ def test_create_metadata_api_error(responses, mock_elements_worker):
value="La Turbine, Grenoble 38000",
)
assert len(responses.calls) == 7
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
),
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
),
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
),
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
),
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
),
]
......@@ -162,13 +179,16 @@ def test_create_metadata(responses, mock_elements_worker):
value="La Turbine, Grenoble 38000",
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
),
]
assert json.loads(responses.calls[2].request.body) == {
assert json.loads(responses.calls[-1].request.body) == {
"type": "location",
"name": "Teklia",
"value": "La Turbine, Grenoble 38000",
......