From 0806509f39f0d822e92b7cc4d3b7acf3f3d24bcc Mon Sep 17 00:00:00 2001
From: mlbonhomme <bonhomme@teklia.com>
Date: Thu, 7 Mar 2024 18:34:30 +0100
Subject: [PATCH] WIP fix API endpoints :'(

---
 arkindex/documents/fixtures/data.json         |  12 +-
 arkindex/project/api_v1.py                    |   4 +-
 arkindex/training/api.py                      |  63 ++---
 .../migrations/0007_datasetset_model.py       |   6 +-
 arkindex/training/models.py                   |   8 +
 arkindex/training/serializers.py              |  21 +-
 arkindex/training/tests/test_datasets_api.py  | 220 +++++++++++-------
 7 files changed, 197 insertions(+), 137 deletions(-)

diff --git a/arkindex/documents/fixtures/data.json b/arkindex/documents/fixtures/data.json
index ac55bc4e72..0037a3d48a 100644
--- a/arkindex/documents/fixtures/data.json
+++ b/arkindex/documents/fixtures/data.json
@@ -4022,7 +4022,7 @@
     "model": "training.datasetset",
     "pk": "00e3b37b-f1ed-4adb-a1de-a1c103edaa24",
     "fields": {
-        "name": "Test",
+        "name": "test",
         "dataset": "170bc276-0e11-483a-9486-88a4bf2079b2"
     }
 },
@@ -4030,7 +4030,7 @@
     "model": "training.datasetset",
     "pk": "95255ff4-7bca-424c-8e7e-d8b5c33c585f",
     "fields": {
-        "name": "Train",
+        "name": "training",
         "dataset": "170bc276-0e11-483a-9486-88a4bf2079b2"
     }
 },
@@ -4038,7 +4038,7 @@
     "model": "training.datasetset",
     "pk": "b21e4b31-1dc1-4015-9933-3a8e180cf2e0",
     "fields": {
-        "name": "Validation",
+        "name": "validation",
         "dataset": "170bc276-0e11-483a-9486-88a4bf2079b2"
     }
 },
@@ -4046,7 +4046,7 @@
     "model": "training.datasetset",
     "pk": "b76d919c-ce7e-4896-94c3-9b47862b997a",
     "fields": {
-        "name": "Validation",
+        "name": "validation",
         "dataset": "6b32dab0-e3f2-4fff-b0ad-2ea1ee550810"
     }
 },
@@ -4054,7 +4054,7 @@
     "model": "training.datasetset",
     "pk": "d5f4d410-0588-4d19-8d08-b57a11bad67f",
     "fields": {
-        "name": "Train",
+        "name": "training",
         "dataset": "6b32dab0-e3f2-4fff-b0ad-2ea1ee550810"
     }
 },
@@ -4062,7 +4062,7 @@
     "model": "training.datasetset",
     "pk": "db70ad8a-8d6b-44c6-b8d3-837bf5a4ab8e",
     "fields": {
-        "name": "Test",
+        "name": "test",
         "dataset": "6b32dab0-e3f2-4fff-b0ad-2ea1ee550810"
     }
 }
diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py
index 17c92d941b..3c893fc12f 100644
--- a/arkindex/project/api_v1.py
+++ b/arkindex/project/api_v1.py
@@ -109,7 +109,7 @@ from arkindex.training.api import (
     DatasetElementDestroy,
     DatasetElements,
     DatasetUpdate,
-    ElementDatasets,
+    ElementDatasetSets,
     MetricValueBulkCreate,
     MetricValueCreate,
     ModelCompatibleWorkerManage,
@@ -184,7 +184,7 @@ api = [
     # Datasets
     path("corpus/<uuid:pk>/datasets/", CorpusDataset.as_view(), name="corpus-datasets"),
     path("corpus/<uuid:pk>/datasets/selection/", CreateDatasetElementsSelection.as_view(), name="dataset-elements-selection"),
-    path("element/<uuid:pk>/datasets/", ElementDatasets.as_view(), name="element-datasets"),
+    path("element/<uuid:pk>/datasets/", ElementDatasetSets.as_view(), name="element-datasets"),
     path("datasets/<uuid:pk>/", DatasetUpdate.as_view(), name="dataset-update"),
     path("datasets/<uuid:pk>/clone/", DatasetClone.as_view(), name="dataset-clone"),
     path("datasets/<uuid:pk>/elements/", DatasetElements.as_view(), name="dataset-elements"),
diff --git a/arkindex/training/api.py b/arkindex/training/api.py
index 06034116e9..240282d88e 100644
--- a/arkindex/training/api.py
+++ b/arkindex/training/api.py
@@ -29,6 +29,7 @@ from arkindex.project.tools import BulkMap
 from arkindex.training.models import (
     Dataset,
     DatasetElement,
+    DatasetSet,
     DatasetState,
     MetricValue,
     Model,
@@ -40,7 +41,7 @@ from arkindex.training.serializers import (
     DatasetElementInfoSerializer,
     DatasetElementSerializer,
     DatasetSerializer,
-    ElementDatasetSerializer,
+    ElementDatasetSetSerializer,
     MetricValueBulkSerializer,
     MetricValueCreateSerializer,
     ModelCompatibleWorkerSerializer,
@@ -58,7 +59,7 @@ def _fetch_datasetelement_neighbors(datasetelements):
     """
     Retrieve the neighbors for a list of DatasetElements, and annotate these DatasetElements
     with next and previous attributes.
-    The ElementDatasets endpoint uses arkindex.project.tools.BulkMap to apply this method and
+    The ElementDatasetSets endpoint uses arkindex.project.tools.BulkMap to apply this method and
     perform the second request *after* DRF's pagination, because there is no way to perform
     post-processing after pagination in Django without having to use Django private methods.
     """
@@ -708,7 +709,8 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView):
             raise ValidationError(detail="This dataset is in complete state and cannot be modified anymore.")
 
     def perform_destroy(self, dataset):
-        dataset.dataset_elements.all().delete()
+        DatasetElement.objects.filter(set__dataset_id=dataset.id).delete()
+        dataset.sets.all().delete()
         super().perform_destroy(dataset)
 
 
@@ -769,13 +771,13 @@ class DatasetElements(CorpusACLMixin, ListCreateAPIView):
 
     def get_queryset(self):
         qs = (
-            self.dataset.dataset_elements
-            .prefetch_related("element")
+            DatasetElement.objects.filter(set__dataset_id=self.dataset.id)
+            .prefetch_related("element", "set")
             .select_related("element__type", "element__corpus", "element__image__server")
             .order_by("element_id", "id")
         )
         if "set" in self.request.query_params:
-            qs = qs.filter(set=self.request.query_params["set"])
+            qs = qs.filter(set__name=self.request.query_params["set"])
         return qs
 
     def get_serializer_context(self):
@@ -788,20 +790,12 @@ class DatasetElements(CorpusACLMixin, ListCreateAPIView):
 @extend_schema_view(
     delete=extend_schema(
         operation_id="DestroyDatasetElement",
-        parameters=[
-            OpenApiParameter(
-                "set",
-                type=str,
-                description="Name of the set from which to remove the element.",
-                required=True,
-            )
-        ],
         tags=["datasets"]
     )
 )
 class DatasetElementDestroy(CorpusACLMixin, DestroyAPIView):
     """
-    Remove an element from a dataset.
+    Remove an element from a dataset set.
 
     Elements can only be removed from **open** datasets.
 
@@ -812,17 +806,15 @@ class DatasetElementDestroy(CorpusACLMixin, DestroyAPIView):
     lookup_url_kwarg = "element"
 
     def destroy(self, request, *args, **kwargs):
-        if not self.request.query_params.get("set"):
+        if not (set_name := self.request.query_params.get("set")):
             raise ValidationError({"set": ["This field is required."]})
         dataset_element = get_object_or_404(
-            DatasetElement.objects.select_related("dataset__corpus"),
-            dataset_id=self.kwargs["dataset"],
-            element_id=self.kwargs["element"],
-            set=self.request.query_params.get("set")
+            DatasetElement.objects.select_related("set__dataset__corpus").filter(set__dataset_id=self.kwargs["dataset"], set__name=set_name),
+            element_id=self.kwargs["element"]
         )
-        if dataset_element.dataset.state != DatasetState.Open:
+        if dataset_element.set.dataset.state != DatasetState.Open:
             raise ValidationError({"dataset": ["Elements can only be removed from open Datasets."]})
-        if not self.has_write_access(dataset_element.dataset.corpus):
+        if not self.has_write_access(dataset_element.set.dataset.corpus):
             raise PermissionDenied(detail="You need a Contributor access to the dataset to perform this action.")
         dataset_element.delete()
         return Response(status=status.HTTP_204_NO_CONTENT)
@@ -897,14 +889,14 @@ class CreateDatasetElementsSelection(CorpusACLMixin, CreateAPIView):
         )
     ],
 )
-class ElementDatasets(CorpusACLMixin, ListAPIView):
+class ElementDatasetSets(CorpusACLMixin, ListAPIView):
     """
-    List all datasets containing a specific element.
+    List all dataset sets containing a specific element.
 
     Requires a **guest** access to the element's corpus.
     """
     permission_classes = (IsVerifiedOrReadOnly, )
-    serializer_class = ElementDatasetSerializer
+    serializer_class = ElementDatasetSetSerializer
 
     @cached_property
     def element(self):
@@ -916,9 +908,9 @@ class ElementDatasets(CorpusACLMixin, ListAPIView):
 
     def get_queryset(self):
         qs = (
-            self.element.dataset_elements.all()
-            .select_related("dataset__creator")
-            .order_by("dataset__name", "set", "dataset_id")
+            self.element.dataset_elements
+            .select_related("set__dataset__creator")
+            .order_by("set__name", "id")
         )
 
         with_neighbors = self.request.query_params.get("with_neighbors", "false")
@@ -961,7 +953,9 @@ class DatasetClone(CorpusACLMixin, CreateAPIView):
     serializer_class = DatasetSerializer
 
     def get_queryset(self):
-        return Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user))
+        return (
+            Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user))
+        )
 
     def check_object_permissions(self, request, dataset):
         if not self.has_write_access(dataset.corpus):
@@ -996,11 +990,18 @@ class DatasetClone(CorpusACLMixin, CreateAPIView):
         clone.creator = request.user
         clone.save()
 
+        # Clone dataset sets
+        cloned_sets = DatasetSet.objects.bulk_create([
+            DatasetSet(dataset_id=clone.id, name=set.name)
+            for set in dataset.sets.all()
+        ])
         # Associate all elements to the clone
         DatasetElement.objects.bulk_create([
-            DatasetElement(element_id=elt_id, dataset_id=clone.id, set=set_name)
-            for elt_id, set_name in dataset.dataset_elements.values_list("element_id", "set")
+            DatasetElement(element_id=elt_id, set=next(new_set for new_set in cloned_sets if new_set.name == set_name))
+            for elt_id, set_name in DatasetElement.objects.filter(set__dataset_id=dataset.id)
+            .values_list("element_id", "set__name")
         ])
+
         return Response(
             DatasetSerializer(clone).data,
             status=status.HTTP_201_CREATED,
diff --git a/arkindex/training/migrations/0007_datasetset_model.py b/arkindex/training/migrations/0007_datasetset_model.py
index 8630baa593..c3052f691c 100644
--- a/arkindex/training/migrations/0007_datasetset_model.py
+++ b/arkindex/training/migrations/0007_datasetset_model.py
@@ -59,7 +59,7 @@ class Migration(migrations.Migration):
             UPDATE training_datasetelement de
             SET set_id_id = ds.id
             FROM training_datasetset ds
-            WHERE de.dataset_id = ds.dataset_id
+            WHERE de.dataset_id = ds.dataset_id AND de.set = ds.name
             """,
             reverse_sql=migrations.RunSQL.noop,
         ),
@@ -85,6 +85,10 @@ class Migration(migrations.Migration):
             name="set",
             field=models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, related_name="set_elements", to="training.datasetset"),
         ),
