Skip to content
Snippets Groups Projects
Commit de03a197 authored by Valentin Rigal's avatar Valentin Rigal Committed by Erwan Rouchet
Browse files

Add an element to a dataset

parent 5c12eb2d
No related branches found
No related tags found
1 merge request!2217Add an element to a dataset
......@@ -654,29 +654,60 @@ class DatasetElementCursorPagination(CountCursorPagination):
@extend_schema(tags=["datasets"])
class DatasetElements(CorpusACLMixin, ListAPIView):
"""
List all elements in a dataset.\n\n
Requires a **guest** access to the dataset corpus.
"""
@extend_schema_view(
get=extend_schema(
operation_id="ListDatasetElements",
description=dedent("""
List all elements in a dataset.
Requires a **guest** access to the dataset corpus.
"""),
),
post=extend_schema(
operation_id="CreateDatasetElement",
description=dedent("""
Add an element to a dataset in an **open** state.
Requires a **contributor** access to the corpus of the dataset.
"""),
)
)
class DatasetElements(CorpusACLMixin, ListCreateAPIView):
permission_classes = (IsVerified, )
queryset = DatasetElement.objects.none()
serializer_class = DatasetElementSerializer
pagination_class = DatasetElementCursorPagination
def get_queryset(self):
dataset = get_object_or_404(
Dataset.objects.select_related("corpus"),
id=self.kwargs["pk"],
@cached_property
def dataset(self):
qs = (
Dataset.objects
.using("default")
.select_related("corpus")
.filter(corpus__in=Corpus.objects.readable(self.request.user))
)
if not self.has_read_access(dataset.corpus):
raise PermissionDenied(detail="You do not have access to the corpus of this dataset")
dataset = get_object_or_404(qs, pk=self.kwargs["pk"])
if self.request.method not in permissions.SAFE_METHODS:
if not self.has_write_access(dataset.corpus):
raise PermissionDenied(detail="You do not have contributor access to the corpus of this dataset.")
if dataset.state != DatasetState.Open:
raise ValidationError("You can only add elements to a dataset in an open state.")
return dataset
def get_queryset(self):
return (
dataset.dataset_elements
self.dataset.dataset_elements
.prefetch_related("element")
.select_related("element__type", "element__corpus", "element__image__server")
.order_by("element__id")
)
def get_serializer_context(self):
context = super().get_serializer_context()
if self.request.method == "POST":
context["dataset"] = self.dataset
return context
@extend_schema_view(
delete=extend_schema(
......
......@@ -42,8 +42,13 @@ def _corpus_from_context(serializer_field):
return serializer_field.context.get("corpus")
def _dataset_from_context(serializer_field):
return serializer_field.context.get("dataset")
_model_from_context.requires_context = True
_corpus_from_context.requires_context = True
_dataset_from_context.requires_context = True
class ModelLightSerializer(serializers.ModelSerializer):
......@@ -651,12 +656,44 @@ class DatasetSerializer(serializers.ModelSerializer):
class DatasetElementSerializer(serializers.ModelSerializer):
element = ElementListSerializer(read_only=True, allow_null=False)
element = ElementListSerializer(
read_only=True,
allow_null=False,
)
element_id = serializers.PrimaryKeyRelatedField(
queryset=Element.objects.none(),
style={"base_template": "input.html"},
source="element",
write_only=True,
)
dataset = serializers.HiddenField(
default=_dataset_from_context,
write_only=True,
)
class Meta:
model = DatasetElement
fields = ("set", "element")
read_only_fields = fields
fields = ("set", "element", "element_id", "dataset")
read_only_fields = ("element",)
validators = [
UniqueTogetherValidator(
queryset=DatasetElement.objects.all(),
fields=["dataset", "element_id", "set"],
message="This element is already part of this set.",
)
]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if dataset := self.context.get("dataset"):
self.fields["element_id"].queryset = Element.objects.filter(corpus=dataset.corpus)
def validate_set(self, value):
# The set must match the `sets` array defined at the dataset level
dataset = self.context["dataset"]
if dataset and value not in dataset.sets:
raise ValidationError(f"This dataset has no set named {value}.")
return value
class DatasetElementInfoSerializer(DatasetElementSerializer):
......
......@@ -1209,25 +1209,17 @@ class TestDatasetsAPI(FixtureAPITestCase):
response = self.client.get(reverse("api:dataset-elements", kwargs={"pk": str(self.dataset.id)}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_list_elements_forbidden_methods(self):
self.client.force_login(self.user)
forbidden_methods = ("post", "patch", "put", "delete")
for method in forbidden_methods:
with self.subTest(method=method):
response = getattr(self.client, method)(reverse("api:dataset-elements", kwargs={"pk": str(self.dataset.id)}))
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
def test_list_elements_invalid_dataset_id(self):
self.client.force_login(self.user)
with self.assertNumQueries(3):
with self.assertNumQueries(5):
response = self.client.get(reverse("api:dataset-elements", kwargs={"pk": str(uuid.uuid4())}))
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_list_elements_readable_corpus(self):
self.client.force_login(self.user)
with self.assertNumQueries(6):
with self.assertNumQueries(5):
response = self.client.get(reverse("api:dataset-elements", kwargs={"pk": str(self.private_dataset.id)}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_list_elements(self):
self.dataset.dataset_elements.create(element_id=self.vol.id, set="test")
......@@ -1416,18 +1408,132 @@ class TestDatasetsAPI(FixtureAPITestCase):
}
first_three = [value for key, value in sorted(elements.items())[:3]]
self.client.force_login(self.read_user)
for state in DatasetState:
# A dataset should be readable in any state
self.dataset.state = state
self.dataset.save()
with self.subTest(state=state):
with self.assertNumQueries(6):
response = self.client.get(
reverse("api:dataset-elements", kwargs={"pk": self.dataset.pk}),
{"page_size": 3},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
data = response.json()
self.assertTrue("?cursor=" in data["next"])
self.assertIsNone(data["count"])
self.assertListEqual(data["results"], first_three)
def test_add_element_requires_login(self):
with self.assertNumQueries(0):
response = self.client.post(reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "Authentication credentials were not provided."})
def test_add_element_requires_verified(self):
user = User.objects.create(email="not_verified@mail.com", display_name="Not Verified", verified_email=False)
self.client.force_login(user)
with self.assertNumQueries(2):
response = self.client.post(reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "You do not have permission to perform this action."})
def test_add_element_private_dataset(self):
self.client.force_login(self.user)
with self.assertNumQueries(4):
response = self.client.get(
reverse("api:dataset-elements", kwargs={"pk": self.dataset.pk}),
{"page_size": 3},
with self.assertNumQueries(5):
response = self.client.post(
reverse("api:dataset-elements", kwargs={"pk": self.private_dataset.id})
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
data = response.json()
self.assertTrue("?cursor=" in data["next"])
self.assertIsNone(data["count"])
self.assertListEqual(data["results"], first_three)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_add_element_requires_writable_corpus(self):
self.corpus.memberships.filter(user=self.user).update(level=Role.Guest.value)
self.client.force_login(self.user)
with self.assertNumQueries(5):
response = self.client.post(
reverse("api:dataset-elements", kwargs={"pk": self.dataset.id})
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {
"detail": "You do not have contributor access to the corpus of this dataset."
})
def test_add_element_required_fields(self):
self.client.force_login(self.user)
with self.assertNumQueries(6):
response = self.client.post(reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}))
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
"element_id": ["This field is required."],
"set": ["This field is required."],
})
def test_add_element_wrong_element(self):
element = self.private_corpus.elements.create(type=self.private_corpus.types.create(slug="folder"))
self.client.force_login(self.user)
with self.assertNumQueries(7):
response = self.client.post(
reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}),
data={"set": "test", "element_id": str(element.id)},
format="json",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
"element_id": [f'Invalid pk "{element.id}" - object does not exist.'],
})
def test_add_element_wrong_set(self):
self.client.force_login(self.user)
with self.assertNumQueries(7):
response = self.client.post(
reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}),
data={"set": "aaaaaaaaaaa", "element_id": str(self.vol.id)},
format="json",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
"set": ["This dataset has no set named aaaaaaaaaaa."],
})
def test_add_element_dataset_requires_open(self):
self.client.force_login(self.user)
self.dataset.state = DatasetState.Complete
self.dataset.save()
with self.assertNumQueries(6):
response = self.client.post(
reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}),
data={"set": "test", "element_id": str(self.vol.id)},
format="json",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertListEqual(response.json(), ["You can only add elements to a dataset in an open state."])
def test_add_element_already_exists(self):
self.dataset.dataset_elements.create(element=self.page1, set="test")
self.client.force_login(self.user)
with self.assertNumQueries(8):
response = self.client.post(
reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}),
data={"set": "test", "element_id": str(self.page1.id)},
format="json",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"non_field_errors": ["This element is already part of this set."]})
def test_add_element(self):
self.client.force_login(self.user)
with self.assertNumQueries(13):
response = self.client.post(
reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}),
data={"set": "training", "element_id": str(self.page1.id)},
format="json",
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertQuerysetEqual(
self.dataset.dataset_elements.values_list("set", "element__name").order_by("element__name"),
[("training", "Volume 1, page 1r")]
)
def test_add_from_selection_requires_login(self):
with self.assertNumQueries(0):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment