From 0b9d79bf0f759f84e35a5c0490f12ace7abbd162 Mon Sep 17 00:00:00 2001
From: mlbonhomme <bonhomme@teklia.com>
Date: Tue, 12 Mar 2024 13:19:19 +0100
Subject: [PATCH] temp dataset process fix

---
 arkindex/process/api.py                       |   2 +
 arkindex/process/serializers/training.py      |   6 +-
 arkindex/process/tests/test_create_process.py |   6 +-
 .../process/tests/test_process_datasets.py    | 164 ++++++++++++------
 arkindex/process/tests/test_processes.py      |  14 +-
 arkindex/training/api.py                      |   4 +-
 arkindex/training/tests/test_datasets_api.py  |  28 +--
 7 files changed, 149 insertions(+), 75 deletions(-)

diff --git a/arkindex/process/api.py b/arkindex/process/api.py
index e6f76cf645..7d1fd99a82 100644
--- a/arkindex/process/api.py
+++ b/arkindex/process/api.py
@@ -706,6 +706,7 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView):
         return (
             ProcessDataset.objects.filter(process_id=self.process.id)
             .select_related("process__creator", "dataset__creator")
+            .prefetch_related("dataset__sets")
             .order_by("dataset__name")
         )
 
@@ -751,6 +752,7 @@ class ProcessDatasetManage(CreateAPIView, UpdateAPIView, DestroyAPIView):
         process_dataset = get_object_or_404(
             ProcessDataset.objects
             .select_related("dataset__creator", "process__corpus")
+            .prefetch_related("dataset__sets")
             # Required to check for a process that have already started
             .annotate(process_has_tasks=Exists(Task.objects.filter(process_id=self.kwargs["process"]))),
             dataset_id=self.kwargs["dataset"], process_id=self.kwargs["process"]
diff --git a/arkindex/process/serializers/training.py b/arkindex/process/serializers/training.py
index e7c25f1ec7..f280a16a65 100644
--- a/arkindex/process/serializers/training.py
+++ b/arkindex/process/serializers/training.py
@@ -89,7 +89,7 @@ class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer):
             else:
                 dataset_qs = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self._user))
             try:
-                dataset = dataset_qs.select_related("creator").get(pk=data["dataset_id"])
+                dataset = dataset_qs.select_related("creator").prefetch_related("sets").get(pk=data["dataset_id"])
             except Dataset.DoesNotExist:
                 raise ValidationError({"dataset": [f'Invalid pk "{str(data["dataset_id"])}" - object does not exist.']})
         else:
@@ -109,11 +109,11 @@ class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer):
         sets = data.get("sets")
         if not sets or len(sets) == 0:
             if not self.instance:
-                data["sets"] = dataset.sets
+                data["sets"] = [item.name for item in list(dataset.sets.all())]
             else:
                 errors["sets"].append("This field cannot be empty.")
         else:
-            if any(s not in dataset.sets for s in sets):
+            if any(s not in [item.name for item in list(dataset.sets.all())] for s in sets):
                 errors["sets"].append("The specified sets must all exist in the specified dataset.")
             if len(set(sets)) != len(sets):
                 errors["sets"].append("Sets must be unique.")
diff --git a/arkindex/process/tests/test_create_process.py b/arkindex/process/tests/test_create_process.py
index 34264e1a04..59e454700c 100644
--- a/arkindex/process/tests/test_create_process.py
+++ b/arkindex/process/tests/test_create_process.py
@@ -899,7 +899,8 @@ class TestCreateProcess(FixtureAPITestCase):
         self.client.force_login(self.user)
         process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
         dataset = self.corpus.datasets.first()
-        ProcessDataset.objects.create(process=process, dataset=dataset, sets=dataset.sets)
+        test_sets = list(dataset.sets.values_list("name", flat=True))
+        ProcessDataset.objects.create(process=process, dataset=dataset, sets=test_sets)
         process.versions.set([self.version_2, self.version_3])
 
         with self.assertNumQueries(9):
@@ -929,7 +930,8 @@ class TestCreateProcess(FixtureAPITestCase):
         self.worker_1.save()
         process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
         dataset = self.corpus.datasets.first()
-        ProcessDataset.objects.create(process=process, dataset=dataset, sets=dataset.sets)
+        test_sets = list(dataset.sets.values_list("name", flat=True))
+        ProcessDataset.objects.create(process=process, dataset=dataset, sets=test_sets)
         process.versions.add(self.version_1)
 
         with self.assertNumQueries(9):