+        migrations.AddConstraint(
+            model_name="datasetelement",
+            constraint=models.UniqueConstraint(fields=("element_id", "set_id"), name="unique_set_element"),
+        ),
         migrations.RemoveField(
             model_name="dataset",
             name="sets"
diff --git a/arkindex/training/models.py b/arkindex/training/models.py
index 06f4f84c5f..c37ea35b09 100644
--- a/arkindex/training/models.py
+++ b/arkindex/training/models.py
@@ -309,3 +309,11 @@ class DatasetElement(models.Model):
         related_name="set_elements",
         on_delete=models.DO_NOTHING,
     )
+
+    class Meta:
+        constraints = [
+            models.UniqueConstraint(
+                fields=["element_id", "set_id"],
+                name="unique_set_element",
+            ),
+        ]
diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index 4acb03a663..c65d048fc1 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -512,7 +512,7 @@ class DatasetSerializer(serializers.ModelSerializer):
         help_text="Display name of the user who created the dataset.",
     )
 
-    set_names = serializers.ListField(child=serializers.CharField(max_length=50), write_only=True, default=["Training", "Validation", "Test"])
+    set_names = serializers.ListField(child=serializers.CharField(max_length=50), write_only=True, default=["training", "validation", "test"])
     sets = DatasetSetSerializer(many=True, read_only=True)
 
     # When creating the dataset, the dataset's corpus comes from the URL, so the APIView passes it through
@@ -657,6 +657,7 @@ class DatasetElementSerializer(serializers.ModelSerializer):
         default=_dataset_from_context,
         write_only=True,
     )
+    set = serializers.SlugRelatedField(queryset=DatasetSet.objects.none(), slug_field="name")
 
     class Meta:
         model = DatasetElement
@@ -665,7 +666,7 @@ class DatasetElementSerializer(serializers.ModelSerializer):
         validators = [
             UniqueTogetherValidator(
                 queryset=DatasetElement.objects.all(),
-                fields=["dataset", "element_id", "set"],
+                fields=["element_id", "set"],
                 message="This element is already part of this set.",
             )
         ]
@@ -674,13 +675,12 @@ class DatasetElementSerializer(serializers.ModelSerializer):
         super().__init__(*args, **kwargs)
         if dataset := self.context.get("dataset"):
             self.fields["element_id"].queryset = Element.objects.filter(corpus=dataset.corpus)
