diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py index d6381b752b78137af75bd30ae41ba8538ad48c67..488bc8fe2503f10c160a5e643ae2baa205c033be 100644 --- a/arkindex_worker/worker.py +++ b/arkindex_worker/worker.py @@ -230,11 +230,8 @@ class ElementsWorker(BaseWorker): element, Element ), "element shouldn't be null and should be of type Element" assert type and isinstance( - type, str - ), "type shouldn't be null and should be of type str" - assert ( - type in TranscriptionType._value2member_map_ - ), "type should be an allowed transcription type" + type, TranscriptionType + ), "type shouldn't be null and should be of type TranscriptionType" assert text and isinstance( text, str ), "text shouldn't be null and should be of type str" @@ -247,14 +244,16 @@ class ElementsWorker(BaseWorker): id=element.id, body={ "text": text, - "type": type, + "type": type.value, "worker_version": self.worker_version_id, "score": score, }, ) self.report.add_transcription(element.id, type) - def create_classification(self, element, ml_class, confidence, high_confidence): + def create_classification( + self, element, ml_class, confidence, high_confidence=False + ): """ Create a classification on the given element through API """ @@ -291,11 +290,8 @@ class ElementsWorker(BaseWorker): name, str ), "name shouldn't be null and should be of type str" assert type and isinstance( - type, str - ), "type shouldn't be null and should be of type str" - assert ( - type in EntityType._value2member_map_ - ), "type should be an allowed entity type" + type, EntityType + ), "type shouldn't be null and should be of type EntityType" assert corpus and isinstance( corpus, str ), "corpus shouldn't be null and should be of type str" @@ -308,7 +304,7 @@ class ElementsWorker(BaseWorker): "CreateEntity", body={ "name": name, - "type": type, + "type": type.value, "metas": metas, "validated": validated, "corpus": corpus, diff --git a/tests/test_elements_worker.py b/tests/test_elements_worker.py index 7f44936ff37a13984a46dfbc7dffe5bfe4c65e17..48baa6852dd6f53eff36c214a54fee1c0a9e094a 100644 --- a/tests/test_elements_worker.py +++ b/tests/test_elements_worker.py @@ -10,7 +10,7 @@ import pytest from apistar.exceptions import ErrorResponse from arkindex_worker.models import Element -from arkindex_worker.worker import ElementsWorker +from arkindex_worker.worker import ElementsWorker, EntityType, TranscriptionType def test_cli_default(monkeypatch): @@ -385,13 +385,16 @@ def test_create_transcription_wrong_element(): worker = ElementsWorker() with pytest.raises(AssertionError) as e: worker.create_transcription( - element=None, text="i am a line", type="line", score=0.42, + element=None, text="i am a line", type=TranscriptionType.Line, score=0.42, ) assert str(e.value) == "element shouldn't be null and should be of type Element" with pytest.raises(AssertionError) as e: worker.create_transcription( - element="not element type", text="i am a line", type="line", score=0.42, + element="not element type", + text="i am a line", + type=TranscriptionType.Line, + score=0.42, ) assert str(e.value) == "element shouldn't be null and should be of type Element" @@ -404,13 +407,17 @@ def test_create_transcription_wrong_type(): worker.create_transcription( element=elt, text="i am a line", type=None, score=0.42, ) - assert str(e.value) == "type shouldn't be null and should be of type str" + assert ( + str(e.value) == "type shouldn't be null and should be of type TranscriptionType" + ) with pytest.raises(AssertionError) as e: worker.create_transcription( element=elt, text="i am a line", type=1234, score=0.42, ) - assert str(e.value) == "type shouldn't be null and should be of type str" + assert ( + str(e.value) == "type shouldn't be null and should be of type TranscriptionType" + ) with pytest.raises(AssertionError) as e: worker.create_transcription( @@ -419,7 +426,9 @@ def test_create_transcription_wrong_type(): type="not_a_transcription_type", score=0.42, ) - assert str(e.value) == "type should be an allowed transcription type" + assert ( + str(e.value) == "type shouldn't be null and should be of type TranscriptionType" + ) def test_create_transcription_wrong_text(): @@ -428,13 +437,13 @@ def test_create_transcription_wrong_text(): with pytest.raises(AssertionError) as e: worker.create_transcription( - element=elt, text=None, type="line", score=0.42, + element=elt, text=None, type=TranscriptionType.Line, score=0.42, ) assert str(e.value) == "text shouldn't be null and should be of type str" with pytest.raises(AssertionError) as e: worker.create_transcription( - element=elt, text=1234, type="line", score=0.42, + element=elt, text=1234, type=TranscriptionType.Line, score=0.42, ) assert str(e.value) == "text shouldn't be null and should be of type str" @@ -445,7 +454,7 @@ def test_create_transcription_wrong_score(): with pytest.raises(AssertionError) as e: worker.create_transcription( - element=elt, text="i am a line", type="line", score=None, + element=elt, text="i am a line", type=TranscriptionType.Line, score=None, ) assert ( str(e.value) == "score shouldn't be null and should be a float in [0..1] range" @@ -453,7 +462,10 @@ def test_create_transcription_wrong_score(): with pytest.raises(AssertionError) as e: worker.create_transcription( - element=elt, text="i am a line", type="line", score="wrong score", + element=elt, + text="i am a line", + type=TranscriptionType.Line, + score="wrong score", ) assert ( str(e.value) == "score shouldn't be null and should be a float in [0..1] range" @@ -461,7 +473,7 @@ def test_create_transcription_wrong_score(): with pytest.raises(AssertionError) as e: worker.create_transcription( - element=elt, text="i am a line", type="line", score=0, + element=elt, text="i am a line", type=TranscriptionType.Line, score=0, ) assert ( str(e.value) == "score shouldn't be null and should be a float in [0..1] range" @@ -469,7 +481,7 @@ def test_create_transcription_wrong_score(): with pytest.raises(AssertionError) as e: worker.create_transcription( - element=elt, text="i am a line", type="line", score=2.00, + element=elt, text="i am a line", type=TranscriptionType.Line, score=2.00, ) assert ( str(e.value) == "score shouldn't be null and should be a float in [0..1] range" @@ -488,7 +500,7 @@ def test_create_transcription_api_error(responses): with pytest.raises(ErrorResponse): worker.create_transcription( - element=elt, text="i am a line", type="line", score=0.42, + element=elt, text="i am a line", type=TranscriptionType.Line, score=0.42, ) assert len(responses.calls) == 1 @@ -509,7 +521,7 @@ def test_create_transcription(responses): ) worker.create_transcription( - element=elt, text="i am a line", type="line", score=0.42, + element=elt, text="i am a line", type=TranscriptionType.Line, score=0.42, ) assert len(responses.calls) == 1 @@ -669,13 +681,17 @@ def test_create_entity_wrong_name(): worker = ElementsWorker() with pytest.raises(AssertionError) as e: worker.create_entity( - name=None, type="person", corpus="12341234-1234-1234-1234-123412341234", + name=None, + type=EntityType.Person, + corpus="12341234-1234-1234-1234-123412341234", ) assert str(e.value) == "name shouldn't be null and should be of type str" with pytest.raises(AssertionError) as e: worker.create_entity( - name=1234, type="person", corpus="12341234-1234-1234-1234-123412341234", + name=1234, + type=EntityType.Person, + corpus="12341234-1234-1234-1234-123412341234", ) assert str(e.value) == "name shouldn't be null and should be of type str" @@ -687,13 +703,13 @@ def test_create_entity_wrong_type(): worker.create_entity( name="Bob Bob", type=None, corpus="12341234-1234-1234-1234-123412341234", ) - assert str(e.value) == "type shouldn't be null and should be of type str" + assert str(e.value) == "type shouldn't be null and should be of type EntityType" with pytest.raises(AssertionError) as e: worker.create_entity( name="Bob Bob", type=1234, corpus="12341234-1234-1234-1234-123412341234", ) - assert str(e.value) == "type shouldn't be null and should be of type str" + assert str(e.value) == "type shouldn't be null and should be of type EntityType" with pytest.raises(AssertionError) as e: worker.create_entity( @@ -701,20 +717,20 @@ def test_create_entity_wrong_type(): type="not_an_entity_type", corpus="12341234-1234-1234-1234-123412341234", ) - assert str(e.value) == "type should be an allowed entity type" + assert str(e.value) == "type shouldn't be null and should be of type EntityType" def test_create_entity_wrong_corpus(): worker = ElementsWorker() with pytest.raises(AssertionError) as e: worker.create_entity( - name="Bob Bob", type="person", corpus=None, + name="Bob Bob", type=EntityType.Person, corpus=None, ) assert str(e.value) == "corpus shouldn't be null and should be of type str" with pytest.raises(AssertionError) as e: worker.create_entity( - name="Bob Bob", type="person", corpus=1234, + name="Bob Bob", type=EntityType.Person, corpus=1234, ) assert str(e.value) == "corpus shouldn't be null and should be of type str" @@ -724,7 +740,7 @@ def test_create_entity_wrong_metas(): with pytest.raises(AssertionError) as e: worker.create_entity( name="Bob Bob", - type="person", + type=EntityType.Person, corpus="12341234-1234-1234-1234-123412341234", metas="wrong metas", ) @@ -736,7 +752,7 @@ def test_create_entity_wrong_validated(): with pytest.raises(AssertionError) as e: worker.create_entity( name="Bob Bob", - type="person", + type=EntityType.Person, corpus="12341234-1234-1234-1234-123412341234", validated="wrong validated", ) @@ -753,7 +769,7 @@ def test_create_entity_api_error(responses): with pytest.raises(ErrorResponse): worker.create_entity( name="Bob Bob", - type="person", + type=EntityType.Person, corpus="12341234-1234-1234-1234-123412341234", ) @@ -771,7 +787,9 @@ def test_create_entity(responses): ) worker.create_entity( - name="Bob Bob", type="person", corpus="12341234-1234-1234-1234-123412341234", + name="Bob Bob", + type=EntityType.Person, + corpus="12341234-1234-1234-1234-123412341234", ) assert len(responses.calls) == 1