diff --git a/arkindex/documents/api/entities.py b/arkindex/documents/api/entities.py index df8b042a43ebcab496059a9e93ed9c3e6a3833bf..cdd2a1fa7ed51d29511c1347eae4bccca1a3bf46 100644 --- a/arkindex/documents/api/entities.py +++ b/arkindex/documents/api/entities.py @@ -3,6 +3,7 @@ from uuid import UUID from django.core.exceptions import ValidationError from django.db.models import Q +from django.shortcuts import get_object_or_404 from drf_spectacular.utils import OpenApiExample, OpenApiParameter, extend_schema, extend_schema_view from rest_framework import permissions, serializers, status from rest_framework.exceptions import NotFound, PermissionDenied @@ -265,7 +266,9 @@ class TranscriptionEntityCreate(CreateAPIView): ) class TranscriptionEntities(ListAPIView): """ - List existing entities linked to a specific transcription + List all entities linked to a transcription, ordeded by their position on the text. + + A guest access is required on the corpus of the transcription's element. """ serializer_class = TranscriptionEntityDetailsSerializer # For OpenAPI type discovery: a transcription's ID is in the path @@ -285,11 +288,7 @@ class TranscriptionEntities(ListAPIView): return validated def get_queryset(self): - filters = { - 'transcription__element__corpus__in': Corpus.objects.readable(self.request.user), - 'transcription_id': self.kwargs['pk'], - } - + filters = {} errors = defaultdict() if 'worker_version' in self.request.query_params: try: @@ -304,12 +303,24 @@ class TranscriptionEntities(ListAPIView): except serializers.ValidationError as e: errors['entity_worker_version'] = e.detail if errors: - raise ValidationError(errors) + raise serializers.ValidationError(errors) + + transcription = get_object_or_404( + Transcription.objects.filter( + id=self.kwargs['pk'], + element__corpus__in=Corpus.objects.readable(self.request.user), + ).only("id") + ) - return TranscriptionEntity.objects \ - .filter(**filters) \ - .order_by('offset') \ - .prefetch_related('entity') + return ( + TranscriptionEntity.objects + .filter( + transcription=transcription, + **filters, + ) + .order_by('offset') + .select_related('entity') + ) @extend_schema_view( diff --git a/arkindex/documents/tests/test_entities_api.py b/arkindex/documents/tests/test_entities_api.py index 3934222f4c70404421afa9cc459ecce37c679fff..5e6c74a1cace13b04773684f2cd5b36760a44f64 100644 --- a/arkindex/documents/tests/test_entities_api.py +++ b/arkindex/documents/tests/test_entities_api.py @@ -1,7 +1,6 @@ import uuid from django.contrib.gis.geos import LinearRing -from django.core.exceptions import ValidationError from django.urls import reverse from rest_framework import status @@ -757,8 +756,7 @@ class TestEntitiesAPI(FixtureAPITestCase): ) self.client.force_login(self.user) response = self.client.get(reverse('api:transcription-entities', kwargs={'pk': str(self.transcription.id)})) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.json()['count'], 0) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_list_transcription_entities(self): self.client.force_login(self.user) @@ -788,25 +786,49 @@ class TestEntitiesAPI(FixtureAPITestCase): }] ) + def test_list_transcription_entities_superuser(self): + self.client.force_login(self.superuser) + self.assertFalse(self.transcription.element.corpus.memberships.filter(user=self.superuser).exists()) + response = self.client.get(reverse('api:transcription-entities', kwargs={'pk': str(self.transcription.id)})) + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + self.assertEqual(data['count'], 1) + self.assertEqual( + data['results'], + [{ + 'entity': { + 'id': str(self.entity_bis.id), + 'name': self.entity_bis.name, + 'type': self.entity_bis.type.value, + 'metas': None, + 'validated': self.entity_bis.validated, + 'dates': [], + 'worker_version_id': str(self.worker_version_2.id), + 'worker_run': None, + }, + 'length': self.transcriptionentity.length, + 'offset': self.transcriptionentity.offset, + 'worker_version_id': str(self.worker_version_1.id), + 'worker_run': None, + 'confidence': None + }], + ) + def test_list_transcription_entities_worker_version_validation(self): self.client.force_login(self.user) - with self.assertRaises(ValidationError): - with self.assertNumQueries(2): - response = self.client.get( - reverse('api:transcription-entities', kwargs={'pk': str(self.transcription.id)}), - data={'worker_version': 'blah'} - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEqual(response.json(), {'worker_version': ['Invalid UUID.']}) - - with self.assertRaises(ValidationError): - with self.assertNumQueries(3): - response = self.client.get( - reverse('api:transcription-entities', kwargs={'pk': str(self.transcription.id)}), - data={'worker_version': str(uuid.uuid4())} - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEqual(response.json(), {'worker_version': ['This worker version does not exist.']}) + parameters = [ + ('blah', 2, 'Invalid UUID.'), + (str(uuid.uuid4()), 3, 'This worker version does not exist.'), + ] + for value, num_queries, expected in parameters: + with self.subTest(value=value): + with self.assertNumQueries(num_queries): + response = self.client.get( + reverse('api:transcription-entities', kwargs={'pk': str(self.transcription.id)}), + data={'worker_version': value} + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), {'worker_version': [expected]}) def test_list_transcription_entities_worker_version(self): TranscriptionEntity.objects.create( @@ -880,6 +902,22 @@ class TestEntitiesAPI(FixtureAPITestCase): }] ) + def test_list_transcription_entities_entity_worker_version_validation(self): + self.client.force_login(self.user) + parameters = [ + ('blah', 2, 'Invalid UUID.'), + (str(uuid.uuid4()), 3, 'This worker version does not exist.'), + ] + for value, num_queries, expected in parameters: + with self.subTest(value=value): + with self.assertNumQueries(num_queries): + response = self.client.get( + reverse('api:transcription-entities', kwargs={'pk': str(self.transcription.id)}), + data={'entity_worker_version': value} + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), {'entity_worker_version': [expected]}) + def test_list_transcription_entities_entity_worker_version(self): TranscriptionEntity.objects.create( transcription=self.transcription,