From 399f0799c73cef838f4b7725f7d9115339da5a31 Mon Sep 17 00:00:00 2001
From: Valentin Rigal <rigal@teklia.com>
Date: Fri, 14 Jun 2024 13:52:20 +0000
Subject: [PATCH] PopulateDataset endpoint

---
 arkindex/project/api_v1.py                   |   2 +
 arkindex/training/api.py                     |  36 +++
 arkindex/training/serializers.py             | 158 +++++++++++++
 arkindex/training/tests/test_datasets_api.py | 235 ++++++++++++++++++-
 4 files changed, 430 insertions(+), 1 deletion(-)

diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py
index b5d5a32327..aab4bb37f5 100644
--- a/arkindex/project/api_v1.py
+++ b/arkindex/project/api_v1.py
@@ -103,6 +103,7 @@ from arkindex.training.api import (
     DatasetClone,
     DatasetElementDestroy,
     DatasetElements,
+    DatasetPopulate,
     DatasetSetCreate,
     DatasetSets,
     DatasetUpdate,
@@ -186,6 +187,7 @@ api = [
     path("datasets/<uuid:dataset>/elements/<uuid:element>/", DatasetElementDestroy.as_view(), name="dataset-element"),
     path("datasets/<uuid:pk>/sets/", DatasetSetCreate.as_view(), name="dataset-sets"),
     path("datasets/<uuid:dataset>/sets/<uuid:set>/", DatasetSets.as_view(), name="dataset-set"),
+    path("datasets/<uuid:pk>/populate/", DatasetPopulate.as_view(), name="dataset-populate"),
 
     # Moderation
     path("classifications/", ClassificationCreate.as_view(), name="classification-create"),
diff --git a/arkindex/training/api.py b/arkindex/training/api.py
index cb124b3107..db033ec5fd 100644
--- a/arkindex/training/api.py
+++ b/arkindex/training/api.py
@@ -40,6 +40,7 @@ from arkindex.training.serializers import (
     CreateModelErrorResponseSerializer,
     DatasetElementInfoSerializer,
     DatasetElementSerializer,
+    DatasetPopulateSerializer,
     DatasetSerializer,
     DatasetSetSerializer,
     ElementDatasetSetSerializer,
@@ -1049,3 +1050,38 @@ class DatasetClone(CorpusACLMixin, CreateAPIView):
             DatasetSerializer(clone).data,
             status=status.HTTP_201_CREATED,
         )
+
+
+@extend_schema_view(
+    post=extend_schema(
+        tags=["datasets"],
+        operation_id="PopulateDataset",
+    ),
+)
+class DatasetPopulate(CorpusACLMixin, CreateAPIView):
+    """
+    Populate a dataset using randomized corpus content
+
+    Requires a **contributor** access to the dataset's corpus.
+    """
+    permission_classes = (IsVerified, )
+    serializer_class = DatasetPopulateSerializer
+
+    def get_queryset(self):
+        return (
+            Dataset.objects
+            .select_related("corpus")
+            .prefetch_related("corpus__types", "sets")
+            .filter(corpus__in=Corpus.objects.readable(self.request.user))
+        )
+
+    def check_object_permissions(self, request, dataset):
+        if not self.has_write_access(dataset.corpus):
+            raise PermissionDenied(detail="You need a Contributor access to the dataset's corpus.")
+        return super().check_object_permissions(request, dataset)
+
+    def get_serializer_context(self):
+        return {
+            **super().get_serializer_context(),
+            "dataset": self.get_object(),
+        }
diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index 855434b8ed..9cd60e6a34 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -899,3 +899,161 @@ class SelectionDatasetElementSerializer(serializers.Serializer):
             ignore_conflicts=True
         )
         return validated_data
+
+
+class DatasetPopulateSerializer(serializers.Serializer):
+    parent_id = serializers.UUIDField(
+        write_only=True,
+        default=None,
+        source="parent",
+        help_text=dedent("""
+            UUID of a parent element from which to select random children. The parent must be in the same corpus as the dataset.
+
+            If this is unset, all elements from the corpus of the dataset will be used.
+        """),
+
+    )
+    recursive = serializers.BooleanField(
+        write_only=True,
+        default=False,
+        help_text=dedent("""
+            When `parent_id` is set, also choose elements among its grandchildren, and so on.
+
+            When `parent_id` is unset, take every element in the corpus instead of only the top-level ones.
+        """),
+    )
+    types = serializers.ListField(
+        child=serializers.SlugField(),
+        write_only=True,
+        allow_empty=False,
+        default=["page"],
+        help_text="Restrict the selected elements to those with any of the specified element type slugs.",
+    )
+    count = serializers.IntegerField(
+        write_only=True,
+        default=1000,
+        help_text=dedent("""
+            Number of elements to fill the dataset with.
+
+            This cannot exceed the number of available elements.
+        """),
+        min_value=1,
+    )
+    sets = serializers.DictField(
+        child=serializers.FloatField(min_value=0, max_value=1),
+        write_only=True,
+        allow_empty=False,
+        default={"train": 0.8, "dev": 0.1, "test": 0.1},
+        help_text=dedent("""
+            A dictionary mapping set names to ratios, defining how the elements will be distributed between the sets.
+
+            The sum of all ratios must be equal to 1.
+        """),
+    )
+
+    def validate_parent_id(self, parent_id):
+        if parent_id is None:
+            return
+        try:
+            parent = self.context["dataset"].corpus.elements.only("id").get(id=parent_id)
+        except Element.DoesNotExist:
+            raise ValidationError("This element does not exist in the corpus of the dataset.")
+        return parent
+
+    def validate_types(self, type_slugs):
+        type_slugs = set(type_slugs)
+        # Types are prefetched at the view level
+        existing_types = {
+            t.slug: t.id
+            for t in self.context["dataset"].corpus.types.all()
+        }
+        if (missing_types := [slug for slug in type_slugs if slug not in existing_types]):
+            raise ValidationError(f"Some types does not exist in the corpus of the dataset: {sorted(missing_types)}.")
+        return [existing_types[slug] for slug in type_slugs]
+
+    def validate_sets(self, sets):
+        # Sets are prefetched at the view level
+        existing_sets = {
+            d.name: d.id
+            for d in self.context["dataset"].sets.all()
+        }
+        errors = []
+        if (missing_sets := [s for s in sets if s not in existing_sets]):
+            errors.append(f"Some sets does not exist in the dataset: {sorted(missing_sets)}.")
+        if sum(sets.values()) != 1:
+            errors.append("The sum of all ratios must be equal to 1.")
+        if errors:
+            raise ValidationError(errors)
+        return {existing_sets[s]: v for s, v in sets.items()}
+
+    def validate(self, data):
+        data = super().validate(data)
+        elts_count = self.filter_elements(**data).count()
+        errors = defaultdict(list)
+        if data["count"] > elts_count:
+            errors["count"].append(f"This value exceeds the number of filtered elements ({elts_count}).")
+
+        min_ratio = 1 / data["count"]
+        if any(ratio < min_ratio for ratio in data["sets"].values()):
+            errors["sets"].append(f"Ratio must be greater than {min_ratio:.2g} to contain at least one element.")
+
+        # Ensure that none of the filtered elements is already part of any
+        # set, to avoid edge cases and to preserve the requested ratios.
+        if (
+            self.filter_elements(**data)
+            .filter(dataset_elements__set__dataset_id=self.context["dataset"].id)
+            .exists()
+        ):
+            errors["__all__"].append("Some filtered elements are already part of a set on this dataset.")
+
+        if errors:
+            raise ValidationError(errors)
+
+        return data
+
+    def filter_elements(self, *, parent=None, recursive=False, types, **kwargs):
+        qs = self.context["dataset"].corpus.elements.all()
+        filters = {}
+        if parent:
+            if recursive:
+                filters["paths__path__overlap"] = [parent.id]
+            else:
+                # Only pick direct children
+                filters["paths__path__last"] = parent.id
+        elif not recursive:
+            # Only pick top-level elements
+            filters["paths__path__last__isnull"] = True
+        filters["type_id__in"] = types
+        return qs.filter(**filters)
+
+
+    @transaction.atomic
+    def save(self):
+        count = self.validated_data["count"]
+        sets = self.validated_data["sets"]
+        element_ids = list(
+            self.filter_elements(**self.validated_data)
+            .order_by("?")
+            .values_list("id", flat=True)
+            [:count]
+        )
+
+        dataset_elements = []
+        index = 0
+        for set_id, ratio in sets.items():
+            set_count = int(ratio * count)
+            dataset_elements.extend(
+                DatasetElement(set_id=set_id, element_id=elt_id)
+                for elt_id in element_ids[index:index + set_count]
+            )
+            index += set_count
+
+        # Distribute remaining elements among sets, as we rounded down using int.
+        # Populate largest sets first, to be closer to the ratio.
+        sorted_sets = sorted(sets.items(), key=lambda item: -item[1])
+        dataset_elements.extend(
+            DatasetElement(set_id=set_id, element_id=elt_id)
+            for elt_id, (set_id, _ratio) in zip(element_ids[index:], sorted_sets)
+        )
+
+        DatasetElement.objects.bulk_create(dataset_elements)
diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py
index d24571674e..1bb2a5aa91 100644
--- a/arkindex/training/tests/test_datasets_api.py
+++ b/arkindex/training/tests/test_datasets_api.py
@@ -7,7 +7,7 @@ from rest_framework import status
 
 from arkindex.documents.models import Corpus
 from arkindex.process.models import Process, ProcessDatasetSet, ProcessMode
-from arkindex.project.tests import FixtureAPITestCase
+from arkindex.project.tests import FixtureAPITestCase, force_constraints_immediate
 from arkindex.project.tools import fake_now
 from arkindex.training.models import Dataset, DatasetElement, DatasetSet, DatasetState
 from arkindex.users.models import Role, User
@@ -2495,3 +2495,236 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.dataset.refresh_from_db()
         self.assertEqual(train_set.set_elements.count(), 0)
         self.assertEqual(dev_set.set_elements.count(), 1)
+
+    def test_populate_dataset_requires_login(self):
+        with self.assertNumQueries(0):
+            response = self.client.post(reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}))
+            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+        self.assertDictEqual(response.json(), {"detail": "Authentication credentials were not provided."})
+
+    def test_populate_dataset_requires_verified(self):
+        self.user.verified_email = False
+        self.user.save()
+        self.client.force_login(self.user)
+        with self.assertNumQueries(2):
+            response = self.client.post(reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}))
+            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+        self.assertDictEqual(response.json(), {"detail": "You do not have permission to perform this action."})
+
+    @patch("arkindex.users.managers.BaseACLManager.filter_rights", return_value=Corpus.objects.none())
+    def test_populate_dataset_private_corpus(self, filter_rights_mock):
+        self.client.force_login(self.user)
+        with self.assertNumQueries(2):
+            response = self.client.post(reverse("api:dataset-populate", kwargs={"pk": self.private_dataset.pk}))
+            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+        self.assertEqual(filter_rights_mock.call_count, 1)
+        self.assertEqual(filter_rights_mock.call_args, call(self.user, Corpus, Role.Guest.value))
+
+    @patch("arkindex.project.mixins.has_access", return_value=False)
+    def test_populate_dataset_requires_contributor(self, has_access_mock):
+        self.client.force_login(self.user)
+        with self.assertNumQueries(5):
+            response = self.client.post(reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}))
+            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+        self.assertDictEqual(response.json(), {"detail": "You need a Contributor access to the dataset's corpus."})
+        self.assertEqual(has_access_mock.call_count, 1)
+        self.assertEqual(has_access_mock.call_args, call(self.user, self.corpus, Role.Contributor.value, skip_public=False))
+
+    def test_populate_dataset_invalid_parameters(self):
+        self.client.force_login(self.user)
+        with self.assertNumQueries(5):
+            response = self.client.post(
+                reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}),
+                data={
+                    "parent_id": "a",
+                    "recursive": "b",
+                    "types": "c",
+                    "count": "d",
+                    "sets": "e",
+                },
+                format="json",
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {
+            "count": ["A valid integer is required."],
+            "parent_id": ["Must be a valid UUID."],
+            "recursive": ["Must be a valid boolean."],
+            "sets": ['Expected a dictionary of items but got type "str".'],
+            "types": ['Expected a list of items but got type "str".'],
+        })
+
+    def test_populate_dataset_unexisting_values(self):
+        self.client.force_login(self.user)
+        type = self.private_corpus.types.create(slug="A")
+        elt = self.private_corpus.elements.create(name="a", type_id=type.id)
+        with self.assertNumQueries(6):
+            response = self.client.post(
+                reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}),
+                data={
+                    "parent_id": str(elt.id),
+                    "types": ["page", "A", "AAAAAA"],
+                    "sets": {"test": .8, "A": .15, "AAAAA": .05},
+                },
+                format="json",
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {
+            "parent_id": ["This element does not exist in the corpus of the dataset."],
+            "sets": ["Some sets does not exist in the dataset: ['A', 'AAAAA']."],
+            "types": ["Some types does not exist in the corpus of the dataset: ['A', 'AAAAAA']."],
+        })
+
+    def test_populate_dataset_count_max_value(self):
+        self.client.force_login(self.user)
+        with self.assertNumQueries(7):
+            response = self.client.post(
+                reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}),
+                data={"recursive": True},
+                format="json",
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {
+            "count": ["This value exceeds the number of filtered elements (6)."]
+        })
+
+    def test_populate_dataset_wrong_sets_ratio(self):
+        self.client.force_login(self.user)
+        with self.assertNumQueries(5):
+            response = self.client.post(
+                reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}),
+                data={"recursive": True, "sets": {"test": 0.13, "train": 0.37}},
+                format="json",
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {
+            "sets": ["The sum of all ratios must be equal to 1."]
+        })
+
+    def test_populate_dataset_min_sets_ratio(self):
+        self.client.force_login(self.user)
+        with self.assertNumQueries(7):
+            response = self.client.post(
+                reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}),
+                data={"recursive": True, "count": 5, "sets": {"test": 0.01, "train": 0.99}},
+                format="json",
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {
+            "sets": ["Ratio must be greater than 0.2 to contain at least one element."]
+        })
+
+    def test_populate_dataset_duplicate_entries(self):
+        """Raises a 400 in case an element of the selection is already part of a set"""
+        DatasetElement.objects.create(
+            element=self.page1,
+            set=self.dataset.sets.get(name="train"),
+        )
+        self.client.force_login(self.user)
+        with self.assertNumQueries(7):
+            response = self.client.post(
+                reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}),
+                data={"recursive": True, "count": 1, "sets": {"test": 1}},
+                format="json",
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {
+            "__all__": ["Some filtered elements are already part of a set on this dataset."]
+        })
+
+    def test_populate_dataset(self):
+        DatasetElement.objects.create(
+            element=self.vol,
+            set=self.dataset.sets.get(name="train"),
+        )
+        self.client.force_login(self.user)
+        with force_constraints_immediate(), self.assertNumQueries(11):
+            response = self.client.post(
+                reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}),
+                data={
+                    "recursive": True,
+                    "types": ["page", "word"],
+                    "count": 11,
+                    "sets": {"train": .3, "dev": .2, "test": .5}},
+                format="json",
+            )
+            self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+        self.assertEqual(response.json(), {})
+        # 3 elements + volume DatasetElement
+        self.assertEqual(DatasetElement.objects.filter(set__dataset=self.dataset, set__name="train").count(), 4)
+        # 2 elements
+        self.assertEqual(DatasetElement.objects.filter(set__dataset=self.dataset, set__name="dev").count(), 2)
+        # 5 elements + the final one
+        self.assertEqual(DatasetElement.objects.filter(set__dataset=self.dataset, set__name="test").count(), 6)
+        # Ensure no element is found in multiple sets at once
+        self.assertEqual(
+            DatasetElement.objects.filter(set__dataset=self.dataset).values("element_id").distinct("element_id").count(),
+            12
+        )
+
+    def test_populate_dataset_direct_children(self):
+        self.client.force_login(self.user)
+        with force_constraints_immediate(), self.assertNumQueries(13):
+            response = self.client.post(
+                reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}),
+                data={
+                    "parent_id": str(self.vol.id),
+                    "recursive": False,
+                    "types": ["page", "word"],
+                    "count": 3,
+                    "sets": {"train": 1}
+                },
+                format="json",
+            )
+            self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+        self.assertEqual(response.json(), {})
+        self.assertQuerysetEqual(
+            (
+                DatasetElement.objects
+                .filter(set__dataset=self.dataset)
+                .order_by("element__name")
+                .values_list("set__name", "element__name")
+            ),
+            (
+                ("train", "Volume 1, page 1r"),
+                ("train", "Volume 1, page 1v"),
+                ("train", "Volume 1, page 2r"),
+            )
+        )
+
+    def test_populate_dataset_recursive_children(self):
+        self.client.force_login(self.user)
+        with force_constraints_immediate(), self.assertNumQueries(13):
+            response = self.client.post(
+                reverse("api:dataset-populate", kwargs={"pk": self.dataset.pk}),
+                data={
+                    "parent_id": str(self.vol.id),
+                    "recursive": True,
+                    "types": ["page", "word"],
+                    "count": 12,
+                    "sets": {"train": 1}
+                },
+                format="json",
+            )
+            self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+        self.assertQuerysetEqual(
+            (
+                DatasetElement.objects
+                .filter(set__dataset=self.dataset)
+                .order_by("element__name")
+                .values_list("set__name", "element__name")
+            ),
+            (
+                ("train", "DATUM"),
+                ("train", "DATUM"),
+                ("train", "DATUM"),
+                ("train", "PARIS"),
+                ("train", "PARIS"),
+                ("train", "PARIS"),
+                ("train", "ROY"),
+                ("train", "ROY"),
+                ("train", "ROY"),
+                ("train", "Volume 1, page 1r"),
+                ("train", "Volume 1, page 1v"),
+                ("train", "Volume 1, page 2r"),
+            ),
+        )
-- 
GitLab