Compare revisions

Changes are shown as if the source revision were being merged into the target revision.

Target project: workers/base-worker
Commits on Source (14)
Showing 868 additions and 314 deletions
......@@ -8,4 +8,4 @@ line_length = 88
default_section=FIRSTPARTY
known_first_party = arkindex,arkindex_common
known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,setuptools,sh,shapely,tenacity,yaml
known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,responses,setuptools,sh,shapely,tenacity,yaml,zstandard
0.2.4-rc3
0.3.0-rc1
......@@ -157,7 +157,7 @@ class CachedTranscription(Model):
text = TextField()
confidence = FloatField()
orientation = CharField(max_length=50)
worker_version_id = UUIDField()
worker_version_id = UUIDField(null=True)
class Meta:
database = db
......@@ -170,7 +170,7 @@ class CachedClassification(Model):
class_name = TextField()
confidence = FloatField()
state = CharField(max_length=10)
worker_version_id = UUIDField()
worker_version_id = UUIDField(null=True)
class Meta:
database = db
......@@ -183,7 +183,7 @@ class CachedEntity(Model):
name = TextField()
validated = BooleanField(default=False)
metas = JSONField(null=True)
worker_version_id = UUIDField()
worker_version_id = UUIDField(null=True)
class Meta:
database = db
......@@ -197,7 +197,7 @@ class CachedTranscriptionEntity(Model):
entity = ForeignKeyField(CachedEntity, backref="transcription_entities")
offset = IntegerField(constraints=[Check("offset >= 0")])
length = IntegerField(constraints=[Check("length > 0")])
worker_version_id = UUIDField()
worker_version_id = UUIDField(null=True)
confidence = FloatField(null=True)
class Meta:
......
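Making worker_version_id nullable lets the SQLite cache store manually created results alongside worker results; a minimal Peewee sketch (assuming the models above live in arkindex_worker.cache):
```
from arkindex_worker.cache import CachedTranscription

# Transcriptions created by a specific worker version
by_worker = CachedTranscription.select().where(
    CachedTranscription.worker_version_id == "12341234-1234-1234-1234-123412341234"
)

# Manually created transcriptions carry no worker version
manual = CachedTranscription.select().where(
    CachedTranscription.worker_version_id.is_null()
)
```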
......@@ -207,7 +207,6 @@ class GitHelper:
in worker.configure(), configure the git helper and start the cloning:
```
gitlab = GitlabHelper(...)
workflow_id = os.environ["ARKINDEX_PROCESS_ID"]
prepare_git_key(...)
self.git_helper = GitHelper(workflow_id=workflow_id, gitlab_helper=gitlab, ...)
self.git_helper.run_clone_in_background()
......
......@@ -136,7 +136,15 @@ class ElementsWorker(
return self.process_information.get("activity_state") == "ready"
def configure(self):
super().configure()
# CLI args are stored on the instance so that implementations can access them
self.args = self.parser.parse_args()
if self.is_read_only:
super().configure_for_developers()
else:
super().configure()
super().configure_cache()
# Add report concerning elements
self.report = Reporter(
**self.worker_details, version=getattr(self, "worker_version_id", None)
......@@ -199,7 +207,9 @@ class ElementsWorker(
if isinstance(e, ErrorResponse):
message = f"An API error occurred while processing element {element_id}: {e.title} - {e.content}"
else:
message = f"Failed running worker on element {element_id}: {e}"
message = (
f"Failed running worker on element {element_id}: {repr(e)}"
)
logger.warning(
message,
......
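Using repr(e) rather than str(e) keeps the exception type in the log message; a quick illustration:
```
e = ValueError("invalid polygon")
str(e)   # 'invalid polygon'
repr(e)  # "ValueError('invalid polygon')"
```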
......@@ -59,6 +59,39 @@ class BaseWorker(object):
"""
self.parser = argparse.ArgumentParser(description=description)
self.parser.add_argument(
"-c",
"--config",
help="Alternative configuration file when running without a Worker Version ID",
type=open,
)
self.parser.add_argument(
"-d",
"--database",
help="Alternative SQLite database to use for worker caching",
type=str,
default=None,
)
self.parser.add_argument(
"-v",
"--verbose",
"--debug",
help="Display more information on events and errors",
action="store_true",
default=False,
)
self.parser.add_argument(
"--dev",
help=(
"Run worker in developer mode. "
"Worker will be in read-only state even if a worker_version is supplied. "
),
action="store_true",
default=False,
)
# Call potential extra arguments
self.add_arguments()
# Setup workdir either in Ponos environment or on host's home
if os.environ.get("PONOS_DATA"):
......@@ -77,6 +110,11 @@ class BaseWorker(object):
logger.warning(
"Missing WORKER_VERSION_ID environment variable, worker is in read-only mode"
)
self.worker_run_id = os.environ.get("ARKINDEX_WORKER_RUN_ID")
if not self.worker_run_id:
logger.warning(
"Missing ARKINDEX_WORKER_RUN_ID environment variable, worker is in read-only mode"
)
logger.info(f"Worker will use {self.work_dir} as working directory")
......@@ -87,6 +125,9 @@ class BaseWorker(object):
# is at least one available sqlite database either given or in the parent tasks
self.use_cache = False
# Define API Client
self.setup_api_client()
@property
def is_read_only(self) -> bool:
"""
......@@ -96,90 +137,25 @@ class BaseWorker(object):
or when no worker version ID or worker run ID is provided.
:rtype: bool
"""
return self.args.dev or self.worker_version_id is None
def configure(self):
"""
Configure worker using CLI args and environment variables.
"""
self.parser.add_argument(
"-c",
"--config",
help="Alternative configuration file when running without a Worker Version ID",
type=open,
return (
self.args.dev
or self.worker_version_id is None
or self.worker_run_id is None
)
self.parser.add_argument(
"-d",
"--database",
help="Alternative SQLite database to use for worker caching",
type=str,
default=None,
)
self.parser.add_argument(
"-v",
"--verbose",
help="Display more information on events and errors",
action="store_true",
default=False,
)
self.parser.add_argument(
"--dev",
help=(
"Run worker in developer mode. "
"Worker will be in read-only state even if a worker_version is supplied. "
"ARKINDEX_PROCESS_ID environment variable is not required with this flag."
),
action="store_true",
)
# Call potential extra arguments
self.add_arguments()
# CLI args are stored on the instance so that implementations can access them
self.args = self.parser.parse_args()
# Setup logging level
if self.args.verbose:
logger.setLevel(logging.DEBUG)
logger.debug("Debug output enabled")
def setup_api_client(self):
# Build Arkindex API client from environment variables
self.api_client = ArkindexClient(**options_from_env())
logger.debug(f"Setup Arkindex API client on {self.api_client.document.url}")
# Load features available on backend, and check authentication
user = self.request("RetrieveUser")
logger.debug(f"Connected as {user['display_name']} - {user['email']}")
self.features = user["features"]
# Load process information except in developer mode
if not self.args.dev:
assert os.environ.get(
"ARKINDEX_PROCESS_ID"
), "ARKINDEX_PROCESS_ID environment variable is not defined"
self.process_information = self.request(
"RetrieveDataImport", id=os.environ["ARKINDEX_PROCESS_ID"]
)
def configure_for_developers(self):
assert self.is_read_only
# Setup logging level if verbose or if ARKINDEX_DEBUG is set to true
if self.args.verbose or os.environ.get("ARKINDEX_DEBUG"):
logger.setLevel(logging.DEBUG)
logger.debug("Debug output enabled")
if self.worker_version_id:
# Retrieve initial configuration from API
worker_version = self.request(
"RetrieveWorkerVersion", id=self.worker_version_id
)
logger.info(
f"Loaded worker {worker_version['worker']['name']} revision {worker_version['revision']['hash'][0:7]} from API"
)
self.config = worker_version["configuration"]["configuration"]
if "user_configuration" in worker_version["configuration"]:
# Add default values (if set) to user_configuration
for key, value in worker_version["configuration"][
"user_configuration"
].items():
if "default" in value:
self.user_configuration[key] = value["default"]
self.worker_details = worker_version["worker"]
required_secrets = worker_version["configuration"].get("secrets", [])
elif self.args.config:
if self.args.config:
# Load config from YAML file
self.config = yaml.safe_load(self.args.config)
self.worker_details = {"name": "Local worker"}
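With the configuration path now split in two, a developer can configure a worker without any Arkindex process; a minimal sketch of the read-only flow (MyWorker is a hypothetical BaseWorker subclass and config.yml a local file):
```
import sys

sys.argv = ["worker", "--dev", "-c", "config.yml"]

worker = MyWorker()
worker.args = worker.parser.parse_args()
assert worker.is_read_only  # --dev always implies read-only mode
worker.configure_for_developers()  # loads config.yml instead of calling the API
```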
......@@ -196,20 +172,54 @@ class BaseWorker(object):
# Load all required secrets
self.secrets = {name: self.load_secret(name) for name in required_secrets}
# Load worker run configuration when available and not in dev mode
if os.environ.get("ARKINDEX_WORKER_RUN_ID") and not self.args.dev:
worker_run = self.request(
"RetrieveWorkerRun", id=os.environ["ARKINDEX_WORKER_RUN_ID"]
)
configuration_id = worker_run.get("configuration_id")
if configuration_id:
worker_configuration = self.request(
"RetrieveWorkerConfiguration", id=configuration_id
)
self.user_configuration = worker_configuration.get("configuration")
if self.user_configuration:
logger.info("Loaded user configuration from WorkerRun")
def configure(self):
"""
Configure worker using CLI args and environment variables.
"""
assert not self.is_read_only
# Setup logging level if verbose or if ARKINDEX_DEBUG is set to true
if self.args.verbose or os.environ.get("ARKINDEX_DEBUG"):
logger.setLevel(logging.DEBUG)
logger.debug("Debug output enabled")
# Load worker run information
worker_run = self.request("RetrieveWorkerRun", id=self.worker_run_id)
# Load process information
self.process_information = worker_run["process"]
# Load worker version information
worker_version = worker_run["worker_version"]
self.worker_details = worker_version["worker"]
logger.info(
f"Loaded worker {self.worker_details['name']} revision {worker_version['revision']['hash'][0:7]} from API"
)
# Retrieve initial configuration from API
self.config = worker_version["configuration"]["configuration"]
if "user_configuration" in worker_version["configuration"]:
# Add default values (if set) to user_configuration
for key, value in worker_version["configuration"][
"user_configuration"
].items():
if "default" in value:
self.user_configuration[key] = value["default"]
# Load all required secrets
required_secrets = worker_version["configuration"].get("secrets", [])
self.secrets = {name: self.load_secret(name) for name in required_secrets}
# Load worker run configuration when available
worker_configuration = worker_run.get("configuration")
self.user_configuration = worker_configuration.get("configuration")
if self.user_configuration:
logger.info("Loaded user configuration from WorkerRun")
# If debug mode is enabled in the user configuration, activate debug logging
if self.user_configuration.get("debug"):
logger.setLevel(logging.DEBUG)
logger.debug("Debug output enabled")
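Debug logging can now also be enabled through the user configuration attached to the WorkerRun; a fragment of the relevant payload (ids and names are placeholders, the shape follows the test fixtures below):
```
worker_run = {
    # ... other WorkerRun fields (worker_version, process, ...)
    "configuration": {
        "id": "af0daaf4-983e-4703-a7ed-a10f146d6684",
        "name": "my-userconfig",
        "configuration": {"debug": True},
    },
}
```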
def configure_cache(self):
task_id = os.environ.get("PONOS_TASK")
paths = None
if self.support_cache and self.args.database is not None:
......
......@@ -261,7 +261,7 @@ class ElementMixin(object):
with_corpus: Optional[bool] = None,
with_has_children: Optional[bool] = None,
with_zone: Optional[bool] = None,
worker_version: Optional[str] = None,
worker_version: Optional[Union[str, bool]] = None,
) -> Union[Iterable[dict], Iterable[CachedElement]]:
"""
List children of an element.
......@@ -295,7 +295,7 @@ class ElementMixin(object):
This parameter is not supported when caching is enabled.
:type with_zone: Optional[bool]
:param worker_version: Restrict to elements created by a worker version with this UUID. Set to False to look for manually created elements.
:type worker_version: Optional[str]
:type worker_version: Optional[Union[str, bool]]
:return: An iterable of dicts from the ``ListElementChildren`` API endpoint,
or an iterable of :class:`CachedElement` when caching is enabled.
:rtype: Union[Iterable[dict], Iterable[CachedElement]]
......@@ -330,10 +330,14 @@ class ElementMixin(object):
if with_zone is not None:
assert isinstance(with_zone, bool), "with_zone should be of type bool"
query_params["with_zone"] = with_zone
if worker_version:
if worker_version is not None:
assert isinstance(
worker_version, str
), "worker_version should be of type str"
worker_version, (str, bool)
), "worker_version should be of type str or bool"
if isinstance(worker_version, bool):
assert (
worker_version is False
), "if of type bool, worker_version can only be set to False"
query_params["worker_version"] = worker_version
if self.use_cache:
......@@ -346,8 +350,12 @@ class ElementMixin(object):
query = CachedElement.select().where(CachedElement.parent_id == element.id)
if type:
query = query.where(CachedElement.type == type)
if worker_version:
query = query.where(CachedElement.worker_version_id == worker_version)
if worker_version is not None:
# If worker_version=False, filter on a manual worker_version, i.e. None
worker_version_id = worker_version if worker_version else None
query = query.where(
CachedElement.worker_version_id == worker_version_id
)
return query
else:
......
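Together, these changes let callers separate worker-produced elements from manual ones; a short usage sketch (worker and element are assumed to already exist):
```
# Children created by one specific worker version
children = worker.list_element_children(
    element=element,
    worker_version="12341234-1234-1234-1234-123412341234",
)

# Manually created children only: sent as ?worker_version=False to the API,
# or translated to worker_version_id IS NULL on the SQLite cache
manual = worker.list_element_children(element=element, worker_version=False)
```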
# -*- coding: utf-8 -*-
import hashlib
import os
import tarfile
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import NewType, Tuple
import requests
import zstandard as zstd
from apistar.exceptions import ErrorResponse
from arkindex_worker import logger
CHUNK_SIZE = 1024
DirPath = NewType("DirPath", Path)
Hash = NewType("Hash", str)
FileSize = NewType("FileSize", int)
Archive = Tuple[DirPath, Hash, FileSize, Hash]
@contextmanager
def create_archive(path: DirPath) -> Archive:
"""First create a tar archive, then compress to a zst archive.
Finally, get its hash and size
"""
assert path.is_dir(), "create_archive needs a directory"
compressor = zstd.ZstdCompressor(level=3)
content_hasher = hashlib.md5()
archive_hasher = hashlib.md5()
# Create a temporary file to hold the uncompressed tar archive
_, path_to_tar_archive = tempfile.mkstemp(prefix="teklia-", suffix=".tar")
# Create an uncompressed tar archive with all the needed files
# The file hierarchy is kept in the archive.
file_list = []
with tarfile.open(path_to_tar_archive, "w") as tar:
for p in path.glob("**/*"):
x = p.relative_to(path)
tar.add(p, arcname=x, recursive=False)
file_list.append(p)
# Sort by path
file_list.sort()
# Compute hash of the files
for file_path in file_list:
with open(file_path, "rb") as file_data:
for chunk in iter(lambda: file_data.read(CHUNK_SIZE), b""):
content_hasher.update(chunk)
_, path_to_zst_archive = tempfile.mkstemp(prefix="teklia-", suffix=".tar.zst")
# Compress the archive
with open(path_to_zst_archive, "wb") as archive_file:
with open(path_to_tar_archive, "rb") as model_data:
for model_chunk in iter(lambda: model_data.read(CHUNK_SIZE), b""):
compressed_chunk = compressor.compress(model_chunk)
archive_hasher.update(compressed_chunk)
archive_file.write(compressed_chunk)
# Remove the tar archive
os.remove(path_to_tar_archive)
# Get the content hash, the archive size and the archive hash
hash = content_hasher.hexdigest()
size = os.path.getsize(path_to_zst_archive)
archive_hash = archive_hasher.hexdigest()
yield path_to_zst_archive, hash, size, archive_hash
# Remove the zstd archive
os.remove(path_to_zst_archive)
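A short usage sketch for create_archive (the directory name is a placeholder):
```
from pathlib import Path

with create_archive(path=Path("my_model_dir")) as (
    zst_archive_path,
    hash,
    size,
    archive_hash,
):
    # The .tar.zst file only exists inside this block
    print(zst_archive_path, hash, size, archive_hash)
# On exit, the compressed archive is removed automatically
```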
class TrainingMixin(object):
"""
Mixin for the Training workers to add Model and ModelVersion helpers
"""
def publish_model_version(self, model_path: DirPath, model_id: str):
"""
Create a model archive and its associated hashes,
then create a unique model version that will be stored on a bucket and published on Arkindex.
"""
# Create the zst archive, get its hash and size
with create_archive(path=model_path) as (
path_to_archive,
hash,
size,
archive_hash,
):
# Create a new model version with hash and size
model_version_details = self.create_model_version(
model_id=model_id,
hash=hash,
size=size,
archive_hash=archive_hash,
)
if model_version_details is None:
return
self.upload_to_s3(
archive_path=path_to_archive,
model_version_details=model_version_details,
)
# Update the model version with state, parsed configuration, tag and description (defaults to the worker's name)
self.update_model_version(
model_version_details=model_version_details,
)
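A minimal publishing sketch from a worker mixing in TrainingMixin, mirroring the test helper further below (ids and paths are placeholders):
```
from pathlib import Path

class TrainingWorker(BaseWorker, TrainingMixin):
    pass

worker = TrainingWorker()
# ... configure the worker and its API client as usual, then:
worker.publish_model_version(
    model_path=Path("my_model_dir"),
    model_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
)
```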
def create_model_version(
self,
model_id: str,
hash: str,
size: int,
archive_hash: str,
) -> dict:
"""
Create a new version of the specified model with the given information (hashes and size).
If a version matching this information already exists, there are two cases:
- The version is in `Created` state: that version's details are reused
- The version is in `Available` state: the same version cannot be created twice, so an error is raised
"""
# Create a new model version with hash and size
try:
model_version_details = self.request(
"CreateModelVersion",
id=model_id,
body={"hash": hash, "size": size, "archive_hash": archive_hash},
)
except ErrorResponse as e:
if e.status_code >= 500:
logger.error(f"Failed to create model version: {e.content}")
# A server error means there is no existing version to reuse
return
model_version_details = e.content.get("hash")
# If the existing model version is in Created state, its details are returned as a dict.
# Otherwise, an error message is returned in a list.
if model_version_details and isinstance(
model_version_details, (list, tuple)
):
logger.error(model_version_details[0])
return
return model_version_details
def upload_to_s3(self, archive_path: str, model_version_details: dict) -> None:
"""
Upload the archive of the model's files to an Amazon S3-compatible storage
"""
s3_put_url = model_version_details.get("s3_put_url")
logger.info("Uploading to s3...")
# Upload the archive on s3
with open(archive_path, "rb") as archive:
r = requests.put(
url=s3_put_url,
data=archive,
headers={"Content-Type": "application/zstd"},
)
r.raise_for_status()
def update_model_version(
self,
model_version_details: dict,
description: str = None,
configuration: dict = None,
tag: str = None,
) -> None:
"""
Update the specified model version to the state `Available` and use the given information
"""
logger.info("Updating the model version...")
try:
self.request(
"UpdateModelVersion",
id=model_version_details.get("id"),
body={
"state": "available",
"description": description,
"configuration": configuration,
"tag": tag,
},
)
except ErrorResponse as e:
logger.error(f"Failed to update model version: {e.content}")
......@@ -4,6 +4,7 @@ ElementsWorker methods for transcriptions.
"""
from enum import Enum
from typing import Iterable, Optional, Union
from peewee import IntegrityError
......@@ -348,8 +349,12 @@ class TranscriptionMixin(object):
return annotations
def list_transcriptions(
self, element, element_type=None, recursive=None, worker_version=None
):
self,
element: Union[Element, CachedElement],
element_type: Optional[str] = None,
recursive: Optional[bool] = None,
worker_version: Optional[Union[str, bool]] = None,
) -> Union[Iterable[dict], Iterable[CachedTranscription]]:
"""
List transcriptions on an element.
......@@ -359,11 +364,11 @@ class TranscriptionMixin(object):
:type element_type: Optional[str]
:param recursive: Include transcriptions of any descendant of this element, recursively.
:type recursive: Optional[bool]
:param worker_version: Restrict to transcriptions created by a worker version with this UUID.
:type worker_version: Optional[str]
:param worker_version: Restrict to transcriptions created by a worker version with this UUID. Set to False to look for manually created transcriptions.
:type worker_version: Optional[Union[str, bool]]
:returns: An iterable of dicts representing each transcription,
or an iterable of CachedTranscription when cache support is enabled.
:rtype: Iterable[dict] or Iterable[CachedTranscription]
:rtype: Union[Iterable[dict], Iterable[CachedTranscription]]
"""
assert element and isinstance(
element, (Element, CachedElement)
......@@ -375,10 +380,14 @@ class TranscriptionMixin(object):
if recursive is not None:
assert isinstance(recursive, bool), "recursive should be of type bool"
query_params["recursive"] = recursive
if worker_version:
if worker_version is not None:
assert isinstance(
worker_version, str
), "worker_version should be of type str"
worker_version, (str, bool)
), "worker_version should be of type str or bool"
if isinstance(worker_version, bool):
assert (
worker_version is False
), "if of type bool, worker_version can only be set to False"
query_params["worker_version"] = worker_version
if self.use_cache:
......@@ -409,9 +418,11 @@ class TranscriptionMixin(object):
if element_type:
transcriptions = transcriptions.where(cte.c.type == element_type)
if worker_version:
if worker_version is not None:
# If worker_version=False, filter on a manual worker_version, i.e. None
worker_version_id = worker_version if worker_version else None
transcriptions = transcriptions.where(
CachedTranscription.worker_version_id == worker_version
CachedTranscription.worker_version_id == worker_version_id
)
else:
transcriptions = self.api_client.paginate(
......
arkindex-client==1.0.8
arkindex-client==1.0.9
peewee==3.14.10
Pillow==9.1.0
Pillow>=9.0
python-gitlab==2.7.1
python-gnupg==0.4.8
sh==1.14.2
shapely==1.8.2
tenacity==8.0.1
zstandard==0.18.0
pytest==7.1.1
pytest-mock==3.7.0
pytest-responses==0.5.0
requests==2.27.1
......@@ -26,6 +26,7 @@ from arkindex_worker.worker import BaseWorker, ElementsWorker
from arkindex_worker.worker.transcription import TextOrientation
FIXTURES_DIR = Path(__file__).resolve().parent / "data"
SAMPLES_DIR = Path(__file__).resolve().parent / "samples"
__yaml_cache = {}
......@@ -104,142 +105,86 @@ def setup_api(responses, monkeypatch, cache_yaml):
def give_env_variable(request, monkeypatch):
"""Defines required environment variables"""
monkeypatch.setenv("WORKER_VERSION_ID", "12341234-1234-1234-1234-123412341234")
monkeypatch.setenv("ARKINDEX_PROCESS_ID", "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff")
monkeypatch.setenv("ARKINDEX_WORKER_RUN_ID", "56785678-5678-5678-5678-567856785678")
monkeypatch.setenv("ARKINDEX_CORPUS_ID", "11111111-1111-1111-1111-111111111111")
@pytest.fixture
def mock_config_api(mock_worker_version_api, mock_process_api, mock_user_api):
"""Mock all API endpoints required to configure a worker"""
pass
@pytest.fixture
def mock_worker_version_api(responses):
"""Provide a mock API response to get worker configuration"""
def mock_worker_run_api(responses):
"""Provide a mock API response to get worker run information"""
payload = {
"id": "12341234-1234-1234-1234-123412341234",
"configuration": {
"docker": {"image": "python:3"},
"configuration": {"someKey": "someValue"},
"secrets": [],
},
"revision": {
"hash": "deadbeef1234",
"name": "some git revision",
},
"docker_image": "python:3",
"docker_image_name": "python:3",
"state": "created",
"id": "56785678-5678-5678-5678-567856785678",
"parents": [],
"worker_version_id": "12341234-1234-1234-1234-123412341234",
"model_version_id": None,
"dataimport_id": "0e6053d5-0d50-41dd-88b6-90907493c433",
"worker": {
"id": "deadbeef-1234-5678-1234-worker",
"name": "Fake worker",
"slug": "fake_worker",
"type": "classifier",
},
}
responses.add(
responses.GET,
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
status=200,
body=json.dumps(payload),
content_type="application/json",
)
@pytest.fixture
def mock_worker_version_user_configuration_api(responses):
"""
Provide a mock API response to get a worker configuration
that includes a `user_configuration`
"""
payload = {
"worker": {"id": "1234", "name": "Workerino", "slug": "workerino"},
"revision": {"hash": "1234lala-lalalalala-lala"},
"configuration": {
"configuration": {"param_1": "/some/path/file.pth", "param_2": 12},
"user_configuration": {
"param_3": {
"title": "A Third Parameter",
"type": "string",
"default": "Animula vagula blandula",
},
"param_4": {"title": "Parameter The Fourth", "type": "int"},
"param_5": {
"title": "Parameter 5 (Five)",
"type": "bool",
"default": True,
},
"configuration_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"worker_version": {
"id": "12341234-1234-1234-1234-123412341234",
"configuration": {
"docker": {"image": "python:3"},
"configuration": {"someKey": "someValue"},
"secrets": [],
},
"revision": {
"hash": "deadbeef1234",
"name": "some git revision",
},
"docker_image": "python:3",
"docker_image_name": "python:3",
"state": "created",
"worker": {
"id": "deadbeef-1234-5678-1234-worker",
"name": "Fake worker",
"slug": "fake_worker",
"type": "classifier",
},
},
}
responses.add(
responses.GET,
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
status=200,
body=json.dumps(payload),
content_type="application/json",
)
@pytest.fixture
def mock_process_api(responses):
"""Provide a mock of the API response to get information on a process. Workers activity is enabled"""
payload = {
"name": None,
"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"state": "running",
"mode": "workers",
"corpus": "11111111-1111-1111-1111-111111111111",
"workflow": "http://testserver/ponos/v1/workflow/12341234-1234-1234-1234-123412341234/",
"files": [],
"revision": None,
"element": {
"id": "12341234-1234-1234-1234-123412341234",
"type": "folder",
"name": "Test folder",
"corpus": {
"id": "11111111-1111-1111-1111-111111111111",
"name": "John Doe project",
"public": False,
"configuration": {
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"name": "string",
"configuration": {},
},
"process": {
"name": None,
"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"state": "running",
"mode": "workers",
"corpus": "11111111-1111-1111-1111-111111111111",
"workflow": "http://testserver/ponos/v1/workflow/12341234-1234-1234-1234-123412341234/",
"files": [],
"revision": None,
"element": {
"id": "1234-deadbeef",
"type": "folder",
"name": "Test folder",
"corpus": {
"id": "11111111-1111-1111-1111-111111111111",
"name": "John Doe project",
"public": False,
},
"thumbnail_url": "http://testserver/thumbnail.png",
"zone": None,
"thumbnail_put_url": "http://testserver/thumbnail.png",
},
"thumbnail_url": "http://testserver/thumbnail.png",
"zone": None,
"thumbnail_put_url": "http://testserver/thumbnail.png",
"folder_type": None,
"element_type": "page",
"element_name_contains": None,
"load_children": True,
"use_cache": False,
"activity_state": "ready",
},
"folder_type": None,
"element_type": "page",
"element_name_contains": None,
"load_children": True,
"use_cache": False,
"activity_state": "ready",
}
responses.add(
responses.GET,
"http://testserver/api/v1/imports/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff/",
status=200,
body=json.dumps(payload),
content_type="application/json",
)
@pytest.fixture
def mock_user_api(responses):
"""
Provide a mock API response to retrieve user details
Signup is disabled in this mock
"""
payload = {
"id": 1,
"email": "bot@teklia.com",
"display_name": "Bender",
"features": {
"signup": False,
},
}
responses.add(
responses.GET,
"http://testserver/api/v1/user/",
"http://testserver/api/v1/imports/workers/56785678-5678-5678-5678-567856785678/",
status=200,
body=json.dumps(payload),
content_type="application/json",
......@@ -259,7 +204,7 @@ def mock_activity_calls(responses):
@pytest.fixture
def mock_elements_worker(monkeypatch, mock_config_api):
def mock_elements_worker(monkeypatch, mock_worker_run_api):
"""Build and configure an ElementsWorker with fixed CLI parameters to avoid issues with pytest"""
monkeypatch.setattr(sys, "argv", ["worker"])
......@@ -290,17 +235,18 @@ def mock_elements_worker_with_list(monkeypatch, responses, mock_elements_worker)
@pytest.fixture
def mock_base_worker_with_cache(mocker, monkeypatch, mock_config_api):
def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_run_api):
"""Build a BaseWorker using SQLite cache, also mocking a PONOS_TASK"""
monkeypatch.setattr(sys, "argv", ["worker"])
worker = BaseWorker(support_cache=True)
worker.setup_api_client()
monkeypatch.setenv("PONOS_TASK", "my_task")
return worker
@pytest.fixture
def mock_elements_worker_with_cache(monkeypatch, mock_config_api, tmp_path):
def mock_elements_worker_with_cache(monkeypatch, mock_worker_run_api, tmp_path):
"""Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest"""
cache_path = tmp_path / "db.sqlite"
init_cache_db(cache_path)
......@@ -309,6 +255,7 @@ def mock_elements_worker_with_cache(monkeypatch, mock_config_api, tmp_path):
worker = ElementsWorker(support_cache=True)
worker.configure()
worker.configure_cache()
return worker
......@@ -330,6 +277,11 @@ def fake_transcriptions_small():
return json.load(f)
@pytest.fixture
def model_file_dir():
return SAMPLES_DIR / "model_files"
@pytest.fixture
def fake_dummy_worker():
api_client = MockApiClient()
......@@ -381,7 +333,14 @@ def mock_cached_elements():
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
assert CachedElement.select().count() == 2
CachedElement.create(
id=UUID("33333333-3333-3333-3333-333333333333"),
parent_id=UUID("12341234-1234-1234-1234-123412341234"),
type="paragraph",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=None,
)
assert CachedElement.select().count() == 3
@pytest.fixture
......@@ -461,6 +420,14 @@ def mock_cached_transcriptions():
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
CachedTranscription.create(
id=UUID("66666666-6666-6666-6666-666666666666"),
element_id=UUID("11111111-1111-1111-1111-111111111111"),
text="This is a manual one",
confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=None,
)
@pytest.fixture(scope="function")
......
Wow this is actually the data of the best model ever created on Arkindex
\ No newline at end of file
Wow this is actually the data of the best model ever created on Arkindex
\ No newline at end of file
# -*- coding: utf-8 -*-
import json
import logging
import os
import sys
......@@ -45,29 +46,40 @@ def test_init_var_ponos_data_given(monkeypatch):
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
def test_init_var_worker_version_id_missing(
monkeypatch, mock_user_api, mock_process_api
):
def test_init_var_worker_version_id_missing(monkeypatch):
monkeypatch.setattr(sys, "argv", ["worker"])
monkeypatch.delenv("WORKER_VERSION_ID")
monkeypatch.delenv("ARKINDEX_WORKER_RUN_ID")
worker = BaseWorker()
worker.configure()
worker.args = worker.parser.parse_args()
worker.configure_for_developers()
assert worker.worker_version_id is None
assert worker.is_read_only is True
assert worker.config == {} # default empty case
def test_init_var_worker_local_file(
monkeypatch, tmp_path, mock_user_api, mock_process_api
):
def test_init_var_worker_run_id_missing(monkeypatch):
monkeypatch.setattr(sys, "argv", ["worker"])
monkeypatch.delenv("ARKINDEX_WORKER_RUN_ID")
worker = BaseWorker()
worker.args = worker.parser.parse_args()
worker.configure_for_developers()
assert worker.worker_run_id is None
assert worker.is_read_only is True
assert worker.config == {} # default empty case
def test_init_var_worker_local_file(monkeypatch, tmp_path):
# Build a dummy yaml config file
config = tmp_path / "config.yml"
config.write_text("---\nlocalKey: abcdef123")
monkeypatch.setattr(sys, "argv", ["worker", "-c", str(config)])
monkeypatch.delenv("WORKER_VERSION_ID")
monkeypatch.delenv("ARKINDEX_WORKER_RUN_ID")
worker = BaseWorker()
worker.configure()
worker.args = worker.parser.parse_args()
worker.configure_for_developers()
assert worker.worker_version_id is None
assert worker.is_read_only is True
assert worker.config == {"localKey": "abcdef123"} # Use a local file for devs
......@@ -75,86 +87,128 @@ def test_init_var_worker_local_file(
config.unlink()
def test_cli_default(mocker, mock_config_api):
def test_cli_default(mocker, mock_worker_run_api):
worker = BaseWorker()
spy = mocker.spy(worker, "add_arguments")
assert not spy.called
assert logger.level == logging.NOTSET
assert not hasattr(worker, "api_client")
mocker.patch.object(sys, "argv", ["worker"])
worker.configure()
worker.args = worker.parser.parse_args()
assert worker.is_read_only is False
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
assert worker.worker_run_id == "56785678-5678-5678-5678-567856785678"
assert spy.called
assert spy.call_count == 1
worker.configure()
assert not worker.args.verbose
assert logger.level == logging.NOTSET
assert worker.api_client
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
assert worker.is_read_only is False
assert worker.config == {"someKey": "someValue"} # from API
logger.setLevel(logging.NOTSET)
def test_cli_arg_verbose_given(mocker, mock_config_api):
def test_cli_arg_verbose_given(mocker, mock_worker_run_api):
worker = BaseWorker()
spy = mocker.spy(worker, "add_arguments")
assert not spy.called
assert logger.level == logging.NOTSET
assert not hasattr(worker, "api_client")
mocker.patch.object(sys, "argv", ["worker", "-v"])
worker.configure()
worker.args = worker.parser.parse_args()
assert worker.is_read_only is False
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
assert worker.worker_run_id == "56785678-5678-5678-5678-567856785678"
assert spy.called
assert spy.call_count == 1
worker.configure()
assert worker.args.verbose
assert logger.level == logging.DEBUG
assert worker.api_client
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
assert worker.config == {"someKey": "someValue"} # from API
logger.setLevel(logging.NOTSET)
def test_cli_envvar_debug_given(mocker, monkeypatch, mock_worker_run_api):
worker = BaseWorker()
assert logger.level == logging.NOTSET
mocker.patch.object(sys, "argv", ["worker"])
monkeypatch.setenv("ARKINDEX_DEBUG", True)
worker.args = worker.parser.parse_args()
assert worker.is_read_only is False
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
assert worker.worker_run_id == "56785678-5678-5678-5678-567856785678"
worker.configure()
assert logger.level == logging.DEBUG
assert worker.api_client
assert worker.config == {"someKey": "someValue"} # from API
logger.setLevel(logging.NOTSET)
def test_configure_dev_mode(
mocker, monkeypatch, mock_user_api, mock_worker_version_api
):
def test_configure_dev_mode(mocker, monkeypatch):
"""
Configuring a worker in developer mode avoids retrieving process information
"""
worker = BaseWorker()
mocker.patch.object(sys, "argv", ["worker", "--dev"])
monkeypatch.setenv(
"ARKINDEX_WORKER_RUN_ID", "aaaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
)
worker.configure()
worker.args = worker.parser.parse_args()
worker.configure_for_developers()
assert worker.args.dev is True
assert worker.process_information is None
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
assert worker.worker_run_id == "56785678-5678-5678-5678-567856785678"
assert worker.is_read_only is True
assert worker.user_configuration == {}
def test_configure_worker_run(mocker, monkeypatch, responses, mock_config_api):
def test_configure_worker_run(mocker, monkeypatch, responses):
worker = BaseWorker()
mocker.patch.object(sys, "argv", ["worker"])
run_id = "aaaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
configuration_id = "bbbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
monkeypatch.setenv("ARKINDEX_WORKER_RUN_ID", run_id)
responses.add(
responses.GET,
f"http://testserver/api/v1/imports/workers/{run_id}/",
json={"id": run_id, "configuration_id": configuration_id},
)
user_configuration = {
"id": "bbbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb",
"name": "BBB",
"configuration": {"a": "b"},
}
payload = {
"id": "56785678-5678-5678-5678-567856785678",
"parents": [],
"worker_version_id": "12341234-1234-1234-1234-123412341234",
"model_version_id": None,
"dataimport_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"worker": {
"id": "deadbeef-1234-5678-1234-worker",
"name": "Fake worker",
"slug": "fake_worker",
"type": "classifier",
},
"configuration_id": user_configuration["id"],
"worker_version": {
"id": "12341234-1234-1234-1234-123412341234",
"worker": {
"id": "deadbeef-1234-5678-1234-worker",
"name": "Fake worker",
"slug": "fake_worker",
"type": "classifier",
},
"revision": {"hash": "deadbeef1234"},
"configuration": {"configuration": {}},
},
"configuration": user_configuration,
"process": {"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff"},
}
responses.add(
responses.GET,
f"http://testserver/api/v1/workers/configurations/{configuration_id}/",
json={"id": configuration_id, "name": "BBB", "configuration": {"a": "b"}},
"http://testserver/api/v1/imports/workers/56785678-5678-5678-5678-567856785678/",
status=200,
body=json.dumps(payload),
content_type="application/json",
)
worker.args = worker.parser.parse_args()
assert worker.is_read_only is False
assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
assert worker.worker_run_id == "56785678-5678-5678-5678-567856785678"
worker.configure()
assert worker.user_configuration == {"a": "b"}
......@@ -163,20 +217,56 @@ def test_configure_worker_run(mocker, monkeypatch, responses, mock_config_api):
def test_configure_user_configuration_defaults(
mocker,
monkeypatch,
mock_worker_version_user_configuration_api,
mock_user_api,
mock_process_api,
responses,
):
worker = BaseWorker()
mocker.patch.object(sys, "argv")
run_id = "aaaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
monkeypatch.setenv("ARKINDEX_WORKER_RUN_ID", run_id)
worker.args = worker.parser.parse_args()
payload = {
"id": "56785678-5678-5678-5678-567856785678",
"parents": [],
"worker_version_id": "12341234-1234-1234-1234-123412341234",
"model_version_id": None,
"dataimport_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"worker": {
"id": "deadbeef-1234-5678-1234-worker",
"name": "Fake worker",
"slug": "fake_worker",
"type": "classifier",
},
"configuration_id": "af0daaf4-983e-4703-a7ed-a10f146d6684",
"worker_version": {
"id": "12341234-1234-1234-1234-123412341234",
"worker": {
"id": "deadbeef-1234-5678-1234-worker",
"name": "Fake worker",
"slug": "fake_worker",
"type": "classifier",
},
"revision": {"hash": "deadbeef1234"},
"configuration": {
"configuration": {"param_1": "/some/path/file.pth", "param_2": 12}
},
},
"configuration": {
"id": "af0daaf4-983e-4703-a7ed-a10f146d6684",
"name": "my-userconfig",
"configuration": {
"param_3": "Animula vagula blandula",
"param_5": True,
},
},
"process": {"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff"},
}
responses.add(
responses.GET,
f"http://testserver/api/v1/imports/workers/{run_id}/",
json={"id": run_id},
"http://testserver/api/v1/imports/workers/56785678-5678-5678-5678-567856785678/",
status=200,
body=json.dumps(payload),
content_type="application/json",
)
worker.configure()
assert worker.config == {"param_1": "/some/path/file.pth", "param_2": 12}
......@@ -186,24 +276,97 @@ def test_configure_user_configuration_defaults(
}
def test_configure_worker_run_missing_conf(
mocker, monkeypatch, responses, mock_config_api
):
@pytest.mark.parametrize("debug", (True, False))
def test_configure_user_config_debug(mocker, monkeypatch, responses, debug):
worker = BaseWorker()
mocker.patch.object(sys, "argv", ["worker"])
run_id = "aaaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
configuration_id = "bbbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
monkeypatch.setenv("ARKINDEX_WORKER_RUN_ID", run_id)
assert logger.level == logging.NOTSET
payload = {
"id": "56785678-5678-5678-5678-567856785678",
"parents": [],
"worker_version_id": "12341234-1234-1234-1234-123412341234",
"model_version_id": None,
"dataimport_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"worker": {
"id": "deadbeef-1234-5678-1234-worker",
"name": "Fake worker",
"slug": "fake_worker",
"type": "classifier",
},
"configuration_id": "af0daaf4-983e-4703-a7ed-a10f146d6684",
"worker_version": {
"id": "12341234-1234-1234-1234-123412341234",
"worker": {
"id": "deadbeef-1234-5678-1234-worker",
"name": "Fake worker",
"slug": "fake_worker",
"type": "classifier",
},
"revision": {"hash": "deadbeef1234"},
"configuration": {"configuration": {}},
},
"configuration": {
"id": "af0daaf4-983e-4703-a7ed-a10f146d6684",
"name": "BBB",
"configuration": {"debug": debug},
},
"process": {"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff"},
}
responses.add(
responses.GET,
f"http://testserver/api/v1/imports/workers/{run_id}/",
json={"id": run_id, "configuration_id": configuration_id},
"http://testserver/api/v1/imports/workers/56785678-5678-5678-5678-567856785678/",
status=200,
body=json.dumps(payload),
content_type="application/json",
)
worker.args = worker.parser.parse_args()
worker.configure()
assert worker.user_configuration == {"debug": debug}
expected_log_level = logging.DEBUG if debug else logging.NOTSET
assert logger.level == expected_log_level
logger.setLevel(logging.NOTSET)
def test_configure_worker_run_missing_conf(mocker, monkeypatch, responses):
worker = BaseWorker()
mocker.patch.object(sys, "argv", ["worker"])
payload = {
"id": "56785678-5678-5678-5678-567856785678",
"parents": [],
"worker_version_id": "12341234-1234-1234-1234-123412341234",
"model_version_id": None,
"dataimport_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"worker": {
"id": "deadbeef-1234-5678-1234-worker",
"name": "Fake worker",
"slug": "fake_worker",
"type": "classifier",
},
"configuration_id": "bbbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb",
"worker_version": {
"id": "12341234-1234-1234-1234-123412341234",
"worker": {
"id": "deadbeef-1234-5678-1234-worker",
"name": "Fake worker",
"slug": "fake_worker",
"type": "classifier",
},
"revision": {"hash": "deadbeef1234"},
"configuration": {"configuration": {}},
},
"configuration": {"id": "bbbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", "name": "BBB"},
"process": {"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff"},
}
responses.add(
responses.GET,
f"http://testserver/api/v1/workers/configurations/{configuration_id}/",
json={"id": configuration_id, "name": "BBB"},
"http://testserver/api/v1/imports/workers/56785678-5678-5678-5678-567856785678/",
status=200,
body=json.dumps(payload),
content_type="application/json",
)
worker.args = worker.parser.parse_args()
worker.configure()
assert worker.user_configuration is None
......
......@@ -58,12 +58,12 @@ def test_create_tables(tmp_path):
init_cache_db(db_path)
create_tables()
expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))
expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))
CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "rotation_angle" INTEGER NOT NULL, "mirrored" INTEGER NOT NULL, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, "confidence" REAL, FOREIGN KEY ("image_id") REFERENCES "images" ("id"))
CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL)
CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT)
CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
actual_schema = "\n".join(
[
......
# -*- coding: utf-8 -*-
# API calls during worker configuration
BASE_API_CALLS = [
("GET", "http://testserver/api/v1/user/"),
("GET", "http://testserver/api/v1/imports/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff/"),
(
"GET",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/imports/workers/56785678-5678-5678-5678-567856785678/",
),
]
......@@ -10,7 +10,7 @@ import pytest
from arkindex_worker.worker import ElementsWorker
def test_cli_default(monkeypatch, mock_config_api):
def test_cli_default(monkeypatch, mock_worker_run_api):
_, path = tempfile.mkstemp()
with open(path, "w") as f:
json.dump(
......@@ -33,7 +33,7 @@ def test_cli_default(monkeypatch, mock_config_api):
os.unlink(path)
def test_cli_arg_elements_list_given(mocker, mock_config_api):
def test_cli_arg_elements_list_given(mocker, mock_worker_run_api):
_, path = tempfile.mkstemp()
with open(path, "w") as f:
json.dump(
......@@ -62,7 +62,7 @@ def test_cli_arg_element_one_given_not_uuid(mocker):
worker.configure()
def test_cli_arg_element_one_given(mocker, mock_config_api):
def test_cli_arg_element_one_given(mocker, mock_worker_run_api):
mocker.patch.object(
sys, "argv", ["worker", "--element", "12341234-1234-1234-1234-123412341234"]
)
......@@ -74,7 +74,7 @@ def test_cli_arg_element_one_given(mocker, mock_config_api):
assert not worker.args.elements_list
def test_cli_arg_element_many_given(mocker, mock_config_api):
def test_cli_arg_element_many_given(mocker, mock_worker_run_api):
mocker.patch.object(
sys,
"argv",
......
......@@ -1255,7 +1255,18 @@ def test_list_element_children_wrong_worker_version(mock_elements_worker):
element=elt,
worker_version=1234,
)
assert str(e.value) == "worker_version should be of type str"
assert str(e.value) == "worker_version should be of type str or bool"
def test_list_element_children_wrong_bool_worker_version(mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
with pytest.raises(AssertionError) as e:
mock_elements_worker.list_element_children(
element=elt,
worker_version=True,
)
assert str(e.value) == "if of type bool, worker_version can only be set to False"
def test_list_element_children_api_error(responses, mock_elements_worker):
......@@ -1363,6 +1374,48 @@ def test_list_element_children(responses, mock_elements_worker):
]
def test_list_element_children_manual_worker_version(responses, mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
expected_children = [
{
"id": "0000",
"type": "page",
"name": "Test",
"corpus": {},
"thumbnail_url": None,
"zone": {},
"best_classes": None,
"has_children": None,
"worker_version_id": None,
}
]
responses.add(
responses.GET,
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/?worker_version=False",
status=200,
json={
"count": 1,
"next": None,
"results": expected_children,
},
)
for idx, child in enumerate(
mock_elements_worker.list_element_children(element=elt, worker_version=False)
):
assert child == expected_children[idx]
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"GET",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/?worker_version=False",
),
]
def test_list_element_children_with_cache_unhandled_param(
mock_elements_worker_with_cache,
):
......@@ -1389,6 +1442,7 @@ def test_list_element_children_with_cache_unhandled_param(
(
"11111111-1111-1111-1111-111111111111",
"22222222-2222-2222-2222-222222222222",
"33333333-3333-3333-3333-333333333333",
),
),
# Filter on element and page should give the second element
......@@ -1399,7 +1453,7 @@ def test_list_element_children_with_cache_unhandled_param(
},
("22222222-2222-2222-2222-222222222222",),
),
# Filter on element and worker version should give all elements
# Filter on element and worker version should give first two elements
(
{
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
......@@ -1410,7 +1464,7 @@ def test_list_element_children_with_cache_unhandled_param(
"22222222-2222-2222-2222-222222222222",
),
),
# Filter on element, type something and worker version should give first
# Filter on element, type something and worker version should give first
(
{
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
......@@ -1419,6 +1473,14 @@ def test_list_element_children_with_cache_unhandled_param(
},
("11111111-1111-1111-1111-111111111111",),
),
# Filter on element, manual worker version should give third
(
{
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
"worker_version": False,
},
("33333333-3333-3333-3333-333333333333",),
),
),
)
def test_list_element_children_with_cache(
......@@ -1430,7 +1492,7 @@ def test_list_element_children_with_cache(
):
# Check we have 3 elements already present in the database
assert CachedElement.select().count() == 2
assert CachedElement.select().count() == 3
# Query database through cache
elements = mock_elements_worker_with_cache.list_element_children(**filters)
......
# -*- coding: utf-8 -*-
import os
import pytest
import responses
from arkindex.mock import MockApiClient
from arkindex_worker.worker import BaseWorker
from arkindex_worker.worker.training import TrainingMixin, create_archive
class TrainingWorker(BaseWorker, TrainingMixin):
"""
This class is only needed for tests
"""
pass
def test_create_archive(model_file_dir):
"""Create an archive when the model's file is in a folder"""
with create_archive(path=model_file_dir) as (
zst_archive_path,
hash,
size,
archive_hash,
):
assert os.path.exists(zst_archive_path), "The archive was not created"
assert (
hash == "c5aedde18a768757351068b840c8c8f9"
), "Hash was not properly computed"
assert 300 < size < 700
assert not os.path.exists(zst_archive_path), "Auto removal failed"
def test_create_model_version():
"""A new model version is returned"""
model_id = "fake_model_id"
model_version_id = "fake_model_version_id"
training = TrainingWorker()
training.api_client = MockApiClient()
model_hash = "hash"
archive_hash = "archive_hash"
size = "30"
model_version_details = {
"id": model_version_id,
"model_id": model_id,
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
}
training.api_client.add_response(
"CreateModelVersion",
id=model_id,
response=model_version_details,
body={"hash": model_hash, "archive_hash": archive_hash, "size": size},
)
assert (
training.create_model_version(model_id, model_hash, size, archive_hash)
== model_version_details
)
@pytest.mark.parametrize(
"content, status_code",
[
(
{
"hash": {
"id": "fake_model_version_id",
"model_id": "fake_model_id",
"hash": "hash",
"archive_hash": "archive_hash",
"size": "size",
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
}
},
400,
),
({"hash": ["A version for this model with this hash already exists."]}, 403),
],
)
def test_retrieve_created_model_version(content, status_code):
"""
If there is an existing model version in Created state,
a 400 is raised, but the model version is still returned in the error content.
If there is an existing model version in Available state,
a 403 is raised and None is returned.
"""
model_id = "fake_model_id"
training = TrainingWorker()
training.api_client = MockApiClient()
model_hash = "hash"
archive_hash = "archive_hash"
size = "30"
training.api_client.add_error_response(
"CreateModelVersion",
id=model_id,
status_code=status_code,
body={"hash": model_hash, "archive_hash": archive_hash, "size": size},
content=content,
)
if status_code == 400:
assert (
training.create_model_version(model_id, model_hash, size, archive_hash)
== content["hash"]
)
elif status_code == 403:
assert (
training.create_model_version(model_id, model_hash, size, archive_hash)
is None
)
def test_handle_s3_uploading_errors(model_file_dir):
training = TrainingWorker()
training.api_client = MockApiClient()
s3_endpoint_url = "http://s3.localhost.com"
responses.add_passthru(s3_endpoint_url)
responses.add(responses.Response(method="PUT", url=s3_endpoint_url, status=400))
file_path = model_file_dir / "model_file.pth"
with pytest.raises(Exception):
training.upload_to_s3(file_path, {"s3_put_url": s3_endpoint_url})