diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py index 86ce70267b019a86fe43544094d856dc9a9acb06..e00e4952ec20e694d916f914fe28599b51dc6592 100644 --- a/arkindex_worker/worker.py +++ b/arkindex_worker/worker.py @@ -84,6 +84,11 @@ class BaseWorker(object): self.api_client = ArkindexClient(**options_from_env()) logger.debug(f"Setup Arkindex API client on {self.api_client.document.url}") + # Load features available on backend, and check authentication + user = self.api_client.request("RetrieveUser") + logger.debug(f"Connected as {user['display_name']} - {user['email']}") + self.features = user["features"] + if self.worker_version_id: # Retrieve initial configuration from API worker_version = self.api_client.request( @@ -189,6 +194,13 @@ class MetaType(Enum): Reference = "reference" +class ActivityState(Enum): + Queued = "queued" + Started = "started" + Processed = "processed" + Error = "error" + + class ElementsWorker(BaseWorker): def __init__(self, description="Arkindex Elements Worker"): super().__init__(description) @@ -254,13 +266,18 @@ class ElementsWorker(BaseWorker): **self.api_client.request("RetrieveElement", id=element_id) ) logger.info(f"Processing {element} ({i}/{count})") + + # Report start of process, run process, then report end of process + self.update_activity(element, ActivityState.Started) self.process_element(element) + self.update_activity(element, ActivityState.Processed) except ErrorResponse as e: failed += 1 logger.warning( f"An API error occurred while processing element {element_id}: {e.title} - {e.content}", exc_info=e if self.args.verbose else None, ) + self.update_activity(element, ActivityState.Error) self.report.error(element_id, e) except Exception as e: failed += 1 @@ -268,6 +285,7 @@ class ElementsWorker(BaseWorker): f"Failed running worker on element {element_id}: {e}", exc_info=e if self.args.verbose else None, ) + self.update_activity(element, ActivityState.Error) self.report.error(element_id, e) # Save report as local artifact @@ -782,3 +800,41 @@ class ElementsWorker(BaseWorker): ) return children + + def update_activity(self, element, state): + """ + Update worker activity for this element + This method should not raise a runtime exception, but simply warn users + """ + assert element and isinstance( + element, Element + ), "element shouldn't be null and should be of type Element" + assert isinstance(state, ActivityState), "state should be an ActivityState" + + if not self.features.get("workers_activity"): + logger.debug("Skipping Worker activity update as it's disabled on backend") + return + + if self.is_read_only: + logger.warning("Cannot update activity as this worker is in read-only mode") + return + + try: + out = self.api_client.request( + "UpdateWorkerActivity", + id=self.worker_version_id, + body={ + "element_id": element.id, + "state": state.value, + }, + ) + logger.debug(f"Updated activity of element {element.id} to {state}") + return out + except ErrorResponse as e: + logger.warning( + f"Failed to update activity of element {element.id} to {state.value} due to an API error: {e.content}" + ) + except Exception as e: + logger.warning( + f"Failed to update activity of element {element.id} to {state.value}: {e}" + ) diff --git a/tests/conftest.py b/tests/conftest.py index fd91f7015ca1776c1089cf7d02d51e63735f843e..5d65a20ee248c4af7ff2a932ec8010799b588a6d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -51,7 +51,7 @@ def give_worker_version_id_env_variable(monkeypatch): @pytest.fixture -def mock_worker_version_api(responses): +def mock_worker_version_api(responses, mock_user_api): """Provide a mock API response to get worker configuration""" payload = { "id": "12341234-1234-1234-1234-123412341234", @@ -83,6 +83,30 @@ def mock_worker_version_api(responses): ) +@pytest.fixture +def mock_user_api(responses): + """ + Provide a mock API response to retrieve user details + Workers Activity is disabled in this mock + """ + payload = { + "id": 1, + "email": "bot@teklia.com", + "display_name": "Bender", + "features": { + "workers_activity": False, + "signup": False, + }, + } + responses.add( + responses.GET, + "http://testserver/api/v1/user/", + status=200, + body=json.dumps(payload), + content_type="application/json", + ) + + @pytest.fixture def mock_elements_worker(monkeypatch, mock_worker_version_api): """Build and configure an ElementsWorker with fixed CLI parameters to avoid issues with pytest""" diff --git a/tests/test_base_worker.py b/tests/test_base_worker.py index dc39b4c0fdb6cc1db7d264bed5ecb820291215f3..d19fd4e333e0a1a9e374b42d26d2705005328dbe 100644 --- a/tests/test_base_worker.py +++ b/tests/test_base_worker.py @@ -37,7 +37,7 @@ def test_init_var_ponos_data_given(monkeypatch): assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234" -def test_init_var_worker_version_id_missing(monkeypatch): +def test_init_var_worker_version_id_missing(monkeypatch, mock_user_api): monkeypatch.setattr(sys, "argv", ["worker"]) monkeypatch.delenv("WORKER_VERSION_ID") worker = BaseWorker() @@ -47,7 +47,7 @@ def test_init_var_worker_version_id_missing(monkeypatch): assert worker.config == {} # default empty case -def test_init_var_worker_local_file(monkeypatch, tmp_path): +def test_init_var_worker_local_file(monkeypatch, tmp_path, mock_user_api): # Build a dummy yaml config file config = tmp_path / "config.yml" config.write_text("---\nlocalKey: abcdef123") @@ -63,7 +63,7 @@ def test_init_var_worker_local_file(monkeypatch, tmp_path): config.unlink() -def test_cli_default(mocker, mock_worker_version_api): +def test_cli_default(mocker, mock_worker_version_api, mock_user_api): worker = BaseWorker() spy = mocker.spy(worker, "add_arguments") assert not spy.called @@ -85,7 +85,7 @@ def test_cli_default(mocker, mock_worker_version_api): logger.setLevel(logging.NOTSET) -def test_cli_arg_verbose_given(mocker, mock_worker_version_api): +def test_cli_arg_verbose_given(mocker, mock_worker_version_api, mock_user_api): worker = BaseWorker() spy = mocker.spy(worker, "add_arguments") assert not spy.called diff --git a/tests/test_elements_worker.py b/tests/test_elements_worker.py index 64372a30589b9961ed3c682989d336da683e7871..2b2e33ac2d638cae0d84d0da7396fbe26bba16c8 100644 --- a/tests/test_elements_worker.py +++ b/tests/test_elements_worker.py @@ -254,8 +254,9 @@ def test_load_corpus_classes_api_error(responses, mock_elements_worker): ): mock_elements_worker.load_corpus_classes(corpus_id) - assert len(responses.calls) == 6 + assert len(responses.calls) == 7 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", # We do 5 retries f"http://testserver/api/v1/corpus/{corpus_id}/classes/", @@ -299,8 +300,9 @@ def test_load_corpus_classes(responses, mock_elements_worker): assert not mock_elements_worker.classes mock_elements_worker.load_corpus_classes(corpus_id) - assert len(responses.calls) == 2 + assert len(responses.calls) == 3 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", f"http://testserver/api/v1/corpus/{corpus_id}/classes/", ] @@ -335,8 +337,9 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker): assert not mock_elements_worker.classes ml_class_id = mock_elements_worker.get_ml_class_id(corpus_id, "good") - assert len(responses.calls) == 2 + assert len(responses.calls) == 3 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", f"http://testserver/api/v1/corpus/{corpus_id}/classes/", ] @@ -437,7 +440,7 @@ def test_get_ml_class_reload(responses, mock_elements_worker): # Simply request class 2, it should be reloaded assert mock_elements_worker.get_ml_class_id(corpus_id, "class2") == "class2_id" - assert len(responses.calls) == 4 + assert len(responses.calls) == 5 assert mock_elements_worker.classes == { corpus_id: { "class1": "class1_id", @@ -445,6 +448,7 @@ def test_get_ml_class_reload(responses, mock_elements_worker): } } assert [(call.request.method, call.request.url) for call in responses.calls] == [ + ("GET", "http://testserver/api/v1/user/"), ( "GET", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", @@ -599,8 +603,9 @@ def test_create_sub_element_api_error(responses, mock_elements_worker): polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], ) - assert len(responses.calls) == 2 + assert len(responses.calls) == 3 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", "http://testserver/api/v1/elements/create/", ] @@ -628,12 +633,13 @@ def test_create_sub_element(responses, mock_elements_worker): polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], ) - assert len(responses.calls) == 2 + assert len(responses.calls) == 3 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", "http://testserver/api/v1/elements/create/", ] - assert json.loads(responses.calls[1].request.body) == { + assert json.loads(responses.calls[2].request.body) == { "type": "something", "name": "0", "image": "22222222-2222-2222-2222-222222222222", @@ -684,8 +690,9 @@ def test_create_transcription_type_warning(responses, mock_elements_worker): == "Transcription types are deprecated and will be removed in the next release." ) - assert len(responses.calls) == 2 + assert len(responses.calls) == 3 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", f"http://testserver/api/v1/element/{elt.id}/transcription/", ] @@ -772,8 +779,9 @@ def test_create_transcription_api_error(responses, mock_elements_worker): score=0.42, ) - assert len(responses.calls) == 2 + assert len(responses.calls) == 3 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", f"http://testserver/api/v1/element/{elt.id}/transcription/", ] @@ -793,13 +801,14 @@ def test_create_transcription(responses, mock_elements_worker): score=0.42, ) - assert len(responses.calls) == 2 + assert len(responses.calls) == 3 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", f"http://testserver/api/v1/element/{elt.id}/transcription/", ] - assert json.loads(responses.calls[1].request.body) == { + assert json.loads(responses.calls[2].request.body) == { "text": "i am a line", "worker_version": "12341234-1234-1234-1234-123412341234", "score": 0.42, @@ -1017,8 +1026,9 @@ def test_create_classification_api_error(responses, mock_elements_worker): high_confidence=True, ) - assert len(responses.calls) == 2 + assert len(responses.calls) == 3 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", "http://testserver/api/v1/classifications/", ] @@ -1047,13 +1057,14 @@ def test_create_classification(responses, mock_elements_worker): high_confidence=True, ) - assert len(responses.calls) == 2 + assert len(responses.calls) == 3 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", "http://testserver/api/v1/classifications/", ] - assert json.loads(responses.calls[1].request.body) == { + assert json.loads(responses.calls[2].request.body) == { "element": "12341234-1234-1234-1234-123412341234", "ml_class": "0000", "worker_version": "12341234-1234-1234-1234-123412341234", @@ -1095,13 +1106,14 @@ def test_create_classification_duplicate(responses, mock_elements_worker): high_confidence=True, ) - assert len(responses.calls) == 2 + assert len(responses.calls) == 3 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", "http://testserver/api/v1/classifications/", ] - assert json.loads(responses.calls[1].request.body) == { + assert json.loads(responses.calls[2].request.body) == { "element": "12341234-1234-1234-1234-123412341234", "ml_class": "0000", "worker_version": "12341234-1234-1234-1234-123412341234", @@ -1252,8 +1264,9 @@ def test_create_entity_api_error(responses, mock_elements_worker): corpus="12341234-1234-1234-1234-123412341234", ) - assert len(responses.calls) == 2 + assert len(responses.calls) == 3 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", "http://testserver/api/v1/entity/", ] @@ -1275,12 +1288,13 @@ def test_create_entity(responses, mock_elements_worker): corpus="12341234-1234-1234-1234-123412341234", ) - assert len(responses.calls) == 2 + assert len(responses.calls) == 3 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", "http://testserver/api/v1/entity/", ] - assert json.loads(responses.calls[1].request.body) == { + assert json.loads(responses.calls[2].request.body) == { "name": "Bob Bob", "type": "person", "metas": None, @@ -1705,8 +1719,9 @@ def test_create_element_transcriptions_api_error(responses, mock_elements_worker transcriptions=TRANSCRIPTIONS_SAMPLE, ) - assert len(responses.calls) == 2 + assert len(responses.calls) == 3 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/", ] @@ -1731,13 +1746,14 @@ def test_create_element_transcriptions(responses, mock_elements_worker): transcriptions=TRANSCRIPTIONS_SAMPLE, ) - assert len(responses.calls) == 2 + assert len(responses.calls) == 3 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/", ] - assert json.loads(responses.calls[1].request.body) == { + assert json.loads(responses.calls[2].request.body) == { "element_type": "page", "worker_version": "12341234-1234-1234-1234-123412341234", "transcriptions": TRANSCRIPTIONS_SAMPLE, @@ -1875,8 +1891,9 @@ def test_create_metadata_api_error(responses, mock_elements_worker): value="La Turbine, Grenoble 38000", ) - assert len(responses.calls) == 2 + assert len(responses.calls) == 3 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/", ] @@ -1898,12 +1915,13 @@ def test_create_metadata(responses, mock_elements_worker): value="La Turbine, Grenoble 38000", ) - assert len(responses.calls) == 2 + assert len(responses.calls) == 3 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/", ] - assert json.loads(responses.calls[1].request.body) == { + assert json.loads(responses.calls[2].request.body) == { "type": "location", "name": "Teklia", "value": "La Turbine, Grenoble 38000", @@ -2105,8 +2123,9 @@ def test_list_transcriptions_api_error(responses, mock_elements_worker): ): next(mock_elements_worker.list_transcriptions(element=elt)) - assert len(responses.calls) == 6 + assert len(responses.calls) == 7 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", # We do 5 retries "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/", @@ -2158,8 +2177,9 @@ def test_list_transcriptions(responses, mock_elements_worker): ): assert transcription == trans[idx] - assert len(responses.calls) == 2 + assert len(responses.calls) == 3 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/", ] @@ -2298,8 +2318,9 @@ def test_list_element_children_api_error(responses, mock_elements_worker): ): next(mock_elements_worker.list_element_children(element=elt)) - assert len(responses.calls) == 6 + assert len(responses.calls) == 7 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", # We do 5 retries "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/", @@ -2363,8 +2384,9 @@ def test_list_element_children(responses, mock_elements_worker): ): assert child == expected_children[idx] - assert len(responses.calls) == 2 + assert len(responses.calls) == 3 assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/", ] diff --git a/tests/test_worker_activity.py b/tests/test_worker_activity.py new file mode 100644 index 0000000000000000000000000000000000000000..4498e5ff66c268cd65eee504351fefd59b4f8381 --- /dev/null +++ b/tests/test_worker_activity.py @@ -0,0 +1,189 @@ +# -*- coding: utf-8 -*- +import json + +import pytest +from apistar.exceptions import ErrorResponse + +from arkindex_worker.worker import ActivityState, Element + +# Common API calls for all workers +BASE_API_CALLS = [ + "http://testserver/api/v1/user/", + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", +] + + +def test_defaults(responses, mock_elements_worker): + """Test the default values from mocked calls""" + assert not mock_elements_worker.is_read_only + assert mock_elements_worker.features == { + "workers_activity": False, + "signup": False, + } + + assert len(responses.calls) == 2 + assert [call.request.url for call in responses.calls] == BASE_API_CALLS + + +def test_feature_disabled(responses, mock_elements_worker): + """Test disabled calls do not trigger any API calls""" + assert not mock_elements_worker.is_read_only + + out = mock_elements_worker.update_activity( + Element({"id": "1234-deadbeef"}), ActivityState.Processed + ) + + assert out is None + assert len(responses.calls) == 2 + assert [call.request.url for call in responses.calls] == BASE_API_CALLS + + +def test_readonly(responses, mock_elements_worker): + """Test readonly worker does not trigger any API calls""" + + # Setup the worker as read-only, but with workers_activity enabled + mock_elements_worker.worker_version_id = None + assert mock_elements_worker.is_read_only is True + mock_elements_worker.features["workers_activity"] = True + + out = mock_elements_worker.update_activity( + Element({"id": "1234-deadbeef"}), ActivityState.Processed + ) + + assert out is None + assert len(responses.calls) == 2 + assert [call.request.url for call in responses.calls] == BASE_API_CALLS + + +def test_update_call(responses, mock_elements_worker): + """Test an update call with feature enabled triggers an API call""" + responses.add( + responses.PUT, + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/activity/", + status=200, + json={ + "element_id": "1234-deadbeef", + "state": "processed", + }, + ) + + # Enable worker activity + mock_elements_worker.features["workers_activity"] = True + + out = mock_elements_worker.update_activity( + Element({"id": "1234-deadbeef"}), ActivityState.Processed + ) + + # Check the response received by worker + assert out == { + "element_id": "1234-deadbeef", + "state": "processed", + } + + assert len(responses.calls) == 3 + assert [call.request.url for call in responses.calls] == BASE_API_CALLS + [ + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/activity/", + ] + + # Check the request sent by worker + assert json.loads(responses.calls[2].request.body) == { + "element_id": "1234-deadbeef", + "state": "processed", + } + + +@pytest.mark.parametrize( + "process_exception, final_state", + [ + # Successful process_element + (None, "processed"), + # Failures in process_element + ( + ErrorResponse(title="bad gateway", status_code=502, content="Bad gateway"), + "error", + ), + (ValueError("Something bad"), "error"), + (Exception("Any error"), "error"), + ], +) +def test_run( + monkeypatch, mock_elements_worker, responses, process_exception, final_state +): + """Check the normal runtime sends 2 API calls to update activity""" + # Disable second configure call from run() + monkeypatch.setattr(mock_elements_worker, "configure", lambda: None) + + # Mock elements + monkeypatch.setattr( + mock_elements_worker, + "list_elements", + lambda: [ + "1234-deadbeef", + ], + ) + responses.add( + responses.GET, + "http://testserver/api/v1/element/1234-deadbeef/", + status=200, + json={ + "id": "1234-deadbeef", + "type": "page", + "name": "Test Page n°1", + }, + ) + + # Mock Update activity + responses.add( + responses.PUT, + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/activity/", + status=200, + json={ + "element_id": "1234-deadbeef", + "state": "started", + }, + ) + responses.add( + responses.PUT, + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/activity/", + status=200, + json={ + "element_id": "1234-deadbeef", + "state": final_state, + }, + ) + + # Enable worker activity + assert mock_elements_worker.is_read_only is False + mock_elements_worker.features["workers_activity"] = True + + # Mock exception in process_element + if process_exception: + + def _err(): + raise process_exception + + monkeypatch.setattr(mock_elements_worker, "process_element", _err) + + # The worker stops because all elements failed ! + with pytest.raises(SystemExit): + mock_elements_worker.run() + else: + # Simply run the process + mock_elements_worker.run() + + assert len(responses.calls) == 5 + assert [call.request.url for call in responses.calls] == BASE_API_CALLS + [ + "http://testserver/api/v1/element/1234-deadbeef/", + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/activity/", + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/activity/", + ] + + # Check the requests sent by worker + assert json.loads(responses.calls[3].request.body) == { + "element_id": "1234-deadbeef", + "state": "started", + } + assert json.loads(responses.calls[4].request.body) == { + "element_id": "1234-deadbeef", + "state": final_state, + }