From 4d7f81fe7daf77a5192759197464318e405ca096 Mon Sep 17 00:00:00 2001
From: Valentin Rigal <rigal@teklia.com>
Date: Thu, 21 Sep 2023 14:20:10 +0000
Subject: [PATCH] List datasets of an element

---
 arkindex/project/api_v1.py                   |  2 +
 arkindex/training/api.py                     | 42 ++++++++++++-
 arkindex/training/serializers.py             |  8 +++
 arkindex/training/tests/test_datasets_api.py | 64 ++++++++++++++++++++
 4 files changed, 114 insertions(+), 2 deletions(-)

diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py
index c1899cdc69..147b69c101 100644
--- a/arkindex/project/api_v1.py
+++ b/arkindex/project/api_v1.py
@@ -123,6 +123,7 @@ from arkindex.training.api import (
     CreateDatasetElementsSelection,
     DatasetElements,
     DatasetUpdate,
+    ElementDatasets,
     MetricValueBulkCreate,
     MetricValueCreate,
     ModelRetrieve,
@@ -208,6 +209,7 @@ api = [
     # Datasets
     path('corpus/<uuid:pk>/datasets/', CorpusDataset.as_view(), name='corpus-datasets'),
     path('corpus/<uuid:pk>/datasets/selection/', CreateDatasetElementsSelection.as_view(), name='dataset-elements-selection'),
+    path('element/<uuid:pk>/datasets/', ElementDatasets.as_view(), name='element-datasets'),
     path('datasets/<uuid:pk>/', DatasetUpdate.as_view(), name='dataset-update'),
     path('datasets/<uuid:pk>/elements/', DatasetElements.as_view(), name='dataset-elements'),
 
diff --git a/arkindex/training/api.py b/arkindex/training/api.py
index 24e8b9d0dd..336da0e422 100644
--- a/arkindex/training/api.py
+++ b/arkindex/training/api.py
@@ -17,11 +17,11 @@ from rest_framework.generics import (
 )
 from rest_framework.response import Response
 
-from arkindex.documents.models import Corpus
+from arkindex.documents.models import Corpus, Element
 from arkindex.process.utils import annotate_image_url
 from arkindex.project.mixins import ACLMixin, CorpusACLMixin, TrainingModelMixin
 from arkindex.project.pagination import CustomCursorPagination
-from arkindex.project.permissions import IsVerified
+from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly
 from arkindex.training.models import (
     Dataset,
     DatasetElement,
@@ -36,6 +36,7 @@ from arkindex.training.serializers import (
     DatasetElementSerializer,
     DatasetLightSerializer,
     DatasetSerializer,
+    ElementDatasetSerializer,
     MetricValueBulkSerializer,
     MetricValueCreateSerializer,
     ModelSerializer,
@@ -631,3 +632,40 @@ class CreateDatasetElementsSelection(CorpusACLMixin, CreateAPIView):
     def create(self, request, *args, **kwargs):
         super().create(request, *args, **kwargs)
         return Response(status=status.HTTP_204_NO_CONTENT)
+
+
+@extend_schema(
+    tags=['datasets'],
+    parameters=[
+        OpenApiParameter(
+            'id',
+            type=UUID,
+            location=OpenApiParameter.PATH,
+            description='ID of the element.',
+            required=True,
+        )
+    ],
+)
+class ElementDatasets(CorpusACLMixin, ListAPIView):
+    """
+    List all datasets containing a specific element.
+
+    Requires a **guest** access to the element's corpus.
+    """
+    permission_classes = (IsVerifiedOrReadOnly, )
+    serializer_class = ElementDatasetSerializer
+
+    @cached_property
+    def element(self):
+        return get_object_or_404(
+            Element.objects.only('id'),
+            id=self.kwargs['pk'],
+            corpus__in=Corpus.objects.readable(self.request.user)
+        )
+
+    def get_queryset(self):
+        return (
+            self.element.dataset_elements.all()
+            .order_by('dataset__name', 'set', 'dataset_id')
+            .values('dataset_id', 'dataset__name', 'set')
+        )
diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index 31d855c371..faf60898df 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -448,6 +448,14 @@ class DatasetElementSerializer(serializers.ModelSerializer):
         read_only_fields = fields
 
 
+class ElementDatasetSerializer(serializers.ModelSerializer):
+    dataset_name = serializers.CharField(max_length=100, source='dataset__name')
+
+    class Meta:
+        model = DatasetElement
+        fields = ('dataset_id', 'dataset_name', 'set')
+
+
 class SelectionDatasetElementSerializer(serializers.Serializer):
     dataset_id = serializers.PrimaryKeyRelatedField(
         queryset=Dataset.objects.all(),
diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py
index 0b5790e6a9..799d2ef956 100644
--- a/arkindex/training/tests/test_datasets_api.py
+++ b/arkindex/training/tests/test_datasets_api.py
@@ -1112,3 +1112,67 @@ class TestDatasetsAPI(FixtureAPITestCase):
                 ('training', 'Volume 1, page 1v'),
             ]
         )
+
+    def test_element_datasets_requires_read_access(self):
+        self.client.force_login(self.user)
+        private_elt = self.private_corpus.elements.create(type=self.private_corpus.types.create(slug='t'), name='elt')
+        with self.assertNumQueries(5):
+            response = self.client.get(reverse('api:element-datasets', kwargs={'pk': private_elt.id}))
+            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+    def test_element_datasets_methods(self):
+        self.client.force_login(self.user)
+        forbidden_methods = ('post', 'patch', 'put', 'delete')
+        for method in forbidden_methods:
+            with self.subTest(method=method):
+                client_method = getattr(self.client, method)
+                response = client_method(reverse('api:element-datasets', kwargs={'pk': str(self.vol.id)}))
+                self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+
+    def test_element_datasets_public(self):
+        """
+        A non authenticated user can list datasets of a public element
+        """
+        self.dataset.dataset_elements.create(element=self.vol, set='train')
+        with self.assertNumQueries(3):
+            response = self.client.get(reverse('api:element-datasets', kwargs={'pk': str(self.vol.id)}))
+            self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertDictEqual(response.json(), {
+            'count': 1,
+            'next': None,
+            'number': 1,
+            'previous': None,
+            'results': [{
+                'dataset_id': str(self.dataset.id),
+                'dataset_name': 'First Dataset',
+                'set': 'train'
+            }]
+        })
+
+    def test_element_datasets(self):
+        self.client.force_login(self.user)
+        self.dataset.dataset_elements.create(element=self.page1, set='train')
+        self.dataset.dataset_elements.create(element=self.page1, set='validation')
+        self.dataset2.dataset_elements.create(element=self.page1, set='train')
+        with self.assertNumQueries(7):
+            response = self.client.get(reverse('api:element-datasets', kwargs={'pk': str(self.page1.id)}))
+            self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertDictEqual(response.json(), {
+            'count': 3,
+            'next': None,
+            'number': 1,
+            'previous': None,
+            'results': [{
+                'dataset_id': str(self.dataset.id),
+                'dataset_name': 'First Dataset',
+                'set': 'train'
+            }, {
+                'dataset_id': str(self.dataset.id),
+                'dataset_name': 'First Dataset',
+                'set': 'validation'
+            }, {
+                'dataset_id': str(self.dataset2.id),
+                'dataset_name': 'Second Dataset',
+                'set': 'train'
+            }]
+        })
-- 
GitLab