+            self.fields["set"].queryset = dataset.sets.all()
 
-    def validate_set(self, value):
-        # The set must match the `sets` array defined at the dataset level
-        dataset = self.context["dataset"]
-        if dataset and value not in dataset.sets:
-            raise ValidationError(f"This dataset has no set named {value}.")
-        return value
+    def validate(self, data):
+        data = super().validate(data)
+        data.pop("dataset")
+        return data
 
 
 class DatasetElementInfoSerializer(DatasetElementSerializer):
@@ -698,10 +698,11 @@ class DatasetElementInfoSerializer(DatasetElementSerializer):
         fields = DatasetElementSerializer.Meta.fields + ("dataset",)
 
 
-class ElementDatasetSerializer(serializers.ModelSerializer):
-    dataset = DatasetSerializer()
+class ElementDatasetSetSerializer(serializers.ModelSerializer):
+    dataset = DatasetSerializer(source="set.dataset")
     previous = serializers.UUIDField(allow_null=True, read_only=True)
     next = serializers.UUIDField(allow_null=True, read_only=True)
+    set = serializers.SlugRelatedField(slug_field="name", read_only=True)
 
     class Meta:
         model = DatasetElement
diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py
index 9b19becc27..6a7387c464 100644
--- a/arkindex/training/tests/test_datasets_api.py
+++ b/arkindex/training/tests/test_datasets_api.py
@@ -9,7 +9,7 @@ from arkindex.documents.models import Corpus
 from arkindex.process.models import Process, ProcessDataset, ProcessMode
 from arkindex.project.tests import FixtureAPITestCase
 from arkindex.project.tools import fake_now
-from arkindex.training.models import Dataset, DatasetSet, DatasetState
+from arkindex.training.models import Dataset, DatasetElement, DatasetSet, DatasetState
 from arkindex.users.models import Role, User
 
 # Using the fake DB fixtures creation date when needed
@@ -267,7 +267,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
 
     def test_create(self):
         self.client.force_login(self.user)
-        with self.assertNumQueries(10):
+        with self.assertNumQueries(11):
             response = self.client.post(
                 reverse("api:corpus-datasets", kwargs={"pk": self.corpus.pk}),
                 data={"name": "My dataset", "description": "My dataset for my experiments."},
@@ -279,11 +279,17 @@ class TestDatasetsAPI(FixtureAPITestCase):
             "id": str(created_dataset.id),
             "name": "My dataset",
             "description": "My dataset for my experiments.",
-            "sets": {},
+            "sets": [
+                {
+                    "id": str(ds.id),
+                    "name": ds.name
+                }
+                for ds in created_dataset.sets.all()
+            ],
             "set_elements": {
-                "Training": 0,
-                "Test": 0,
-                "Validation": 0,
+                "training": 0,
+                "test": 0,
+                "validation": 0,
             },
             "state": "open",
             "creator": "Test user",
@@ -1066,14 +1072,15 @@ class TestDatasetsAPI(FixtureAPITestCase):
         """
         DestroyDataset also deletes DatasetElements
         """
-        self.dataset.dataset_elements.create(element_id=self.vol.id, set="test")
-        self.dataset.dataset_elements.create(element_id=self.vol.id, set="training")
-        self.dataset.dataset_elements.create(element_id=self.page1.id, set="training")
-        self.dataset.dataset_elements.create(element_id=self.page2.id, set="validation")
-        self.dataset.dataset_elements.create(element_id=self.page3.id, set="validation")
+        test_set, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        test_set.set_elements.create(element_id=self.vol.id)
+        train_set.set_elements.create(element_id=self.vol.id)
+        train_set.set_elements.create(element_id=self.page1.id)
+        validation_set.set_elements.create(element_id=self.page2.id)
+        validation_set.set_elements.create(element_id=self.page3.id)
         self.client.force_login(self.user)
 
-        with self.assertNumQueries(6):
+        with self.assertNumQueries(7):
             response = self.client.delete(
                 reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}),
             )
@@ -1082,7 +1089,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
         with self.assertRaises(Dataset.DoesNotExist):
             self.dataset.refresh_from_db()
 
-        self.assertFalse(self.dataset.dataset_elements.exists())
+        self.assertFalse(DatasetElement.objects.filter(set__dataset_id=self.dataset.id).exists())
 
         # No elements should have been deleted
         self.vol.refresh_from_db()
@@ -1112,12 +1119,13 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.assertEqual(filter_rights_mock.call_args, call(self.user, Corpus, Role.Guest.value))
 
     def test_list_elements_set_filter_wrong_set(self):
