diff --git a/arkindex/training/api.py b/arkindex/training/api.py index db6f07f48f75c8faf8043fb401505cf666c5c2b9..55b6ed38aaa01d344f4cb712e1cb69a0b5bf0765 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -654,29 +654,60 @@ class DatasetElementCursorPagination(CountCursorPagination): @extend_schema(tags=["datasets"]) -class DatasetElements(CorpusACLMixin, ListAPIView): - """ - List all elements in a dataset.\n\n - Requires a **guest** access to the dataset corpus. - """ +@extend_schema_view( + get=extend_schema( + operation_id="ListDatasetElements", + description=dedent(""" + List all elements in a dataset. + + Requires a **guest** access to the dataset corpus. + """), + ), + post=extend_schema( + operation_id="CreateDatasetElement", + description=dedent(""" + Add an element to a dataset in an **open** state. + + Requires a **contributor** access to the corpus of the dataset. + """), + ) +) +class DatasetElements(CorpusACLMixin, ListCreateAPIView): permission_classes = (IsVerified, ) queryset = DatasetElement.objects.none() serializer_class = DatasetElementSerializer pagination_class = DatasetElementCursorPagination - def get_queryset(self): - dataset = get_object_or_404( - Dataset.objects.select_related("corpus"), - id=self.kwargs["pk"], + @cached_property + def dataset(self): + qs = ( + Dataset.objects + .using("default") + .select_related("corpus") + .filter(corpus__in=Corpus.objects.readable(self.request.user)) ) - if not self.has_read_access(dataset.corpus): - raise PermissionDenied(detail="You do not have access to the corpus of this dataset") + dataset = get_object_or_404(qs, pk=self.kwargs["pk"]) + if self.request.method not in permissions.SAFE_METHODS: + if not self.has_write_access(dataset.corpus): + raise PermissionDenied(detail="You do not have contributor access to the corpus of this dataset.") + if dataset.state != DatasetState.Open: + raise ValidationError("You can only add elements to a dataset in an open state.") + return dataset + + def get_queryset(self): return ( - dataset.dataset_elements + self.dataset.dataset_elements .prefetch_related("element") .select_related("element__type", "element__corpus", "element__image__server") + .order_by("element__id") ) + def get_serializer_context(self): + context = super().get_serializer_context() + if self.request.method == "POST": + context["dataset"] = self.dataset + return context + @extend_schema_view( delete=extend_schema( diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py index e5fc75f822d0e282d1717fb089a22c2a44d4238d..3d5f9e56a216ddba62e6ca8dd15faf590c90d5dc 100644 --- a/arkindex/training/serializers.py +++ b/arkindex/training/serializers.py @@ -42,8 +42,13 @@ def _corpus_from_context(serializer_field): return serializer_field.context.get("corpus") +def _dataset_from_context(serializer_field): + return serializer_field.context.get("dataset") + + _model_from_context.requires_context = True _corpus_from_context.requires_context = True +_dataset_from_context.requires_context = True class ModelLightSerializer(serializers.ModelSerializer): @@ -651,12 +656,44 @@ class DatasetSerializer(serializers.ModelSerializer): class DatasetElementSerializer(serializers.ModelSerializer): - element = ElementListSerializer(read_only=True, allow_null=False) + element = ElementListSerializer( + read_only=True, + allow_null=False, + ) + element_id = serializers.PrimaryKeyRelatedField( + queryset=Element.objects.none(), + style={"base_template": "input.html"}, + source="element", + write_only=True, + ) + dataset = serializers.HiddenField( + default=_dataset_from_context, + write_only=True, + ) class Meta: model = DatasetElement - fields = ("set", "element") - read_only_fields = fields + fields = ("set", "element", "element_id", "dataset") + read_only_fields = ("element",) + validators = [ + UniqueTogetherValidator( + queryset=DatasetElement.objects.all(), + fields=["dataset", "element_id", "set"], + message="This element is already part of this set.", + ) + ] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if dataset := self.context.get("dataset"): + self.fields["element_id"].queryset = Element.objects.filter(corpus=dataset.corpus) + + def validate_set(self, value): + # The set must match the `sets` array defined at the dataset level + dataset = self.context["dataset"] + if dataset and value not in dataset.sets: + raise ValidationError(f"This dataset has no set named {value}.") + return value class DatasetElementInfoSerializer(DatasetElementSerializer): diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py index 42711839ddd3da9e5d8565fe9d22c5ae89b01245..7ddb449dfe582417246533c2f784e2197a50495d 100644 --- a/arkindex/training/tests/test_datasets_api.py +++ b/arkindex/training/tests/test_datasets_api.py @@ -1209,25 +1209,17 @@ class TestDatasetsAPI(FixtureAPITestCase): response = self.client.get(reverse("api:dataset-elements", kwargs={"pk": str(self.dataset.id)})) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - def test_list_elements_forbidden_methods(self): - self.client.force_login(self.user) - forbidden_methods = ("post", "patch", "put", "delete") - for method in forbidden_methods: - with self.subTest(method=method): - response = getattr(self.client, method)(reverse("api:dataset-elements", kwargs={"pk": str(self.dataset.id)})) - self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - def test_list_elements_invalid_dataset_id(self): self.client.force_login(self.user) - with self.assertNumQueries(3): + with self.assertNumQueries(5): response = self.client.get(reverse("api:dataset-elements", kwargs={"pk": str(uuid.uuid4())})) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_list_elements_readable_corpus(self): self.client.force_login(self.user) - with self.assertNumQueries(6): + with self.assertNumQueries(5): response = self.client.get(reverse("api:dataset-elements", kwargs={"pk": str(self.private_dataset.id)})) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_list_elements(self): self.dataset.dataset_elements.create(element_id=self.vol.id, set="test") @@ -1416,18 +1408,132 @@ class TestDatasetsAPI(FixtureAPITestCase): } first_three = [value for key, value in sorted(elements.items())[:3]] + self.client.force_login(self.read_user) + for state in DatasetState: + # A dataset should be readable in any state + self.dataset.state = state + self.dataset.save() + with self.subTest(state=state): + with self.assertNumQueries(6): + response = self.client.get( + reverse("api:dataset-elements", kwargs={"pk": self.dataset.pk}), + {"page_size": 3}, + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + self.assertTrue("?cursor=" in data["next"]) + self.assertIsNone(data["count"]) + self.assertListEqual(data["results"], first_three) + + def test_add_element_requires_login(self): + with self.assertNumQueries(0): + response = self.client.post(reverse("api:dataset-elements", kwargs={"pk": self.dataset.id})) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertDictEqual(response.json(), {"detail": "Authentication credentials were not provided."}) + def test_add_element_requires_verified(self): + user = User.objects.create(email="not_verified@mail.com", display_name="Not Verified", verified_email=False) + self.client.force_login(user) + with self.assertNumQueries(2): + response = self.client.post(reverse("api:dataset-elements", kwargs={"pk": self.dataset.id})) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertDictEqual(response.json(), {"detail": "You do not have permission to perform this action."}) + + def test_add_element_private_dataset(self): self.client.force_login(self.user) - with self.assertNumQueries(4): - response = self.client.get( - reverse("api:dataset-elements", kwargs={"pk": self.dataset.pk}), - {"page_size": 3}, + with self.assertNumQueries(5): + response = self.client.post( + reverse("api:dataset-elements", kwargs={"pk": self.private_dataset.id}) ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - data = response.json() - self.assertTrue("?cursor=" in data["next"]) - self.assertIsNone(data["count"]) - self.assertListEqual(data["results"], first_three) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_add_element_requires_writable_corpus(self): + self.corpus.memberships.filter(user=self.user).update(level=Role.Guest.value) + self.client.force_login(self.user) + with self.assertNumQueries(5): + response = self.client.post( + reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}) + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertDictEqual(response.json(), { + "detail": "You do not have contributor access to the corpus of this dataset." + }) + + def test_add_element_required_fields(self): + self.client.force_login(self.user) + with self.assertNumQueries(6): + response = self.client.post(reverse("api:dataset-elements", kwargs={"pk": self.dataset.id})) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + "element_id": ["This field is required."], + "set": ["This field is required."], + }) + + def test_add_element_wrong_element(self): + element = self.private_corpus.elements.create(type=self.private_corpus.types.create(slug="folder")) + self.client.force_login(self.user) + with self.assertNumQueries(7): + response = self.client.post( + reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}), + data={"set": "test", "element_id": str(element.id)}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + "element_id": [f'Invalid pk "{element.id}" - object does not exist.'], + }) + + def test_add_element_wrong_set(self): + self.client.force_login(self.user) + with self.assertNumQueries(7): + response = self.client.post( + reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}), + data={"set": "aaaaaaaaaaa", "element_id": str(self.vol.id)}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + "set": ["This dataset has no set named aaaaaaaaaaa."], + }) + + def test_add_element_dataset_requires_open(self): + self.client.force_login(self.user) + self.dataset.state = DatasetState.Complete + self.dataset.save() + with self.assertNumQueries(6): + response = self.client.post( + reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}), + data={"set": "test", "element_id": str(self.vol.id)}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertListEqual(response.json(), ["You can only add elements to a dataset in an open state."]) + + def test_add_element_already_exists(self): + self.dataset.dataset_elements.create(element=self.page1, set="test") + self.client.force_login(self.user) + with self.assertNumQueries(8): + response = self.client.post( + reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}), + data={"set": "test", "element_id": str(self.page1.id)}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), {"non_field_errors": ["This element is already part of this set."]}) + + def test_add_element(self): + self.client.force_login(self.user) + with self.assertNumQueries(13): + response = self.client.post( + reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}), + data={"set": "training", "element_id": str(self.page1.id)}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertQuerysetEqual( + self.dataset.dataset_elements.values_list("set", "element__name").order_by("element__name"), + [("training", "Volume 1, page 1r")] + ) def test_add_from_selection_requires_login(self): with self.assertNumQueries(0):