diff --git a/arkindex/process/tests/test_process_datasets.py b/arkindex/process/tests/test_process_datasets.py
index 5f35cff0b9..1785250e9e 100644
--- a/arkindex/process/tests/test_process_datasets.py
+++ b/arkindex/process/tests/test_process_datasets.py
@@ -9,7 +9,7 @@ from arkindex.documents.models import Corpus
 from arkindex.ponos.models import Farm
 from arkindex.process.models import Process, ProcessDataset, ProcessMode
 from arkindex.project.tests import FixtureAPITestCase
-from arkindex.training.models import Dataset
+from arkindex.training.models import Dataset, DatasetSet
 from arkindex.users.models import Role, User
 
 # Using the fake DB fixtures creation date when needed
@@ -28,6 +28,10 @@ class TestProcessDatasets(FixtureAPITestCase):
                 description="Human instrumentality manual",
                 creator=cls.user
             )
+        DatasetSet.objects.bulk_create([
+            DatasetSet(dataset_id=cls.private_dataset.id, name=set_name)
+            for set_name in ["validation", "training", "test"]
+        ])
         cls.test_user = User.objects.create(email="katsuragi@nerv.co.jp", verified_email=True)
         cls.private_corpus.memberships.create(user=cls.test_user, level=Role.Admin.value)
 
@@ -40,8 +44,8 @@ class TestProcessDatasets(FixtureAPITestCase):
             corpus_id=cls.private_corpus.id,
             farm=Farm.objects.get(name="Wheat farm")
         )
-        cls.process_dataset_1 = ProcessDataset.objects.create(process=cls.dataset_process, dataset=cls.dataset1, sets=cls.dataset1.sets)
-        cls.process_dataset_2 = ProcessDataset.objects.create(process=cls.dataset_process, dataset=cls.private_dataset, sets=cls.private_dataset.sets)
+        cls.process_dataset_1 = ProcessDataset.objects.create(process=cls.dataset_process, dataset=cls.dataset1, sets=list(cls.dataset1.sets.values_list("name", flat=True)))
+        cls.process_dataset_2 = ProcessDataset.objects.create(process=cls.dataset_process, dataset=cls.private_dataset, sets=list(cls.private_dataset.sets.values_list("name", flat=True)))
 
         # Control process to check that its datasets are not retrieved
         cls.dataset_process_2 = Process.objects.create(
@@ -49,7 +53,7 @@ class TestProcessDatasets(FixtureAPITestCase):
             mode=ProcessMode.Dataset,
             corpus_id=cls.corpus.id
         )
-        ProcessDataset.objects.create(process=cls.dataset_process_2, dataset=cls.dataset2, sets=cls.dataset2.sets)
+        ProcessDataset.objects.create(process=cls.dataset_process_2, dataset=cls.dataset2, sets=list(cls.dataset2.sets.values_list("name", flat=True)))
 
     # List process datasets
 
@@ -78,9 +82,10 @@ class TestProcessDatasets(FixtureAPITestCase):
 
     def test_list(self):
         self.client.force_login(self.test_user)
-        with self.assertNumQueries(5):
+        with self.assertNumQueries(6):
             response = self.client.get(reverse("api:process-datasets", kwargs={"pk": self.dataset_process.id}))
             self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.maxDiff = None
         self.assertEqual(response.json()["results"], [
             {
                 "id": str(self.process_dataset_2.id),
@@ -89,7 +94,13 @@ class TestProcessDatasets(FixtureAPITestCase):
                     "name": "Dead sea scrolls",
                     "description": "Human instrumentality manual",
                     "creator": "Test user",
-                    "sets": ["training", "test", "validation"],
+                    "sets": [
+                        {
+                            "id": str(ds.id),
+                            "name": ds.name
+                        }
+                        for ds in self.private_dataset.sets.all()
+                    ],
                     "set_elements": None,
                     "corpus_id": str(self.private_corpus.id),
                     "state": "open",
@@ -97,7 +108,7 @@ class TestProcessDatasets(FixtureAPITestCase):
                     "created": FAKE_CREATED,
                     "updated": FAKE_CREATED
                 },
-                "sets": ["training", "test", "validation"]
+                "sets": ["validation", "training", "test"]
             },
             {
                 "id": str(self.process_dataset_1.id),
@@ -106,7 +117,13 @@ class TestProcessDatasets(FixtureAPITestCase):
                     "name": "First Dataset",
                     "description": "dataset number one",
                     "creator": "Test user",
-                    "sets": ["training", "test", "validation"],
+                    "sets": [
+                        {
+                            "id": str(ds.id),
+                            "name": ds.name
+                        }
+                        for ds in self.dataset1.sets.all()
+                    ],
                     "set_elements": None,
                     "corpus_id": str(self.corpus.id),
                     "state": "open",
@@ -114,33 +131,36 @@ class TestProcessDatasets(FixtureAPITestCase):
                     "created": FAKE_CREATED,
                     "updated": FAKE_CREATED
                 },
-                "sets": ["training", "test", "validation"]
+                "sets": ["validation", "training", "test"]
             }
         ])
 
     # Create process dataset
 
     def test_create_requires_login(self):
+        test_sets = list(self.dataset2.sets.values_list("name", flat=True))
         with self.assertNumQueries(0):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
-                data={"sets": self.dataset2.sets}
+                data={"sets": test_sets}
             )
             self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
 
     def test_create_requires_verified(self):
         unverified_user = User.objects.create(email="email@mail.com")
+        test_sets = list(self.dataset2.sets.values_list("name", flat=True))
         self.client.force_login(unverified_user)
         with self.assertNumQueries(2):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
-                data={"sets": self.dataset2.sets}
+                data={"sets": test_sets}
             )
             self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
 
     @patch("arkindex.project.mixins.get_max_level")
     def test_create_access_level(self, get_max_level_mock):
         cases = [None, Role.Guest.value, Role.Contributor.value]
+        test_sets = list(self.dataset2.sets.values_list("name", flat=True))
         for level in cases:
             with self.subTest(level=level):
                 get_max_level_mock.reset_mock()
@@ -150,7 +170,7 @@ class TestProcessDatasets(FixtureAPITestCase):
                 with self.assertNumQueries(3):
                     response = self.client.post(
                         reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
-                        data={"sets": self.dataset2.sets}
+                        data={"sets": test_sets}
                     )
                     self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
 
@@ -161,16 +181,17 @@ class TestProcessDatasets(FixtureAPITestCase):
 
     def test_create_process_mode(self):
         cases = set(ProcessMode) - {ProcessMode.Dataset, ProcessMode.Local}
+        test_sets = list(self.dataset2.sets.values_list("name", flat=True))
         for mode in cases:
             with self.subTest(mode=mode):
                 self.dataset_process.mode = mode
                 self.dataset_process.save()
                 self.client.force_login(self.test_user)
 
-                with self.assertNumQueries(5):
+                with self.assertNumQueries(6):
                     response = self.client.post(
                         reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
-                        data={"sets": self.dataset2.sets}
+                        data={"sets": test_sets}
                     )
                     self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
 
@@ -178,22 +199,24 @@ class TestProcessDatasets(FixtureAPITestCase):
 
     def test_create_process_mode_local(self):
         self.client.force_login(self.user)
+        test_sets = list(self.dataset2.sets.values_list("name", flat=True))
         local_process = Process.objects.get(creator=self.user, mode=ProcessMode.Local)
         with self.assertNumQueries(3):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": local_process.id, "dataset": self.dataset2.id}),
-                data={"sets": self.dataset2.sets}
+                data={"sets": test_sets}
             )
             self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
         self.assertEqual(response.json(), {"detail": "You do not have admin access to this process."})
 
     def test_create_wrong_process_uuid(self):
         self.client.force_login(self.test_user)
+        test_sets = list(self.dataset2.sets.values_list("name", flat=True))
         wrong_id = uuid.uuid4()
         with self.assertNumQueries(3):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": wrong_id, "dataset": self.dataset2.id}),
-                data={"sets": self.dataset2.sets}
+                data={"sets": test_sets}
             )
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertEqual(response.json(), {"process": [f'Invalid pk "{str(wrong_id)}" - object does not exist.']})
@@ -213,12 +236,13 @@ class TestProcessDatasets(FixtureAPITestCase):
     def test_create_dataset_access(self, filter_rights_mock):
         new_corpus = Corpus.objects.create(name="NERV")
         new_dataset = new_corpus.datasets.create(name="Eva series", description="We created the Evas from Adam", creator=self.user)
+        test_sets = list(new_dataset.sets.values_list("name", flat=True))
         self.client.force_login(self.test_user)
 
         with self.assertNumQueries(3):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": new_dataset.id}),
-                data={"sets": new_dataset.sets}
+                data={"sets": test_sets}
             )
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
 
@@ -229,12 +253,13 @@ class TestProcessDatasets(FixtureAPITestCase):
 
     def test_create_unique(self):
         self.client.force_login(self.test_user)
+        test_sets = list(self.dataset1.sets.values_list("name", flat=True))
         self.assertTrue(self.dataset_process.datasets.filter(id=self.dataset1.id).exists())
 
-        with self.assertNumQueries(5):
+        with self.assertNumQueries(6):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
-                data={"sets": self.dataset1.sets}
+                data={"sets": test_sets}
             )
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
 
@@ -243,11 +268,12 @@ class TestProcessDatasets(FixtureAPITestCase):
     def test_create_started(self):
         self.client.force_login(self.test_user)
         self.dataset_process.tasks.create(run=0, depth=0, slug="makrout")
+        test_sets = list(self.dataset2.sets.values_list("name", flat=True))
 
-        with self.assertNumQueries(5):
+        with self.assertNumQueries(6):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
-                data={"sets": self.dataset2.sets}
+                data={"sets": test_sets}
             )
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
 
@@ -257,7 +283,7 @@ class TestProcessDatasets(FixtureAPITestCase):
         self.client.force_login(self.test_user)
         self.assertEqual(ProcessDataset.objects.count(), 3)
         self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists())
-        with self.assertNumQueries(6):
+        with self.assertNumQueries(7):
             response = self.client.post(
                 reverse(
                     "api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}
@@ -272,6 +298,7 @@ class TestProcessDatasets(FixtureAPITestCase):
             self.dataset2
         ])
         created = ProcessDataset.objects.get(process=self.dataset_process.id, dataset=self.dataset2.id)
+        self.maxDiff = None
         self.assertDictEqual(response.json(), {
             "id": str(created.id),
             "dataset": {
@@ -279,7 +306,13 @@ class TestProcessDatasets(FixtureAPITestCase):
                 "name": "Second Dataset",
                 "description": "dataset number two",
                 "creator": "Test user",
-                "sets": ["training", "test", "validation"],
+                "sets": [
+                    {
+                        "id": str(ds.id),
+                        "name": ds.name
+                    }
+                    for ds in self.dataset2.sets.all()
+                ],
                 "set_elements": None,
                 "corpus_id": str(self.corpus.id),
                 "state": "open",
@@ -287,14 +320,14 @@ class TestProcessDatasets(FixtureAPITestCase):
                 "created": FAKE_CREATED,
                 "updated": FAKE_CREATED
             },
-            "sets": ["training", "test", "validation"]
+            "sets": ["test", "training", "validation"]
         })
 
     def test_create(self):
         self.client.force_login(self.test_user)
         self.assertEqual(ProcessDataset.objects.count(), 3)
         self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists())
-        with self.assertNumQueries(6):
+        with self.assertNumQueries(7):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
                 data={"sets": ["validation", "test"]}
@@ -315,7 +348,13 @@ class TestProcessDatasets(FixtureAPITestCase):
                 "name": "Second Dataset",
                 "description": "dataset number two",
                 "creator": "Test user",
-                "sets": ["training", "test", "validation"],
+                "sets": [
+                    {
+                        "id": str(ds.id),
+                        "name": ds.name
+                    }
+                    for ds in self.dataset2.sets.all()
+                ],
                 "set_elements": None,
                 "corpus_id": str(self.corpus.id),
                 "state": "open",
@@ -330,7 +369,7 @@ class TestProcessDatasets(FixtureAPITestCase):
         self.client.force_login(self.test_user)
         self.assertEqual(ProcessDataset.objects.count(), 3)
         self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists())
-        with self.assertNumQueries(5):
+        with self.assertNumQueries(6):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
                 data={"sets": ["Unit-01"]}
@@ -367,7 +406,7 @@ class TestProcessDatasets(FixtureAPITestCase):
                 if level:
                     self.private_corpus.memberships.create(user=self.test_user, level=level.value)
                 self.client.force_login(self.test_user)
