Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
Commits on Source (4)
Showing
with 541 additions and 44 deletions
......@@ -8,4 +8,4 @@ line_length = 88
default_section=FIRSTPARTY
known_first_party = arkindex,arkindex_common
known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,setuptools,sh,shapely,tenacity,yaml
known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,responses,setuptools,sh,shapely,tenacity,yaml,zstandard
......@@ -157,7 +157,7 @@ class CachedTranscription(Model):
text = TextField()
confidence = FloatField()
orientation = CharField(max_length=50)
worker_version_id = UUIDField()
worker_version_id = UUIDField(null=True)
class Meta:
database = db
......@@ -170,7 +170,7 @@ class CachedClassification(Model):
class_name = TextField()
confidence = FloatField()
state = CharField(max_length=10)
worker_version_id = UUIDField()
worker_version_id = UUIDField(null=True)
class Meta:
database = db
......@@ -183,7 +183,7 @@ class CachedEntity(Model):
name = TextField()
validated = BooleanField(default=False)
metas = JSONField(null=True)
worker_version_id = UUIDField()
worker_version_id = UUIDField(null=True)
class Meta:
database = db
......@@ -197,7 +197,7 @@ class CachedTranscriptionEntity(Model):
entity = ForeignKeyField(CachedEntity, backref="transcription_entities")
offset = IntegerField(constraints=[Check("offset >= 0")])
length = IntegerField(constraints=[Check("length > 0")])
worker_version_id = UUIDField()
worker_version_id = UUIDField(null=True)
confidence = FloatField(null=True)
class Meta:
......
......@@ -138,6 +138,7 @@ class ElementsWorker(
def configure(self):
# CLI args are stored on the instance so that implementations can access them
self.args = self.parser.parse_args()
if self.is_read_only:
super().configure_for_developers()
else:
......
......@@ -125,6 +125,9 @@ class BaseWorker(object):
# is at least one available sqlite database either given or in the parent tasks
self.use_cache = False
# Define API Client
self.setup_api_client()
@property
def is_read_only(self) -> bool:
"""
......@@ -140,6 +143,11 @@ class BaseWorker(object):
or self.worker_run_id is None
)
def setup_api_client(self):
# Build Arkindex API client from environment variables
self.api_client = ArkindexClient(**options_from_env())
logger.debug(f"Setup Arkindex API client on {self.api_client.document.url}")
def configure_for_developers(self):
assert self.is_read_only
# Setup logging level if verbose or if ARKINDEX_DEBUG is set to true
......@@ -174,10 +182,6 @@ class BaseWorker(object):
logger.setLevel(logging.DEBUG)
logger.debug("Debug output enabled")
# Build Arkindex API client from environment variables
self.api_client = ArkindexClient(**options_from_env())
logger.debug(f"Setup Arkindex API client on {self.api_client.document.url}")
# Load worker run information
worker_run = self.request("RetrieveWorkerRun", id=self.worker_run_id)
......
......@@ -261,7 +261,7 @@ class ElementMixin(object):
with_corpus: Optional[bool] = None,
with_has_children: Optional[bool] = None,
with_zone: Optional[bool] = None,
worker_version: Optional[str] = None,
worker_version: Optional[Union[str, bool]] = None,
) -> Union[Iterable[dict], Iterable[CachedElement]]:
"""
List children of an element.
......@@ -295,7 +295,7 @@ class ElementMixin(object):
This parameter is not supported when caching is enabled.
:type with_zone: Optional[bool]
:param worker_version: Restrict to elements created by a worker version with this UUID.
:type worker_version: Optional[str]
:type worker_version: Optional[Union[str, bool]]
:return: An iterable of dicts from the ``ListElementChildren`` API endpoint,
or an iterable of :class:`CachedElement` when caching is enabled.
:rtype: Union[Iterable[dict], Iterable[CachedElement]]
......@@ -330,10 +330,14 @@ class ElementMixin(object):
if with_zone is not None:
assert isinstance(with_zone, bool), "with_zone should be of type bool"
query_params["with_zone"] = with_zone
if worker_version:
if worker_version is not None:
assert isinstance(
worker_version, str
), "worker_version should be of type str"
worker_version, (str, bool)
), "worker_version should be of type str or bool"
if isinstance(worker_version, bool):
assert (
worker_version is False
), "if of type bool, worker_version can only be set to False"
query_params["worker_version"] = worker_version
if self.use_cache:
......@@ -346,8 +350,12 @@ class ElementMixin(object):
query = CachedElement.select().where(CachedElement.parent_id == element.id)
if type:
query = query.where(CachedElement.type == type)
if worker_version:
query = query.where(CachedElement.worker_version_id == worker_version)
if worker_version is not None:
# If worker_version=False, filter by manual worker_version, i.e. None
worker_version_id = worker_version if worker_version else None
query = query.where(
CachedElement.worker_version_id == worker_version_id
)
return query
else:
......
# -*- coding: utf-8 -*-
import hashlib
import os
import tarfile
import tempfile
from contextlib import contextmanager
from typing import NewType, Tuple
import requests
import zstandard as zstd
from apistar.exceptions import ErrorResponse
from arkindex_worker import logger
# Size of the chunks used when reading files for hashing and compression
CHUNK_SIZE = 1024

