From 7e4fc11de0aa2c7f70e4089d0b2b7bd810551d9b Mon Sep 17 00:00:00 2001
From: mlbonhomme <bonhomme@teklia.com>
Date: Mon, 11 Mar 2024 19:06:29 +0100
Subject: [PATCH] fixed datasets API

---
 arkindex/training/api.py                     | 17 ++++++----
 arkindex/training/serializers.py             | 11 +++---
 arkindex/training/tests/test_datasets_api.py | 35 +++++++++++---------
 3 files changed, 35 insertions(+), 28 deletions(-)

diff --git a/arkindex/training/api.py b/arkindex/training/api.py
index 0a42b292e2..530b759921 100644
--- a/arkindex/training/api.py
+++ b/arkindex/training/api.py
@@ -72,18 +72,18 @@ def _fetch_datasetelement_neighbors(datasetelements):
                 SELECT
                     n.id,
                     lag(element_id) OVER (
-                        partition BY (n.dataset_id, n.set)
+                        partition BY (n.set_id)
                         order by
                             n.element_id
                     ) as previous,
                     lead(element_id) OVER (
-                        partition BY (n.dataset_id, n.set)
+                        partition BY (n.set_id)
                         order by
                             n.element_id
                     ) as next
                 FROM training_datasetelement as n
-                WHERE (dataset_id, set) IN (
-                    SELECT dataset_id, set
+                WHERE set_id IN (
+                    SELECT set_id
                     FROM training_datasetelement
                     WHERE id IN %(ids)s
                 )
@@ -688,7 +688,7 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView):
     serializer_class = DatasetSerializer
 
     def get_queryset(self):
-        queryset = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user)).prefetch_related("sets")
+        queryset = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user))
         return queryset.select_related("corpus", "creator")
 
     def check_object_permissions(self, request, obj):
@@ -910,7 +910,8 @@ class ElementDatasetSets(CorpusACLMixin, ListAPIView):
         qs = (
             self.element.dataset_elements
             .select_related("set__dataset__creator")
-            .order_by("set__name", "id")
+            .prefetch_related("set__dataset__sets")
+            .order_by("set__dataset__name", "set__name")
         )
 
         with_neighbors = self.request.query_params.get("with_neighbors", "false")
@@ -953,7 +954,9 @@ class DatasetClone(CorpusACLMixin, CreateAPIView):
     serializer_class = DatasetSerializer
 
     def get_queryset(self):
-        return Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user))
+        return (
+            Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user))
+        )
 
     def check_object_permissions(self, request, dataset):
         if not self.has_write_access(dataset.corpus):
diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index 37dacc0d76..2765dcfaef 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -512,7 +512,11 @@ class DatasetSerializer(serializers.ModelSerializer):
         help_text="Display name of the user who created the dataset.",
     )
 
-    set_names = serializers.ListField(child=serializers.CharField(max_length=50), write_only=True, required=False)
+    set_names = serializers.ListField(
+        child=serializers.CharField(max_length=50),
+        write_only=True,
+        default=serializers.CreateOnlyDefault(["training", "validation", "test"])
+    )
     sets = DatasetSetSerializer(many=True, read_only=True)
 
     # When creating the dataset, the dataset's corpus comes from the URL, so the APIView passes it through
@@ -587,10 +591,7 @@ class DatasetSerializer(serializers.ModelSerializer):
 
     @transaction.atomic
     def create(self, validated_data):
-        if "set_names" not in validated_data:
-            sets = ["training", "validation", "test"]
-        else:
-            sets = validated_data.pop("set_names")
+        sets = validated_data.pop("set_names")
         dataset = Dataset.objects.create(**validated_data)
         DatasetSet.objects.bulk_create(
             DatasetSet(
diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py
index 56f7e3fb41..20b854a193 100644
--- a/arkindex/training/tests/test_datasets_api.py
+++ b/arkindex/training/tests/test_datasets_api.py
@@ -981,7 +981,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
 
     def test_retrieve(self):
         self.client.force_login(self.user)
-        with self.assertNumQueries(5):
+        with self.assertNumQueries(6):
             response = self.client.get(
                 reverse("api:dataset-update", kwargs={"pk": self.dataset.pk})
             )
@@ -1010,7 +1010,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.client.force_login(self.user)
         self.dataset.task = self.task
         self.dataset.save()
-        with self.assertNumQueries(5):
+        with self.assertNumQueries(6):
             response = self.client.get(
                 reverse("api:dataset-update", kwargs={"pk": self.dataset.pk})
             )
@@ -1431,7 +1431,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
     def test_add_element_wrong_element(self):
         element = self.private_corpus.elements.create(type=self.private_corpus.types.create(slug="folder"))
         self.client.force_login(self.user)
-        with self.assertNumQueries(4):
+        with self.assertNumQueries(5):
             response = self.client.post(
                 reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}),
                 data={"set": "test", "element_id": str(element.id)},
@@ -1724,7 +1724,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "created": self.dataset.created.isoformat().replace("+00:00", "Z"),
                     "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
                 },