-                with self.assertNumQueries(4):
+                with self.assertNumQueries(5):
                     response = self.client.put(
                         reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
                         data={"sets": ["test"]}
@@ -403,7 +442,7 @@ class TestProcessDatasets(FixtureAPITestCase):
 
     def test_update(self):
         self.client.force_login(self.test_user)
-        with self.assertNumQueries(4):
+        with self.assertNumQueries(5):
             response = self.client.put(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
                 data={"sets": ["test"]}
@@ -416,7 +455,13 @@ class TestProcessDatasets(FixtureAPITestCase):
                 "name": "First Dataset",
                 "description": "dataset number one",
                 "creator": "Test user",
-                "sets": ["training", "test", "validation"],
+                "sets": [
+                    {
+                        "id": str(ds.id),
+                        "name": ds.name
+                    }
+                    for ds in self.dataset1.sets.all()
+                ],
                 "set_elements": None,
                 "corpus_id": str(self.corpus.id),
                 "state": "open",
@@ -429,7 +474,7 @@ class TestProcessDatasets(FixtureAPITestCase):
 
     def test_update_wrong_sets(self):
         self.client.force_login(self.test_user)
-        with self.assertNumQueries(3):
+        with self.assertNumQueries(4):
             response = self.client.put(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
                 data={"sets": ["Unit-01", "Unit-02"]}
@@ -439,7 +484,7 @@ class TestProcessDatasets(FixtureAPITestCase):
 
     def test_update_unique_sets(self):
         self.client.force_login(self.test_user)
-        with self.assertNumQueries(3):
+        with self.assertNumQueries(4):
             response = self.client.put(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
                 data={"sets": ["test", "test"]}
@@ -455,7 +500,7 @@ class TestProcessDatasets(FixtureAPITestCase):
             expiry=datetime(1970, 1, 1, tzinfo=timezone.utc),
         )
         self.client.force_login(self.test_user)
-        with self.assertNumQueries(3):
+        with self.assertNumQueries(4):
             response = self.client.put(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
                 data={"sets": ["test"]}
@@ -468,7 +513,7 @@ class TestProcessDatasets(FixtureAPITestCase):
         Non "sets" fields in the update request are ignored
         """
         self.client.force_login(self.test_user)
-        with self.assertNumQueries(4):
+        with self.assertNumQueries(5):
             response = self.client.put(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
                 data={"process": str(self.dataset_process_2.id), "dataset": str(self.dataset2.id), "sets": ["test"]}
@@ -481,7 +526,13 @@ class TestProcessDatasets(FixtureAPITestCase):
                 "name": "First Dataset",
                 "description": "dataset number one",
                 "creator": "Test user",
-                "sets": ["training", "test", "validation"],
+                "sets": [
+                    {
+                        "id": str(ds.id),
+                        "name": ds.name
+                    }
+                    for ds in self.dataset1.sets.all()
+                ],
                 "set_elements": None,
                 "corpus_id": str(self.corpus.id),
                 "state": "open",
@@ -520,7 +571,7 @@ class TestProcessDatasets(FixtureAPITestCase):
                 if level:
                     self.private_corpus.memberships.create(user=self.test_user, level=level.value)
                 self.client.force_login(self.test_user)
-                with self.assertNumQueries(4):
+                with self.assertNumQueries(5):
                     response = self.client.patch(
                         reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
                         data={"sets": ["test"]}
@@ -556,7 +607,7 @@ class TestProcessDatasets(FixtureAPITestCase):
 
     def test_partial_update(self):
         self.client.force_login(self.test_user)
-        with self.assertNumQueries(4):
+        with self.assertNumQueries(5):
             response = self.client.patch(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
                 data={"sets": ["test"]}
@@ -569,7 +620,13 @@ class TestProcessDatasets(FixtureAPITestCase):
                 "name": "First Dataset",
                 "description": "dataset number one",
                 "creator": "Test user",
-                "sets": ["training", "test", "validation"],
+                "sets": [
+                    {
+                        "id": str(ds.id),
+                        "name": ds.name
+                    }
+                    for ds in self.dataset1.sets.all()
+                ],
                 "set_elements": None,
                 "corpus_id": str(self.corpus.id),
                 "state": "open",
@@ -582,7 +639,7 @@ class TestProcessDatasets(FixtureAPITestCase):
 
     def test_partial_update_wrong_sets(self):
         self.client.force_login(self.test_user)
-        with self.assertNumQueries(3):
+        with self.assertNumQueries(4):
             response = self.client.patch(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
                 data={"sets": ["Unit-01", "Unit-02"]}
@@ -592,7 +649,7 @@ class TestProcessDatasets(FixtureAPITestCase):
 
     def test_partial_update_unique_sets(self):
         self.client.force_login(self.test_user)
-        with self.assertNumQueries(3):
+        with self.assertNumQueries(4):
             response = self.client.patch(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
                 data={"sets": ["test", "test"]}
@@ -608,7 +665,7 @@ class TestProcessDatasets(FixtureAPITestCase):
             expiry=datetime(1970, 1, 1, tzinfo=timezone.utc),
         )
         self.client.force_login(self.test_user)
-        with self.assertNumQueries(3):
+        with self.assertNumQueries(4):
             response = self.client.patch(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
                 data={"sets": ["test"]}
@@ -621,7 +678,7 @@ class TestProcessDatasets(FixtureAPITestCase):
         Non "sets" fields in the partial update request are ignored
         """
         self.client.force_login(self.test_user)
-        with self.assertNumQueries(4):
+        with self.assertNumQueries(5):
             response = self.client.patch(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
                 data={"process": str(self.dataset_process_2.id), "dataset": str(self.dataset2.id), "sets": ["test"]}
@@ -634,7 +691,13 @@ class TestProcessDatasets(FixtureAPITestCase):
                 "name": "First Dataset",
                 "description": "dataset number one",
                 "creator": "Test user",
-                "sets": ["training", "test", "validation"],
+                "sets": [
+                    {
+                        "id": str(ds.id),
+                        "name": ds.name
+                    }
+                    for ds in self.dataset1.sets.all()
+                ],
                 "set_elements": None,
                 "corpus_id": str(self.corpus.id),
                 "state": "open",
@@ -677,7 +740,7 @@ class TestProcessDatasets(FixtureAPITestCase):
     def test_destroy_not_found(self):
         self.assertFalse(self.dataset_process.datasets.filter(id=self.dataset2.id).exists())
         self.client.force_login(self.test_user)
-        with self.assertNumQueries(5):
+        with self.assertNumQueries(6):
             response = self.client.delete(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
             )
@@ -700,10 +763,11 @@ class TestProcessDatasets(FixtureAPITestCase):
     def test_destroy_no_dataset_access_requirement(self):
         new_corpus = Corpus.objects.create(name="NERV")
         new_dataset = new_corpus.datasets.create(name="Eva series", description="We created the Evas from Adam", creator=self.user)
-        ProcessDataset.objects.create(process=self.dataset_process, dataset=new_dataset, sets=new_dataset.sets)
+        test_sets = list(new_dataset.sets.values_list("name", flat=True))
+        ProcessDataset.objects.create(process=self.dataset_process, dataset=new_dataset, sets=test_sets)
         self.assertTrue(ProcessDataset.objects.filter(process=self.dataset_process, dataset=new_dataset).exists())
         self.client.force_login(self.test_user)
-        with self.assertNumQueries(6):
+        with self.assertNumQueries(7):
             response = self.client.delete(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": new_dataset.id}),
             )
@@ -718,7 +782,7 @@ class TestProcessDatasets(FixtureAPITestCase):
                 self.dataset_process.save()
                 self.client.force_login(self.test_user)
 
-                with self.assertNumQueries(4):
+                with self.assertNumQueries(5):
                     response = self.client.delete(
                         reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
                     )
@@ -740,7 +804,7 @@ class TestProcessDatasets(FixtureAPITestCase):
         self.client.force_login(self.test_user)
         self.dataset_process.tasks.create(run=0, depth=0, slug="makrout")
 
-        with self.assertNumQueries(4):
+        with self.assertNumQueries(5):
             response = self.client.delete(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
             )
@@ -750,7 +814,7 @@ class TestProcessDatasets(FixtureAPITestCase):
 
     def test_destroy(self):
         self.client.force_login(self.test_user)
-        with self.assertNumQueries(6):
+        with self.assertNumQueries(7):
             response = self.client.delete(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
             )
@@ -765,7 +829,7 @@ class TestProcessDatasets(FixtureAPITestCase):
         self.process_dataset_1.sets = ["test"]
         self.process_dataset_1.save()
         self.client.force_login(self.test_user)
-        with self.assertNumQueries(6):
+        with self.assertNumQueries(7):
             response = self.client.delete(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
             )
diff --git a/arkindex/process/tests/test_processes.py b/arkindex/process/tests/test_processes.py
index 8a4d715cd2..b897d3250d 100644
--- a/arkindex/process/tests/test_processes.py
+++ b/arkindex/process/tests/test_processes.py
@@ -2324,7 +2324,7 @@ class TestProcesses(FixtureAPITestCase):
 
     def test_start_process_dataset_requires_dataset_in_same_corpus(self):
         process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
-        ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.private_dataset.sets)
+        ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=list(self.private_dataset.sets.values_list("name", flat=True)))
         process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
         self.assertFalse(process2.tasks.exists())
 
@@ -2341,8 +2341,8 @@ class TestProcesses(FixtureAPITestCase):
 
     def test_start_process_dataset_unsupported_parameters(self):
         process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
-        ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=self.dataset1.sets)
-        ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.dataset2.sets)
+        ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=list(self.dataset1.sets.values_list("name", flat=True)))
+        ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=list(self.dataset2.sets.values_list("name", flat=True)))
         process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
 
         self.client.force_login(self.user)