-        self.dataset.dataset_elements.create(element_id=self.page1.id, set="test")
+        test_set = self.dataset.sets.order_by("name").first()
+        test_set.set_elements.create(element_id=self.page1.id)
         self.client.force_login(self.user)
         with self.assertNumQueries(4):
             response = self.client.get(
                 reverse("api:dataset-elements", kwargs={"pk": str(self.dataset.id)}),
-                data={"set": "aaaaa"}
+                data={"set": "training"}
             )
             self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertDictEqual(response.json(), {
@@ -1128,10 +1136,11 @@ class TestDatasetsAPI(FixtureAPITestCase):
         })
 
     def test_list_elements_set_filter(self):
-        self.dataset.dataset_elements.create(element_id=self.page1.id, set="test")
-        self.dataset.dataset_elements.create(element_id=self.page2.id, set="training")
+        test_set, train_set, _ = self.dataset.sets.all().order_by("name")
+        test_set.set_elements.create(element_id=self.page1.id)
+        train_set.set_elements.create(element_id=self.page2.id)
         self.client.force_login(self.user)
-        with self.assertNumQueries(5):
+        with self.assertNumQueries(6):
             response = self.client.get(
                 reverse("api:dataset-elements", kwargs={"pk": self.dataset.pk}),
                 data={"set": "training", "with_count": "true"},
@@ -1142,15 +1151,17 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.assertDictEqual(data, {"count": 1, "next": None, "previous": None})
         self.assertEqual(len(results), 1)
         dataset_element = results[0]
+        # Single result: check the element id and the set name below
         self.assertEqual(dataset_element["element"]["id"], str(self.page2.id))
         self.assertEqual(dataset_element["set"], "training")
 
     @patch("arkindex.documents.models.Element.thumbnail", MagicMock(s3_url="s3_url"))
     def test_list_elements(self):
-        self.dataset.dataset_elements.create(element_id=self.vol.id, set="test")
-        self.dataset.dataset_elements.create(element_id=self.page1.id, set="training")
-        self.dataset.dataset_elements.create(element_id=self.page2.id, set="validation")
-        self.dataset.dataset_elements.create(element_id=self.page3.id, set="validation")
+        test_set, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        test_set.set_elements.create(element_id=self.vol.id)
+        train_set.set_elements.create(element_id=self.page1.id)
+        validation_set.set_elements.create(element_id=self.page2.id)
+        validation_set.set_elements.create(element_id=self.page3.id)
         self.page3.confidence = 0.42
         self.page3.mirrored = True
         self.page3.rotation_angle = 42
@@ -1339,7 +1350,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
             self.dataset.state = state
             self.dataset.save()
             with self.subTest(state=state):
-                with self.assertNumQueries(4):
+                with self.assertNumQueries(5):
                     response = self.client.get(
                         reverse("api:dataset-elements", kwargs={"pk": self.dataset.pk}),
                         {"page_size": 3},
@@ -1418,7 +1429,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
 
     def test_add_element_wrong_set(self):
         self.client.force_login(self.user)
-        with self.assertNumQueries(4):
+        with self.assertNumQueries(5):
             response = self.client.post(
                 reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}),
                 data={"set": "aaaaaaaaaaa", "element_id": str(self.vol.id)},
@@ -1426,7 +1437,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
             )
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertDictEqual(response.json(), {
-            "set": ["This dataset has no set named aaaaaaaaaaa."],
+            "set": ["Object with name=aaaaaaaaaaa does not exist."],
         })
 
     def test_add_element_dataset_requires_open(self):
@@ -1443,9 +1454,10 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.assertListEqual(response.json(), ["You can only add elements to a dataset in an open state."])
 
     def test_add_element_already_exists(self):
-        self.dataset.dataset_elements.create(element=self.page1, set="test")
+        test_set = self.dataset.sets.order_by("name").first()
+        test_set.set_elements.create(element=self.page1)
         self.client.force_login(self.user)
-        with self.assertNumQueries(5):
+        with self.assertNumQueries(6):
             response = self.client.post(
                 reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}),
                 data={"set": "test", "element_id": str(self.page1.id)},
@@ -1455,8 +1467,9 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.assertDictEqual(response.json(), {"non_field_errors": ["This element is already part of this set."]})
 
     def test_add_element(self):
+        train_set = self.dataset.sets.get(name="training")
         self.client.force_login(self.user)
-        with self.assertNumQueries(10):
+        with self.assertNumQueries(11):
             response = self.client.post(
                 reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}),
                 data={"set": "training", "element_id": str(self.page1.id)},
@@ -1464,7 +1477,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
             )
             self.assertEqual(response.status_code, status.HTTP_201_CREATED)
         self.assertQuerysetEqual(
-            self.dataset.dataset_elements.values_list("set", "element__name").order_by("element__name"),
+            train_set.set_elements.values_list("set__name", "element__name").order_by("element__name"),
             [("training", "Volume 1, page 1r")]
         )
 
