diff --git a/arkindex_worker/worker/entity.py b/arkindex_worker/worker/entity.py index ccdf2267bf3698a46da6a2069386ff4783bace9e..e0cbd7c511c7ce3066824bb952a61039a31f9c64 100644 --- a/arkindex_worker/worker/entity.py +++ b/arkindex_worker/worker/entity.py @@ -3,6 +3,7 @@ ElementsWorker methods for entities. """ +from operator import itemgetter from typing import Dict, List, Optional, TypedDict, Union from peewee import IntegrityError @@ -273,6 +274,10 @@ class EntityMixin(object): isinstance(confidence, float) and 0 <= confidence <= 1 ), f"Entity at index {index} in entities: confidence should be None or a float in [0..1] range" + assert len(entities) == len( + set(map(itemgetter("offset", "length", "name", "type_id"), entities)) + ), "entities should be unique" + if self.is_read_only: logger.warning( "Cannot create transcription entities in bulk as this worker is in read-only mode" diff --git a/tests/test_elements_worker/test_entities.py b/tests/test_elements_worker/test_entities.py index 63a09da4225b4187f168f3687b251c6d8d93fc28..426c485cedb0c8954b7a79c97d157e778c3ba9b9 100644 --- a/tests/test_elements_worker/test_entities.py +++ b/tests/test_elements_worker/test_entities.py @@ -864,11 +864,34 @@ def test_create_transcription_entities_wrong_transcription( ) -@pytest.mark.parametrize("entities", (None, "not a list of entities", 1)) -def test_create_transcription_entities_wrong_entities(mock_elements_worker, entities): - with pytest.raises( - AssertionError, match="entities shouldn't be null and should be of type list" - ): +@pytest.mark.parametrize( + "entities, error", + ( + (None, "entities shouldn't be null and should be of type list"), + ( + "not a list of entities", + "entities shouldn't be null and should be of type list", + ), + (1, "entities shouldn't be null and should be of type list"), + ( + [ + { + "name": "A", + "type_id": "12341234-1234-1234-1234-123412341234", + "offset": 0, + "length": 1, + "confidence": 0.5, + } + ] + * 2, + "entities should be unique", + ), + ), +) +def test_create_transcription_entities_wrong_entities( + mock_elements_worker, entities, error +): + with pytest.raises(AssertionError, match=error): mock_elements_worker.create_transcription_entities( transcription=Transcription(id="transcription_id"), entities=entities,