Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
......@@ -45,6 +45,24 @@ def test_create_archive(model_file_dir):
assert not os.path.exists(zst_archive_path), "Auto removal failed"
def test_create_archive_with_subfolder(model_file_dir_with_subfolder):
    """Archive a model directory that itself contains a subfolder.

    While the context manager is open the archive must exist on disk, report
    the expected hash and have a size within the accepted range. Once the
    context exits, the archive file must be gone.
    """
    expected_hash = "3e453881404689e6e125144d2db3e605"

    with create_archive(path=model_file_dir_with_subfolder) as archive_info:
        # The last item (the archive's own hash) is not checked by this test.
        archive_path, content_hash, archive_size, _archive_hash = archive_info

        assert os.path.exists(archive_path), "The archive was not created"
        assert content_hash == expected_hash, "Hash was not properly computed"
        assert 300 < archive_size < 1500

    # Leaving the context manager is expected to delete the archive.
    assert not os.path.exists(archive_path), "Auto removal failed"
@pytest.mark.parametrize(
"tag, description",
[
......@@ -53,7 +71,7 @@ def test_create_archive(model_file_dir):
("", "description"),
("tag", ""),
("", ""),
(None, ""),
(None, None),
],
)
def test_create_model_version(mock_training_worker, tag, description):
......@@ -76,17 +94,21 @@ def test_create_model_version(mock_training_worker, tag, description):
"s3_put_url": "http://hehehe.com",
}
expected_payload = {
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
}
if description:
expected_payload["description"] = description
if tag:
expected_payload["tag"] = tag
mock_training_worker.api_client.add_response(
"CreateModelVersion",
id=model_id,
response=model_version_details,
body={
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
"tag": tag,
"description": description,
},
body=expected_payload,
)
assert (
mock_training_worker.create_model_version(
......@@ -101,37 +123,19 @@ def test_create_model_version(mock_training_worker, tag, description):
[
(
{
"hash": {
"id": "fake_model_version_id",
"model_id": "fake_model_id",
"hash": "hash",
"archive_hash": "archive_hash",
"size": "size",
"tag": "tag",
"description": "description",
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
}
"id": "fake_model_version_id",
"model_id": "fake_model_id",
"hash": "hash",
"archive_hash": "archive_hash",
"size": "size",
"tag": "tag",
"description": "description",
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
},
400,
),
(
{
"hash": {
"id": "fake_model_version_id",
"model_id": "fake_model_id",
"hash": "hash",
"archive_hash": "archive_hash",
"size": "size",
"tag": None,
"description": "",
"s3_url": "http://hehehe.com",
"s3_put_url": "http://hehehe.com",
}
},
400,
),
({"hash": ["A version for this model with this hash already exists."]}, 403),
(["A version for this model with this hash already exists."], 403),
],
)
def test_retrieve_created_model_version(mock_training_worker, content, status_code):
......@@ -141,13 +145,10 @@ def test_retrieve_created_model_version(mock_training_worker, content, status_co
Else if an existing model version in Available mode,
403 was raised, but None will be returned
"""
model_id = "fake_model_id"
model_hash = "hash"
archive_hash = "archive_hash"
size = "30"
tag = "tag"
description = "description"
mock_training_worker.api_client.add_error_response(
"CreateModelVersion",
id=model_id,
......@@ -156,27 +157,57 @@ def test_retrieve_created_model_version(mock_training_worker, content, status_co
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
"tag": tag,
"description": description,
},
content=content,
content={"hash": content},
)
if status_code == 400:
assert (
mock_training_worker.create_model_version(
model_id, model_hash, size, archive_hash, tag, description
model_id, model_hash, size, archive_hash, tag=None, description=None
)
== content["hash"]
== content
)
elif status_code == 403:
assert (
mock_training_worker.create_model_version(
model_id, model_hash, size, archive_hash, tag, description
model_id, model_hash, size, archive_hash, tag=None, description=None
)
is None
)
@pytest.mark.parametrize(
    "content, status_code",
    [
        # Server-side failure (HTTP 500)
        ({"id": "fake_id"}, 500),
        # HTTP 403 whose payload carries no usable model version details
        ({}, 403),
        (None, 403),
    ],
)
def test_handle_500_create_model_version(mock_training_worker, content, status_code):
    """Check that ``create_model_version`` raises on unrecoverable API errors.

    The mocked ``CreateModelVersion`` endpoint answers with the parametrized
    error ``content`` and ``status_code``; in every case the call is expected
    to raise instead of returning a value.
    """
    model_id = "fake_model_id"
    model_hash = "hash"
    archive_hash = "archive_hash"
    size = "30"

    # Payload the worker is expected to send (no tag, no description).
    request_body = {
        "hash": model_hash,
        "archive_hash": archive_hash,
        "size": size,
    }
    mock_training_worker.api_client.add_error_response(
        "CreateModelVersion",
        id=model_id,
        status_code=status_code,
        body=request_body,
        content=content,
    )

    # NOTE(review): the concrete exception type is not visible from this test,
    # hence the broad ``Exception``; narrow it if the worker raises a specific one.
    with pytest.raises(Exception):
        mock_training_worker.create_model_version(
            model_id, model_hash, size, archive_hash, tag=None, description=None
        )
def test_handle_s3_uploading_errors(mock_training_worker, model_file_dir):
s3_endpoint_url = "http://s3.localhost.com"
responses.add_passthru(s3_endpoint_url)
......
......@@ -133,7 +133,7 @@ def test_create_transcription_default_orientation(responses, mock_elements_worke
"id": "56785678-5678-5678-5678-567856785678",
"text": "Animula vagula blandula",
"confidence": 0.42,
"worker_version_id": "12341234-1234-1234-1234-123412341234",
"worker_run_id": "56785678-5678-5678-5678-567856785678",
},
)
mock_elements_worker.create_transcription(
......@@ -143,7 +143,7 @@ def test_create_transcription_default_orientation(responses, mock_elements_worke
)
assert json.loads(responses.calls[-1].request.body) == {
"text": "Animula vagula blandula",
"worker_version": "12341234-1234-1234-1234-123412341234",
"worker_run_id": "56785678-5678-5678-5678-567856785678",
"confidence": 0.42,
"orientation": "horizontal-lr",
}
......@@ -159,7 +159,7 @@ def test_create_transcription_orientation(responses, mock_elements_worker):
"id": "56785678-5678-5678-5678-567856785678",
"text": "Animula vagula blandula",
"confidence": 0.42,
"worker_version_id": "12341234-1234-1234-1234-123412341234",
"worker_run_id": "56785678-5678-5678-5678-567856785678",
},
)
mock_elements_worker.create_transcription(
......@@ -170,7 +170,7 @@ def test_create_transcription_orientation(responses, mock_elements_worker):
)
assert json.loads(responses.calls[-1].request.body) == {
"text": "Animula vagula blandula",
"worker_version": "12341234-1234-1234-1234-123412341234",
"worker_run_id": "56785678-5678-5678-5678-567856785678",
"confidence": 0.42,
"orientation": "vertical-lr",
}
......@@ -229,7 +229,7 @@ def test_create_transcription(responses, mock_elements_worker):
"id": "56785678-5678-5678-5678-567856785678",
"text": "i am a line",
"confidence": 0.42,
"worker_version_id": "12341234-1234-1234-1234-123412341234",
"worker_run_id": "56785678-5678-5678-5678-567856785678",
},
)
......@@ -248,7 +248,7 @@ def test_create_transcription(responses, mock_elements_worker):
assert json.loads(responses.calls[-1].request.body) == {
"text": "i am a line",
"worker_version": "12341234-1234-1234-1234-123412341234",
"worker_run_id": "56785678-5678-5678-5678-567856785678",
"confidence": 0.42,
"orientation": "horizontal-lr",
}
......@@ -266,7 +266,7 @@ def test_create_transcription_with_cache(responses, mock_elements_worker_with_ca
"text": "i am a line",
"confidence": 0.42,
"orientation": "horizontal-lr",
"worker_version_id": "12341234-1234-1234-1234-123412341234",
"worker_run_id": "56785678-5678-5678-5678-567856785678",
},
)
......@@ -285,7 +285,7 @@ def test_create_transcription_with_cache(responses, mock_elements_worker_with_ca
assert json.loads(responses.calls[-1].request.body) == {
"text": "i am a line",
"worker_version": "12341234-1234-1234-1234-123412341234",
"worker_run_id": "56785678-5678-5678-5678-567856785678",
"orientation": "horizontal-lr",
"confidence": 0.42,
}
......@@ -298,7 +298,8 @@ def test_create_transcription_with_cache(responses, mock_elements_worker_with_ca
text="i am a line",
confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
worker_version_id=None,
worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
)
]
......@@ -316,7 +317,7 @@ def test_create_transcription_orientation_with_cache(
"text": "Animula vagula blandula",
"confidence": 0.42,
"orientation": "vertical-lr",
"worker_version_id": "12341234-1234-1234-1234-123412341234",
"worker_run_id": "56785678-5678-5678-5678-567856785678",
},
)
mock_elements_worker_with_cache.create_transcription(
......@@ -327,7 +328,7 @@ def test_create_transcription_orientation_with_cache(
)
assert json.loads(responses.calls[-1].request.body) == {
"text": "Animula vagula blandula",
"worker_version": "12341234-1234-1234-1234-123412341234",
"worker_run_id": "56785678-5678-5678-5678-567856785678",
"orientation": "vertical-lr",
"confidence": 0.42,
}
......@@ -345,12 +346,14 @@ def test_create_transcription_orientation_with_cache(
"mirrored": False,
"initial": False,
"worker_version_id": None,
"worker_run_id": None,
"confidence": None,
},
"text": "Animula vagula blandula",
"confidence": 0.42,
"orientation": TextOrientation.VerticalLeftToRight.value,
"worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
"worker_version_id": None,
"worker_run_id": UUID("56785678-5678-5678-5678-567856785678"),
}
]
......@@ -662,7 +665,7 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache):
"http://testserver/api/v1/transcription/bulk/",
status=200,
json={
"worker_version": "12341234-1234-1234-1234-123412341234",
"worker_run_id": "56785678-5678-5678-5678-567856785678",
"transcriptions": [
{
"id": "00000000-0000-0000-0000-000000000000",
......@@ -694,7 +697,7 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache):
]
assert json.loads(responses.calls[-1].request.body) == {
"worker_version": "12341234-1234-1234-1234-123412341234",
"worker_run_id": "56785678-5678-5678-5678-567856785678",
"transcriptions": [
{
"element_id": "11111111-1111-1111-1111-111111111111",
......@@ -719,7 +722,7 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache):
text="The",
confidence=0.75,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
),
CachedTranscription(
id=UUID("11111111-1111-1111-1111-111111111111"),
......@@ -727,7 +730,7 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache):
text="word",
confidence=0.42,
orientation=TextOrientation.HorizontalLeftToRight,
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
),
]
......@@ -754,7 +757,7 @@ def test_create_transcriptions_orientation(responses, mock_elements_worker_with_
"http://testserver/api/v1/transcription/bulk/",
status=200,
json={
"worker_version": "12341234-1234-1234-1234-123412341234",
"worker_run_id": "56785678-5678-5678-5678-567856785678",
"transcriptions": [
{
"id": "00000000-0000-0000-0000-000000000000",
......@@ -779,7 +782,7 @@ def test_create_transcriptions_orientation(responses, mock_elements_worker_with_
)
assert json.loads(responses.calls[-1].request.body) == {
"worker_version": "12341234-1234-1234-1234-123412341234",
"worker_run_id": "56785678-5678-5678-5678-567856785678",
"transcriptions": [
{
"element_id": "11111111-1111-1111-1111-111111111111",
......@@ -810,12 +813,14 @@ def test_create_transcriptions_orientation(responses, mock_elements_worker_with_
"mirrored": False,
"initial": False,
"worker_version_id": None,
"worker_run_id": None,
"confidence": None,
},
"text": "Animula vagula blandula",
"confidence": 0.12,
"orientation": TextOrientation.HorizontalRightToLeft.value,
"worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
"worker_version_id": None,
"worker_run_id": UUID("56785678-5678-5678-5678-567856785678"),
},
{
"id": UUID("11111111-1111-1111-1111-111111111111"),
......@@ -829,12 +834,14 @@ def test_create_transcriptions_orientation(responses, mock_elements_worker_with_
"mirrored": False,
"initial": False,
"worker_version_id": None,
"worker_run_id": None,
"confidence": None,
},
"text": "Hospes comesque corporis",
"confidence": 0.21,
"orientation": TextOrientation.VerticalLeftToRight.value,
"worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
"worker_version_id": None,
"worker_run_id": UUID("56785678-5678-5678-5678-567856785678"),
},
]
......@@ -1306,7 +1313,7 @@ def test_create_element_transcriptions(responses, mock_elements_worker):
assert json.loads(responses.calls[-1].request.body) == {
"element_type": "page",
"worker_version": "12341234-1234-1234-1234-123412341234",
"worker_run_id": "56785678-5678-5678-5678-567856785678",
"transcriptions": [
{
"polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
......@@ -1391,7 +1398,7 @@ def test_create_element_transcriptions_with_cache(
assert json.loads(responses.calls[-1].request.body) == {
"element_type": "page",
"worker_version": "12341234-1234-1234-1234-123412341234",
"worker_run_id": "56785678-5678-5678-5678-567856785678",
"transcriptions": [
{
"polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
......@@ -1439,14 +1446,14 @@ def test_create_element_transcriptions_with_cache(
parent_id=UUID("12341234-1234-1234-1234-123412341234"),
type="page",
polygon="[[100, 150], [700, 150], [700, 200], [100, 200]]",
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
),
CachedElement(
id=UUID("22222222-2222-2222-2222-222222222222"),
parent_id=UUID("12341234-1234-1234-1234-123412341234"),
type="page",
polygon="[[0, 0], [2000, 0], [2000, 3000], [0, 3000]]",
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
),
]
assert list(CachedTranscription.select()) == [
......@@ -1456,7 +1463,7 @@ def test_create_element_transcriptions_with_cache(
text="The",
confidence=0.5,
orientation=TextOrientation.HorizontalLeftToRight.value,
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
),
CachedTranscription(
id=UUID("67896789-6789-6789-6789-678967896789"),
......@@ -1464,7 +1471,7 @@ def test_create_element_transcriptions_with_cache(
text="first",
confidence=0.75,
orientation=TextOrientation.HorizontalLeftToRight.value,
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
),
CachedTranscription(
id=UUID("78907890-7890-7890-7890-789078907890"),
......@@ -1472,7 +1479,7 @@ def test_create_element_transcriptions_with_cache(
text="line",
confidence=0.9,
orientation=TextOrientation.HorizontalLeftToRight.value,
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
),
]
......@@ -1533,7 +1540,7 @@ def test_create_transcriptions_orientation_with_cache(
assert json.loads(responses.calls[-1].request.body) == {
"element_type": "page",
"worker_version": "12341234-1234-1234-1234-123412341234",
"worker_run_id": "56785678-5678-5678-5678-567856785678",
"transcriptions": [
{
"polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
......@@ -1587,13 +1594,15 @@ def test_create_transcriptions_orientation_with_cache(
"rotation_angle": 0,
"mirrored": False,
"initial": False,
"worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
"worker_version_id": None,
"worker_run_id": UUID("56785678-5678-5678-5678-567856785678"),
"confidence": None,
},
"text": "Animula vagula blandula",
"confidence": 0.5,
"orientation": TextOrientation.HorizontalLeftToRight.value,
"worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
"worker_version_id": None,
"worker_run_id": UUID("56785678-5678-5678-5678-567856785678"),
},
{
"id": UUID("67896789-6789-6789-6789-678967896789"),
......@@ -1606,13 +1615,15 @@ def test_create_transcriptions_orientation_with_cache(
"rotation_angle": 0,
"mirrored": False,
"initial": False,
"worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
"worker_version_id": None,
"worker_run_id": UUID("56785678-5678-5678-5678-567856785678"),
"confidence": None,
},
"text": "Hospes comesque corporis",
"confidence": 0.75,
"orientation": TextOrientation.VerticalLeftToRight.value,
"worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
"worker_version_id": None,
"worker_run_id": UUID("56785678-5678-5678-5678-567856785678"),
},
{
"id": UUID("78907890-7890-7890-7890-789078907890"),
......@@ -1625,13 +1636,15 @@ def test_create_transcriptions_orientation_with_cache(
"rotation_angle": 0,
"mirrored": False,
"initial": False,
"worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
"worker_version_id": None,
"worker_run_id": UUID("56785678-5678-5678-5678-567856785678"),
"confidence": None,
},
"text": "Quae nunc abibis in loca",
"confidence": 0.9,
"orientation": TextOrientation.HorizontalRightToLeft.value,
"worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
"worker_version_id": None,
"worker_run_id": UUID("56785678-5678-5678-5678-567856785678"),
},
]
......
......@@ -66,7 +66,7 @@ def test_readonly(responses, mock_elements_worker):
"""Test readonly worker does not trigger any API calls"""
# Setup the worker as read-only
mock_elements_worker.worker_version_id = None
mock_elements_worker.worker_run_id = None
assert mock_elements_worker.is_read_only is True
out = mock_elements_worker.update_activity("1234-deadbeef", ActivityState.Processed)
......
......@@ -7,6 +7,7 @@ from tempfile import NamedTemporaryFile
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.models import Transcription
from arkindex_worker.reporting import Reporter
......@@ -35,6 +36,7 @@ def test_process():
"transcriptions": 0,
"classifications": {},
"entities": [],
"transcription_entities": [],
"metadata": [],
"errors": [],
}
......@@ -51,6 +53,7 @@ def test_add_element():
"transcriptions": 0,
"classifications": {},
"entities": [],
"transcription_entities": [],
"metadata": [],
"errors": [],
}
......@@ -70,6 +73,7 @@ def test_add_element_count():
"transcriptions": 0,
"classifications": {},
"entities": [],
"transcription_entities": [],
"metadata": [],
"errors": [],
}
......@@ -86,6 +90,7 @@ def test_add_classification():
"transcriptions": 0,
"classifications": {"three": 1},
"entities": [],
"transcription_entities": [],
"metadata": [],
"errors": [],
}
......@@ -116,6 +121,7 @@ def test_add_classifications():
"transcriptions": 0,
"classifications": {"three": 3, "two": 2},
"entities": [],
"transcription_entities": [],
"metadata": [],
"errors": [],
}
......@@ -132,6 +138,7 @@ def test_add_transcription():
"transcriptions": 1,
"classifications": {},
"entities": [],
"transcription_entities": [],
"metadata": [],
"errors": [],
}
......@@ -151,6 +158,7 @@ def test_add_transcription_count():
"transcriptions": 1337,
"classifications": {},
"entities": [],
"transcription_entities": [],
"metadata": [],
"errors": [],
}
......@@ -175,6 +183,34 @@ def test_add_entity():
"name": "Bob Bob",
}
],
"transcription_entities": [],
"metadata": [],
"errors": [],
}
def test_add_transcription_entity():
    """Check that a transcription entity is recorded under its parent element."""
    transcription = Transcription(
        {"id": "1234-5678", "element": {"id": "myelement"}}
    )

    reporter = Reporter("worker")
    reporter.add_transcription_entity("5678", transcription, "1234")

    elements = reporter.report_data["elements"]
    assert "myelement" in elements

    element_report = elements["myelement"]
    # Drop the "started" entry before comparing; it is not part of the
    # expected structure below.
    element_report.pop("started")

    expected_link = {
        "transcription_id": "1234-5678",
        "entity_id": "5678",
        "transcription_entity_id": "1234",
    }
    assert element_report == {
        "elements": {},
        "transcriptions": 0,
        "classifications": {},
        "entities": [],
        "transcription_entities": [expected_link],
        "metadata": [],
        "errors": [],
    }
......@@ -193,6 +229,7 @@ def test_add_metadata():
"transcriptions": 0,
"classifications": {},
"entities": [],
"transcription_entities": [],
"metadata": [
{
"id": "12341234-1234-1234-1234-123412341234",
......@@ -246,6 +283,7 @@ def test_reporter_save(mocker):
"classifications": {},
"elements": {"text_line": 4},
"entities": [],
"transcription_entities": [],
"errors": [],
"metadata": [],
"started": "2000-01-01T00:00:00",
......
......@@ -7,14 +7,14 @@ repos:
rev: 22.6.0
hooks:
- id: black
- repo: https://gitlab.com/pycqa/flake8
rev: 3.9.2
- repo: https://github.com/pycqa/flake8
rev: 5.0.4
hooks:
- id: flake8
additional_dependencies:
- 'flake8-coding==1.3.1'
- 'flake8-copyright==0.2.2'
- 'flake8-debugger==3.1.0'
- flake8-coding==1.3.1
- flake8-copyright==0.2.3
- flake8-debugger==3.1.0
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
hooks:
......@@ -40,6 +40,3 @@ repos:
- repo: meta
hooks:
- id: check-useless-excludes
default_language_version:
python: python3.9
# {{ cookiecutter.slug }}
{{ cookiecutter.description }}
### Development
For development and testing purposes, it may be useful to install the worker as an editable package with pip.
```shell
pip3 install -e .
```
### Linter
Code syntax is analyzed before submitting the code.\
To run the linter tools suite you may use pre-commit.
```shell
pip install pre-commit
pre-commit run -a
```
### Run tests
Tests are executed with tox using [pytest](https://pytest.org).
```shell
pip install tox
tox
```
To recreate the tox virtual environment (e.g. after a dependency update), you may run `tox -r`.
arkindex-base-worker==0.2.4
arkindex-base-worker==0.3.1
......@@ -2,10 +2,13 @@
import os
import pytest
from arkindex_worker.worker.base import BaseWorker
from arkindex.mock import MockApiClient
@pytest.fixture(autouse=True)
def setup_environment(responses):
def setup_environment(responses, monkeypatch):
"""Setup needed environment variables"""
# Allow accessing remote API schemas
......@@ -18,7 +21,8 @@ def setup_environment(responses):
# Set schema url in environment
os.environ["ARKINDEX_API_SCHEMA_URL"] = schema_url
# Setup a fake worker version ID
os.environ["WORKER_VERSION_ID"] = "1234-{{ cookiecutter.slug }}"
# Setup a fake worker run ID
os.environ["ARKINDEX_WORKER_RUN_ID"] = "1234-{{ cookiecutter.slug }}"
# Setup a mock api client instead of using a real one
monkeypatch.setattr(BaseWorker, "setup_api_client", lambda _: MockApiClient())