@@ -1575,9 +1588,10 @@ class TestDatasetsAPI(FixtureAPITestCase):
         })
 
     def test_add_from_selection(self):
-        self.dataset.dataset_elements.create(element=self.page1, set="training")
+        train_set = self.dataset.sets.get(name="training")
+        train_set.set_elements.create(element=self.page1)
         self.assertQuerysetEqual(
-            self.dataset.dataset_elements.values_list("set", "element__name").order_by("element__name"),
+            train_set.set_elements.values_list("set__name", "element__name").order_by("element__name"),
             [("training", "Volume 1, page 1r")]
         )
         self.user.selected_elements.set([self.vol, self.page1, self.page2])
@@ -1586,12 +1600,12 @@ class TestDatasetsAPI(FixtureAPITestCase):
         with self.assertNumQueries(6):
             response = self.client.post(
                 reverse("api:dataset-elements-selection", kwargs={"pk": self.corpus.id}),
-                data={"set": "training", "dataset_id": self.dataset.id},
+                data={"set_id": str(train_set.id)},
                 format="json",
             )
             self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
         self.assertQuerysetEqual(
-            self.dataset.dataset_elements.values_list("set", "element__name").order_by("element__name"),
+            train_set.set_elements.values_list("set__name", "element__name").order_by("element__name"),
             [
                 ("training", "Volume 1"),
                 ("training", "Volume 1, page 1r"),
@@ -1623,8 +1637,9 @@ class TestDatasetsAPI(FixtureAPITestCase):
         """
         A non authenticated user can list datasets of a public element
         """
-        self.dataset.dataset_elements.create(element=self.vol, set="train")
-        with self.assertNumQueries(3):
+        train_set = self.dataset.sets.get(name="training")
+        train_set.set_elements.create(element=self.vol)
+        with self.assertNumQueries(4):
             response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.vol.id)}))
             self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertDictEqual(response.json(), {
@@ -1637,7 +1652,13 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "id": str(self.dataset.id),
                     "name": "First Dataset",
                     "description": "dataset number one",
-                    "sets": ["training", "test", "validation"],
+                    "sets": [
+                        {
+                            "id": str(ds.id),
+                            "name": ds.name
+                        }
+                        for ds in self.dataset.sets.all()
+                    ],
                     "set_elements": None,
                     "state": "open",
                     "corpus_id": str(self.corpus.id),
@@ -1646,7 +1667,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "created": self.dataset.created.isoformat().replace("+00:00", "Z"),
                     "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
                 },
-                "set": "train",
+                "set": "training",
                 "previous": None,
                 "next": None
             }]
@@ -1654,10 +1675,12 @@ class TestDatasetsAPI(FixtureAPITestCase):
 
     def test_element_datasets(self):
         self.client.force_login(self.user)
