diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py
index 00ac21114ccc44b47346574823702238490d7a1b..c1899cdc69591cf8e08a6e44356e4d0ab38d765a 100644
--- a/arkindex/project/api_v1.py
+++ b/arkindex/project/api_v1.py
@@ -120,6 +120,7 @@ from arkindex.process.api import (
 from arkindex.project.openapi import OpenApiSchemaView
 from arkindex.training.api import (
     CorpusDataset,
+    CreateDatasetElementsSelection,
     DatasetElements,
     DatasetUpdate,
     MetricValueBulkCreate,
@@ -206,6 +207,7 @@ api = [
 
     # Datasets
     path('corpus/<uuid:pk>/datasets/', CorpusDataset.as_view(), name='corpus-datasets'),
+    path('corpus/<uuid:pk>/datasets/selection/', CreateDatasetElementsSelection.as_view(), name='dataset-elements-selection'),
     path('datasets/<uuid:pk>/', DatasetUpdate.as_view(), name='dataset-update'),
     path('datasets/<uuid:pk>/elements/', DatasetElements.as_view(), name='dataset-elements'),
 
diff --git a/arkindex/training/api.py b/arkindex/training/api.py
index 9aa97fb698162eafe4112675eeae7f2e1ea59368..24e8b9d0dd35c07893682e37c2ced3adad90ae26 100644
--- a/arkindex/training/api.py
+++ b/arkindex/training/api.py
@@ -41,6 +41,7 @@ from arkindex.training.serializers import (
     ModelSerializer,
     ModelVersionSerializer,
     ModelVersionValidateSerializer,
+    SelectionDatasetElementSerializer,
 )
 from arkindex.users.models import Role
 from arkindex.users.utils import get_max_level
@@ -579,3 +580,54 @@ class DatasetElements(CorpusACLMixin, ListAPIView):
                 'element__mirrored',
             )
         )
+
+
+@extend_schema(tags=['datasets'])
+@extend_schema_view(
+    post=extend_schema(
+        operation_id='CreateDatasetElementsSelection',
+        responses={
+            204: None,
+            400: None,
+            403: None,
+        },
+        parameters=[
+            OpenApiParameter(
+                'id',
+                type=UUID,
+                location=OpenApiParameter.PATH,
+                description='ID of the corpus containing the selected elements.',
+                required=True,
+            )
+        ],
+    )
+)
+class CreateDatasetElementsSelection(CorpusACLMixin, CreateAPIView):
+    """
+    Add elements from a selection to a dataset.
+
+    Requires a **contributor** access to the corpus.
+    """
+    permission_classes = (IsVerified, )
+    serializer_class = SelectionDatasetElementSerializer
+
+    def get_queryset(self):
+        return Corpus.objects.readable(self.request.user)
+
+    def check_object_permissions(self, request, corpus):
+        if not self.has_write_access(corpus):
+            raise PermissionDenied(detail="You need a Contributor access to the corpus to perform this action.")
+        super().check_object_permissions(request, corpus)
+
+    def get_serializer_context(self):
+        context = super().get_serializer_context()
+        context['corpus'] = self.corpus
+        return context
+
+    @cached_property
+    def corpus(self):
+        return self.get_object()
+
+    def create(self, request, *args, **kwargs):
+        super().create(request, *args, **kwargs)
+        return Response(status=status.HTTP_204_NO_CONTENT)
diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index 1de611dca9af0b614322c9f5fab4355ecfcaabe1..31d855c3710c1ad8e3199a533fae945adf3bf54a 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -446,3 +446,52 @@ class DatasetElementSerializer(serializers.ModelSerializer):
         model = DatasetElement
         fields = ('set', 'element')
         read_only_fields = fields
+
+
+class SelectionDatasetElementSerializer(serializers.Serializer):
+    dataset_id = serializers.PrimaryKeyRelatedField(
+        queryset=Dataset.objects.all(),
+        source='dataset',
+        write_only=True,
+        help_text="UUID of a dataset to add elements from your corpus' selection.",
+        style={'base_template': 'input.html'},
+    )
+    set = serializers.CharField(
+        max_length=50,
+        write_only=True,
+        help_text='Name of the set elements will be added to.',
+    )
+
+    def validate_dataset_id(self, dataset):
+        if (
+            (corpus := self.context.get('corpus'))
+            and dataset.corpus_id != corpus.id
+        ):
+            raise ValidationError(f'Dataset {dataset.id} is not part of corpus {corpus.name}.')
+        if dataset.state == DatasetState.Complete:
+            raise ValidationError(f'Dataset {dataset.id} is marked as completed.')
+        return dataset
+
+    def validate(self, data):
+        data = super().validate(data)
+        dataset = data['dataset']
+        if data['set'] not in dataset.sets:
+            raise ValidationError({'set': [f'This dataset only allows one of {", ".join(dataset.sets)}.']})
+        return data
+
+    def create(self, validated_data):
+        user = self.context['request'].user
+        corpus = self.context['corpus']
+
+        DatasetElement.objects.bulk_create(
+            (
+                DatasetElement(element_id=elt_id, **validated_data)
+                for elt_id in (
+                    user.selected_elements
+                    .filter(corpus=corpus)
+                    .values_list("id", flat=True)
+                )
+            ),
+            ignore_conflicts=True
+        )
+        return validated_data
diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py
index 9d663a05a4a95e334b30745f29585389a99d982c..0b5790e6a9b5e87f32e29b38aa920a826f305aff 100644
--- a/arkindex/training/tests/test_datasets_api.py
+++ b/arkindex/training/tests/test_datasets_api.py
@@ -10,7 +10,7 @@ from arkindex.process.models import Process, ProcessMode
 from arkindex.project.tests import FixtureAPITestCase
 from arkindex.project.tools import fake_now
 from arkindex.training.models import Dataset, DatasetState
-from arkindex.users.models import User
+from arkindex.users.models import Role, User
 
 # Using the fake DB fixtures creation date when needed
 FAKE_CREATED = '2020-02-02T01:23:45.678000Z'
@@ -977,3 +977,138 @@ class TestDatasetsAPI(FixtureAPITestCase):
                 'rotation_angle': 0,
             },
         }])
