@extend_schema(tags=['process'])
@extend_schema_view(
    get=extend_schema(
        operation_id='ListProcessDatasets',
        description=dedent(
            """
            List all datasets on a process.

            Requires a **guest** access to the process.
            """
        ),
    ),
)
class ProcessDatasets(ProcessACLMixin, ListAPIView):
    """List the datasets linked to a process, ordered by name."""
    permission_classes = (IsVerified, )
    serializer_class = DatasetSerializer
    # Placeholder only; the real queryset always comes from get_queryset().
    queryset = Dataset.objects.none()

    @cached_property
    def process(self):
        """Fetch the process from the URL and check the user has at least guest access.

        Raises Http404 when the process does not exist and PermissionDenied
        when the user has no access level on it at all.
        """
        # Use the default (writable) database to avoid replication lag.
        # NOTE: the original wrapped the lookup in a redundant Q() object;
        # a plain keyword argument is equivalent and clearer.
        process = get_object_or_404(
            Process.objects.using('default'),
            pk=self.kwargs['pk'],
        )
        # process_access_level returns a falsy value when the user has no access
        if not self.process_access_level(process):
            raise PermissionDenied(detail='You do not have guest access to this process.')
        return process

    def get_queryset(self):
        return self.process.datasets.select_related('creator').order_by('name')

    def get_serializer_context(self):
        context = super().get_serializer_context()
        # Ignore this step when generating the schema with OpenAPI
        if not self.kwargs:
            return context
        context['process'] = self.process
        return context


@extend_schema(tags=['process'])
@extend_schema_view(
    post=extend_schema(
        operation_id='CreateProcessDataset',
        description=dedent(
            """
            Add a dataset to a process.

            Requires an **admin** access to the process and a **guest** access to the dataset's corpus.
            """
        ),
    ),
)
class ProcessDataset(CreateAPIView):
    """Link a dataset to a process; both IDs come from the URL, not the payload."""
    permission_classes = (IsVerified, )
    serializer_class = ProcessDatasetSerializer

    def get_serializer_from_params(self, process=None, dataset=None, **kwargs):
        """Build the serializer from the URL keyword arguments instead of request.data."""
        data = {'process': process, 'dataset': dataset}
        kwargs['context'] = self.get_serializer_context()
        return ProcessDatasetSerializer(data=data, **kwargs)

    def create(self, request, *args, **kwargs):
        serializer = self.get_serializer_from_params(**kwargs)
        serializer.is_valid(raise_exception=True)
        # Use save() rather than calling serializer.create() directly: it goes
        # through DRF's standard save flow and sets serializer.instance before
        # the response is serialized.
        serializer.save()
        headers = self.get_success_headers(serializer.data)
        return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
class ProcessDatasetSerializer(serializers.ModelSerializer, ProcessACLMixin):
    """Validate and create the link between a process and a dataset.

    The view (CreateProcessDataset) passes both IDs from the URL through
    `data`; the request's user must have admin access to the process and the
    dataset must belong to a corpus readable by that user.
    """
    process = serializers.PrimaryKeyRelatedField(queryset=Process.objects.using('default'))
    # Placeholder queryset; restricted in __init__ to datasets of corpora
    # readable by the request's user, so unreadable datasets 404 like
    # nonexistent ones.
    dataset = serializers.PrimaryKeyRelatedField(queryset=Dataset.objects.none())

    class Meta:
        model = ProcessDataset
        fields = ('dataset', 'process', 'id', )
        # NOTE(review): the original also listed 'process' here, but DRF only
        # applies read_only_fields to auto-generated fields — an explicitly
        # declared field is unaffected, and the view relies on 'process' being
        # writable. Dropping the dead entry removes the contradiction.
        read_only_fields = ('id', )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not self.context.get('request'):
            # Do not raise Error in order to create OpenAPI schema
            return
        # Required for the ProcessACLMixin and readable corpora
        self._user = self.context['request'].user
        self.fields['dataset'].queryset = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self._user))

    def validate_process(self, process):
        """Require admin access on the process and a process of mode Dataset."""
        access = self.process_access_level(process)
        # `access` may be None when the user has no access at all, so keep the
        # falsiness guard before the numeric comparison.
        if not access or access < Role.Admin.value:
            raise PermissionDenied(detail='You do not have admin access to this process.')
        if process.mode != ProcessMode.Dataset:
            raise ValidationError(detail='Datasets can only be added to processes of mode "dataset".')
        return process
