From 2df4601f7bfe75c6ee081b500d077ff9cc073677 Mon Sep 17 00:00:00 2001
From: Erwan Rouchet <rouchet@teklia.com>
Date: Tue, 26 Mar 2024 12:12:50 +0100
Subject: [PATCH] Always sort sets by name

---
 arkindex/training/api.py                     | 20 ++++++++----
 arkindex/training/serializers.py             |  2 +-
 arkindex/training/tests/test_datasets_api.py | 32 ++++++++++----------
 3 files changed, 31 insertions(+), 23 deletions(-)

diff --git a/arkindex/training/api.py b/arkindex/training/api.py
index b167ea914f..a2b21048eb 100644
--- a/arkindex/training/api.py
+++ b/arkindex/training/api.py
@@ -55,9 +55,9 @@ from arkindex.users.models import Role
 from arkindex.users.utils import get_max_level
 
 # A prefetch object that includes the number of elements per set.
-SET_COUNTS_PREFETCH = Prefetch(
+DATASET_SET_COUNTS_PREFETCH = Prefetch(
     "sets",
-    DatasetSet.objects.annotate(element_count=Count("set_elements"))
+    DatasetSet.objects.annotate(element_count=Count("set_elements")).order_by("name")
 )
 
 
@@ -616,7 +616,11 @@ class CorpusDataset(CorpusACLMixin, ListCreateAPIView):
     def get_queryset(self):
         return Dataset.objects \
             .select_related("creator") \
-            .prefetch_related("sets") \
+            .prefetch_related(Prefetch(
+                "sets",
+                # Prefetch sets, but ensure they are ordered by name
+                DatasetSet.objects.order_by("name")
+            )) \
             .filter(corpus=self.corpus) \
             .order_by("name")
 
@@ -696,7 +700,7 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView):
             .select_related("corpus", "creator")
         )
         if self.request.method != "DELETE":
-            queryset = queryset.prefetch_related(SET_COUNTS_PREFETCH)
+            queryset = queryset.prefetch_related(DATASET_SET_COUNTS_PREFETCH)
         return queryset
 
     def check_object_permissions(self, request, obj):
@@ -937,7 +941,11 @@ class ElementDatasetSets(CorpusACLMixin, ListAPIView):
         qs = (
             self.element.dataset_elements
             .select_related("set__dataset__creator")
-            .prefetch_related("set__dataset__sets")
+            .prefetch_related(Prefetch(
+                "set__dataset__sets",
+                # Prefetch sets, but ensure they are ordered by name
+                DatasetSet.objects.order_by("name")
+            ))
             .order_by("set__dataset__name", "set__name")
         )
 
@@ -1025,7 +1033,7 @@ class DatasetClone(CorpusACLMixin, CreateAPIView):
         ])
 
         # Add the set counts to the API response
-        prefetch_related_objects([clone], SET_COUNTS_PREFETCH)
+        prefetch_related_objects([clone], DATASET_SET_COUNTS_PREFETCH)
 
         return Response(
             DatasetSerializer(clone).data,
diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index 9fc5c2bae4..85b11b1eee 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -598,7 +598,7 @@ class DatasetSerializer(serializers.ModelSerializer):
             DatasetSet(
                 name=set_name,
                 dataset_id=dataset.id
-            ) for set_name in set_names
+            ) for set_name in sorted(set_names)
         )
         # We will output set element counts in the API, but we know there are zero,
         # so no need to make another query to prefetch the sets and count them
diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py
index 073d299eb3..e7b54da344 100644
--- a/arkindex/training/tests/test_datasets_api.py
+++ b/arkindex/training/tests/test_datasets_api.py
@@ -89,7 +89,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                         "id": str(ds.id),
                         "name": ds.name
                     }
-                    for ds in self.dataset.sets.all()
+                    for ds in self.dataset.sets.order_by("name")
                 ],
                 "set_elements": None,
                 "state": "open",
@@ -108,7 +108,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                         "id": str(ds.id),
                         "name": ds.name
                     }