@@ -2366,8 +2366,8 @@ class TestProcesses(FixtureAPITestCase):
 
     def test_start_process_dataset(self):
         process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
-        ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=self.dataset1.sets)
-        ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.private_dataset.sets)
+        ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=list(self.dataset1.sets.values_list("name", flat=True)))
+        ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=list(self.private_dataset.sets.values_list("name", flat=True)))
         run = process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
         self.assertFalse(process2.tasks.exists())
 
@@ -2562,8 +2562,8 @@ class TestProcesses(FixtureAPITestCase):
         It should be possible to pass chunks when starting a dataset process
         """
         process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
-        ProcessDataset.objects.create(process=process, dataset=self.dataset1, sets=self.dataset1.sets)
-        ProcessDataset.objects.create(process=process, dataset=self.dataset2, sets=self.dataset2.sets)
+        ProcessDataset.objects.create(process=process, dataset=self.dataset1, sets=list(self.dataset1.sets.values_list("name", flat=True)))
+        ProcessDataset.objects.create(process=process, dataset=self.dataset2, sets=list(self.dataset2.sets.values_list("name", flat=True)))
         # Add a worker run to this process
         run = process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
 
diff --git a/arkindex/training/api.py b/arkindex/training/api.py
index 530b759921..f0b71f16bc 100644
--- a/arkindex/training/api.py
+++ b/arkindex/training/api.py
@@ -954,9 +954,7 @@ class DatasetClone(CorpusACLMixin, CreateAPIView):
     serializer_class = DatasetSerializer
 
     def get_queryset(self):
-        return (
-            Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user))
-        )
+        return Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user))
 
     def check_object_permissions(self, request, dataset):
         if not self.has_write_access(dataset.corpus):
diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py
index 20b854a193..72149bfc3c 100644
--- a/arkindex/training/tests/test_datasets_api.py
+++ b/arkindex/training/tests/test_datasets_api.py
@@ -2118,28 +2118,36 @@ class TestDatasetsAPI(FixtureAPITestCase):
         data = response.json()
         data.pop("created")
         data.pop("updated")
-        cloned_dataset = Dataset.objects.get(id=data["id"])
-        self.maxDiff = None
+        clone = Dataset.objects.get(id=data["id"])
+        test_clone, train_clone, val_clone = clone.sets.all().order_by("name")
+        cloned_sets = data.pop("sets")
         self.assertDictEqual(
             response.json(),
             {
-                "id": str(cloned_dataset.id),
+                "id": str(clone.id),
                 "name": "Clone of First Dataset 1",
                 "description": self.dataset.description,
                 "creator": self.user.display_name,
                 "corpus_id": str(self.corpus.id),
-                "sets": [
-                    {
-                        "id": str(ds.id),
-                        "name": ds.name
-                    }
-                    for ds in cloned_dataset.sets.all()
-                ],
                 "set_elements": {str(k.name): 0 for k in self.dataset.sets.all()},
                 "state": DatasetState.Open.value,
                 "task_id": None,
             },
         )
+        self.assertCountEqual(cloned_sets, [
+            {
+                "name": "training",
+                "id": str(train_clone.id)
+            },
+            {
+                "name": "test",
+                "id": str(test_clone.id)
+            },
+            {
+                "name": "validation",
+                "id": str(val_clone.id)
+            }
+        ])
 
     def test_clone_name_too_long(self):
         dataset = self.corpus.datasets.create(name="A" * 99, creator=self.user)
-- 
GitLab