diff --git a/tests/test_elements_worker/test_transcriptions.py b/tests/test_elements_worker/test_transcriptions.py
index 3f66c63db9229a9a6cf9f974c84afc30938644c7..e37e3ea9654f43c67f03794826f1618156addca3 100644
--- a/tests/test_elements_worker/test_transcriptions.py
+++ b/tests/test_elements_worker/test_transcriptions.py
@@ -28,6 +28,28 @@ TRANSCRIPTIONS_SAMPLE = [
         "text": "line",
     },
 ]
+# Two cached transcriptions on the same element, with different worker
+# versions; inserted into the local cache by test_list_transcriptions_with_cache.
+TRANSCRIPTIONS_TO_INSERT = [
+    CachedTranscription(
+        id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
+        element_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
+        text="Hello!",
+        confidence=0.42,
+        worker_version_id=convert_str_uuid_to_hex(
+            "56785678-5678-5678-5678-567856785678"
+        ),
+    ),
+    CachedTranscription(
+        id=convert_str_uuid_to_hex("22222222-2222-2222-2222-222222222222"),
+        element_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
+        text="How are you?",
+        confidence=0.42,
+        worker_version_id=convert_str_uuid_to_hex(
+            "90129012-9012-9012-9012-901290129012"
+        ),
+    ),
+]
 
 
 def test_create_transcription_wrong_element(mock_elements_worker):
@@ -1155,3 +1177,95 @@ def test_list_transcriptions(responses, mock_elements_worker):
         "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
         "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/",
     ]
+
+
+def test_list_transcriptions_with_cache_unhandled_param(
+    responses, mock_elements_worker_with_cache
+):
+    """With the cache enabled, filters other than worker_version are rejected."""
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker_with_cache.list_transcriptions(
+            element=elt, element_type="page"
+        )
+    assert (
+        str(e.value)
+        == "When using the local cache, you can only filter by 'worker_version'"
+    )
+
+
+def test_list_transcriptions_with_cache_skip_recursive(
+    responses, mock_elements_worker_with_cache
+):
+    """With the cache enabled, a recursive listing is still served by the API."""
+    # When the local cache is activated and the user defines the recursive
+    # filter, we should fall back to the API
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
+    trans = [
+        {
+            "id": "0000",
+            "text": "hey",
+            "confidence": 0.42,
+            "worker_version_id": "56785678-5678-5678-5678-567856785678",
+            "element": None,
+        },
+    ]
+    responses.add(
+        responses.GET,
+        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/?recursive=True",
+        status=200,
+        json={
+            "count": 3,
+            "next": None,
+            "results": trans,
+        },
+    )
+
+    # Compare the whole list: a per-item loop asserts nothing on an empty result
+    transcriptions = list(
+        mock_elements_worker_with_cache.list_transcriptions(element=elt, recursive=True)
+    )
+    assert transcriptions == trans
+
+    assert len(responses.calls) == 3
+    assert [call.request.url for call in responses.calls] == [
+        "http://testserver/api/v1/user/",
+        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/transcriptions/?recursive=True",
+    ]
+
+
+def test_list_transcriptions_with_cache(responses, mock_elements_worker_with_cache):
+    """List transcriptions from the local cache, with and without a worker version."""
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
+
+    # The listing is expected to be empty before anything is inserted. Compare
+    # the whole list: a per-item loop would never run and so assert nothing.
+    assert list(mock_elements_worker_with_cache.list_transcriptions(element=elt)) == []
+
+    # Initialize SQLite cache with some transcriptions
+    mock_elements_worker_with_cache.cache.insert(
+        "transcriptions", TRANSCRIPTIONS_TO_INSERT
+    )
+
+    # Without a filter, both inserted transcriptions come back, in order
+    transcriptions = list(
+        mock_elements_worker_with_cache.list_transcriptions(element=elt)
+    )
+    assert transcriptions == TRANSCRIPTIONS_TO_INSERT
+
+    # Only the first inserted transcription has this worker version
+    transcriptions = list(
+        mock_elements_worker_with_cache.list_transcriptions(
+            element=elt, worker_version="56785678-5678-5678-5678-567856785678"
+        )
+    )
+    assert transcriptions == [TRANSCRIPTIONS_TO_INSERT[0]]
+
+    # No request beyond the user and worker version lookups is expected
+    assert len(responses.calls) == 2
+    assert [call.request.url for call in responses.calls] == [
+        "http://testserver/api/v1/user/",
+        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+    ]