-                "set": "train",
+                "set": "training",
                 "previous": None,
                 "next": None
             }, {
@@ -1770,7 +1770,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "created": self.dataset2.created.isoformat().replace("+00:00", "Z"),
                     "updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"),
                 },
-                "set": "train",
+                "set": "training",
                 "previous": None,
                 "next": None
             }]
@@ -1811,7 +1811,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "created": self.dataset.created.isoformat().replace("+00:00", "Z"),
                     "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
                 },
-                "set": "train",
+                "set": "training",
                 "previous": None,
                 "next": None
             }, {
@@ -1857,7 +1857,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "created": self.dataset2.created.isoformat().replace("+00:00", "Z"),
                     "updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"),
                 },
-                "set": "train",
+                "set": "training",
                 "previous": None,
                 "next": None
             }]
@@ -1880,9 +1880,10 @@ class TestDatasetsAPI(FixtureAPITestCase):
         sorted_dataset2_elements = sorted([str(self.page1.id), str(self.page3.id)])
         page1_index_2 = sorted_dataset2_elements.index(str(self.page1.id))
 
-        with self.assertNumQueries(7):
+        with self.assertNumQueries(8):
             response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.page1.id)}), {"with_neighbors": True})
             self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.maxDiff = None
         self.assertDictEqual(response.json(), {
             "count": 3,
             "next": None,
@@ -1908,7 +1909,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "created": self.dataset.created.isoformat().replace("+00:00", "Z"),
                     "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
                 },
-                "set": "train",
+                "set": "training",
                 "previous": (
                     sorted_dataset_elements[page1_index_1 - 1]
                     if page1_index_1 - 1 >= 0
@@ -1952,7 +1953,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                             "id": str(ds.id),
                             "name": ds.name
                         }
-                        for ds in self.dataset.sets.all()
+                        for ds in self.dataset2.sets.all()
                     ],
                     "set_elements": None,
                     "state": "open",
@@ -1962,7 +1963,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "created": self.dataset2.created.isoformat().replace("+00:00", "Z"),
                     "updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"),
                 },
-                "set": "train",
+                "set": "training",
                 "previous": (
                     sorted_dataset2_elements[page1_index_2 - 1]
                     if page1_index_2 == 1
@@ -2101,7 +2102,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
     def test_clone_existing_name(self):
         self.corpus.datasets.create(name="Clone of First Dataset", creator=self.user)
         self.client.force_login(self.user)
-        with self.assertNumQueries(11):
+        with self.assertNumQueries(15):
             response = self.client.post(
                 reverse("api:dataset-clone", kwargs={"pk": self.dataset.id}),
                 format="json",
@@ -2115,12 +2116,14 @@ class TestDatasetsAPI(FixtureAPITestCase):
         ])
 
         data = response.json()
-        data.pop("id")
         data.pop("created")
         data.pop("updated")
+        cloned_dataset = Dataset.objects.get(id=data["id"])
+        self.maxDiff = None
         self.assertDictEqual(
             response.json(),
             {
+                "id": str(cloned_dataset.id),
                 "name": "Clone of First Dataset 1",
                 "description": self.dataset.description,
                 "creator": self.user.display_name,
@@ -2130,9 +2133,9 @@ class TestDatasetsAPI(FixtureAPITestCase):
                         "id": str(ds.id),
                         "name": ds.name
                     }
-                    for ds in self.dataset.sets.all()
+                    for ds in cloned_dataset.sets.all()
                 ],
-                "set_elements": {k: 0 for k in self.dataset.sets.all()},
+                "set_elements": {str(k.name): 0 for k in self.dataset.sets.all()},
                 "state": DatasetState.Open.value,
                 "task_id": None,
             },
@@ -2141,7 +2144,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
     def test_clone_name_too_long(self):
         dataset = self.corpus.datasets.create(name="A" * 99, creator=self.user)
         self.client.force_login(self.user)
-        with self.assertNumQueries(11):
+        with self.assertNumQueries(14):
             response = self.client.post(
                 reverse("api:dataset-clone", kwargs={"pk": dataset.id}),
                 format="json",
-- 
GitLab