# Type aliases documenting the values handled by create_archive
DirPath = NewType("DirPath", str)
Hash = NewType("Hash", str)
FileSize = NewType("FileSize", int)
# create_archive yields (archive path, content hash, archive size, archive hash);
# the alias is a 4-tuple to match — the previous 3-tuple omitted the archive hash.
Archive = Tuple[DirPath, Hash, FileSize, Hash]
@contextmanager
def create_archive(path: DirPath) -> Tuple[DirPath, Hash, FileSize, Hash]:
    """Create a tar archive of a directory, compress it to zstandard and yield it.

    Yields a 4-tuple ``(zst_archive_path, content_hash, size, archive_hash)``:
    the path to the temporary ``.tar.zst`` archive, the MD5 hash of the files'
    contents, the size in bytes of the compressed archive, and the MD5 hash of
    the compressed archive itself. Both temporary files are removed on exit,
    even when an exception is raised.

    :param path: Directory whose files should be archived. NOTE(review):
        despite the ``DirPath = NewType(..., str)`` alias, a ``pathlib.Path``
        is required here since ``.is_dir()``/``.glob()`` are called on it.
    """
    assert path.is_dir(), "create_archive needs a directory"

    compressor = zstd.ZstdCompressor(level=3)
    content_hasher = hashlib.md5()
    archive_hasher = hashlib.md5()

    # Create an uncompressed tar archive with all the needed files.
    # The file hierarchy is kept in the archive.
    _, path_to_tar_archive = tempfile.mkstemp(prefix="teklia-", suffix=".tar")
    try:
        file_list = []
        with tarfile.open(path_to_tar_archive, "w") as tar:
            for p in path.glob("**/*"):
                x = p.relative_to(path)
                tar.add(p, arcname=x, recursive=False)
                file_list.append(p)

        # Hash file contents in sorted order so the content hash does not
        # depend on filesystem iteration order
        file_list.sort()
        for file_path in file_list:
            with open(file_path, "rb") as file_data:
                for chunk in iter(lambda: file_data.read(CHUNK_SIZE), b""):
                    content_hasher.update(chunk)

        _, path_to_zst_archive = tempfile.mkstemp(prefix="teklia-", suffix=".tar.zst")
        # Compress the tar archive, hashing the compressed bytes on the fly
        with open(path_to_zst_archive, "wb") as archive_file:
            with open(path_to_tar_archive, "rb") as model_data:
                for model_chunk in iter(lambda: model_data.read(CHUNK_SIZE), b""):
                    compressed_chunk = compressor.compress(model_chunk)
                    archive_hasher.update(compressed_chunk)
                    archive_file.write(compressed_chunk)
    finally:
        # Always remove the intermediate tar archive, even if archiving
        # or compression failed part-way
        os.remove(path_to_tar_archive)

    try:
        # Get content hash, archive size and hash
        hash = content_hasher.hexdigest()
        size = os.path.getsize(path_to_zst_archive)
        archive_hash = archive_hasher.hexdigest()

        yield path_to_zst_archive, hash, size, archive_hash
    finally:
        # Remove the zstd archive even if the caller's with-body raised
        os.remove(path_to_zst_archive)
