diff --git a/arkindex_worker/worker/entity.py b/arkindex_worker/worker/entity.py
index 5e2b6cf64180ef804a7f7dbc332c894a3d54e3a6..15129bee346c6616b51774c7615c1597e1a66f24 100644
--- a/arkindex_worker/worker/entity.py
+++ b/arkindex_worker/worker/entity.py
@@ -193,3 +193,28 @@ class EntityMixin(object):
         return self.api_client.paginate(
             "ListTranscriptionEntities", id=transcription.id, **query_params
         )
+
+    def list_corpus_entities(
+        self,
+        name: str = None,
+        parent: Element = None,
+    ):
+        """
+        List all entities in the worker's corpus
+        This method does not support cache
+        :param name str: Filter entities by part of their name (case-insensitive)
+        :param parent Element: Restrict entities to those linked to all transcriptions of an element and all its descendants. Note that links to metadata are ignored.
+        """
+        query_params = {}
+
+        if name is not None:
+            assert name and isinstance(name, str), "name should be of type str"
+            query_params["name"] = name
+
+        if parent is not None:
+            assert isinstance(parent, Element), "parent should be of type Element"
+            query_params["parent"] = parent.id
+
+        return self.api_client.paginate(
+            "ListCorpusEntities", id=self.corpus_id, **query_params
+        )
diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py
index 0d5d18ccb91bf47883bd41d36024d0275fda7af6..46a606306bd0be3004c69e863903f824c5b995ae 100644
--- a/tests/test_elements_worker/test_elements.py
+++ b/tests/test_elements_worker/test_elements.py
@@ -271,10 +271,9 @@ def test_database_arg_cache_missing_version_table(
 
 
 def test_load_corpus_classes_api_error(responses, mock_elements_worker):
-    mock_elements_worker.corpus_id = "12341234-1234-1234-1234-123412341234"
     responses.add(
         responses.GET,
-        f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
+        "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/",
         status=500,
     )
 
@@ -291,33 +290,32 @@ def test_load_corpus_classes_api_error(responses, mock_elements_worker):
         # We do 5 retries
         (
             "GET",
-            f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
+            "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/",
         ),
         (
             "GET",
-            f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
+            "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/",
         ),
         (
             "GET",
-            f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
+            "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/",
         ),
         (
             "GET",
-            f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
+            "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/",
         ),
         (
             "GET",
-            f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
+            "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/",
         ),
     ]
     assert not mock_elements_worker.classes
 
 
 def test_load_corpus_classes(responses, mock_elements_worker):
-    mock_elements_worker.corpus_id = "12341234-1234-1234-1234-123412341234"
     responses.add(
         responses.GET,
-        f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
+        "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/",
         status=200,
         json={
             "count": 3,
@@ -348,11 +346,11 @@ def test_load_corpus_classes(responses, mock_elements_worker):
     ] == BASE_API_CALLS + [
         (
             "GET",
-            f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
+            "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/",
         ),
     ]
     assert mock_elements_worker.classes == {
-        "12341234-1234-1234-1234-123412341234": {
+        "11111111-1111-1111-1111-111111111111": {
             "good": "0000",
             "average": "1111",
             "bad": "2222",
diff --git a/tests/test_elements_worker/test_entities.py b/tests/test_elements_worker/test_entities.py
index fded79794d9adbadbb53d0df8d7f31ebb9dd595e..0df2fe4ab97dd749d129e731a68e33f92580c1a0 100644
--- a/tests/test_elements_worker/test_entities.py
+++ b/tests/test_elements_worker/test_entities.py
@@ -654,3 +654,60 @@ def test_list_transcription_entities(fake_dummy_worker):
 
     assert len(fake_dummy_worker.api_client.history) == 1
     assert len(fake_dummy_worker.api_client.responses) == 0
+
+
+def test_list_corpus_entities(responses, mock_elements_worker):
+    corpus_id = "11111111-1111-1111-1111-111111111111"
+    responses.add(
+        responses.GET,
+        f"http://testserver/api/v1/corpus/{corpus_id}/entities/",
+        json={
+            "count": 1,
+            "next": None,
+            "results": [
+                {
+                    "id": "fake_entity_id",
+                }
+            ],
+        },
+    )
+
+    # list is required to actually do the request
+    assert list(mock_elements_worker.list_corpus_entities()) == [
+        {
+            "id": "fake_entity_id",
+        }
+    ]
+
+    assert len(responses.calls) == len(BASE_API_CALLS) + 1
+    assert [
+        (call.request.method, call.request.url) for call in responses.calls
+    ] == BASE_API_CALLS + [
+        (
+            "GET",
+            f"http://testserver/api/v1/corpus/{corpus_id}/entities/",
+        ),
+    ]
+
+
+@pytest.mark.parametrize(
+    "wrong_name",
+    [
+        1234,
+        12.5,
+    ],
+)
+def test_list_corpus_entities_wrong_name(mock_elements_worker, wrong_name):
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.list_corpus_entities(name=wrong_name)
+    assert str(e.value) == "name should be of type str"
+
+
+@pytest.mark.parametrize(
+    "wrong_parent",
+    [{"id": "element_id"}, 12.5, "blabla"],
+)
+def test_list_corpus_entities_wrong_parent(mock_elements_worker, wrong_parent):
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.list_corpus_entities(parent=wrong_parent)
+    assert str(e.value) == "parent should be of type Element"