From 9b8e7fe592e8c98534b96c663e1ed048b5ad8099 Mon Sep 17 00:00:00 2001
From: mlbonhomme <bonhomme@teklia.com>
Date: Wed, 19 Jul 2023 09:01:40 +0000
Subject: [PATCH] Dataset process list create

---
 arkindex/process/api.py                       |  74 +++++-
 arkindex/process/serializers/training.py      |  35 ++-
 .../process/tests/test_process_datasets.py    | 214 ++++++++++++++++++
 arkindex/project/api_v1.py                    |   4 +
 4 files changed, 323 insertions(+), 4 deletions(-)
 create mode 100644 arkindex/process/tests/test_process_datasets.py

diff --git a/arkindex/process/api.py b/arkindex/process/api.py
index 4207a922f9..a7828117e3 100644
--- a/arkindex/process/api.py
+++ b/arkindex/process/api.py
@@ -78,7 +78,7 @@ from arkindex.process.serializers.imports import (
     StartProcessSerializer,
 )
 from arkindex.process.serializers.ingest import BucketSerializer, S3ImportSerializer
-from arkindex.process.serializers.training import StartTrainingSerializer
+from arkindex.process.serializers.training import ProcessDatasetSerializer, StartTrainingSerializer
 from arkindex.process.serializers.worker_runs import WorkerRunEditSerializer, WorkerRunSerializer
 from arkindex.process.serializers.workers import (
     DockerWorkerVersionSerializer,
@@ -108,6 +108,8 @@ from arkindex.project.pagination import CustomCursorPagination
 from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly
 from arkindex.project.tools import PercentileCont, RTrimChr
 from arkindex.project.triggers import process_delete
+from arkindex.training.models import Dataset
+from arkindex.training.serializers import DatasetSerializer
 from arkindex.users.models import OAuthCredentials, Role, Scope
 from arkindex.users.utils import get_max_level
 
@@ -652,6 +654,76 @@ class DataFileCreate(CreateAPIView):
     serializer_class = DataFileCreateSerializer
 
 
+@extend_schema(tags=['process'])
+@extend_schema_view(
+    get=extend_schema(
+        operation_id='ListProcessDatasets',
+        description=dedent(
+            """
+            List all datasets on a process.
+
+            Requires a **guest** access to the process.
+            """
+        ),
+    ),
+)
+class ProcessDatasets(ProcessACLMixin, ListAPIView):
+    permission_classes = (IsVerified, )
+    serializer_class = DatasetSerializer
+    queryset = Dataset.objects.none()
+
+    @cached_property
+    def process(self):
+        process = get_object_or_404(
+            Process.objects.using('default'),
+            pk=self.kwargs['pk']
+        )
+        if not self.process_access_level(process):
+            raise PermissionDenied(detail='You do not have guest access to this process.')
+        return process
+
+    def get_queryset(self):
+        return self.process.datasets.select_related('creator').order_by('name')
+
+    def get_serializer_context(self):
+        context = super().get_serializer_context()
+        # Ignore this step when generating the schema with OpenAPI
+        if not self.kwargs:
+            return context
+        context['process'] = self.process
+        return context
+
+
+@extend_schema(tags=['process'])
+@extend_schema_view(
+    post=extend_schema(
+        operation_id='CreateProcessDataset',
+        description=dedent(
+            """
+            Add a dataset to a process.
+
+            Requires an **admin** access to the process and a **guest** access to the dataset's corpus.
+            """
+        ),
+    ),
+)
+class ProcessDataset(CreateAPIView):
+    permission_classes = (IsVerified, )
+    serializer_class = ProcessDatasetSerializer
+
+    def get_serializer_from_params(self, process=None, dataset=None, **kwargs):
+        data = {'process': process, 'dataset': dataset}
+        kwargs['context'] = self.get_serializer_context()
+        return ProcessDatasetSerializer(data=data, **kwargs)
+
+    def create(self, request, *args, **kwargs):
+        serializer = self.get_serializer_from_params(**kwargs)
+        serializer.is_valid(raise_exception=True)
+        self.perform_create(serializer)
+        headers = self.get_success_headers(serializer.data)
+        return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
+
+
 @extend_schema(exclude=True)
 class GitRepositoryImportHook(APIView):
     """
diff --git a/arkindex/process/serializers/training.py b/arkindex/process/serializers/training.py
index 243ee14a03..11802e7a65 100644
--- a/arkindex/process/serializers/training.py
+++ b/arkindex/process/serializers/training.py
@@ -1,17 +1,19 @@
 from rest_framework import serializers
-from rest_framework.exceptions import ValidationError
+from rest_framework.exceptions import PermissionDenied, ValidationError
 
 from arkindex.documents.models import Corpus, Element
 from arkindex.process.models import (
     Process,
+    ProcessDataset,
     ProcessMode,
     WorkerConfiguration,
     WorkerVersion,
     WorkerVersionGPUUsage,
     WorkerVersionState,
 )
-from arkindex.project.mixins import TrainingModelMixin, WorkerACLMixin
-from arkindex.training.models import Model, ModelVersion
+from arkindex.project.mixins import ProcessACLMixin, TrainingModelMixin, WorkerACLMixin
+from arkindex.training.models import Dataset, Model, ModelVersion
+from arkindex.users.models import Role
 
 
 class StartTrainingSerializer(serializers.ModelSerializer, WorkerACLMixin, TrainingModelMixin):
@@ -192,3 +194,30 @@ class StartTrainingSerializer(serializers.ModelSerializer, WorkerACLMixin, Train
             use_gpu=validated_data["use_gpu"],
         )
         return self.instance
+
+
+class ProcessDatasetSerializer(serializers.ModelSerializer, ProcessACLMixin):
+    process = serializers.PrimaryKeyRelatedField(queryset=Process.objects.using('default'))
+    dataset = serializers.PrimaryKeyRelatedField(queryset=Dataset.objects.none())
+
+    class Meta:
+        model = ProcessDataset
+        fields = ('dataset', 'process', 'id', )
+        read_only_fields = ('process', 'id', )
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if not self.context.get('request'):
+            # Do not raise Error in order to create OpenAPI schema
+            return
+        # Required for the ProcessACLMixin and readable corpora
+        self._user = self.context['request'].user
+        self.fields['dataset'].queryset = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self._user))
+
+    def validate_process(self, process):
+        access = self.process_access_level(process)
+        if not access or access < Role.Admin.value:
+            raise PermissionDenied(detail='You do not have admin access to this process.')
+        if process.mode != ProcessMode.Dataset:
+            raise ValidationError(detail='Datasets can only be added to processes of mode "dataset".')
+        return process
diff --git a/arkindex/process/tests/test_process_datasets.py b/arkindex/process/tests/test_process_datasets.py
new file mode 100644
index 0000000000..755dbfb0e7
--- /dev/null
+++ b/arkindex/process/tests/test_process_datasets.py
@@ -0,0 +1,214 @@
+import uuid
+from unittest.mock import patch
+
+from django.urls import reverse
+from rest_framework import status
+
+from arkindex.documents.models import Corpus
+from arkindex.process.models import Process, ProcessDataset, ProcessMode
+from arkindex.project.tests import FixtureAPITestCase
+from arkindex.training.models import Dataset
+from arkindex.users.models import Role, User
+
+# Using the fake DB fixtures creation date when needed
+FAKE_CREATED = '2020-02-02T01:23:45.678000Z'
+
+
+class TestProcessDatasets(FixtureAPITestCase):
+    @classmethod
+    def setUpTestData(cls):
+        super().setUpTestData()
+        cls.private_corpus = Corpus.objects.create(name='Private corpus')
+        with patch('django.utils.timezone.now') as mock_now:
+            mock_now.return_value = FAKE_CREATED
+            cls.private_dataset = cls.private_corpus.datasets.create(
+                name='Dead sea scrolls',
+                description='Human instrumentality manual',
+                creator=cls.user
+            )
+        cls.test_user = User.objects.create(email='katsuragi@nerv.co.jp', verified_email=True)
+        cls.private_corpus.memberships.create(user=cls.test_user, level=Role.Admin.value)
+
+        # Datasets from another corpus
+        cls.dataset1, cls.dataset2 = Dataset.objects.filter(corpus=cls.corpus).order_by('name')
+
+        cls.dataset_process = Process.objects.create(
+            creator_id=cls.user.id,
+            mode=ProcessMode.Dataset,
+            corpus_id=cls.private_corpus.id
+        )
+        cls.dataset_process.datasets.set([cls.dataset1, cls.private_dataset])
+
+        # Control process to check that its datasets are not retrieved
+        cls.dataset_process_2 = Process.objects.create(
+            creator_id=cls.user.id,
+            mode=ProcessMode.Dataset,
+            corpus_id=cls.corpus.id
+        )
+        cls.dataset_process_2.datasets.set([cls.dataset2])
+
+        # For repository process
+        cls.creds = cls.user.credentials.get()
+        cls.repo = cls.creds.repos.get(url='http://my_repo.fake/workers/worker')
+        cls.repo.memberships.create(user=cls.test_user, level=Role.Admin.value)
+        cls.rev = cls.repo.revisions.get()
+
+    def test_list_requires_login(self):
+        with self.assertNumQueries(0):
+            response = self.client.get(reverse('api:process-datasets', kwargs={'pk': self.dataset_process.id}))
+            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+    def test_list_process_does_not_exist(self):
+        self.client.force_login(self.test_user)
+        with self.assertNumQueries(3):
+            response = self.client.get(reverse('api:process-datasets', kwargs={'pk': str(uuid.uuid4())}))
+            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+    def test_list_process_access_level(self):
+        self.private_corpus.memberships.filter(user=self.test_user).delete()
+        self.client.force_login(self.test_user)
+        with self.assertNumQueries(6):
+            response = self.client.get(reverse('api:process-datasets', kwargs={'pk': self.dataset_process.id}))
+            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+        self.assertDictEqual(response.json(), {'detail': 'You do not have guest access to this process.'})
+
+    def test_list_process_datasets(self):
+        self.client.force_login(self.test_user)
+        with self.assertNumQueries(9):
+            response = self.client.get(reverse('api:process-datasets', kwargs={'pk': self.dataset_process.id}))
+            self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertEqual(response.json()['results'], [
+            {
+                'id': str(self.private_dataset.id),
+                'name': 'Dead sea scrolls',
+                'description': 'Human instrumentality manual',
+                'creator': 'Test user',
+                'sets': ['training', 'test', 'validation'],
+                'corpus_id': str(self.private_corpus.id),
+                'state': 'open',
+                'task_id': None,
+                'created': FAKE_CREATED,
+                'updated': FAKE_CREATED
+            },
+            {
+                'id': str(self.dataset1.id),
+                'name': 'First Dataset',
+                'description': 'dataset number one',
+                'creator': 'Test user',
+                'sets': ['training', 'test', 'validation'],
+                'corpus_id': str(self.corpus.id),
+                'state': 'open',
+                'task_id': None,
+                'created': FAKE_CREATED,
+                'updated': FAKE_CREATED
+            }
+        ])
+
+    def test_create_process_dataset_requires_login(self):
+        with self.assertNumQueries(0):
+            response = self.client.post(
+                reverse('api:process-dataset', kwargs={'process': self.dataset_process.id, 'dataset': self.dataset2.id}),
+            )
+            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+    def test_create_process_dataset_requires_verified(self):
+        unverified_user = User.objects.create(email='email@mail.com')
+        self.client.force_login(unverified_user)
+        with self.assertNumQueries(2):
+            response = self.client.post(
+                reverse('api:process-dataset', kwargs={'process': self.dataset_process.id, 'dataset': self.dataset2.id}),
+            )
+            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+    def test_create_process_dataset_access_level(self):
+        cases = [None, Role.Guest, Role.Contributor]
+        for level in cases:
+            with self.subTest(level=level):
+                self.private_corpus.memberships.filter(user=self.test_user).delete()
+                if level:
+                    self.private_corpus.memberships.create(user=self.test_user, level=level.value)
+                self.client.force_login(self.test_user)
+                with self.assertNumQueries(7):
+                    response = self.client.post(
+                        reverse('api:process-dataset', kwargs={'process': self.dataset_process.id, 'dataset': self.dataset2.id}),
+                    )
+                    self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+                self.assertEqual(response.json(), {'detail': 'You do not have admin access to this process.'})
+
+    def test_create_process_dataset_process_mode(self):
+        cases = set(ProcessMode) - {ProcessMode.Dataset, ProcessMode.Local}
+        for mode in cases:
+            with self.subTest(mode=mode):
+                request_count = 8
+                self.dataset_process.mode = mode
+                self.dataset_process.corpus = self.private_corpus
+                if mode == ProcessMode.Repository:
+                    self.dataset_process.corpus = None
+                    self.dataset_process.revision = self.rev
+                    request_count = 10
+                self.dataset_process.save()
+                self.client.force_login(self.test_user)
+                with self.assertNumQueries(request_count):
+                    response = self.client.post(
+                        reverse('api:process-dataset', kwargs={'process': self.dataset_process.id, 'dataset': self.dataset2.id}),
+                    )
+                    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+                self.assertEqual(response.json(), {'process': ['Datasets can only be added to processes of mode "dataset".']})
+
+    def test_create_process_dataset_process_mode_local(self):
+        self.client.force_login(self.user)
+        local_process = Process.objects.get(creator=self.user, mode=ProcessMode.Local)
+        with self.assertNumQueries(6):
+            response = self.client.post(
+                reverse('api:process-dataset', kwargs={'process': local_process.id, 'dataset': self.dataset2.id}),
+            )
+            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+        self.assertEqual(response.json(), {'detail': 'You do not have admin access to this process.'})
+
+    def test_create_process_dataset_wrong_process_uuid(self):
+        self.client.force_login(self.test_user)
+        wrong_id = uuid.uuid4()
+        with self.assertNumQueries(6):
+            response = self.client.post(
+                reverse('api:process-dataset', kwargs={'process': wrong_id, 'dataset': self.dataset2.id}),
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertEqual(response.json(), {'process': [f'Invalid pk "{str(wrong_id)}" - object does not exist.']})
+
+    def test_create_process_dataset_wrong_dataset_uuid(self):
+        self.client.force_login(self.test_user)
+        wrong_id = uuid.uuid4()
+        with self.assertNumQueries(8):
+            response = self.client.post(
+                reverse('api:process-dataset', kwargs={'process': self.dataset_process.id, 'dataset': wrong_id}),
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertEqual(response.json(), {'dataset': [f'Invalid pk "{str(wrong_id)}" - object does not exist.']})
+
+    def test_create_process_dataset_dataset_access(self):
+        new_corpus = Corpus.objects.create(name='NERV')
+        new_dataset = new_corpus.datasets.create(name='Eva series', description='We created the Evas from Adam', creator=self.user)
+        self.client.force_login(self.test_user)
+        with self.assertNumQueries(8):
+            response = self.client.post(
+                reverse('api:process-dataset', kwargs={'process': self.dataset_process.id, 'dataset': new_dataset.id}),
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertEqual(response.json(), {'dataset': [f'Invalid pk "{str(new_dataset.id)}" - object does not exist.']})
+
+    def test_create_process_dataset(self):
+        self.client.force_login(self.test_user)
+        self.assertEqual(ProcessDataset.objects.count(), 3)
+        self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists())
+        with self.assertNumQueries(9):
+            response = self.client.post(
+                reverse('api:process-dataset', kwargs={'process': self.dataset_process.id, 'dataset': self.dataset2.id}),
+            )
+            self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+        self.assertEqual(ProcessDataset.objects.count(), 4)
+        self.assertTrue(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists())
+        self.assertQuerysetEqual(self.dataset_process.datasets.order_by('name'), [
+            self.private_dataset,
+            self.dataset1,
+            self.dataset2
+        ])
diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py
index a315d4ae3f..a545c81195 100644
--- a/arkindex/project/api_v1.py
+++ b/arkindex/project/api_v1.py
@@ -92,6 +92,8 @@ from arkindex.process.api import (
     GitRepositoryImportHook,
     ImportTranskribus,
     ListProcessElements,
+    ProcessDataset,
+    ProcessDatasets,
     ProcessDetails,
     ProcessList,
     ProcessRetry,
@@ -289,6 +291,8 @@ api = [
     path('process/<uuid:pk>/clear/', ClearProcess.as_view(), name='clear-process'),
     path('process/training/', StartTraining.as_view(), name='process-training'),
     path('process/<uuid:pk>/select-failures/', SelectProcessFailures.as_view(), name='process-select-failures'),
+    path('process/<uuid:pk>/datasets/', ProcessDatasets.as_view(), name='process-datasets'),
+    path('process/<uuid:process>/dataset/<uuid:dataset>/', ProcessDataset.as_view(), name='process-dataset'),
 
     # ML models training
     path('modelversion/<uuid:pk>/', ModelVersionsRetrieve.as_view(), name='model-version-retrieve'),
-- 
GitLab