class TrainingMixin(object):
    """
    Mixin for the Training workers to add Model and ModelVersion helpers.

    Relies on ``self.request`` and ``self.api_client`` being provided by the
    worker class this mixin is combined with.
    """

    def publish_model_version(self, model_path: DirPath, model_id: str):
        """
        Create a compressed archive of a model's files, compute its hashes,
        create a unique model version from them, upload the archive to a
        bucket and mark the version as available on Arkindex.

        :param model_path: Directory containing the model's files.
        :param model_id: UUID of the model to create a version of.
        """
        # Create the zst archive, get its hash and size
        with create_archive(path=model_path) as (
            path_to_archive,
            hash,
            size,
            archive_hash,
        ):
            # Create a new model version with hash and size
            model_version_details = self.create_model_version(
                model_id=model_id,
                hash=hash,
                size=size,
                archive_hash=archive_hash,
            )
            # None means an identical version already exists in Available
            # state: nothing to publish
            if model_version_details is None:
                return
            self.upload_to_s3(
                archive_path=path_to_archive,
                model_version_details=model_version_details,
            )

            # Update the model version with state, configuration parsed, tag, description (defaults to name of the worker)
            self.update_model_version(
                model_version_details=model_version_details,
            )

    def create_model_version(
        self,
        model_id: str,
        hash: str,
        size: int,
        archive_hash: str,
    ) -> dict:
        """
        Create a new version of the specified model with the given information (hashes and size).

        If a version matching the information already exists, there are two cases:

        - The version is in `Created` state: this version's details is used
        - The version is in `Available` state: you cannot create twice the same version,
          an error is logged and ``None`` is returned

        :param model_id: UUID of the model to create a version of.
        :param hash: MD5 hash of the archived files' contents.
        :param size: Size in bytes of the compressed archive.
        :param archive_hash: MD5 hash of the compressed archive.
        :returns: The model version's details as a dict, or ``None`` when an
            identical version is already Available.
        """
        # Create a new model version with hash and size
        try:
            model_version_details = self.request(
                "CreateModelVersion",
                id=model_id,
                body={"hash": hash, "size": size, "archive_hash": archive_hash},
            )
        except ErrorResponse as e:
            if e.status_code >= 500:
                logger.error(f"Failed to create model version: {e.content}")

            # NOTE(review): on a 5xx error `e.content` may not be a dict,
            # in which case `.get` would raise — confirm against the API client
            model_version_details = e.content.get("hash")
            # If the existing model is in Created state, this model is returned as a dict.
            # Else an error in a list is returned.
            if model_version_details and isinstance(
                model_version_details, (list, tuple)
            ):
                logger.error(model_version_details[0])
                return

        return model_version_details

    def upload_to_s3(self, archive_path: str, model_version_details: dict) -> None:
        """
        Upload the archive of the model's files to an Amazon s3 compatible storage.

        :param archive_path: Path to the compressed archive to upload.
        :param model_version_details: Model version payload; its ``s3_put_url``
            key is used as the upload destination.
        :raises requests.HTTPError: When the PUT request fails.
        """
        s3_put_url = model_version_details.get("s3_put_url")
        logger.info("Uploading to s3...")
        # Upload the archive on s3
        with open(archive_path, "rb") as archive:
            r = requests.put(
                url=s3_put_url,
                data=archive,
                headers={"Content-Type": "application/zstd"},
            )
        r.raise_for_status()

    def update_model_version(
        self,
        model_version_details: dict,
        description: str = None,
        configuration: dict = None,
        tag: str = None,
    ) -> None:
        """
        Update the specified model version to the state `Available` and use the given information.

        :param model_version_details: Model version payload; its ``id`` key
            identifies the version to update.
        :param description: Optional description for the version.
        :param configuration: Optional parsed configuration for the version.
        :param tag: Optional tag for the version.
        """
        logger.info("Updating the model version...")
        try:
            self.request(
                "UpdateModelVersion",
                id=model_version_details.get("id"),
                body={
                    "state": "available",
                    "description": description,
                    "configuration": configuration,
                    "tag": tag,
                },
            )
        except ErrorResponse as e:
            # Best-effort: failures are logged, not raised
            logger.error(f"Failed to update model version: {e.content}")