-                    for ds in self.dataset2.sets.all()
+                    for ds in self.dataset2.sets.order_by("name")
                 ],
                 "set_elements": None,
                 "state": "open",
@@ -284,7 +284,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "id": str(ds.id),
                     "name": ds.name
                 }
-                for ds in created_dataset.sets.all()
+                for ds in created_dataset.sets.order_by("name")
             ],
             "set_elements": {
                 "training": 0,
@@ -319,7 +319,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "id": str(ds.id),
                     "name": ds.name
                 }
-                for ds in created_dataset.sets.all()
+                for ds in created_dataset.sets.order_by("name")
             ],
             "set_elements": {
                 "training": 0,
@@ -353,7 +353,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "id": str(ds.id),
                     "name": ds.name
                 }
-                for ds in created_dataset.sets.all()
+                for ds in created_dataset.sets.order_by("name")
             ],
             "set_elements": {"a": 0, "b": 0, "c": 0, "d": 0},
             "state": "open",
@@ -996,7 +996,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "id": str(ds.id),
                     "name": ds.name
                 }
-                for ds in self.dataset.sets.all()
+                for ds in self.dataset.sets.order_by("name")
             ],
             "set_elements": {"test": 0, "training": 0, "validation": 0},
             "creator": "Test user",
@@ -1672,7 +1672,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                             "id": str(ds.id),
                             "name": ds.name
                         }
-                        for ds in self.dataset.sets.all()
+                        for ds in self.dataset.sets.order_by("name")
                     ],
                     "set_elements": None,
                     "state": "open",
@@ -1713,7 +1713,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                             "id": str(ds.id),
                             "name": ds.name
                         }
-                        for ds in self.dataset.sets.all()
+                        for ds in self.dataset.sets.order_by("name")
                     ],
                     "set_elements": None,
                     "state": "open",
@@ -1736,7 +1736,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                             "id": str(ds.id),
                             "name": ds.name
                         }
-                        for ds in self.dataset.sets.all()
+                        for ds in self.dataset.sets.order_by("name")
                     ],
                     "set_elements": None,
                     "state": "open",
@@ -1759,7 +1759,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                             "id": str(ds.id),
                             "name": ds.name
                         }
-                        for ds in self.dataset2.sets.all()
+                        for ds in self.dataset2.sets.order_by("name")
                     ],
                     "set_elements": None,
                     "state": "open",
@@ -1800,7 +1800,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                             "id": str(ds.id),
                             "name": ds.name
                         }
-                        for ds in self.dataset.sets.all()
+                        for ds in self.dataset.sets.order_by("name")
                     ],
                     "set_elements": None,
                     "state": "open",
@@ -1823,7 +1823,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                             "id": str(ds.id),
                             "name": ds.name
                         }
-                        for ds in self.dataset.sets.all()
+                        for ds in self.dataset.sets.order_by("name")
                     ],
                     "set_elements": None,
                     "state": "open",
@@ -1846,7 +1846,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                             "id": str(ds.id),
                             "name": ds.name
                         }
-                        for ds in self.dataset2.sets.all()
+                        for ds in self.dataset2.sets.order_by("name")
                     ],
                     "set_elements": None,
                     "state": "open",
@@ -1898,7 +1898,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                             "id": str(ds.id),
                             "name": ds.name
                         }
-                        for ds in self.dataset.sets.all()
+                        for ds in self.dataset.sets.order_by("name")
                     ],
                     "set_elements": None,
                     "state": "open",
@@ -1929,7 +1929,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                             "id": str(ds.id),
                             "name": ds.name
                         }
-                        for ds in self.dataset.sets.all()
+                        for ds in self.dataset.sets.order_by("name")
                     ],
                     "set_elements": None,
                     "state": "open",
@@ -1952,7 +1952,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                             "id": str(ds.id),
                             "name": ds.name
                         }
-                        for ds in self.dataset2.sets.all()
+                        for ds in self.dataset2.sets.order_by("name")
                     ],
                     "set_elements": None,
                     "state": "open",
-- 
GitLab