diff --git a/arkindex_worker/worker/entity.py b/arkindex_worker/worker/entity.py index d587433c94125eecd0f9634186f613e4e6a00042..7b561c3ebea5bd566ca378d3275ec015698b0537 100644 --- a/arkindex_worker/worker/entity.py +++ b/arkindex_worker/worker/entity.py @@ -10,7 +10,7 @@ from peewee import IntegrityError from arkindex_worker import logger from arkindex_worker.cache import CachedElement, CachedEntity, CachedTranscriptionEntity -from arkindex_worker.models import Element +from arkindex_worker.models import Element, Transcription class EntityType(Enum): @@ -174,3 +174,35 @@ class EntityMixin(object): f"Couldn't save created transcription entity in local cache: {e}" ) return transcription_ent + + def list_transcription_entities( + self, + transcription: Transcription, + worker_version: bool = None, + ): + """ + List existing entities on a transcription + This method does not support cache + + :param transcription Transcription: The transcription to list entities on. + :param worker_version str or bool: Restrict to entities created by a worker version with this UUID. Set to False to look for manually created transcriptions. + """ + query_params = {} + assert transcription and isinstance( + transcription, Transcription + ), "transcription shouldn't be null and should be a Transcription" + + if worker_version is not None: + assert isinstance( + worker_version, (str, bool) + ), "worker_version should be of type str or bool" + + if isinstance(worker_version, bool): + assert ( + worker_version is False + ), "if of type bool, worker_version can only be set to False" + query_params["worker_version"] = worker_version + + return self.api_client.paginate( + "ListTranscriptionEntities", id=transcription.id, **query_params + ) diff --git a/tests/test_elements_worker/test_entities.py b/tests/test_elements_worker/test_entities.py index 99a7492c1463ffe8a5e68e5f9995903ad5da1ad9..57ea05af869c63bb4859b37dbf536e525d3aa1e0 100644 --- a/tests/test_elements_worker/test_entities.py +++ b/tests/test_elements_worker/test_entities.py @@ -11,7 +11,7 @@ from arkindex_worker.cache import ( CachedTranscription, CachedTranscriptionEntity, ) -from arkindex_worker.models import Element +from arkindex_worker.models import Element, Transcription from arkindex_worker.worker import EntityType from arkindex_worker.worker.transcription import TextOrientation @@ -669,3 +669,20 @@ def test_create_transcription_entity_with_confidence_with_cache( confidence=0.77, ) ] + + +def test_list_transcription_entities(fake_dummy_worker): + transcription = Transcription({"id": "fake_transcription_id"}) + worker_version = "worker_version_id" + fake_dummy_worker.api_client.add_response( + "ListTranscriptionEntities", + id=transcription.id, + worker_version=worker_version, + response={"id": "entity_id"}, + ) + assert fake_dummy_worker.list_transcription_entities( + transcription, worker_version + ) == {"id": "entity_id"} + + assert len(fake_dummy_worker.api_client.history) == 1 + assert len(fake_dummy_worker.api_client.responses) == 0