......@@ -4,6 +4,7 @@ ElementsWorker methods for transcriptions.
"""
from enum import Enum
from typing import Iterable, Optional, Union
from peewee import IntegrityError
......@@ -348,8 +349,12 @@ class TranscriptionMixin(object):
return annotations
def list_transcriptions(
self, element, element_type=None, recursive=None, worker_version=None
):
self,
element: Union[Element, CachedElement],
element_type: Optional[str] = None,
recursive: Optional[bool] = None,
worker_version: Optional[Union[str, bool]] = None,
) -> Union[Iterable[dict], Iterable[CachedTranscription]]:
"""
List transcriptions on an element.
......@@ -359,11 +364,11 @@ class TranscriptionMixin(object):
:type element_type: Optional[str]
:param recursive: Include transcriptions of any descendant of this element, recursively.
:type recursive: Optional[bool]
:param worker_version: Restrict to transcriptions created by a worker version with this UUID.
:type worker_version: Optional[str]
:param worker_version: Restrict to transcriptions created by a worker version with this UUID. Set to False to look for manually created transcriptions.
:type worker_version: Optional[Union[str, bool]]
:returns: An iterable of dicts representing each transcription,
or an iterable of CachedTranscription when cache support is enabled.
:rtype: Iterable[dict] or Iterable[CachedTranscription]
:rtype: Union[Iterable[dict], Iterable[CachedTranscription]]
"""
assert element and isinstance(
element, (Element, CachedElement)
......@@ -375,10 +380,14 @@ class TranscriptionMixin(object):
if recursive is not None:
assert isinstance(recursive, bool), "recursive should be of type bool"
query_params["recursive"] = recursive
if worker_version:
if worker_version is not None:
assert isinstance(
worker_version, str
), "worker_version should be of type str"
worker_version, (str, bool)
), "worker_version should be of type str or bool"
if isinstance(worker_version, bool):
assert (
worker_version is False
), "if of type bool, worker_version can only be set to False"
query_params["worker_version"] = worker_version
if self.use_cache:
......@@ -409,9 +418,11 @@ class TranscriptionMixin(object):
if element_type:
transcriptions = transcriptions.where(cte.c.type == element_type)
if worker_version:
if worker_version is not None:
# If worker_version=False, filter by manual worker_version, i.e. None
worker_version_id = worker_version if worker_version else None
transcriptions = transcriptions.where(
CachedTranscription.worker_version_id == worker_version
CachedTranscription.worker_version_id == worker_version_id
)
else:
transcriptions = self.api_client.paginate(
......
......@@ -6,3 +6,4 @@ python-gnupg==0.4.8
sh==1.14.2
shapely==1.8.2
tenacity==8.0.1
zstandard==0.18.0
pytest==7.1.1
pytest-mock==3.7.0
pytest-responses==0.5.0
requests==2.27.1
......@@ -26,6 +26,7 @@ from arkindex_worker.worker import BaseWorker, ElementsWorker
from arkindex_worker.worker.transcription import TextOrientation
FIXTURES_DIR = Path(__file__).resolve().parent / "data"
SAMPLES_DIR = Path(__file__).resolve().parent / "samples"
__yaml_cache = {}
......@@ -239,6 +240,7 @@ def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_run_api):
monkeypatch.setattr(sys, "argv", ["worker"])
worker = BaseWorker(support_cache=True)
worker.setup_api_client()
monkeypatch.setenv("PONOS_TASK", "my_task")
return worker
......@@ -275,6 +277,11 @@ def fake_transcriptions_small():
return json.load(f)
@pytest.fixture
def model_file_dir():
return SAMPLES_DIR / "model_files"
@pytest.fixture
def fake_dummy_worker():
api_client = MockApiClient()
......@@ -326,7 +333,14 @@ def mock_cached_elements():
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
)
assert CachedElement.select().count() == 2
CachedElement.create(
id=UUID("33333333-3333-3333-3333-333333333333"),
parent_id=UUID("12341234-1234-1234-1234-123412341234"),
type="paragraph",
polygon="[[1, 1], [2, 2], [2, 1], [1, 2]]",
worker_version_id=None,
)
assert CachedElement.select().count() == 3
@pytest.fixture
......@@ -406,6 +420,14 @@ def mock_cached_transcriptions():
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
)
CachedTranscription.create(
id=UUID("66666666-6666-6666-6666-666666666666"),
element_id=UUID("11111111-1111-1111-1111-111111111111"),
text="This is a manual one",
confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=None,
)
@pytest.fixture(scope="function")
......
Wow this is actually the data of the best model ever created on Arkindex
\ No newline at end of file
Wow this is actually the data of the best model ever created on Arkindex
\ No newline at end of file
......@@ -90,7 +90,6 @@ def test_init_var_worker_local_file(monkeypatch, tmp_path):
def test_cli_default(mocker, mock_worker_run_api):
worker = BaseWorker()
assert logger.level == logging.NOTSET
assert not hasattr(worker, "api_client")
mocker.patch.object(sys, "argv", ["worker"])
worker.args = worker.parser.parse_args()
......@@ -110,7 +109,6 @@ def test_cli_default(mocker, mock_worker_run_api):
def test_cli_arg_verbose_given(mocker, mock_worker_run_api):
worker = BaseWorker()
assert logger.level == logging.NOTSET
assert not hasattr(worker, "api_client")
mocker.patch.object(sys, "argv", ["worker", "-v"])
worker.args = worker.parser.parse_args()
......@@ -131,7 +129,6 @@ def test_cli_envvar_debug_given(mocker, monkeypatch, mock_worker_run_api):
worker = BaseWorker()
assert logger.level == logging.NOTSET
assert not hasattr(worker, "api_client")
mocker.patch.object(sys, "argv", ["worker"])
monkeypatch.setenv("ARKINDEX_DEBUG", True)
worker.args = worker.parser.parse_args()
......
......@@ -58,12 +58,12 @@ def test_create_tables(tmp_path):
init_cache_db(db_path)
create_tables()
expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))
expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))
CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "rotation_angle" INTEGER NOT NULL, "mirrored" INTEGER NOT NULL, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, "confidence" REAL, FOREIGN KEY ("image_id") REFERENCES "images" ("id"))
CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL)
CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT)
CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
actual_schema = "\n".join(
[
......
......@@ -1255,7 +1255,18 @@ def test_list_element_children_wrong_worker_version(mock_elements_worker):
element=elt,
worker_version=1234,
)
assert str(e.value) == "worker_version should be of type str"
assert str(e.value) == "worker_version should be of type str or bool"
def test_list_element_children_wrong_bool_worker_version(mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
with pytest.raises(AssertionError) as e:
mock_elements_worker.list_element_children(
element=elt,
worker_version=True,
)
assert str(e.value) == "if of type bool, worker_version can only be set to False"
def test_list_element_children_api_error(responses, mock_elements_worker):
......@@ -1363,6 +1374,48 @@ def test_list_element_children(responses, mock_elements_worker):
]
def test_list_element_children_manual_worker_version(responses, mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
expected_children = [
{
"id": "0000",
"type": "page",
"name": "Test",
"corpus": {},
"thumbnail_url": None,
"zone": {},
"best_classes": None,
"has_children": None,
"worker_version_id": None,
}
]
responses.add(
responses.GET,
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/?worker_version=False",
status=200,
json={
"count": 1,
"next": None,
"results": expected_children,
},
)
for idx, child in enumerate(
mock_elements_worker.list_element_children(element=elt, worker_version=False)
):
assert child == expected_children[idx]
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"GET",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/?worker_version=False",
),
]
def test_list_element_children_with_cache_unhandled_param(
mock_elements_worker_with_cache,
):
......@@ -1389,6 +1442,7 @@ def test_list_element_children_with_cache_unhandled_param(
(
"11111111-1111-1111-1111-111111111111",
"22222222-2222-2222-2222-222222222222",
"33333333-3333-3333-3333-333333333333",
),
),
# Filter on element and page should give the second element
......@@ -1399,7 +1453,7 @@ def test_list_element_children_with_cache_unhandled_param(
},
("22222222-2222-2222-2222-222222222222",),
),
# Filter on element and worker version should give all elements
# Filter on element and worker version should give first two elements
(
{
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
......@@ -1410,7 +1464,7 @@ def test_list_element_children_with_cache_unhandled_param(
"22222222-2222-2222-2222-222222222222",
),
),
# Filter on element, type something and worker version should give first
# Filter on element, type something and worker version should give first
(
{
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
......@@ -1419,6 +1473,14 @@ def test_list_element_children_with_cache_unhandled_param(
},
("11111111-1111-1111-1111-111111111111",),
),
# Filter on element, manual worker version should give third
(
{
"element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
"worker_version": False,
},
("33333333-3333-3333-3333-333333333333",),
),
),
)
def test_list_element_children_with_cache(
......@@ -1430,7 +1492,7 @@ def test_list_element_children_with_cache(
):
# Check we have 2 elements already present in database
assert CachedElement.select().count() == 2
assert CachedElement.select().count() == 3
# Query database through cache
elements = mock_elements_worker_with_cache.list_element_children(**filters)
......
# -*- coding: utf-8 -*-
import os
import pytest
import responses
from arkindex.mock import MockApiClient
from arkindex_worker.worker import BaseWorker
from arkindex_worker.worker.training import TrainingMixin, create_archive
class TrainingWorker(BaseWorker, TrainingMixin):
    """Minimal worker combining BaseWorker with TrainingMixin, used by tests only."""
def test_create_archive(model_file_dir):
    """Archiving a directory of model files yields a zst archive that is removed on exit."""
    with create_archive(path=model_file_dir) as (
        archive_path,
        content_hash,
        archive_size,
        archive_hash,
    ):
        assert os.path.exists(archive_path), "The archive was not created"
        assert (
            content_hash == "c5aedde18a768757351068b840c8c8f9"
        ), "Hash was not properly computed"
        assert 300 < archive_size < 700

    # Leaving the context manager must delete the temporary archive
    assert not os.path.exists(archive_path), "Auto removal failed"
def test_create_model_version():
    """A brand new model version is created and its details are returned."""
    training = TrainingWorker()
    training.api_client = MockApiClient()

    model_id = "fake_model_id"
    model_hash = "hash"
    archive_hash = "archive_hash"
    size = "30"
    expected_version = {
        "id": "fake_model_version_id",
        "model_id": model_id,
        "hash": model_hash,
        "archive_hash": archive_hash,
        "size": size,
        "s3_url": "http://hehehe.com",
        "s3_put_url": "http://hehehe.com",
    }
    # Mock the CreateModelVersion endpoint to return the expected payload
    training.api_client.add_response(
        "CreateModelVersion",
        id=model_id,
        response=expected_version,
        body={"hash": model_hash, "archive_hash": archive_hash, "size": size},
    )

    assert (
        training.create_model_version(model_id, model_hash, size, archive_hash)
        == expected_version
    )
# Two error scenarios from CreateModelVersion:
# - 400: an identical version already exists in Created state and its details
#   are embedded in the error content under the "hash" key
# - 403: an identical version already exists in Available state; the "hash"
#   key holds a list with an error message instead
@pytest.mark.parametrize(
    "content, status_code",
    [
        (
            {
                "hash": {
                    "id": "fake_model_version_id",
                    "model_id": "fake_model_id",
                    "hash": "hash",
                    "archive_hash": "archive_hash",
                    "size": "size",
                    "s3_url": "http://hehehe.com",
                    "s3_put_url": "http://hehehe.com",
                }
            },
            400,
        ),
        ({"hash": ["A version for this model with this hash already exists."]}, 403),
    ],
)
def test_retrieve_created_model_version(content, status_code):
    """
    If there is an existing model version in Created mode,
    A 400 was raised, but the model is still returned in error content.
    Else if an existing model version in Available mode,
    403 was raised, but None will be returned
    """
    model_id = "fake_model_id"
    training = TrainingWorker()
    training.api_client = MockApiClient()
    model_hash = "hash"
    archive_hash = "archive_hash"
    size = "30"
    # Mock the endpoint to fail with the parametrized status code and content
    training.api_client.add_error_response(
        "CreateModelVersion",
        id=model_id,
        status_code=status_code,
        body={"hash": model_hash, "archive_hash": archive_hash, "size": size},
        content=content,
    )
    if status_code == 400:
        # Created state: the existing version's details are recovered
        assert (
            training.create_model_version(model_id, model_hash, size, archive_hash)
            == content["hash"]
        )
    elif status_code == 403:
        # Available state: nothing to create, None is returned
        assert (
            training.create_model_version(model_id, model_hash, size, archive_hash)
            is None
        )
def test_handle_s3_uploading_errors(model_file_dir):
    """A failing (HTTP 400) s3 PUT makes upload_to_s3 raise."""
    training = TrainingWorker()
    training.api_client = MockApiClient()
    s3_endpoint_url = "http://s3.localhost.com"
    # NOTE(review): the passthru followed by an explicit mock on the same URL
    # looks redundant — presumably the explicit Response takes precedence so
    # the PUT is answered with a 400; confirm against the responses library
    responses.add_passthru(s3_endpoint_url)
    responses.add(responses.Response(method="PUT", url=s3_endpoint_url, status=400))
    file_path = model_file_dir / "model_file.pth"
    # raise_for_status() in upload_to_s3 must propagate the HTTP error
    with pytest.raises(Exception):
        training.upload_to_s3(file_path, {"s3_put_url": s3_endpoint_url})
......@@ -1682,7 +1682,18 @@ def test_list_transcriptions_wrong_worker_version(mock_elements_worker):
element=elt,
worker_version=1234,
)
assert str(e.value) == "worker_version should be of type str"
assert str(e.value) == "worker_version should be of type str or bool"
def test_list_transcriptions_wrong_bool_worker_version(mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
with pytest.raises(AssertionError) as e:
mock_elements_worker.list_transcriptions(
element=elt,
worker_version=True,
)
assert str(e.value) == "if of type bool, worker_version can only be set to False"
def test_list_transcriptions_api_error(responses, mock_elements_worker):
......@@ -1778,19 +1789,60 @@ def test_list_transcriptions(responses, mock_elements_worker):
]
def test_list_transcriptions_manual_worker_version(responses, mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
trans = [
{
"id": "0000",
"text": "hey",
"confidence": 0.42,
"worker_version_id": None,
"element": None,
}
]
responses.add(
responses.GET,
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/?worker_version=False",
status=200,
json={
"count": 1,
"next": None,
"results": trans,
},
)
for idx, transcription in enumerate(
mock_elements_worker.list_transcriptions(element=elt, worker_version=False)
):
assert transcription == trans[idx]
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"GET",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/?worker_version=False",
),
]
@pytest.mark.parametrize(
"filters, expected_ids",
(
# Filter on element should give first transcription
# Filter on element should give first and sixth transcription
(
{
"element": CachedElement(
id="11111111-1111-1111-1111-111111111111", type="page"
),
},
("11111111-1111-1111-1111-111111111111",),
(
"11111111-1111-1111-1111-111111111111",
"66666666-6666-6666-6666-666666666666",
),
),
# Filter on element and element_type should give first transcription
# Filter on element and element_type should give first and sixth transcription
(
{
"element": CachedElement(
......@@ -1798,7 +1850,10 @@ def test_list_transcriptions(responses, mock_elements_worker):
),
"element_type": "page",
},
("11111111-1111-1111-1111-111111111111",),
(
"11111111-1111-1111-1111-111111111111",
"66666666-6666-6666-6666-666666666666",
),
),
# Filter on element and worker_version should give first transcription
(
......@@ -1824,6 +1879,7 @@ def test_list_transcriptions(responses, mock_elements_worker):
"33333333-3333-3333-3333-333333333333",
"44444444-4444-4444-4444-444444444444",
"55555555-5555-5555-5555-555555555555",
"66666666-6666-6666-6666-666666666666",
),
),
# Filter recursively on element and worker_version should give four transcriptions
......@@ -1857,6 +1913,16 @@ def test_list_transcriptions(responses, mock_elements_worker):
"55555555-5555-5555-5555-555555555555",
),
),
# Filter on element with manually created transcription should give sixth transcription
(
{
"element": CachedElement(
id="11111111-1111-1111-1111-111111111111", type="page"
),
"worker_version": False,
},
("66666666-6666-6666-6666-666666666666",),
),
),
)
def test_list_transcriptions_with_cache(
......@@ -1867,7 +1933,7 @@ def test_list_transcriptions_with_cache(
expected_ids,
):
# Check we have 5 elements already present in database
assert CachedTranscription.select().count() == 5
assert CachedTranscription.select().count() == 6
# Query database through cache
transcriptions = mock_elements_worker_with_cache.list_transcriptions(**filters)
......