class TestProcessDatasets(FixtureAPITestCase):
    """Tests for the ListProcessDatasets and CreateProcessDataset endpoints.

    Query counts in assertNumQueries are exact and depend on the fixture
    state built in setUpTestData.
    """

    @classmethod
    def setUpTestData(cls):
        super().setUpTestData()
        cls.private_corpus = Corpus.objects.create(name='Private corpus')
        # Freeze timezone.now() so the dataset's created/updated fields match
        # FAKE_CREATED in the list-endpoint assertions below.
        with patch('django.utils.timezone.now') as mock_now:
            mock_now.return_value = FAKE_CREATED
            cls.private_dataset = cls.private_corpus.datasets.create(
                name='Dead sea scrolls',
                description='Human instrumentality manual',
                creator=cls.user
            )
        cls.test_user = User.objects.create(email='katsuragi@nerv.co.jp', verified_email=True)
        cls.private_corpus.memberships.create(user=cls.test_user, level=Role.Admin.value)

        # Datasets from another corpus
        cls.dataset1, cls.dataset2 = Dataset.objects.filter(corpus=cls.corpus).order_by('name')

        cls.dataset_process = Process.objects.create(
            creator_id=cls.user.id,
            mode=ProcessMode.Dataset,
            corpus_id=cls.private_corpus.id
        )
        cls.dataset_process.datasets.set([cls.dataset1, cls.private_dataset])

        # Control process to check that its datasets are not retrieved
        cls.dataset_process_2 = Process.objects.create(
            creator_id=cls.user.id,
            mode=ProcessMode.Dataset,
            corpus_id=cls.corpus.id
        )
        cls.dataset_process_2.datasets.set([cls.dataset2])

        # For repository process
        cls.creds = cls.user.credentials.get()
        cls.repo = cls.creds.repos.get(url='http://my_repo.fake/workers/worker')
        cls.repo.memberships.create(user=cls.test_user, level=Role.Admin.value)
        cls.rev = cls.repo.revisions.get()

    def test_list_requires_login(self):
        # Anonymous users are rejected before any database query
        with self.assertNumQueries(0):
            response = self.client.get(reverse('api:process-datasets', kwargs={'pk': self.dataset_process.id}))
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_list_process_does_not_exist(self):
        self.client.force_login(self.test_user)
        with self.assertNumQueries(3):
            response = self.client.get(reverse('api:process-datasets', kwargs={'pk': str(uuid.uuid4())}))
        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

    def test_list_process_access_level(self):
        # A user without any access level on the process gets a 403
        self.private_corpus.memberships.filter(user=self.test_user).delete()
        self.client.force_login(self.test_user)
        with self.assertNumQueries(6):
            response = self.client.get(reverse('api:process-datasets', kwargs={'pk': self.dataset_process.id}))
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
        self.assertDictEqual(response.json(), {'detail': 'You do not have guest access to this process.'})

    def test_list_process_datasets(self):
        # Datasets are ordered by name; dataset2 (on the control process) is absent
        self.client.force_login(self.test_user)
        with self.assertNumQueries(9):
            response = self.client.get(reverse('api:process-datasets', kwargs={'pk': self.dataset_process.id}))
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.json()['results'], [
            {
                'id': str(self.private_dataset.id),
                'name': 'Dead sea scrolls',
                'description': 'Human instrumentality manual',
                'creator': 'Test user',
                'sets': ['training', 'test', 'validation'],
                'corpus_id': str(self.private_corpus.id),
                'state': 'open',
                'task_id': None,
                'created': FAKE_CREATED,
                'updated': FAKE_CREATED
            },
            {
                'id': str(self.dataset1.id),
                'name': 'First Dataset',
                'description': 'dataset number one',
                'creator': 'Test user',
                'sets': ['training', 'test', 'validation'],
                'corpus_id': str(self.corpus.id),
                'state': 'open',
                'task_id': None,
                'created': FAKE_CREATED,
                'updated': FAKE_CREATED
            }
        ])

    def test_create_process_dataset_requires_login(self):
        with self.assertNumQueries(0):
            response = self.client.post(
                reverse('api:process-dataset', kwargs={'process': self.dataset_process.id, 'dataset': self.dataset2.id}),
            )
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_create_process_dataset_requires_verified(self):
        # IsVerified rejects users whose email is not verified
        unverified_user = User.objects.create(email='email@mail.com')
        self.client.force_login(unverified_user)
        with self.assertNumQueries(2):
            response = self.client.post(
                reverse('api:process-dataset', kwargs={'process': self.dataset_process.id, 'dataset': self.dataset2.id}),
            )
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_create_process_dataset_access_level(self):
        # Anything below admin access on the process (none, guest, contributor) is rejected
        cases = [None, Role.Guest, Role.Contributor]
        for level in cases:
            with self.subTest(level=level):
                self.private_corpus.memberships.filter(user=self.test_user).delete()
                if level:
                    self.private_corpus.memberships.create(user=self.test_user, level=level.value)
                self.client.force_login(self.test_user)
                with self.assertNumQueries(7):
                    response = self.client.post(
                        reverse('api:process-dataset', kwargs={'process': self.dataset_process.id, 'dataset': self.dataset2.id}),
                    )
                self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
                self.assertEqual(response.json(), {'detail': 'You do not have admin access to this process.'})

    def test_create_process_dataset_process_mode(self):
        # Every non-Dataset, non-Local mode must be refused with a validation error
        cases = set(ProcessMode) - {ProcessMode.Dataset, ProcessMode.Local}
        for mode in cases:
            with self.subTest(mode=mode):
                request_count = 8
                self.dataset_process.mode = mode
                self.dataset_process.corpus = self.private_corpus
                # Repository processes have no corpus and need a revision,
                # which changes the expected query count
                if mode == ProcessMode.Repository:
                    self.dataset_process.corpus = None
                    self.dataset_process.revision = self.rev
                    request_count = 10
                self.dataset_process.save()
                self.client.force_login(self.test_user)
                with self.assertNumQueries(request_count):
                    response = self.client.post(
                        reverse('api:process-dataset', kwargs={'process': self.dataset_process.id, 'dataset': self.dataset2.id}),
                    )
                self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
                self.assertEqual(response.json(), {'process': ['Datasets can only be added to processes of mode "dataset".']})

    def test_create_process_dataset_process_mode_local(self):
        # Local processes have no ACL level at all, so the access check fails first
        self.client.force_login(self.user)
        local_process = Process.objects.get(creator=self.user, mode=ProcessMode.Local)
        with self.assertNumQueries(6):
            response = self.client.post(
                reverse('api:process-dataset', kwargs={'process': local_process.id, 'dataset': self.dataset2.id}),
            )
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
        self.assertEqual(response.json(), {'detail': 'You do not have admin access to this process.'})

    def test_create_process_dataset_wrong_process_uuid(self):
        self.client.force_login(self.test_user)
        wrong_id = uuid.uuid4()
        with self.assertNumQueries(6):
            response = self.client.post(
                reverse('api:process-dataset', kwargs={'process': wrong_id, 'dataset': self.dataset2.id}),
            )
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertEqual(response.json(), {'process': [f'Invalid pk "{str(wrong_id)}" - object does not exist.']})

    def test_create_process_dataset_wrong_dataset_uuid(self):
        self.client.force_login(self.test_user)
        wrong_id = uuid.uuid4()
        with self.assertNumQueries(8):
            response = self.client.post(
                reverse('api:process-dataset', kwargs={'process': self.dataset_process.id, 'dataset': wrong_id}),
            )
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertEqual(response.json(), {'dataset': [f'Invalid pk "{str(wrong_id)}" - object does not exist.']})

    def test_create_process_dataset_dataset_access(self):
        # A dataset in a corpus the user cannot read behaves like a nonexistent one
        # (the serializer's dataset queryset is restricted to readable corpora)
        new_corpus = Corpus.objects.create(name='NERV')
        new_dataset = new_corpus.datasets.create(name='Eva series', description='We created the Evas from Adam', creator=self.user)
        self.client.force_login(self.test_user)
        with self.assertNumQueries(8):
            response = self.client.post(
                reverse('api:process-dataset', kwargs={'process': self.dataset_process.id, 'dataset': new_dataset.id}),
            )
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertEqual(response.json(), {'dataset': [f'Invalid pk "{str(new_dataset.id)}" - object does not exist.']})

    def test_create_process_dataset(self):
        # Nominal case: the link is created and the dataset joins the process
        self.client.force_login(self.test_user)
        self.assertEqual(ProcessDataset.objects.count(), 3)
        self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists())
        with self.assertNumQueries(9):
            response = self.client.post(
                reverse('api:process-dataset', kwargs={'process': self.dataset_process.id, 'dataset': self.dataset2.id}),
            )
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(ProcessDataset.objects.count(), 4)
        self.assertTrue(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists())
        self.assertQuerysetEqual(self.dataset_process.datasets.order_by('name'), [
            self.private_dataset,
            self.dataset1,
            self.dataset2
        ])