diff --git a/arkindex_worker/worker/entity.py b/arkindex_worker/worker/entity.py index 921fedc5a727914fc6c81b98228c5f17544e59a0..89d108e481f4f70203aedcfbf00ed9c88f573552 100644 --- a/arkindex_worker/worker/entity.py +++ b/arkindex_worker/worker/entity.py @@ -210,15 +210,15 @@ class EntityMixin(object): def list_corpus_entities( self, corpus: Corpus, - name: str, - parent: str, + name: str = None, + parent: str or Element = None, ): """ List all entities in a corpus This method does not support cache :param corpus Corpus: The corpus that contains the entities to list. - :param name str: uuid for filter entities by part of their name (case-insensitive) + :param name str: For filter entities by part of their name (case-insensitive) :param parent str: uuid for restrict entities to those linked to all transcriptions of an element and all its descendants. Note that links to metadata are ignored. """ query_params = {} @@ -226,16 +226,21 @@ class EntityMixin(object): corpus, Corpus ), "corpus shouldn't be null and should be a Corpus" - assert name and isinstance( - name, str - ), "name shouldn't be null and should be of type str" - - assert parent and isinstance( - parent, str - ), "parent shouldn't be null and should be of type str" - - query_params["name"] = name - query_params["parent"] = parent + if name is not None: + assert name and isinstance(name, str), "name should be of type str" + query_params["name"] = name + + if parent is not None: + assert ( + parent + and isinstance(parent, str) + or parent + and isinstance(parent, Element) + ), "parent should be of type str or Element" + query_params["parent"] = parent + + if type(parent) == Element: + query_params["parent"] = parent.id return self.api_client.paginate( "ListCorpusEntities", id=corpus.id, **query_params diff --git a/tests/test_elements_worker/test_entities.py b/tests/test_elements_worker/test_entities.py index 39b488ab57a285868beb5357a9ef278542dfde76..4742eaf2234bed36f608fc21ed6fc1795b76a5d0 100644 --- a/tests/test_elements_worker/test_entities.py +++ b/tests/test_elements_worker/test_entities.py @@ -688,20 +688,35 @@ def test_list_transcription_entities(fake_dummy_worker): assert len(fake_dummy_worker.api_client.responses) == 0 -def test_list_corpus_entities(fake_dummy_worker): +@pytest.mark.parametrize( + "name, parent", + [ + (None, None), + ("fake_name", "fake_parent"), + ("fake_name", Element({"id": "fake_parent_id"})), + (None, "fake_parent"), + ("fake_name", None), + ], +) +def test_list_corpus_entities(fake_dummy_worker, name, parent): corpus = Corpus({"id": "fake_corpus_id"}) - name = "fake_name" - parent = "fake_parent" + query_params = {} + if name is not None: + query_params["name"] = name + if parent is not None: + query_params["parent"] = parent + if type(parent) == Element: + query_params["parent"] = parent.id + fake_dummy_worker.api_client.add_response( "ListCorpusEntities", id=corpus.id, - name=name, - parent=parent, response={"id": "fake_entity_id"}, + **query_params, ) - assert fake_dummy_worker.list_corpus_entities(corpus, name, parent) == { + + assert fake_dummy_worker.list_corpus_entities(corpus, **query_params) == { "id": "fake_entity_id" } - assert len(fake_dummy_worker.api_client.history) == 1 assert len(fake_dummy_worker.api_client.responses) == 0