diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py index b9825eae73828da19fb3c6f3380e966bd57a908c..031b9da3fd6cb669f59f52e79b2456d24bc75051 100644 --- a/arkindex_worker/worker.py +++ b/arkindex_worker/worker.py @@ -532,6 +532,63 @@ class ElementsWorker(BaseWorker): f"Couldn't save created transcription in local cache: {e}" ) + def create_transcriptions(self, transcriptions): + """ + Create multiple transcriptions at once on existing elements through the API. + """ + + assert transcriptions and isinstance( + transcriptions, list + ), "transcriptions shouldn't be null and should be of type list" + + for index, transcription in enumerate(transcriptions): + element_id = transcription.get("element_id") + assert element_id and isinstance( + element_id, str + ), f"Transcription at index {index} in transcriptions: element_id shouldn't be null and should be of type str" + + text = transcription.get("text") + assert text and isinstance( + text, str + ), f"Transcription at index {index} in transcriptions: text shouldn't be null and should be of type str" + + score = transcription.get("score") + assert ( + score is not None and isinstance(score, float) and 0 <= score <= 1 + ), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range" + + created_trs = self.api_client.request( + "CreateTranscriptions", + body={ + "worker_version": self.worker_version_id, + "transcriptions": transcriptions, + }, + )["transcriptions"] + + for created_tr in created_trs: + self.report.add_transcription(created_tr["element_id"]) + + if self.cache: + # Store transcriptions in local cache + try: + to_insert = [ + CachedTranscription( + id=convert_str_uuid_to_hex(created_tr["id"]), + element_id=convert_str_uuid_to_hex(created_tr["element_id"]), + text=created_tr["text"], + confidence=created_tr["confidence"], + worker_version_id=convert_str_uuid_to_hex( + self.worker_version_id + ), + ) + for created_tr in created_trs + ] + self.cache.insert("transcriptions", to_insert) + except sqlite3.IntegrityError as e: + logger.warning( + f"Couldn't save created transcriptions in local cache: {e}" + ) + def create_classification( self, element, ml_class, confidence, high_confidence=False ): diff --git a/tests/test_elements_worker/test_transcriptions.py b/tests/test_elements_worker/test_transcriptions.py index 1fa9845e39feaf79ce1fef7ef9e321990967f5f2..3f66c63db9229a9a6cf9f974c84afc30938644c7 100644 --- a/tests/test_elements_worker/test_transcriptions.py +++ b/tests/test_elements_worker/test_transcriptions.py @@ -189,6 +189,349 @@ def test_create_transcription(responses, mock_elements_worker_with_cache): ] +def test_create_transcriptions_wrong_transcriptions(mock_elements_worker): + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcriptions( + transcriptions=None, + ) + assert str(e.value) == "transcriptions shouldn't be null and should be of type list" + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcriptions( + transcriptions=1234, + ) + assert str(e.value) == "transcriptions shouldn't be null and should be of type list" + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcriptions( + transcriptions=[ + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "The", + "score": 0.75, + }, + { + "text": "word", + "score": 0.5, + }, + ], + ) + assert ( + str(e.value) + == "Transcription at index 1 in transcriptions: element_id shouldn't be null and should be of type str" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcriptions( + transcriptions=[ + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "The", + "score": 0.75, + }, + { + "element_id": None, + "text": "word", + "score": 0.5, + }, + ], + ) + assert ( + str(e.value) + == "Transcription at index 1 in transcriptions: element_id shouldn't be null and should be of type str" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcriptions( + transcriptions=[ + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "The", + "score": 0.75, + }, + { + "element_id": 1234, + "text": "word", + "score": 0.5, + }, + ], + ) + assert ( + str(e.value) + == "Transcription at index 1 in transcriptions: element_id shouldn't be null and should be of type str" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcriptions( + transcriptions=[ + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "The", + "score": 0.75, + }, + { + "element_id": "11111111-1111-1111-1111-111111111111", + "score": 0.5, + }, + ], + ) + assert ( + str(e.value) + == "Transcription at index 1 in transcriptions: text shouldn't be null and should be of type str" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcriptions( + transcriptions=[ + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "The", + "score": 0.75, + }, + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": None, + "score": 0.5, + }, + ], + ) + assert ( + str(e.value) + == "Transcription at index 1 in transcriptions: text shouldn't be null and should be of type str" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcriptions( + transcriptions=[ + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "The", + "score": 0.75, + }, + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": 1234, + "score": 0.5, + }, + ], + ) + assert ( + str(e.value) + == "Transcription at index 1 in transcriptions: text shouldn't be null and should be of type str" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcriptions( + transcriptions=[ + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "The", + "score": 0.75, + }, + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "word", + }, + ], + ) + assert ( + str(e.value) + == "Transcription at index 1 in transcriptions: score shouldn't be null and should be a float in [0..1] range" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcriptions( + transcriptions=[ + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "The", + "score": 0.75, + }, + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "word", + "score": None, + }, + ], + ) + assert ( + str(e.value) + == "Transcription at index 1 in transcriptions: score shouldn't be null and should be a float in [0..1] range" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcriptions( + transcriptions=[ + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "The", + "score": 0.75, + }, + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "word", + "score": "a wrong score", + }, + ], + ) + assert ( + str(e.value) + == "Transcription at index 1 in transcriptions: score shouldn't be null and should be a float in [0..1] range" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcriptions( + transcriptions=[ + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "The", + "score": 0.75, + }, + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "word", + "score": 0, + }, + ], + ) + assert ( + str(e.value) + == "Transcription at index 1 in transcriptions: score shouldn't be null and should be a float in [0..1] range" + ) + + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcriptions( + transcriptions=[ + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "The", + "score": 0.75, + }, + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "word", + "score": 2.00, + }, + ], + ) + assert ( + str(e.value) + == "Transcription at index 1 in transcriptions: score shouldn't be null and should be a float in [0..1] range" + ) + + +def test_create_transcriptions_api_error(responses, mock_elements_worker): + responses.add( + responses.POST, + "http://testserver/api/v1/transcription/bulk/", + status=500, + ) + trans = [ + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "The", + "score": 0.75, + }, + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "word", + "score": 0.42, + }, + ] + + with pytest.raises(ErrorResponse): + mock_elements_worker.create_transcriptions(transcriptions=trans) + + assert len(responses.calls) == 3 + assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", + "http://testserver/api/v1/transcription/bulk/", + ] + + +def test_create_transcriptions(responses, mock_elements_worker_with_cache): + trans = [ + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "The", + "score": 0.75, + }, + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "word", + "score": 0.42, + }, + ] + + responses.add( + responses.POST, + "http://testserver/api/v1/transcription/bulk/", + status=200, + json={ + "worker_version": "12341234-1234-1234-1234-123412341234", + "transcriptions": [ + { + "id": "00000000-0000-0000-0000-000000000000", + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "The", + "confidence": 0.75, + }, + { + "id": "11111111-1111-1111-1111-111111111111", + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "word", + "confidence": 0.42, + }, + ], + }, + ) + + mock_elements_worker_with_cache.create_transcriptions( + transcriptions=trans, + ) + + assert len(responses.calls) == 3 + assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", + "http://testserver/api/v1/transcription/bulk/", + ] + + assert json.loads(responses.calls[2].request.body) == { + "worker_version": "12341234-1234-1234-1234-123412341234", + "transcriptions": trans, + } + + # Check that created transcriptions were properly stored in SQLite cache + cache_path = f"{CACHE_DIR}/db.sqlite" + assert os.path.isfile(cache_path) + + rows = mock_elements_worker_with_cache.cache.cursor.execute( + "SELECT * FROM transcriptions" + ).fetchall() + assert [CachedTranscription(**dict(row)) for row in rows] == [ + CachedTranscription( + id=convert_str_uuid_to_hex("00000000-0000-0000-0000-000000000000"), + element_id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"), + text="The", + confidence=0.75, + worker_version_id=convert_str_uuid_to_hex( + "12341234-1234-1234-1234-123412341234" + ), + ), + CachedTranscription( + id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"), + element_id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"), + text="word", + confidence=0.42, + worker_version_id=convert_str_uuid_to_hex( + "12341234-1234-1234-1234-123412341234" + ), + ), + ] + + def test_create_element_transcriptions_wrong_element(mock_elements_worker): with pytest.raises(AssertionError) as e: mock_elements_worker.create_element_transcriptions(