diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py
index 4ab39d2f6c8fd5b7df84379b360568b8778bcce1..5c496b9fb06b29fcc1ed16ee145b0c52dc774173 100644
--- a/arkindex_worker/worker.py
+++ b/arkindex_worker/worker.py
@@ -688,3 +688,47 @@ class ElementsWorker(BaseWorker):
             return MANUAL_SLUG
         else:
             raise ValueError(f"Unable to get slug from: {ml_result}")
+
+    def list_transcriptions(
+        self, element, element_type=None, recursive=None, worker_version=None
+    ):
+        """
+        List transcriptions on an element
+
+        :param element: Element whose transcriptions are listed; its ``id`` is
+            passed as ``id`` to the ``ListTranscriptions`` operation.
+        :param element_type: Optional str, sent as the ``element_type`` query
+            parameter. ``None`` and the empty string leave it out.
+        :param recursive: Optional bool, sent as the ``recursive`` query
+            parameter. ``None`` leaves it out; ``False`` is sent explicitly.
+        :param worker_version: Optional str, sent as the ``worker_version``
+            query parameter. ``None`` and the empty string leave it out.
+        :returns: The value returned by ``self.api_client.paginate``.
+        :raises AssertionError: If ``element`` is falsy or an argument has the wrong type.
+        """
+        assert element and isinstance(
+            element, Element
+        ), "element shouldn't be null and should be of type Element"
+        query_params = {}
+        # Validate every optional argument that is not None, so that falsy
+        # values of the wrong type (0, False, []) fail the assertion instead
+        # of being silently dropped. Empty strings are still left out.
+        if element_type is not None:
+            assert isinstance(element_type, str), "element_type should be of type str"
+            if element_type:
+                query_params["element_type"] = element_type
+        if recursive is not None:
+            assert isinstance(recursive, bool), "recursive should be of type bool"
+            query_params["recursive"] = recursive
+        if worker_version is not None:
+            assert isinstance(
+                worker_version, str
+            ), "worker_version should be of type str"
+            if worker_version:
+                query_params["worker_version"] = worker_version
+
+        transcriptions = self.api_client.paginate(
+            "ListTranscriptions", id=element.id, **query_params
+        )
+
+        return transcriptions
diff --git a/tests/test_elements_worker.py b/tests/test_elements_worker.py
index c254bdf190bd01ffd65131e2a6380e62593d96e1..5af17a9846b6340dc4e5cbdc8de2e3ca7556f0a3 100644
--- a/tests/test_elements_worker.py
+++ b/tests/test_elements_worker.py
@@ -2047,3 +2047,119 @@ def test_get_ml_result_slug__fail(fake_dummy_worker, ml_result):
     with pytest.raises(ValueError) as excinfo:
         fake_dummy_worker.get_ml_result_slug(ml_result)
     assert str(excinfo.value).startswith("Unable to get slug from")
+
+
+def test_list_transcriptions_wrong_element(mock_elements_worker):
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.list_transcriptions(element=None)
+    assert str(e.value) == "element shouldn't be null and should be of type Element"
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.list_transcriptions(element="not element type")
+    assert str(e.value) == "element shouldn't be null and should be of type Element"
+
+
+def test_list_transcriptions_wrong_element_type(mock_elements_worker):
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.list_transcriptions(
+            element=elt,
+            element_type=1234,
+        )
+    assert str(e.value) == "element_type should be of type str"
+
+
+def test_list_transcriptions_wrong_recursive(mock_elements_worker):
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.list_transcriptions(
+            element=elt,
+            recursive="not bool",
+        )
+    assert str(e.value) == "recursive should be of type bool"
+
+
+def test_list_transcriptions_wrong_worker_version(mock_elements_worker):
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.list_transcriptions(
+            element=elt,
+            worker_version=1234,
+        )
+    assert str(e.value) == "worker_version should be of type str"
+
+
+def test_list_transcriptions_api_error(responses, mock_elements_worker):
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
+    responses.add(
+        responses.GET,
+        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
+        status=500,
+    )
+
+    with pytest.raises(
+        Exception, match="Stopping pagination as data will be incomplete"
+    ):
+        next(mock_elements_worker.list_transcriptions(element=elt))
+
+    assert len(responses.calls) == 6
+    assert [call.request.url for call in responses.calls] == [
+        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+        # We do 5 retries
+        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
+        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
+        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
+        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
+        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
+    ]
+
+
+def test_list_transcriptions(responses, mock_elements_worker):
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
+    trans = [
+        {
+            "id": "0000",
+            "text": "hey",
+            "confidence": 0.42,
+            "worker_version_id": "56785678-5678-5678-5678-567856785678",
+            "element": None,
+        },
+        {
+            "id": "1111",
+            "text": "it's",
+            "confidence": 0.42,
+            "worker_version_id": "56785678-5678-5678-5678-567856785678",
+            "element": None,
+        },
+        {
+            "id": "2222",
+            "text": "me",
+            "confidence": 0.42,
+            "worker_version_id": "56785678-5678-5678-5678-567856785678",
+            "element": None,
+        },
+    ]
+    responses.add(
+        responses.GET,
+        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
+        status=200,
+        json={
+            "count": 3,
+            "next": None,
+            "results": trans,
+        },
+    )
+
+    for idx, transcription in enumerate(
+        mock_elements_worker.list_transcriptions(element=elt)
+    ):
+        assert transcription == trans[idx]
+
+    assert len(responses.calls) == 2
+    assert [call.request.url for call in responses.calls] == [
+        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
+    ]