From 399f0799c73cef838f4b7725f7d9115339da5a31 Mon Sep 17 00:00:00 2001 From: Valentin Rigal <rigal@teklia.com> Date: Fri, 14 Jun 2024 13:52:20 +0000 Subject: [PATCH] PopulateDataset endpoint --- arkindex/project/api_v1.py | 2 + arkindex/training/api.py | 36 +++ arkindex/training/serializers.py | 158 +++++++++++++ arkindex/training/tests/test_datasets_api.py | 235 ++++++++++++++++++- 4 files changed, 430 insertions(+), 1 deletion(-) diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py index b5d5a32327..aab4bb37f5 100644 --- a/arkindex/project/api_v1.py +++ b/arkindex/project/api_v1.py @@ -103,6 +103,7 @@ from arkindex.training.api import ( DatasetClone, DatasetElementDestroy, DatasetElements, + DatasetPopulate, DatasetSetCreate, DatasetSets, DatasetUpdate, @@ -186,6 +187,7 @@ api = [ path("datasets/<uuid:dataset>/elements/<uuid:element>/", DatasetElementDestroy.as_view(), name="dataset-element"), path("datasets/<uuid:pk>/sets/", DatasetSetCreate.as_view(), name="dataset-sets"), path("datasets/<uuid:dataset>/sets/<uuid:set>/", DatasetSets.as_view(), name="dataset-set"), + path("datasets/<uuid:pk>/populate/", DatasetPopulate.as_view(), name="dataset-populate"), # Moderation path("classifications/", ClassificationCreate.as_view(), name="classification-create"), diff --git a/arkindex/training/api.py b/arkindex/training/api.py index cb124b3107..db033ec5fd 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -40,6 +40,7 @@ from arkindex.training.serializers import ( CreateModelErrorResponseSerializer, DatasetElementInfoSerializer, DatasetElementSerializer, + DatasetPopulateSerializer, DatasetSerializer, DatasetSetSerializer, ElementDatasetSetSerializer, @@ -1049,3 +1050,38 @@ class DatasetClone(CorpusACLMixin, CreateAPIView): DatasetSerializer(clone).data, status=status.HTTP_201_CREATED, ) + + +@extend_schema_view( + post=extend_schema( + tags=["datasets"], + operation_id="PopulateDataset", + ), +) +class DatasetPopulate(CorpusACLMixin, 
class DatasetPopulateSerializer(serializers.Serializer):
    """
    Fill the sets of a dataset with randomly picked elements of its corpus.

    The target dataset is read from `self.context["dataset"]`; its corpus types
    and its sets are iterated with `.all()` and are expected to be prefetched.
    Every field is write-only, so the serialized output is an empty object.
    """

    parent_id = serializers.UUIDField(
        write_only=True,
        default=None,
        source="parent",
        help_text=dedent("""
            UUID of a parent element from which to select random children. The parent must be in the same corpus as the dataset.

            If this is unset, all elements from the corpus of the dataset will be used.
        """),
    )
    recursive = serializers.BooleanField(
        write_only=True,
        default=False,
        help_text=dedent("""
            When `parent_id` is set, also choose elements among its grandchildren, and so on.

            When `parent_id` is unset, take every element in the corpus instead of only the top-level ones.
        """),
    )
    types = serializers.ListField(
        child=serializers.SlugField(),
        write_only=True,
        allow_empty=False,
        default=["page"],
        help_text="Restrict the selected elements to those with any of the specified element type slugs.",
    )
    count = serializers.IntegerField(
        write_only=True,
        default=1000,
        help_text=dedent("""
            Number of elements to fill the dataset with.

            This cannot exceed the number of available elements.
        """),
        min_value=1,
    )
    sets = serializers.DictField(
        child=serializers.FloatField(min_value=0, max_value=1),
        write_only=True,
        allow_empty=False,
        default={"train": 0.8, "dev": 0.1, "test": 0.1},
        help_text=dedent("""
            A dictionary mapping set names to ratios, defining how the elements will be distributed between the sets.

            The sum of all ratios must be equal to 1.
        """),
    )

    def validate_parent_id(self, parent_id):
        """
        Resolve the parent UUID to an element of the dataset's corpus.

        Returns `None` when no parent was requested, otherwise the matching
        element (only its `id` is loaded).
        Raises a `ValidationError` when the element is not in the corpus.
        """
        if parent_id is None:
            return None
        try:
            parent = self.context["dataset"].corpus.elements.only("id").get(id=parent_id)
        except Element.DoesNotExist:
            raise ValidationError("This element does not exist in the corpus of the dataset.")
        return parent

    def validate_types(self, type_slugs):
        """
        Translate element type slugs into element type IDs.

        Duplicated slugs are collapsed. Raises a `ValidationError` listing the
        slugs that do not exist in the corpus of the dataset.
        """
        type_slugs = set(type_slugs)
        # Types are prefetched at view's level
        existing_types = {
            t.slug: t.id
            for t in self.context["dataset"].corpus.types.all()
        }
        if (missing_types := [slug for slug in type_slugs if slug not in existing_types]):
            raise ValidationError(f"Some types does not exist in the corpus of the dataset: {sorted(missing_types)}.")
        return [existing_types[slug] for slug in type_slugs]

    def validate_sets(self, sets):
        """
        Check that every set exists on the dataset and that ratios add up to 1.

        Returns the ratios keyed by set ID instead of set name.
        Raises a `ValidationError` carrying every detected problem at once.
        """
        # Imported locally: the module-level import block is outside this block
        import math

        # Sets are prefetched at view's level
        existing_sets = {
            d.name: d.id
            for d in self.context["dataset"].sets.all()
        }
        errors = []
        if (missing_sets := [s for s in sets if s not in existing_sets]):
            errors.append(f"Some sets does not exist in the dataset: {sorted(missing_sets)}.")
        # Ratios are floats, so an exact `!= 1` comparison rejects valid input:
        # 0.7 + 0.2 + 0.1 evaluates to 0.9999999999999999 for instance.
        if not math.isclose(math.fsum(sets.values()), 1):
            errors.append("The sum of all ratios must be equal to 1.")
        if errors:
            raise ValidationError(errors)
        return {existing_sets[s]: v for s, v in sets.items()}

    def validate(self, data):
        """
        Run the checks that depend on several fields and on the database.

        Rejects the payload when `count` exceeds the number of filtered
        elements, when a ratio is too small to yield a single element, or when
        a filtered element already belongs to a set of this dataset.
        """
        data = super().validate(data)
        # Build the filtered queryset once; it is evaluated twice below
        # (a COUNT, then an EXISTS), in that order.
        candidates = self.filter_elements(**data)
        elts_count = candidates.count()
        errors = defaultdict(list)
        if data["count"] > elts_count:
            errors["count"].append(f"This value exceeds the number of filtered elements ({elts_count}).")

        # `count` has min_value=1, so this division is always defined
        min_ratio = 1 / data["count"]
        if any(ratio < min_ratio for ratio in data["sets"].values()):
            errors["sets"].append(f"Ratio must be greater than {min_ratio:.2g} to contain at least one element.")

        # Ensure none of the filtered elements is already part of any set of
        # this dataset, to avoid edge cases and preserve valid ratios.
        if (
            candidates
            .filter(dataset_elements__set__dataset_id=self.context["dataset"].id)
            .exists()
        ):
            errors["__all__"].append("Some filtered elements are already part of a set on this dataset.")

        if errors:
            raise ValidationError(errors)

        return data

    def filter_elements(self, *, parent=None, recursive=False, types, **kwargs):
        """
        Build the queryset of corpus elements eligible for the population.

        `parent`, `recursive` and `types` are the validated field values; extra
        keyword arguments are accepted and ignored so that the whole validated
        data can be unpacked into this method.
        """
        qs = self.context["dataset"].corpus.elements.all()
        filters = {}
        if parent:
            if recursive:
                filters["paths__path__overlap"] = [parent.id]
            else:
                # Only pick direct children
                filters["paths__path__last"] = parent.id
        elif not recursive:
            # Only pick top-level elements
            filters["paths__path__last__isnull"] = True
        filters["type_id__in"] = types
        # NOTE(review): filtering through `paths` joins on that relation; an
        # element with several matching paths might be returned more than once.
        # Confirm whether that can happen and whether it needs deduplication.
        return qs.filter(**filters)

    @transaction.atomic
    def save(self):
        """
        Pick `count` random elements and spread them over the requested sets.

        Each set first receives `int(ratio * count)` elements. Elements left
        over by that rounding down are then handed out one per set, largest
        ratio first. All rows are inserted with a single `bulk_create`.
        """
        count = self.validated_data["count"]
        sets = self.validated_data["sets"]
        element_ids = list(
            self.filter_elements(**self.validated_data)
            .order_by("?")
            .values_list("id", flat=True)
            [:count]
        )

        dataset_elements = []
        index = 0
        for set_id, ratio in sets.items():
            set_count = int(ratio * count)
            dataset_elements.extend(
                DatasetElement(set_id=set_id, element_id=elt_id)
                for elt_id in element_ids[index:index + set_count]
            )
            index += set_count

        # Distribute remaining elements among sets, as we rounded down using int.
        # Populate largest sets first, to be closer to the ratio.
        sorted_sets = sorted(sets.items(), key=lambda item: -item[1])
        dataset_elements.extend(
            DatasetElement(set_id=set_id, element_id=elt_id)
            for elt_id, (set_id, _ratio) in zip(element_ids[index:], sorted_sets)
        )

        DatasetElement.objects.bulk_create(dataset_elements)
status.HTTP_403_FORBIDDEN) + self.assertDictEqual(response.json(), {"detail": "You do not have permission to perform this action."}) + + @patch("arkindex.users.managers.BaseACLManager.filter_rights", return_value=Corpus.objects.none()) + def test_populate_dataset_private_corpus(self, filter_rights_mock): + self.client.force_login(self.user) + with self.assertNumQueries(2): + response = self.client.post(reverse("api:dataset-populate", kwargs={"pk": self.private_dataset.pk})) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + self.assertEqual(filter_rights_mock.call_count, 1) + self.assertEqual(filter_rights_mock.call_args, call(self.user, Corpus, Role.Guest.value)) + + @patch("arkindex.project.mixins.has_access", return_value=False) + def test_populate_dataset_requires_contributor(self, has_access_mock): + self.client.force_login(self.user) + with self.assertNumQueries(5): + response = self.client.post(reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk})) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertDictEqual(response.json(), {"detail": "You need a Contributor access to the dataset's corpus."}) + self.assertEqual(has_access_mock.call_count, 1) + self.assertEqual(has_access_mock.call_args, call(self.user, self.corpus, Role.Contributor.value, skip_public=False)) + + def test_populate_dataset_invalid_parameters(self): + self.client.force_login(self.user) + with self.assertNumQueries(5): + response = self.client.post( + reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}), + data={ + "parent_id": "a", + "recursive": "b", + "types": "c", + "count": "d", + "sets": "e", + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + "count": ["A valid integer is required."], + "parent_id": ["Must be a valid UUID."], + "recursive": ["Must be a valid boolean."], + "sets": ['Expected a dictionary of items but got type 
"str".'], + "types": ['Expected a list of items but got type "str".'], + }) + + def test_populate_dataset_unexisting_values(self): + self.client.force_login(self.user) + type = self.private_corpus.types.create(slug="A") + elt = self.private_corpus.elements.create(name="a", type_id=type.id) + with self.assertNumQueries(6): + response = self.client.post( + reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}), + data={ + "parent_id": str(elt.id), + "types": ["page", "A", "AAAAAA"], + "sets": {"test": .8, "A": .15, "AAAAA": .05}, + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + "parent_id": ["This element does not exist in the corpus of the dataset."], + "sets": ["Some sets does not exist in the dataset: ['A', 'AAAAA']."], + "types": ["Some types does not exist in the corpus of the dataset: ['A', 'AAAAAA']."], + }) + + def test_populate_dataset_count_max_value(self): + self.client.force_login(self.user) + with self.assertNumQueries(7): + response = self.client.post( + reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}), + data={"recursive": True}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + "count": ["This value exceeds the number of filtered elements (6)."] + }) + + def test_populate_dataset_wrong_sets_ratio(self): + self.client.force_login(self.user) + with self.assertNumQueries(5): + response = self.client.post( + reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}), + data={"recursive": True, "sets": {"test": 0.13, "train": 0.37}}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + "sets": ["The sum of all ratios must be equal to 1."] + }) + + def test_populate_dataset_min_sets_ratio(self): + self.client.force_login(self.user) + with self.assertNumQueries(7): + response 
= self.client.post( + reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}), + data={"recursive": True, "count": 5, "sets": {"test": 0.01, "train": 0.99}}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + "sets": ["Ratio must be greater than 0.2 to contain at least one element."] + }) + + def test_populate_dataset_duplicate_entries(self): + """Raises a 400 in case an element of the selection is already part of a set""" + DatasetElement.objects.create( + element=self.page1, + set=self.dataset.sets.get(name="train"), + ) + self.client.force_login(self.user) + with self.assertNumQueries(7): + response = self.client.post( + reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}), + data={"recursive": True, "count": 1, "sets": {"test": 1}}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + "__all__": ["Some filtered elements are already part of a set on this dataset."] + }) + + def test_populate_dataset(self): + DatasetElement.objects.create( + element=self.vol, + set=self.dataset.sets.get(name="train"), + ) + self.client.force_login(self.user) + with force_constraints_immediate(), self.assertNumQueries(11): + response = self.client.post( + reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}), + data={ + "recursive": True, + "types": ["page", "word"], + "count": 11, + "sets": {"train": .3, "dev": .2, "test": .5}}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.json(), {}) + # 3 elements + volume DatasetElement + self.assertEqual(DatasetElement.objects.filter(set__dataset=self.dataset, set__name="train").count(), 4) + # 2 elements + self.assertEqual(DatasetElement.objects.filter(set__dataset=self.dataset, set__name="dev").count(), 2) + # 5 elements + the final one + 
self.assertEqual(DatasetElement.objects.filter(set__dataset=self.dataset, set__name="test").count(), 6) + # Ensure no element is found in multiple sets at once + self.assertEqual( + DatasetElement.objects.filter(set__dataset=self.dataset).values("element_id").distinct("element_id").count(), + 12 + ) + + def test_populate_dataset_direct_children(self): + self.client.force_login(self.user) + with force_constraints_immediate(), self.assertNumQueries(13): + response = self.client.post( + reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}), + data={ + "parent_id": str(self.vol.id), + "recursive": False, + "types": ["page", "word"], + "count": 3, + "sets": {"train": 1} + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.json(), {}) + self.assertQuerysetEqual( + ( + DatasetElement.objects + .filter(set__dataset=self.dataset) + .order_by("element__name") + .values_list("set__name", "element__name") + ), + ( + ("train", "Volume 1, page 1r"), + ("train", "Volume 1, page 1v"), + ("train", "Volume 1, page 2r"), + ) + ) + + def test_populate_dataset_recursive_children(self): + self.client.force_login(self.user) + with force_constraints_immediate(), self.assertNumQueries(13): + response = self.client.post( + reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}), + data={ + "parent_id": str(self.vol.id), + "recursive": True, + "types": ["page", "word"], + "count": 12, + "sets": {"train": 1} + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertQuerysetEqual( + ( + DatasetElement.objects + .filter(set__dataset=self.dataset) + .order_by("element__name") + .values_list("set__name", "element__name") + ), + ( + ("train", "DATUM"), + ("train", "DATUM"), + ("train", "DATUM"), + ("train", "PARIS"), + ("train", "PARIS"), + ("train", "PARIS"), + ("train", "ROY"), + ("train", "ROY"), + ("train", "ROY"), + ("train", "Volume 1, page 1r"), + ("train", 
"Volume 1, page 1v"), + ("train", "Volume 1, page 2r"), + ), + ) -- GitLab