diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py index c1899cdc69591cf8e08a6e44356e4d0ab38d765a..147b69c101eb7c6f117d380dc15c03993018ce86 100644 --- a/arkindex/project/api_v1.py +++ b/arkindex/project/api_v1.py @@ -123,6 +123,7 @@ from arkindex.training.api import ( CreateDatasetElementsSelection, DatasetElements, DatasetUpdate, + ElementDatasets, MetricValueBulkCreate, MetricValueCreate, ModelRetrieve, @@ -208,6 +209,7 @@ api = [ # Datasets path('corpus/<uuid:pk>/datasets/', CorpusDataset.as_view(), name='corpus-datasets'), path('corpus/<uuid:pk>/datasets/selection/', CreateDatasetElementsSelection.as_view(), name='dataset-elements-selection'), + path('element/<uuid:pk>/datasets/', ElementDatasets.as_view(), name='element-datasets'), path('datasets/<uuid:pk>/', DatasetUpdate.as_view(), name='dataset-update'), path('datasets/<uuid:pk>/elements/', DatasetElements.as_view(), name='dataset-elements'), diff --git a/arkindex/training/api.py b/arkindex/training/api.py index 24e8b9d0dd35c07893682e37c2ced3adad90ae26..336da0e422bdef987a889062afd7f124d63047a0 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -17,11 +17,11 @@ from rest_framework.generics import ( ) from rest_framework.response import Response -from arkindex.documents.models import Corpus +from arkindex.documents.models import Corpus, Element from arkindex.process.utils import annotate_image_url from arkindex.project.mixins import ACLMixin, CorpusACLMixin, TrainingModelMixin from arkindex.project.pagination import CustomCursorPagination -from arkindex.project.permissions import IsVerified +from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly from arkindex.training.models import ( Dataset, DatasetElement, @@ -36,6 +36,7 @@ from arkindex.training.serializers import ( DatasetElementSerializer, DatasetLightSerializer, DatasetSerializer, + ElementDatasetSerializer, MetricValueBulkSerializer, MetricValueCreateSerializer, ModelSerializer, @@ -631,3 +632,40 @@ class CreateDatasetElementsSelection(CorpusACLMixin, CreateAPIView): def create(self, request, *args, **kwargs): super().create(request, *args, **kwargs) return Response(status=status.HTTP_204_NO_CONTENT) + + +@extend_schema( + tags=['datasets'], + parameters=[ + OpenApiParameter( + 'id', + type=UUID, + location=OpenApiParameter.PATH, + description='ID of the element.', + required=True, + ) + ], +) +class ElementDatasets(CorpusACLMixin, ListAPIView): + """ + List all datasets containing a specific element. + + Requires a **guest** access to the element's corpus. + """ + permission_classes = (IsVerifiedOrReadOnly, ) + serializer_class = ElementDatasetSerializer + + @cached_property + def element(self): + return get_object_or_404( + Element.objects.only('id'), + id=self.kwargs['pk'], + corpus__in=Corpus.objects.readable(self.request.user) + ) + + def get_queryset(self): + return ( + self.element.dataset_elements.all() + .order_by('dataset__name', 'set', 'dataset_id') + .values('dataset_id', 'dataset__name', 'set') + ) diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py index 31d855c3710c1ad8e3199a533fae945adf3bf54a..faf60898dfe6853559cc3375c513d0cca203f6eb 100644 --- a/arkindex/training/serializers.py +++ b/arkindex/training/serializers.py @@ -448,6 +448,14 @@ class DatasetElementSerializer(serializers.ModelSerializer): read_only_fields = fields +class ElementDatasetSerializer(serializers.ModelSerializer): + dataset_name = serializers.CharField(max_length=100, source='dataset__name') + + class Meta: + model = DatasetElement + fields = ('dataset_id', 'dataset_name', 'set') + + class SelectionDatasetElementSerializer(serializers.Serializer): dataset_id = serializers.PrimaryKeyRelatedField( queryset=Dataset.objects.all(), diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py index 0b5790e6a9b5e87f32e29b38aa920a826f305aff..799d2ef956c87a91df17789970a0a25ab8113212 100644 --- a/arkindex/training/tests/test_datasets_api.py +++ b/arkindex/training/tests/test_datasets_api.py @@ -1112,3 +1112,67 @@ class TestDatasetsAPI(FixtureAPITestCase): ('training', 'Volume 1, page 1v'), ] ) + + def test_element_datasets_requires_read_access(self): + self.client.force_login(self.user) + private_elt = self.private_corpus.elements.create(type=self.private_corpus.types.create(slug='t'), name='elt') + with self.assertNumQueries(5): + response = self.client.get(reverse('api:element-datasets', kwargs={'pk': private_elt.id})) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_element_datasets_methods(self): + self.client.force_login(self.user) + forbidden_methods = ('post', 'patch', 'put', 'delete') + for method in forbidden_methods: + with self.subTest(method=method): + client_method = getattr(self.client, method) + response = client_method(reverse('api:element-datasets', kwargs={'pk': str(self.vol.id)})) + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + + def test_element_datasets_public(self): + """ + A non authenticated user can list datasets of a public element + """ + self.dataset.dataset_elements.create(element=self.vol, set='train') + with self.assertNumQueries(3): + response = self.client.get(reverse('api:element-datasets', kwargs={'pk': str(self.vol.id)})) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual(response.json(), { + 'count': 1, + 'next': None, + 'number': 1, + 'previous': None, + 'results': [{ + 'dataset_id': str(self.dataset.id), + 'dataset_name': 'First Dataset', + 'set': 'train' + }] + }) + + def test_element_datasets(self): + self.client.force_login(self.user) + self.dataset.dataset_elements.create(element=self.page1, set='train') + self.dataset.dataset_elements.create(element=self.page1, set='validation') + self.dataset2.dataset_elements.create(element=self.page1, set='train') + with self.assertNumQueries(7): + response = self.client.get(reverse('api:element-datasets', kwargs={'pk': str(self.page1.id)})) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual(response.json(), { + 'count': 3, + 'next': None, + 'number': 1, + 'previous': None, + 'results': [{ + 'dataset_id': str(self.dataset.id), + 'dataset_name': 'First Dataset', + 'set': 'train' + }, { + 'dataset_id': str(self.dataset.id), + 'dataset_name': 'First Dataset', + 'set': 'validation' + }, { + 'dataset_id': str(self.dataset2.id), + 'dataset_name': 'Second Dataset', + 'set': 'train' + }] + })