-        self.dataset.dataset_elements.create(element=self.page1, set="train")
-        self.dataset.dataset_elements.create(element=self.page1, set="validation")
-        self.dataset2.dataset_elements.create(element=self.page1, set="train")
-        with self.assertNumQueries(5):
+        _, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        train_set.set_elements.create(element=self.page1)
+        validation_set.set_elements.create(element=self.page1)
+        train_set_2 = self.dataset2.sets.get(name="training")
+        train_set_2.set_elements.create(element=self.page1)
+        with self.assertNumQueries(6):
             response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.page1.id)}))
             self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertDictEqual(response.json(), {
@@ -1721,10 +1744,11 @@ class TestDatasetsAPI(FixtureAPITestCase):
 
     def test_element_datasets_with_neighbors_false(self):
         self.client.force_login(self.user)
-        self.dataset.dataset_elements.create(element=self.page1, set="train")
-        self.dataset.dataset_elements.create(element=self.page1, set="validation")
-        self.dataset2.dataset_elements.create(element=self.page1, set="train")
-        with self.assertNumQueries(5):
+        _, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        train_set.set_elements.create(element=self.page1)
+        validation_set.set_elements.create(element=self.page1)
+        self.dataset2.sets.get(name="training").set_elements.create(element=self.page1)
+        with self.assertNumQueries(6):
             response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.page1.id)}), {"with_neighbors": False})
             self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertDictEqual(response.json(), {
@@ -1788,12 +1812,13 @@ class TestDatasetsAPI(FixtureAPITestCase):
 
     def test_element_datasets_with_neighbors(self):
         self.client.force_login(self.user)
-        self.dataset.dataset_elements.create(element=self.page1, set="train")
-        self.dataset.dataset_elements.create(element=self.page2, set="train")
-        self.dataset.dataset_elements.create(element=self.page3, set="train")
-        self.dataset.dataset_elements.create(element=self.page1, set="validation")
-        self.dataset2.dataset_elements.create(element=self.page1, set="train")
-        self.dataset2.dataset_elements.create(element=self.page3, set="train")
+        _, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        train_set.set_elements.create(element=self.page1)
+        train_set.set_elements.create(element=self.page2)
+        train_set.set_elements.create(element=self.page3)
+        validation_set.set_elements.create(element=self.page1)
+        self.dataset2.sets.get(name="training").set_elements.create(element=self.page1)
+        self.dataset2.sets.get(name="training").set_elements.create(element=self.page3)
 
         # Results are alphabetically ordered and must not depend on the random page UUIDs
         sorted_dataset_elements = sorted([str(self.page1.id), str(self.page2.id), str(self.page3.id)])
@@ -1940,13 +1965,14 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.dataset.state = DatasetState.Error
         self.dataset.task = self.task
         self.dataset.save()
-        self.dataset.dataset_elements.create(element=self.page1, set="test")
-        self.dataset.dataset_elements.create(element=self.page1, set="validation")
-        self.dataset.dataset_elements.create(element=self.vol, set="validation")
+        test_set, _, validation_set = self.dataset.sets.all().order_by("name")
+        test_set.set_elements.create(element=self.page1)
+        validation_set.set_elements.create(element=self.page1)
+        validation_set.set_elements.create(element=self.vol)
         self.assertCountEqual(self.corpus.datasets.values_list("name", flat=True), ["First Dataset", "Second Dataset"])
 
         self.client.force_login(self.user)
-        with self.assertNumQueries(12):
+        with self.assertNumQueries(16):
             response = self.client.post(
                 reverse("api:dataset-clone", kwargs={"pk": self.dataset.id}),
                 format="json",
@@ -1959,9 +1985,11 @@ class TestDatasetsAPI(FixtureAPITestCase):
         ])
         data = response.json()
         clone = self.corpus.datasets.get(id=data["id"])
+        test_clone, train_clone, val_clone = clone.sets.all().order_by("name")
         self.assertEqual(clone.creator, self.user)
         data.pop("created")
         data.pop("updated")
+        cloned_sets = data.pop("sets")
         self.assertDictEqual(
             response.json(),
             {
@@ -1970,14 +1998,27 @@ class TestDatasetsAPI(FixtureAPITestCase):
                 "description": self.dataset.description,
                 "creator": self.user.display_name,
                 "corpus_id": str(self.corpus.id),
-                "sets": ["training", "test", "validation"],
                 "set_elements": {"test": 1, "training": 0, "validation": 2},
                 "state": DatasetState.Open.value,
                 "task_id": str(self.task.id),
             },
         )
+        self.assertCountEqual(cloned_sets, [
+            {
+                "name": "training",
+                "id": str(train_clone.id)
+            },
+            {
+                "name": "test",
+                "id": str(test_clone.id)
+            },
+            {
+                "name": "validation",
+                "id": str(val_clone.id)
+            }
+        ])
         self.assertQuerysetEqual(
-            clone.dataset_elements.values_list("set", "element__name").order_by("element__name", "set"),
+            DatasetElement.objects.filter(set__dataset_id=clone.id).values_list("set__name", "element__name").order_by("element__name", "set__name"),
             [
                 ("validation", "Volume 1"),
                 ("test", "Volume 1, page 1r"),
@@ -2040,7 +2081,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": str(self.dataset.id), "element": str(self.page1.id)})
-                + "?set=train"
+                + "?set=training"
             )
             self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
 
@@ -2051,28 +2092,29 @@ class TestDatasetsAPI(FixtureAPITestCase):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": str(self.dataset.id), "element": str(self.page1.id)})
-                + "?set=train"
+                + "?set=training"
             )
             self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
 
     @patch("arkindex.project.mixins.has_access", return_value=False)
     def test_destroy_dataset_element_requires_contributor(self, has_access_mock):
         self.client.force_login(self.read_user)
-        self.dataset.dataset_elements.create(element=self.page1, set="train")
-        self.dataset.dataset_elements.create(element=self.page1, set="validation")
-        self.assertEqual(self.dataset.dataset_elements.filter(set="train").count(), 1)
-        self.assertEqual(self.dataset.dataset_elements.filter(set="validation").count(), 1)
+        _, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        train_set.set_elements.create(element=self.page1)
+        validation_set.set_elements.create(element=self.page1)
+        self.assertEqual(train_set.set_elements.count(), 1)
+        self.assertEqual(validation_set.set_elements.count(), 1)
         with self.assertNumQueries(3):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": str(self.dataset.id), "element": str(self.page1.id)})
-                + "?set=train"
+                + "?set=training"
             )
             self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
         self.assertDictEqual(response.json(), {"detail": "You need a Contributor access to the dataset to perform this action."})
         self.dataset.refresh_from_db()
