Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
Commits on Source (3)
0.2.2-beta2
0.2.2-rc1
......@@ -127,16 +127,22 @@ class ElementsWorker(
logger.info(f"Processing {element} ({i}/{count})")
# Process the element and report its progress if activities are enabled
self.update_activity(element.id, ActivityState.Started)
self.process_element(element)
self.update_activity(element.id, ActivityState.Processed)
if self.update_activity(element.id, ActivityState.Started):
self.process_element(element)
self.update_activity(element.id, ActivityState.Processed)
else:
logger.info(
f"Skipping element {element.id} as it was already processed"
)
continue
except Exception as e:
# Handle errors occurring while retrieving, processing or patching the activity for this element.
# Count the element as failed in case the activity update to "started" failed with no conflict.
# This prevent from processing the element
failed += 1
element_id = (
element.id
if isinstance(element, (Element, CachedElement))
else item
)
# Handle the case where we failed retrieving the element
element_id = element.id if element else item
if isinstance(e, ErrorResponse):
message = f"An API error occurred while processing element {element_id}: {e.title} - {e.content}"
......@@ -147,7 +153,12 @@ class ElementsWorker(
message,
exc_info=e if self.args.verbose else None,
)
self.update_activity(element_id, ActivityState.Error)
if element:
# Try to update the activity to error state regardless of the response
try:
self.update_activity(element.id, ActivityState.Error)
except Exception:
pass
self.report.error(element_id, e)
# Save report as local artifact
......@@ -168,13 +179,14 @@ class ElementsWorker(
def update_activity(self, element_id, state):
"""
Update worker activity for this element
This method should not raise a runtime exception, but simply warn users
Returns False when there is a conflict initializing the activity
Otherwise return True or the response payload
"""
if not self.store_activity:
logger.debug(
"Activity is not stored as the feature is disabled on this process"
)
return
return True
assert element_id and isinstance(
element_id, (uuid.UUID, str)
......@@ -183,10 +195,10 @@ class ElementsWorker(
if self.is_read_only:
logger.warning("Cannot update activity as this worker is in read-only mode")
return
return True
try:
out = self.request(
self.request(
"UpdateWorkerActivity",
id=self.worker_version_id,
body={
......@@ -195,13 +207,22 @@ class ElementsWorker(
"state": state.value,
},
)
logger.debug(f"Updated activity of element {element_id} to {state}")
return out
except ErrorResponse as e:
if state == ActivityState.Started and e.status_code == 409:
# 409 conflict error when updating the state of an activity to "started" mean that we
# cannot process this element. We assume that the reason is that the state transition
# was forbidden i.e. that the activity was already in a started or processed state.
# This allow concurrent access to an element activity between multiple processes.
# Element should not be counted as failed as it is probably handled somewhere else.
logger.debug(
f"Cannot start processing element {element_id} due to a conflict. "
f"Another process could have processed it with the same version already."
)
return False
logger.warning(
f"Failed to update activity of element {element_id} to {state.value} due to an API error: {e.content}"
)
except Exception as e:
logger.warning(
f"Failed to update activity of element {element_id} to {state.value}: {e}"
)
raise e
logger.debug(f"Updated activity of element {element_id} to {state}")
return True
......@@ -131,3 +131,71 @@ class ClassificationMixin(object):
raise
self.report.add_classification(element.id, ml_class)
def create_classifications(self, element, classifications):
"""
Create multiple classifications at once on the given element through the API
"""
assert element and isinstance(
element, (Element, CachedElement)
), "element shouldn't be null and should be an Element or CachedElement"
assert classifications and isinstance(
classifications, list
), "classifications shouldn't be null and should be of type list"
for index, classification in enumerate(classifications):
class_name = classification.get("class_name")
assert class_name and isinstance(
class_name, str
), f"Classification at index {index} in classifications: class_name shouldn't be null and should be of type str"
confidence = classification.get("confidence")
assert (
confidence is not None
and isinstance(confidence, float)
and 0 <= confidence <= 1
), f"Classification at index {index} in classifications: confidence shouldn't be null and should be a float in [0..1] range"
high_confidence = classification.get("high_confidence")
if high_confidence is not None:
assert isinstance(
high_confidence, bool
), f"Classification at index {index} in classifications: high_confidence should be of type bool"
if self.is_read_only:
logger.warning(
"Cannot create classifications as this worker is in read-only mode"
)
return
created_cls = self.request(
"CreateClassifications",
body={
"parent": str(element.id),
"worker_version": self.worker_version_id,
"classifications": classifications,
},
)["classifications"]
for created_cl in created_cls:
self.report.add_classification(element.id, created_cl["class_name"])
if self.use_cache:
# Store classifications in local cache
try:
to_insert = [
{
"id": created_cl["id"],
"element_id": element.id,
"class_name": created_cl["class_name"],
"confidence": created_cl["confidence"],
"state": created_cl["state"],
"worker_version_id": self.worker_version_id,
}
for created_cl in created_cls
]
CachedClassification.insert_many(to_insert).execute()
except IntegrityError as e:
logger.warning(
f"Couldn't save created classifications in local cache: {e}"
)
......@@ -202,6 +202,18 @@ def mock_user_api(responses):
)
@pytest.fixture
def mock_activity_calls(responses):
"""
Mock responses when updating the activity state for multiple element of the same version
"""
responses.add(
responses.PUT,
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/activity/",
status=200,
)
@pytest.fixture
def mock_elements_worker(monkeypatch, mock_config_api):
"""Build and configure an ElementsWorker with fixed CLI parameters to avoid issues with pytest"""
......@@ -212,6 +224,27 @@ def mock_elements_worker(monkeypatch, mock_config_api):
return worker
@pytest.fixture
def mock_elements_worker_with_list(monkeypatch, responses, mock_elements_worker):
"""
Mock a worker instance to list and retrieve a single element
"""
monkeypatch.setattr(
mock_elements_worker, "list_elements", lambda: ["1234-deadbeef"]
)
responses.add(
responses.GET,
"http://testserver/api/v1/element/1234-deadbeef/",
status=200,
json={
"id": "1234-deadbeef",
"type": "page",
"name": "Test Page n°1",
},
)
return mock_elements_worker
@pytest.fixture
def mock_base_worker_with_cache(mocker, monkeypatch, mock_config_api):
"""Build a BaseWorker using SQLite cache, also mocking a PONOS_TASK"""
......
......@@ -501,3 +501,373 @@ def test_create_classification_duplicate(responses, mock_elements_worker):
# Classification has NOT been created
assert mock_elements_worker.report.report_data["elements"] == {}
def test_create_classifications_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=None,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element="not element type",
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
def test_create_classifications_wrong_classifications(mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=None,
)
assert (
str(e.value) == "classifications shouldn't be null and should be of type list"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=1234,
)
assert (
str(e.value) == "classifications shouldn't be null and should be of type list"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": None,
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": 1234,
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": None,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": "wrong confidence",
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 2.00,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": "wrong high_confidence",
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: high_confidence should be of type bool"
)
def test_create_classifications_api_error(responses, mock_elements_worker):
responses.add(
responses.POST,
"http://testserver/api/v1/classification/bulk/",
status=500,
)
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
classes = [
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": False,
},
]
with pytest.raises(ErrorResponse):
mock_elements_worker.create_classifications(
element=elt, classifications=classes
)
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
("POST", "http://testserver/api/v1/classification/bulk/"),
("POST", "http://testserver/api/v1/classification/bulk/"),
("POST", "http://testserver/api/v1/classification/bulk/"),
("POST", "http://testserver/api/v1/classification/bulk/"),
("POST", "http://testserver/api/v1/classification/bulk/"),
]
def test_create_classifications(responses, mock_elements_worker_with_cache):
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
classes = [
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": False,
},
]
responses.add(
responses.POST,
"http://testserver/api/v1/classification/bulk/",
status=200,
json={
"parent": str(elt.id),
"worker_version": "12341234-1234-1234-1234-123412341234",
"classifications": [
{
"id": "00000000-0000-0000-0000-000000000000",
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
"state": "pending",
},
{
"id": "11111111-1111-1111-1111-111111111111",
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": False,
"state": "pending",
},
],
},
)
mock_elements_worker_with_cache.create_classifications(
element=elt, classifications=classes
)
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/classification/bulk/"),
]
assert json.loads(responses.calls[-1].request.body) == {
"parent": str(elt.id),
"worker_version": "12341234-1234-1234-1234-123412341234",
"classifications": classes,
}
# Check that created classifications were properly stored in SQLite cache
assert list(CachedClassification.select()) == [
CachedClassification(
id=UUID("00000000-0000-0000-0000-000000000000"),
element_id=UUID(elt.id),
class_name="portrait",
confidence=0.75,
state="pending",
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
),
CachedClassification(
id=UUID("11111111-1111-1111-1111-111111111111"),
element_id=UUID(elt.id),
class_name="landscape",
confidence=0.25,
state="pending",
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
),
]
......@@ -71,7 +71,8 @@ def test_readonly(responses, mock_elements_worker):
out = mock_elements_worker.update_activity("1234-deadbeef", ActivityState.Processed)
assert out is None
# update_activity returns False in very specific cases
assert out is True
assert len(responses.calls) == len(BASE_API_CALLS)
assert [
(call.request.method, call.request.url) for call in responses.calls
......@@ -130,11 +131,7 @@ def test_update_call(responses, mock_elements_worker, mock_process_api):
out = mock_elements_worker.update_activity("1234-deadbeef", ActivityState.Processed)
# Check the response received by worker
assert out == {
"element_id": "1234-deadbeef",
"process_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"state": "processed",
}
assert out is True
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
......@@ -169,54 +166,17 @@ def test_update_call(responses, mock_elements_worker, mock_process_api):
],
)
def test_run(
monkeypatch, mock_elements_worker, responses, process_exception, final_state
monkeypatch,
mock_elements_worker_with_list,
responses,
process_exception,
final_state,
mock_activity_calls,
):
"""Check the normal runtime sends 2 API calls to update activity"""
# Disable second configure call from run()
monkeypatch.setattr(mock_elements_worker, "configure", lambda: None)
# Mock elements
monkeypatch.setattr(
mock_elements_worker,
"list_elements",
lambda: [
"1234-deadbeef",
],
)
responses.add(
responses.GET,
"http://testserver/api/v1/element/1234-deadbeef/",
status=200,
json={
"id": "1234-deadbeef",
"type": "page",
"name": "Test Page n°1",
},
)
# Mock Update activity
responses.add(
responses.PUT,
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/activity/",
status=200,
json={
"element_id": "1234-deadbeef",
"process_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"state": "started",
},
)
responses.add(
responses.PUT,
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/activity/",
status=200,
json={
"element_id": "1234-deadbeef",
"process_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"state": final_state,
},
)
assert mock_elements_worker.is_read_only is False
monkeypatch.setattr(mock_elements_worker_with_list, "configure", lambda: None)
assert mock_elements_worker_with_list.is_read_only is False
# Mock exception in process_element
if process_exception:
......@@ -224,14 +184,14 @@ def test_run(
def _err():
raise process_exception
monkeypatch.setattr(mock_elements_worker, "process_element", _err)
monkeypatch.setattr(mock_elements_worker_with_list, "process_element", _err)
# The worker stops because all elements failed !
with pytest.raises(SystemExit):
mock_elements_worker.run()
mock_elements_worker_with_list.run()
else:
# Simply run the process
mock_elements_worker.run()
mock_elements_worker_with_list.run()
assert len(responses.calls) == len(BASE_API_CALLS) + 3
assert [
......@@ -262,7 +222,11 @@ def test_run(
def test_run_cache(
monkeypatch, mocker, mock_elements_worker_with_cache, mock_cached_elements
monkeypatch,
mocker,
mock_elements_worker_with_cache,
mock_cached_elements,
mock_activity_calls,
):
# Disable second configure call from run()
monkeypatch.setattr(mock_elements_worker_with_cache, "configure", lambda: None)
......@@ -278,3 +242,79 @@ def test_run_cache(
mocker.call(elt)
for elt in CachedElement.select()
]
def test_start_activity_conflict(
monkeypatch, responses, mocker, mock_elements_worker_with_list
):
# Disable second configure call from run()
monkeypatch.setattr(mock_elements_worker_with_list, "configure", lambda: None)
# Mock a "normal" conflict during in activity update, which returns the Exception
responses.add(
responses.PUT,
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/activity/",
body=ErrorResponse(
title="conflict",
status_code=409,
content="Either this activity does not exists or this state is not allowed.",
),
)
from arkindex_worker.worker import logger
logger.info = mocker.MagicMock()
mock_elements_worker_with_list.run()
assert len(responses.calls) == len(BASE_API_CALLS) + 2
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("GET", "http://testserver/api/v1/element/1234-deadbeef/"),
(
"PUT",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/activity/",
),
]
assert logger.info.call_args_list[:2] == [
mocker.call("Processing page Test Page n°1 (1234-deadbeef) (1/1)"),
mocker.call("Skipping element 1234-deadbeef as it was already processed"),
]
def test_start_activity_error(
monkeypatch, responses, mocker, mock_elements_worker_with_list
):
# Disable second configure call from run()
monkeypatch.setattr(mock_elements_worker_with_list, "configure", lambda: None)
# Mock a random error occurring during the activity update
responses.add(
responses.PUT,
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/activity/",
body=Exception("A wild Petilil appears !"),
)
from arkindex_worker.worker import logger
logger.error = mocker.MagicMock()
with pytest.raises(SystemExit):
mock_elements_worker_with_list.run()
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("GET", "http://testserver/api/v1/element/1234-deadbeef/"),
(
"PUT",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/activity/",
),
# Activity is updated to the "error" state regardless of the Exception occurring during the call
(
"PUT",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/activity/",
),
]
assert logger.error.call_args_list == [
mocker.call("Ran on 1 elements: 0 completed, 1 failed")
]