diff --git a/arkindex/documents/api/elements.py b/arkindex/documents/api/elements.py index d9a9312c00cfc4742ccfb0bb49c0e345ce5b3e88..99b7ae8ab5661fa4d5b88ef9cca323b8b19b625d 100644 --- a/arkindex/documents/api/elements.py +++ b/arkindex/documents/api/elements.py @@ -659,7 +659,7 @@ class TranscriptionsPagination(PageNumberPagination): class ElementTranscriptions(ListAPIView): """ - List all transcriptions for an element, optionally filtered by type. + List all transcriptions for an element, optionally filtered by type or worker version id. Recursive parameter allow listing transcriptions on sub-elements, otherwise element fields in the response will be set to null. """ @@ -684,9 +684,13 @@ class ElementTranscriptions(ListAPIView): 'in': 'query', 'required': False, 'description': 'Recursively list transcriptions on sub-elements', - 'schema': { - 'type': 'boolean', - } + 'schema': {'type': 'boolean'} + }, { + 'name': 'worker_version', + 'in': 'query', + 'required': False, + 'description': 'Filter transcriptions by worker version', + 'schema': {'type': 'string', 'format': 'uuid'} }, ] } @@ -732,6 +736,10 @@ class ElementTranscriptions(ListAPIView): else: queryset = queryset.filter(element_id=element.id) + return queryset + + def filter_queryset(self, queryset): + # Filter by transcription type req_type = self.request.query_params.get('type') if req_type: try: @@ -739,6 +747,12 @@ class ElementTranscriptions(ListAPIView): queryset = queryset.filter(type=req_type) except ValueError: raise ValidationError({'type': 'Not a valid transcription type'}) + + # Filter by worker version + worker_version = self.request.query_params.get('worker_version') + if worker_version: + queryset = queryset.filter(worker_version_id=worker_version) + return queryset diff --git a/arkindex/documents/serializers/ml.py b/arkindex/documents/serializers/ml.py index 0aff55e168409d7e62446a9e102521a5d5bfc5e7..1d3859269ef496707eaa53476532e79d10b36a0f 100644 --- a/arkindex/documents/serializers/ml.py +++ b/arkindex/documents/serializers/ml.py @@ -246,6 +246,7 @@ class TranscriptionSerializer(serializers.ModelSerializer): 'score', 'zone', 'source', + 'worker_version_id', ) diff --git a/arkindex/documents/tests/test_create_transcriptions.py b/arkindex/documents/tests/test_create_transcriptions.py index 81444eb8b6567f616fd46900e7261ea0008e5a17..29c8de81a6ba63d78b40468014c0c4ce57985d15 100644 --- a/arkindex/documents/tests/test_create_transcriptions.py +++ b/arkindex/documents/tests/test_create_transcriptions.py @@ -127,6 +127,7 @@ class TestTranscriptionCreate(FixtureAPITestCase): }, 'text': 'A perfect day in a perfect place', 'type': 'line', + 'worker_version_id': None, 'zone': None }) @@ -340,6 +341,7 @@ class TestTranscriptionCreate(FixtureAPITestCase): 'source': None, 'text': 'NEKUDOTAYIM', 'type': 'word', + 'worker_version_id': str(version.id), 'zone': None }) diff --git a/arkindex/documents/tests/test_edit_transcriptions.py b/arkindex/documents/tests/test_edit_transcriptions.py index e3bf5704fbb04dba0b4bfc8b695d90227ae69d3f..37b88f76f51cb51998f39b52dd3305620c029075 100644 --- a/arkindex/documents/tests/test_edit_transcriptions.py +++ b/arkindex/documents/tests/test_edit_transcriptions.py @@ -62,6 +62,7 @@ class TestEditTranscription(FixtureAPITestCase): }, 'text': 'A manual transcription', 'type': 'line', + 'worker_version_id': None, 'zone': None }) @@ -133,6 +134,7 @@ class TestEditTranscription(FixtureAPITestCase): }, 'text': 'a knight was living lonely', 'type': 'line', + 'worker_version_id': None, 'zone': None }) diff --git a/arkindex/documents/tests/test_transcriptions.py b/arkindex/documents/tests/test_transcriptions.py index e9bbc842995d24d6a7b904624ca01aac1bfb6669..58e625f6d22f3add1b0470a6043c484f78de597b 100644 --- a/arkindex/documents/tests/test_transcriptions.py +++ b/arkindex/documents/tests/test_transcriptions.py @@ -3,7 +3,9 @@ from rest_framework import status from arkindex.project.tests import FixtureAPITestCase from arkindex.project.polygon import Polygon from arkindex_common.enums import TranscriptionType +from arkindex_common.ml_tool import MLToolType from arkindex.documents.models import Corpus, DataSource +from arkindex.dataimport.models import Worker, WorkerVersion from arkindex.users.models import User @@ -25,6 +27,17 @@ class TestTranscriptions(FixtureAPITestCase): cls.private_read_user = User.objects.create_user('a@bc.de', 'a') cls.private_read_user.verified_email = True cls.private_read_user.save() + cls.repo = cls.user.credentials.get().repos.get() + cls.worker_version = WorkerVersion.objects.create( + worker=Worker.objects.create( + repository=cls.repo, + name='Test Worker', + slug='test_worker', + type=MLToolType.Classifier + ), + revision=cls.repo.revisions.get(), + configuration={"test": "test1"} + ) def test_list_transcriptions_read_right(self): # A read right on the element corpus is required to access transcriptions @@ -74,7 +87,7 @@ class TestTranscriptions(FixtureAPITestCase): def test_list_transcriptions_recursive(self): for i in range(1, 5): # Add 4 transcriptions on the page line - self.line.transcriptions.create(source_id=self.src.id, type=TranscriptionType.Line, text=f'Text {i}') + self.line.transcriptions.create(worker_version=self.worker_version, type=TranscriptionType.Line, text=f'Text {i}') for i in range(1, 5): # Add 4 transcribed line children zone, _ = self.page.zone.image.zones.get_or_create(polygon=Polygon.from_coords(0, 0, i + 1, i + 1)) @@ -97,20 +110,22 @@ class TestTranscriptions(FixtureAPITestCase): page_polygon = [[0, 0], [0, 1000], [1000, 1000], [1000, 0], [0, 0]] line_polygon = [[400, 400], [400, 500], [500, 500], [500, 400], [400, 400]] self.assertCountEqual( - [(data['element']['type'], data['element']['zone']['polygon'], data['text']) for data in results], [ - ('page', page_polygon, 'PARIS'), - ('page', page_polygon, 'ROY'), - ('page', page_polygon, 'Lorem ipsum dolor sit amet'), - ('page', page_polygon, 'DATUM'), - ('text_line', line_polygon, 'Text 1'), - ('text_line', line_polygon, 'Text 2'), - ('text_line', line_polygon, 'Text 3'), - ('text_line', line_polygon, 'Text 4'), - ('text_line', [[0, 0], [0, 2], [2, 2], [2, 0], [0, 0]], 'Added text 1'), - ('text_line', [[0, 0], [0, 3], [3, 3], [3, 0], [0, 0]], 'Added text 2'), - ('text_line', [[0, 0], [0, 4], [4, 4], [4, 0], [0, 0]], 'Added text 3'), - ('text_line', [[0, 0], [0, 5], [5, 5], [5, 0], [0, 0]], 'Added text 4') + (data['element']['type'], data['worker_version_id'], data['element']['zone']['polygon'], data['text']) + for data in results + ], [ + ('page', None, page_polygon, 'PARIS'), + ('page', None, page_polygon, 'ROY'), + ('page', None, page_polygon, 'Lorem ipsum dolor sit amet'), + ('page', None, page_polygon, 'DATUM'), + ('text_line', str(self.worker_version.id), line_polygon, 'Text 1'), + ('text_line', str(self.worker_version.id), line_polygon, 'Text 2'), + ('text_line', str(self.worker_version.id), line_polygon, 'Text 3'), + ('text_line', str(self.worker_version.id), line_polygon, 'Text 4'), + ('text_line', None, [[0, 0], [0, 2], [2, 2], [2, 0], [0, 0]], 'Added text 1'), + ('text_line', None, [[0, 0], [0, 3], [3, 3], [3, 0], [0, 0]], 'Added text 2'), + ('text_line', None, [[0, 0], [0, 4], [4, 4], [4, 0], [0, 0]], 'Added text 3'), + ('text_line', None, [[0, 0], [0, 5], [5, 5], [5, 0], [0, 0]], 'Added text 4') ] ) @@ -132,3 +147,27 @@ class TestTranscriptions(FixtureAPITestCase): self.assertEqual(len(results), 4) for tr in results: self.assertEqual(tr.get('type'), 'line') + + def test_list_worker_version_transcriptions(self): + + for i in range(1, 5): + # Add 4 transcriptions on the page line with a specific worker_version + self.line.transcriptions.create( + type=TranscriptionType.Line, + text=f'Text {i}', + worker_version=self.worker_version + ) + + self.client.force_login(self.user) + + with self.assertNumQueries(12): + response = self.client.get( + reverse('api:element-transcriptions', kwargs={'pk': str(self.page.id)}), + data={'recursive': 'true', 'worker_version': str(self.worker_version.id)} + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + results = response.json()['results'] + self.assertEqual(len(results), 4) + for tr in results: + self.assertEqual(tr.get('worker_version_id'), str(self.worker_version.id))