diff --git a/arkindex_worker/worker/entity.py b/arkindex_worker/worker/entity.py index 449e8ce3df433e6f6cf8c40e2e71a4eb63b96e7f..aa6e9a06a03ad71d7a9ed1167ceabe8df01570d8 100644 --- a/arkindex_worker/worker/entity.py +++ b/arkindex_worker/worker/entity.py @@ -109,16 +109,19 @@ class EntityMixin(object): ) return + body = { + "entity": entity, + "length": length, + "offset": offset, + "worker_version_id": self.worker_version_id, + } + if confidence is not None: + body["confidence"] = confidence + transcription_ent = self.request( "CreateTranscriptionEntity", id=transcription, - body={ - "entity": entity, - "length": length, - "offset": offset, - "worker_version_id": self.worker_version_id, - "confidence": confidence, - }, + body=body, ) # TODO: Report transcription entity creation diff --git a/tests/test_elements_worker/test_entities.py b/tests/test_elements_worker/test_entities.py index 075f8e5b7d95ccab43c2d96e56480d1bea264f54..99a7492c1463ffe8a5e68e5f9995903ad5da1ad9 100644 --- a/tests/test_elements_worker/test_entities.py +++ b/tests/test_elements_worker/test_entities.py @@ -418,7 +418,7 @@ def test_create_transcription_entity_api_error(responses, mock_elements_worker): ] -def test_create_transcription_entity(responses, mock_elements_worker): +def test_create_transcription_entity_no_confidence(responses, mock_elements_worker): responses.add( responses.POST, "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/", @@ -451,7 +451,83 @@ def test_create_transcription_entity(responses, mock_elements_worker): "offset": 5, "length": 10, "worker_version_id": "12341234-1234-1234-1234-123412341234", - "confidence": None, + } + + +def test_create_transcription_entity_with_confidence(responses, mock_elements_worker): + responses.add( + responses.POST, + "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/", + status=200, + json={ + "entity": "11111111-1111-1111-1111-111111111111", + "offset": 5, + "length": 10, + "confidence": 0.33, + }, + ) + + mock_elements_worker.create_transcription_entity( + transcription="11111111-1111-1111-1111-111111111111", + entity="11111111-1111-1111-1111-111111111111", + offset=5, + length=10, + confidence=0.33, + ) + + assert len(responses.calls) == len(BASE_API_CALLS) + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ( + "POST", + "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/", + ), + ] + assert json.loads(responses.calls[-1].request.body) == { + "entity": "11111111-1111-1111-1111-111111111111", + "offset": 5, + "length": 10, + "worker_version_id": "12341234-1234-1234-1234-123412341234", + "confidence": 0.33, + } + + +def test_create_transcription_entity_confidence_none(responses, mock_elements_worker): + responses.add( + responses.POST, + "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/", + status=200, + json={ + "entity": "11111111-1111-1111-1111-111111111111", + "offset": 5, + "length": 10, + "confidence": None, + }, + ) + + mock_elements_worker.create_transcription_entity( + transcription="11111111-1111-1111-1111-111111111111", + entity="11111111-1111-1111-1111-111111111111", + offset=5, + length=10, + confidence=None, + ) + + assert len(responses.calls) == len(BASE_API_CALLS) + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + ( + "POST", + "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/", + ), + ] + assert json.loads(responses.calls[-1].request.body) == { + "entity": "11111111-1111-1111-1111-111111111111", + "offset": 5, + "length": 10, + "worker_version_id": "12341234-1234-1234-1234-123412341234", } @@ -509,7 +585,6 @@ def test_create_transcription_entity_with_cache( "offset": 5, "length": 10, "worker_version_id": "12341234-1234-1234-1234-123412341234", - "confidence": None, } # Check that created transcription entity was properly stored in SQLite cache