+
+    def test_add_from_selection_requires_login(self):
+        with self.assertNumQueries(0):
+            response = self.client.post(reverse('api:dataset-elements-selection', kwargs={'pk': self.corpus.id}))
+            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+        self.assertDictEqual(response.json(), {'detail': 'Authentication credentials were not provided.'})
+
+    def test_add_from_selection_forbidden_methods(self):
+        self.client.force_login(self.user)
+        forbidden_methods = ('get', 'patch', 'put', 'delete')
+        for method in forbidden_methods:
+            with self.subTest(method=method):
+                response = getattr(self.client, method)(
+                    reverse('api:dataset-elements-selection', kwargs={'pk': self.corpus.id})
+                )
+                self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+
+    def test_add_from_selection_requires_verified(self):
+        user = User.objects.create(email='not_verified@mail.com', display_name='Not Verified', verified_email=False)
+        self.client.force_login(user)
+        with self.assertNumQueries(2):
+            response = self.client.post(reverse('api:dataset-elements-selection', kwargs={'pk': self.corpus.id}))
+            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+        self.assertDictEqual(response.json(), {'detail': 'You do not have permission to perform this action.'})
+
+    def test_add_from_selection_private_corpus(self):
+        self.client.force_login(self.user)
+        with self.assertNumQueries(5):
+            response = self.client.post(
+                reverse('api:dataset-elements-selection', kwargs={'pk': self.private_corpus.id})
+            )
+            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+    def test_add_from_selection_requires_writable_corpus(self):
+        self.corpus.memberships.filter(user=self.user).update(level=Role.Guest.value)
+        self.client.force_login(self.user)
+        with self.assertNumQueries(5):
+            response = self.client.post(
+                reverse('api:dataset-elements-selection', kwargs={'pk': self.corpus.id})
+            )
+            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+        self.assertDictEqual(response.json(), {
+            'detail': 'You need a Contributor access to the corpus to perform this action.'
+        })
+
+    def test_add_from_selection_required_fields(self):
+        self.client.force_login(self.user)
+        with self.assertNumQueries(6):
+            response = self.client.post(reverse('api:dataset-elements-selection', kwargs={'pk': self.corpus.id}))
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {
+            'dataset_id': ['This field is required.'],
+            'set': ['This field is required.'],
+        })
+
+    def test_add_from_selection_wrong_values(self):
+        self.client.force_login(self.user)
+        with self.assertNumQueries(6):
+            response = self.client.post(
+                reverse('api:dataset-elements-selection', kwargs={'pk': self.corpus.id}),
+                data={'set': {}, 'dataset_id': 'AAA'},
+                format='json',
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {
+            'dataset_id': ['“AAA” is not a valid UUID.'],
+            'set': ['Not a valid string.'],
+        })
+
+    def test_add_from_selection_wrong_dataset(self):
+        self.client.force_login(self.user)
+        with self.assertNumQueries(7):
+            response = self.client.post(
+                reverse('api:dataset-elements-selection', kwargs={'pk': self.corpus.id}),
+                data={'set': 'aaa', 'dataset_id': self.private_dataset.id},
+                format='json',
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {
+            'dataset_id': [f'Dataset {self.private_dataset.id} is not part of corpus Unit Tests.'],
+        })
+
+    def test_add_from_selection_completed_dataset(self):
+        """A dataset in the Complete state is immutable"""
+        self.client.force_login(self.user)
+        self.dataset.state = DatasetState.Complete
+        self.dataset.save()
+        with self.assertNumQueries(7):
+            response = self.client.post(
+                reverse('api:dataset-elements-selection', kwargs={'pk': self.corpus.id}),
+                data={'set': 'aaa', 'dataset_id': self.dataset.id},
+                format='json',
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {
+            'dataset_id': [f'Dataset {self.dataset.id} is marked as completed.']
+        })
+
+    def test_add_from_selection_wrong_set(self):
+        self.client.force_login(self.user)
+        with self.assertNumQueries(7):
+            response = self.client.post(
+                reverse('api:dataset-elements-selection', kwargs={'pk': self.corpus.id}),
+                data={'set': 'aaa', 'dataset_id': self.dataset.id},
+                format='json',
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {
+            'set': ['This dataset only allows one of training, test, validation.'],
+        })
+
+    def test_add_from_selection(self):
+        self.dataset.dataset_elements.create(element=self.page1, set="training")
+        self.assertQuerysetEqual(
+            self.dataset.dataset_elements.values_list('set', 'element__name').order_by('element__name'),
+            [('training', 'Volume 1, page 1r')]
+        )
+        self.user.selected_elements.set([self.vol, self.page1, self.page2])
+
+        self.client.force_login(self.user)
+        with self.assertNumQueries(9):
+            response = self.client.post(
+                reverse('api:dataset-elements-selection', kwargs={'pk': self.corpus.id}),
+                data={'set': 'training', 'dataset_id': self.dataset.id},
+                format='json',
+            )
+            self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+        self.assertQuerysetEqual(
+            self.dataset.dataset_elements.values_list('set', 'element__name').order_by('element__name'),
+            [
+                ('training', 'Volume 1'),
+                ('training', 'Volume 1, page 1r'),
+                ('training', 'Volume 1, page 1v'),
+            ]
+        )