-        self.assertEqual(self.dataset.dataset_elements.filter(set="train").count(), 1)
-        self.assertEqual(self.dataset.dataset_elements.filter(set="validation").count(), 1)
+        self.assertEqual(train_set.set_elements.count(), 1)
+        self.assertEqual(validation_set.set_elements.count(), 1)
 
         self.assertEqual(has_access_mock.call_count, 1)
         self.assertEqual(has_access_mock.call_args, call(self.read_user, self.corpus, Role.Contributor.value, skip_public=False))
@@ -2089,23 +2131,24 @@ class TestDatasetsAPI(FixtureAPITestCase):
 
     def test_destroy_dataset_element_requires_open_dataset(self):
         self.client.force_login(self.user)
-        self.dataset.dataset_elements.create(element=self.page1, set="train")
-        self.dataset.dataset_elements.create(element=self.page1, set="validation")
+        _, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        train_set.set_elements.create(element=self.page1)
+        validation_set.set_elements.create(element=self.page1)
         self.dataset.state = DatasetState.Error
         self.dataset.save()
-        self.assertEqual(self.dataset.dataset_elements.filter(set="train").count(), 1)
-        self.assertEqual(self.dataset.dataset_elements.filter(set="validation").count(), 1)
+        self.assertEqual(train_set.set_elements.count(), 1)
+        self.assertEqual(validation_set.set_elements.count(), 1)
         with self.assertNumQueries(3):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": str(self.dataset.id), "element": str(self.page1.id)})
-                + "?set=train"
+                + "?set=training"
             )
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertDictEqual(response.json(), {"dataset": ["Elements can only be removed from open Datasets."]})
         self.dataset.refresh_from_db()
-        self.assertEqual(self.dataset.dataset_elements.filter(set="train").count(), 1)
-        self.assertEqual(self.dataset.dataset_elements.filter(set="validation").count(), 1)
+        self.assertEqual(train_set.set_elements.count(), 1)
+        self.assertEqual(validation_set.set_elements.count(), 1)
 
     def test_destroy_dataset_element_dataset_doesnt_exist(self):
         self.client.force_login(self.user)
@@ -2113,7 +2156,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", "element": str(self.page1.id)})
-                + "?set=train"
+                + "?set=training"
             )
             self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
         self.assertDictEqual(response.json(), {"detail": "Not found."})
@@ -2135,49 +2178,52 @@ class TestDatasetsAPI(FixtureAPITestCase):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": str(self.dataset.id), "element": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"})
-                + "?set=train"
+                + "?set=training"
             )
             self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
         self.assertDictEqual(response.json(), {"detail": "Not found."})
 
     def test_destroy_dataset_element_element_not_in_dataset(self):
-        self.dataset.dataset_elements.create(element=self.page1, set="train")
+        train_set = self.dataset.sets.get(name="training")
+        train_set.set_elements.create(element=self.page1)
         self.client.force_login(self.user)
         with self.assertNumQueries(3):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": str(self.dataset.id), "element": str(self.page2.id)})
-                + "?set=train"
+                + "?set=training"
             )
             self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
         self.assertDictEqual(response.json(), {"detail": "Not found."})
 
     def test_destroy_dataset_element_wrong_set(self):
-        self.dataset.dataset_elements.create(element=self.page1, set="train")
-        self.dataset.dataset_elements.create(element=self.page2, set="validation")
+        _, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        train_set.set_elements.create(element=self.page1)
+        validation_set.set_elements.create(element=self.page2)
         self.client.force_login(self.user)
         with self.assertNumQueries(3):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": str(self.dataset.id), "element": str(self.page2.id)})
-                + "?set=train"
+                + "?set=training"
             )
             self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
         self.assertDictEqual(response.json(), {"detail": "Not found."})
 
     def test_destroy_dataset_element(self):
         self.client.force_login(self.user)
-        self.dataset.dataset_elements.create(element=self.page1, set="train")
-        self.dataset.dataset_elements.create(element=self.page1, set="validation")
-        self.assertEqual(self.dataset.dataset_elements.filter(set="train").count(), 1)
-        self.assertEqual(self.dataset.dataset_elements.filter(set="validation").count(), 1)
+        _, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        train_set.set_elements.create(element=self.page1)
+        validation_set.set_elements.create(element=self.page1)
+        self.assertEqual(train_set.set_elements.count(), 1)
+        self.assertEqual(validation_set.set_elements.count(), 1)
         with self.assertNumQueries(4):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": str(self.dataset.id), "element": str(self.page1.id)})
-                + "?set=train"
+                + "?set=training"
             )
             self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
         self.dataset.refresh_from_db()
-        self.assertEqual(self.dataset.dataset_elements.filter(set="train").count(), 0)
-        self.assertEqual(self.dataset.dataset_elements.filter(set="validation").count(), 1)
+        self.assertEqual(train_set.set_elements.count(), 0)
+        self.assertEqual(validation_set.set_elements.count(), 1)
-- 
GitLab