diff --git a/arkindex/documents/fixtures/data.json b/arkindex/documents/fixtures/data.json
index 85895a434fbde8fb01a0bd2d4073aca23ebe8669..17a2db9bca93b5660affba6729723065a5164426 100644
--- a/arkindex/documents/fixtures/data.json
+++ b/arkindex/documents/fixtures/data.json
@@ -5579,7 +5579,7 @@
     "model": "training.datasetset",
     "pk": "00e3b37b-f1ed-4adb-a1de-a1c103edaa24",
     "fields": {
-        "name": "Test",
+        "name": "test",
         "dataset": "170bc276-0e11-483a-9486-88a4bf2079b2"
     }
 },
@@ -5587,7 +5587,7 @@
     "model": "training.datasetset",
     "pk": "95255ff4-7bca-424c-8e7e-d8b5c33c585f",
     "fields": {
-        "name": "Train",
+        "name": "training",
         "dataset": "170bc276-0e11-483a-9486-88a4bf2079b2"
     }
 },
@@ -5595,7 +5595,7 @@
     "model": "training.datasetset",
     "pk": "b21e4b31-1dc1-4015-9933-3a8e180cf2e0",
     "fields": {
-        "name": "Validation",
+        "name": "validation",
         "dataset": "170bc276-0e11-483a-9486-88a4bf2079b2"
     }
 },
@@ -5603,7 +5603,7 @@
     "model": "training.datasetset",
     "pk": "b76d919c-ce7e-4896-94c3-9b47862b997a",
     "fields": {
-        "name": "Validation",
+        "name": "validation",
         "dataset": "6b32dab0-e3f2-4fff-b0ad-2ea1ee550810"
     }
 },
@@ -5611,7 +5611,7 @@
     "model": "training.datasetset",
     "pk": "d5f4d410-0588-4d19-8d08-b57a11bad67f",
     "fields": {
-        "name": "Train",
+        "name": "training",
         "dataset": "6b32dab0-e3f2-4fff-b0ad-2ea1ee550810"
     }
 },
@@ -5619,7 +5619,7 @@
     "model": "training.datasetset",
     "pk": "db70ad8a-8d6b-44c6-b8d3-837bf5a4ab8e",
     "fields": {
-        "name": "Test",
+        "name": "test",
         "dataset": "6b32dab0-e3f2-4fff-b0ad-2ea1ee550810"
     }
 }
diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py
index 17c92d941b684639b7d6219340ce33da615ba0ed..3c893fc12fc95707cd24b4363ff115a9f67bfbfe 100644
--- a/arkindex/project/api_v1.py
+++ b/arkindex/project/api_v1.py
@@ -109,7 +109,7 @@ from arkindex.training.api import (
     DatasetElementDestroy,
     DatasetElements,
     DatasetUpdate,
-    ElementDatasets,
+    ElementDatasetSets,
     MetricValueBulkCreate,
     MetricValueCreate,
     ModelCompatibleWorkerManage,
@@ -184,7 +184,7 @@ api = [
     # Datasets
     path("corpus/<uuid:pk>/datasets/", CorpusDataset.as_view(), name="corpus-datasets"),
     path("corpus/<uuid:pk>/datasets/selection/", CreateDatasetElementsSelection.as_view(), name="dataset-elements-selection"),
-    path("element/<uuid:pk>/datasets/", ElementDatasets.as_view(), name="element-datasets"),
+    path("element/<uuid:pk>/datasets/", ElementDatasetSets.as_view(), name="element-datasets"),
     path("datasets/<uuid:pk>/", DatasetUpdate.as_view(), name="dataset-update"),
     path("datasets/<uuid:pk>/clone/", DatasetClone.as_view(), name="dataset-clone"),
     path("datasets/<uuid:pk>/elements/", DatasetElements.as_view(), name="dataset-elements"),
diff --git a/arkindex/project/serializer_fields.py b/arkindex/project/serializer_fields.py
index 7667083d861344b70be059872111055dcea360c7..b7bf6429586f4d9f737ee86382b6f2126953b223 100644
--- a/arkindex/project/serializer_fields.py
+++ b/arkindex/project/serializer_fields.py
@@ -286,7 +286,7 @@ class DatasetSetsCountField(serializers.DictField):
             return None
         elts_count = {k.name: 0 for k in instance.sets.all()}
         elts_count.update(
-            DatasetElement.objects.filter(set__dataset_id=instance.id)
+            DatasetElement.objects.filter(set_id__in=instance.sets.values_list("id"))
             .values("set__name")
             .annotate(count=Count("id"))
             .values_list("set__name", "count")
diff --git a/arkindex/training/api.py b/arkindex/training/api.py
index 06034116e91209f2d4f5eb20be8d18ba66921670..0a42b292e206b7a48e9b1eb247a31a02e03dad09 100644
--- a/arkindex/training/api.py
+++ b/arkindex/training/api.py
@@ -29,6 +29,7 @@ from arkindex.project.tools import BulkMap
 from arkindex.training.models import (
     Dataset,
     DatasetElement,
+    DatasetSet,
     DatasetState,
     MetricValue,
     Model,
@@ -40,7 +41,7 @@ from arkindex.training.serializers import (
     DatasetElementInfoSerializer,
     DatasetElementSerializer,
     DatasetSerializer,
-    ElementDatasetSerializer,
+    ElementDatasetSetSerializer,
     MetricValueBulkSerializer,
     MetricValueCreateSerializer,
     ModelCompatibleWorkerSerializer,
@@ -58,7 +59,7 @@ def _fetch_datasetelement_neighbors(datasetelements):
     """
     Retrieve the neighbors for a list of DatasetElements, and annotate these DatasetElements
     with next and previous attributes.
-    The ElementDatasets endpoint uses arkindex.project.tools.BulkMap to apply this method and
+    The ElementDatasetSets endpoint uses arkindex.project.tools.BulkMap to apply this method and
     perform the second request *after* DRF's pagination, because there is no way to perform
     post-processing after pagination in Django without having to use Django private methods.
     """
@@ -687,7 +688,7 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView):
     serializer_class = DatasetSerializer
 
     def get_queryset(self):
-        queryset = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user))
+        queryset = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user)).prefetch_related("sets")
         return queryset.select_related("corpus", "creator")
 
     def check_object_permissions(self, request, obj):
@@ -708,7 +709,8 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView):
             raise ValidationError(detail="This dataset is in complete state and cannot be modified anymore.")
 
     def perform_destroy(self, dataset):
-        dataset.dataset_elements.all().delete()
+        DatasetElement.objects.filter(set__dataset_id=dataset.id).delete()
+        dataset.sets.all().delete()
         super().perform_destroy(dataset)
 
 
@@ -769,13 +771,13 @@ class DatasetElements(CorpusACLMixin, ListCreateAPIView):
 
     def get_queryset(self):
         qs = (
-            self.dataset.dataset_elements
-            .prefetch_related("element")
+            DatasetElement.objects.filter(set__dataset_id=self.dataset.id)
+            .prefetch_related("element", "set")
             .select_related("element__type", "element__corpus", "element__image__server")
             .order_by("element_id", "id")
         )
         if "set" in self.request.query_params:
-            qs = qs.filter(set=self.request.query_params["set"])
+            qs = qs.filter(set__name=self.request.query_params["set"])
         return qs
 
     def get_serializer_context(self):
@@ -788,20 +790,12 @@ class DatasetElements(CorpusACLMixin, ListCreateAPIView):
 @extend_schema_view(
     delete=extend_schema(
         operation_id="DestroyDatasetElement",
-        parameters=[
-            OpenApiParameter(
-                "set",
-                type=str,
-                description="Name of the set from which to remove the element.",
-                required=True,
-            )
-        ],
         tags=["datasets"]
     )
 )
 class DatasetElementDestroy(CorpusACLMixin, DestroyAPIView):
     """
-    Remove an element from a dataset.
+    Remove an element from a dataset set.
 
     Elements can only be removed from **open** datasets.
 
@@ -812,17 +806,15 @@ class DatasetElementDestroy(CorpusACLMixin, DestroyAPIView):
     lookup_url_kwarg = "element"
 
     def destroy(self, request, *args, **kwargs):
-        if not self.request.query_params.get("set"):
+        if not (set_name := self.request.query_params.get("set")):
             raise ValidationError({"set": ["This field is required."]})
         dataset_element = get_object_or_404(
-            DatasetElement.objects.select_related("dataset__corpus"),
-            dataset_id=self.kwargs["dataset"],
-            element_id=self.kwargs["element"],
-            set=self.request.query_params.get("set")
+            DatasetElement.objects.select_related("set__dataset__corpus").filter(set__dataset_id=self.kwargs["dataset"], set__name=set_name),
+            element_id=self.kwargs["element"]
         )
-        if dataset_element.dataset.state != DatasetState.Open:
+        if dataset_element.set.dataset.state != DatasetState.Open:
             raise ValidationError({"dataset": ["Elements can only be removed from open Datasets."]})
-        if not self.has_write_access(dataset_element.dataset.corpus):
+        if not self.has_write_access(dataset_element.set.dataset.corpus):
             raise PermissionDenied(detail="You need a Contributor access to the dataset to perform this action.")
         dataset_element.delete()
         return Response(status=status.HTTP_204_NO_CONTENT)
@@ -897,14 +889,14 @@ class CreateDatasetElementsSelection(CorpusACLMixin, CreateAPIView):
         )
     ],
 )
-class ElementDatasets(CorpusACLMixin, ListAPIView):
+class ElementDatasetSets(CorpusACLMixin, ListAPIView):
     """
-    List all datasets containing a specific element.
+    List all dataset sets containing a specific element.
 
     Requires a **guest** access to the element's corpus.
     """
     permission_classes = (IsVerifiedOrReadOnly, )
-    serializer_class = ElementDatasetSerializer
+    serializer_class = ElementDatasetSetSerializer
 
     @cached_property
     def element(self):
@@ -916,9 +908,9 @@ class ElementDatasets(CorpusACLMixin, ListAPIView):
 
     def get_queryset(self):
         qs = (
-            self.element.dataset_elements.all()
-            .select_related("dataset__creator")
-            .order_by("dataset__name", "set", "dataset_id")
+            self.element.dataset_elements
+            .select_related("set__dataset__creator")
+            .order_by("set__name", "id")
         )
 
         with_neighbors = self.request.query_params.get("with_neighbors", "false")
@@ -996,11 +988,18 @@ class DatasetClone(CorpusACLMixin, CreateAPIView):
         clone.creator = request.user
         clone.save()
 
+        # Clone dataset sets
+        cloned_sets = DatasetSet.objects.bulk_create([
+            DatasetSet(dataset_id=clone.id, name=set.name)
+            for set in dataset.sets.all()
+        ])
         # Associate all elements to the clone
         DatasetElement.objects.bulk_create([
-            DatasetElement(element_id=elt_id, dataset_id=clone.id, set=set_name)
-            for elt_id, set_name in dataset.dataset_elements.values_list("element_id", "set")
+            DatasetElement(element_id=elt_id, set=next(new_set for new_set in cloned_sets if new_set.name == set_name))
+            for elt_id, set_name in DatasetElement.objects.filter(set__dataset_id=dataset.id)
+            .values_list("element_id", "set__name")
         ])
+
         return Response(
             DatasetSerializer(clone).data,
             status=status.HTTP_201_CREATED,
diff --git a/arkindex/training/migrations/0007_datasetset_model.py b/arkindex/training/migrations/0007_datasetset_model.py
index 8630baa59354605087acdc25145170ebb2183b87..c3052f691c88ebbbe65ace7e66e24cfbaa5f3065 100644
--- a/arkindex/training/migrations/0007_datasetset_model.py
+++ b/arkindex/training/migrations/0007_datasetset_model.py
@@ -59,7 +59,7 @@ class Migration(migrations.Migration):
             UPDATE training_datasetelement de
             SET set_id_id = ds.id
             FROM training_datasetset ds
-            WHERE de.dataset_id = ds.dataset_id
+            WHERE de.dataset_id = ds.dataset_id AND de.set = ds.name
             """,
             reverse_sql=migrations.RunSQL.noop,
         ),
@@ -85,6 +85,10 @@ class Migration(migrations.Migration):
             name="set",
             field=models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, related_name="set_elements", to="training.datasetset"),
         ),
+        migrations.AddConstraint(
+            model_name="datasetelement",
+            constraint=models.UniqueConstraint(fields=("element_id", "set_id"), name="unique_set_element"),
+        ),
         migrations.RemoveField(
             model_name="dataset",
             name="sets"
diff --git a/arkindex/training/models.py b/arkindex/training/models.py
index 06f4f84c5ff4452f56d7f65437256751cb93f1c5..c37ea35b09193f087793b2c48c194cc5f3c64e48 100644
--- a/arkindex/training/models.py
+++ b/arkindex/training/models.py
@@ -309,3 +309,11 @@ class DatasetElement(models.Model):
         related_name="set_elements",
         on_delete=models.DO_NOTHING,
     )
+
+    class Meta:
+        constraints = [
+            models.UniqueConstraint(
+                fields=["element_id", "set_id"],
+                name="unique_set_element",
+            ),
+        ]
diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index 4acb03a6638b16a41be046c37fef9667096f5a3a..37dacc0d76fda21965f198cdefbf847a4f4d8b17 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -10,7 +10,7 @@ from rest_framework import permissions, serializers
 from rest_framework.exceptions import PermissionDenied, ValidationError
 from rest_framework.validators import UniqueTogetherValidator
 
-from arkindex.documents.models import Element
+from arkindex.documents.models import Corpus, Element
 from arkindex.documents.serializers.elements import ElementListSerializer
 from arkindex.ponos.models import Task
 from arkindex.process.models import Worker
@@ -512,7 +512,7 @@ class DatasetSerializer(serializers.ModelSerializer):
         help_text="Display name of the user who created the dataset.",
     )
 
-    set_names = serializers.ListField(child=serializers.CharField(max_length=50), write_only=True, default=["Training", "Validation", "Test"])
+    set_names = serializers.ListField(child=serializers.CharField(max_length=50), write_only=True, required=False)
     sets = DatasetSetSerializer(many=True, read_only=True)
 
     # When creating the dataset, the dataset's corpus comes from the URL, so the APIView passes it through
@@ -555,6 +555,8 @@ class DatasetSerializer(serializers.ModelSerializer):
             raise ValidationError("This API endpoint does not allow updating a dataset's sets.")
         if set_names is not None and len(set(set_names)) != len(set_names):
             raise ValidationError("Set names must be unique.")
+        if set_names is not None and len(set_names) == 0:
+            raise ValidationError("Either do not specify set names to use the default values, or specify a non-empty list of names.")
         return set_names
 
     def validate(self, data):
@@ -585,7 +587,10 @@ class DatasetSerializer(serializers.ModelSerializer):
 
     @transaction.atomic
     def create(self, validated_data):
-        sets = validated_data.pop("set_names")
+        if "set_names" not in validated_data:
+            sets = ["training", "validation", "test"]
+        else:
+            sets = validated_data.pop("set_names")
         dataset = Dataset.objects.create(**validated_data)
         DatasetSet.objects.bulk_create(
             DatasetSet(
@@ -657,6 +662,7 @@ class DatasetElementSerializer(serializers.ModelSerializer):
         default=_dataset_from_context,
         write_only=True,
     )
+    set = serializers.SlugRelatedField(queryset=DatasetSet.objects.none(), slug_field="name")
 
     class Meta:
         model = DatasetElement
@@ -665,7 +671,7 @@ class DatasetElementSerializer(serializers.ModelSerializer):
         validators = [
             UniqueTogetherValidator(
                 queryset=DatasetElement.objects.all(),
-                fields=["dataset", "element_id", "set"],
+                fields=["element_id", "set"],
                 message="This element is already part of this set.",
             )
         ]
@@ -674,13 +680,12 @@ class DatasetElementSerializer(serializers.ModelSerializer):
         super().__init__(*args, **kwargs)
         if dataset := self.context.get("dataset"):
             self.fields["element_id"].queryset = Element.objects.filter(corpus=dataset.corpus)
+            self.fields["set"].queryset = dataset.sets.all()
 
-    def validate_set(self, value):
-        # The set must match the `sets` array defined at the dataset level
-        dataset = self.context["dataset"]
-        if dataset and value not in dataset.sets:
-            raise ValidationError(f"This dataset has no set named {value}.")
-        return value
+    def validate(self, data):
+        data = super().validate(data)
+        data.pop("dataset")
+        return data
 
 
 class DatasetElementInfoSerializer(DatasetElementSerializer):
@@ -698,10 +703,11 @@ class DatasetElementInfoSerializer(DatasetElementSerializer):
         fields = DatasetElementSerializer.Meta.fields + ("dataset",)
 
 
-class ElementDatasetSerializer(serializers.ModelSerializer):
-    dataset = DatasetSerializer()
+class ElementDatasetSetSerializer(serializers.ModelSerializer):
+    dataset = DatasetSerializer(source="set.dataset")
     previous = serializers.UUIDField(allow_null=True, read_only=True)
     next = serializers.UUIDField(allow_null=True, read_only=True)
+    set = serializers.SlugRelatedField(slug_field="name", read_only=True)
 
     class Meta:
         model = DatasetElement
@@ -721,7 +727,7 @@ class SelectionDatasetElementSerializer(serializers.Serializer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.fields["set_id"].queryset = DatasetSet.objects.filter(
-            dataset__corpus_id=self.context["corpus"].id
+            dataset__corpus_id__in=Corpus.objects.readable(self.context["request"].user)
         ).select_related("dataset")
 
     def validate_set_id(self, set):
diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py
index 618d806ed1b2a4f74b0260de31546734243ed329..56f7e3fb418a4d8364a3831c4cc6948630144791 100644
--- a/arkindex/training/tests/test_datasets_api.py
+++ b/arkindex/training/tests/test_datasets_api.py
@@ -9,7 +9,7 @@ from arkindex.documents.models import Corpus
 from arkindex.process.models import Process, ProcessDataset, ProcessMode
 from arkindex.project.tests import FixtureAPITestCase
 from arkindex.project.tools import fake_now
-from arkindex.training.models import Dataset, DatasetSet, DatasetState
+from arkindex.training.models import Dataset, DatasetElement, DatasetSet, DatasetState
 from arkindex.users.models import Role, User
 
 # Using the fake DB fixtures creation date when needed
@@ -267,7 +267,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
 
     def test_create(self):
         self.client.force_login(self.user)
-        with self.assertNumQueries(10):
+        with self.assertNumQueries(11):
             response = self.client.post(
                 reverse("api:corpus-datasets", kwargs={"pk": self.corpus.pk}),
                 data={"name": "My dataset", "description": "My dataset for my experiments."},
@@ -279,11 +279,17 @@ class TestDatasetsAPI(FixtureAPITestCase):
             "id": str(created_dataset.id),
             "name": "My dataset",
             "description": "My dataset for my experiments.",
-            "sets": {},
+            "sets": [
+                {
+                    "id": str(ds.id),
+                    "name": ds.name
+                }
+                for ds in created_dataset.sets.all()
+            ],
             "set_elements": {
-                "Training": 0,
-                "Test": 0,
-                "Validation": 0,
+                "training": 0,
+                "test": 0,
+                "validation": 0,
             },
             "state": "open",
             "creator": "Test user",
@@ -295,7 +301,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
 
     def test_create_state_ignored(self):
         self.client.force_login(self.user)
-        with self.assertNumQueries(6):
+        with self.assertNumQueries(11):
             response = self.client.post(
                 reverse("api:corpus-datasets", kwargs={"pk": self.corpus.pk}),
                 data={"name": "My dataset", "description": "My dataset for my experiments.", "state": "complete"},
@@ -308,8 +314,18 @@ class TestDatasetsAPI(FixtureAPITestCase):
             "id": str(created_dataset.id),
             "name": "My dataset",
             "description": "My dataset for my experiments.",
-            "sets": ["training", "test", "validation"],
-            "set_elements": {"test": 0, "training": 0, "validation": 0},
+            "sets": [
+                {
+                    "id": str(ds.id),
+                    "name": ds.name
+                }
+                for ds in created_dataset.sets.all()
+            ],
+            "set_elements": {
+                "training": 0,
+                "test": 0,
+                "validation": 0,
+            },
             "state": "open",
             "creator": "Test user",
             "task_id": None,
@@ -320,10 +336,10 @@ class TestDatasetsAPI(FixtureAPITestCase):
 
     def test_create_sets(self):
         self.client.force_login(self.user)
-        with self.assertNumQueries(6):
+        with self.assertNumQueries(11):
             response = self.client.post(
                 reverse("api:corpus-datasets", kwargs={"pk": self.corpus.pk}),
-                data={"name": "My dataset", "description": "My dataset for my experiments.", "sets": ["a", "b", "c", "d"]},
+                data={"name": "My dataset", "description": "My dataset for my experiments.", "set_names": ["a", "b", "c", "d"]},
                 format="json"
             )
             self.assertEqual(response.status_code, status.HTTP_201_CREATED)
@@ -332,7 +348,13 @@ class TestDatasetsAPI(FixtureAPITestCase):
             "id": str(created_dataset.id),
             "name": "My dataset",
             "description": "My dataset for my experiments.",
-            "sets": ["a", "b", "c", "d"],
+            "sets": [
+                {
+                    "id": str(ds.id),
+                    "name": ds.name
+                }
+                for ds in created_dataset.sets.all()
+            ],
             "set_elements": {"a": 0, "b": 0, "c": 0, "d": 0},
             "state": "open",
             "creator": "Test user",
@@ -342,39 +364,39 @@ class TestDatasetsAPI(FixtureAPITestCase):
             "updated": created_dataset.updated.isoformat().replace("+00:00", "Z"),
         })
 
-    def test_create_sets_length(self):
+    def test_create_sets_empty_list(self):
         self.client.force_login(self.user)
         with self.assertNumQueries(3):
             response = self.client.post(
                 reverse("api:corpus-datasets", kwargs={"pk": self.corpus.pk}),
-                data={"name": "My dataset", "description": "My dataset for my experiments.", "sets": []},
+                data={"name": "My dataset", "description": "My dataset for my experiments.", "set_names": []},
                 format="json"
             )
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-        self.assertDictEqual(response.json(), {"sets": ["Either do not specify set names to use the default values, or specify a non-empty list of names."]})
+        self.assertDictEqual(response.json(), {"set_names": ["Either do not specify set names to use the default values, or specify a non-empty list of names."]})
 
     def test_create_sets_unique_names(self):
         self.client.force_login(self.user)
         with self.assertNumQueries(3):
             response = self.client.post(
                 reverse("api:corpus-datasets", kwargs={"pk": self.corpus.pk}),
-                data={"name": "My dataset", "description": "My dataset for my experiments.", "sets": ["a", "a", "b"]},
+                data={"name": "My dataset", "description": "My dataset for my experiments.", "set_names": ["a", "a", "b"]},
                 format="json"
             )
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-        self.assertDictEqual(response.json(), {"sets": ["Set names must be unique."]})
+        self.assertDictEqual(response.json(), {"set_names": ["Set names must be unique."]})
 
     def test_create_sets_blank_names(self):
         self.client.force_login(self.user)
         with self.assertNumQueries(3):
             response = self.client.post(
                 reverse("api:corpus-datasets", kwargs={"pk": self.corpus.pk}),
-                data={"name": "My dataset", "description": "My dataset for my experiments.", "sets": ["     ", " ", "b"]},
+                data={"name": "My dataset", "description": "My dataset for my experiments.", "set_names": ["     ", " ", "b"]},
                 format="json"
             )
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertDictEqual(response.json(), {
-            "sets":
+            "set_names":
                 {
                     "0": ["This field may not be blank."],
                     "1": ["This field may not be blank."]
@@ -389,13 +411,13 @@ class TestDatasetsAPI(FixtureAPITestCase):
                 data={
                     "name": "My Dataset",
                     "description": "My dataset for my experiments.",
-                    "sets": ["unit-00", "Etiam accumsan ullamcorper mauris eget mattis. Ut porttitor."]
+                    "set_names": ["unit-00", "Etiam accumsan ullamcorper mauris eget mattis. Ut porttitor."]
                 },
                 format="json"
             )
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertDictEqual(response.json(), {
-            "sets": {
+            "set_names": {
                 "1": ["Ensure this field has no more than 50 characters."]
             }
         })
@@ -409,7 +431,6 @@ class TestDatasetsAPI(FixtureAPITestCase):
                 data={
                     "name": "Shin Seiki Evangelion",
                     "description": "Omedeto!",
-                    "sets": ["unit-01", "unit-00", "unit-02"],
                 },
                 format="json"
             )
@@ -425,7 +446,6 @@ class TestDatasetsAPI(FixtureAPITestCase):
                 data={
                     "name": "Shin Seiki Evangelion",
                     "description": "Omedeto!",
-                    "sets": ["unit-01", "unit-00", "unit-02"],
                 },
                 format="json"
             )
@@ -441,7 +461,6 @@ class TestDatasetsAPI(FixtureAPITestCase):
                 data={
                     "name": "Shin Seiki Evangelion",
                     "description": "Omedeto!",
-                    "sets": ["unit-01", "unit-00", "unit-02"],
                 },
                 format="json"
             )
@@ -459,7 +478,6 @@ class TestDatasetsAPI(FixtureAPITestCase):
                 data={
                     "name": "Shin Seiki Evangelion",
                     "description": "Omedeto!",
-                    "sets": ["unit-01", "unit-00", "unit-02"],
                 },
                 format="json"
             )
@@ -489,7 +507,6 @@ class TestDatasetsAPI(FixtureAPITestCase):
                 data={
                     "name": "Another Dataset",
                     "description": "My dataset for my experiments.",
-                    "sets": self.dataset.sets,
                 },
                 format="json"
             )
@@ -652,7 +669,6 @@ class TestDatasetsAPI(FixtureAPITestCase):
                         data={
                             "name": "AA",
                             "description": "BB",
-                            "sets": self.dataset.sets + ["CC"],
                             "state": new_state.value,
                         },
                         format="json"
@@ -779,7 +795,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.assertEqual(self.dataset.state, DatasetState.Open)
         self.assertEqual(self.dataset.name, "First Dataset")
         self.assertEqual(self.dataset.description, "Omedeto!")
-        self.assertListEqual(self.dataset.sets, ["training", "test", "validation"])
+        self.assertCountEqual(list(self.dataset.sets.values_list("name", flat=True)), ["training", "test", "validation"])
 
     def test_partial_update_empty_or_blank_description_or_name(self):
         self.client.force_login(self.user)
@@ -909,7 +925,6 @@ class TestDatasetsAPI(FixtureAPITestCase):
                         data={
                             "name": "AA",
                             "description": "BB",
-                            "sets": self.dataset.sets + ["CC"],
                             "state": new_state.value,
                         },
                         format="json"
@@ -966,7 +981,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
 
     def test_retrieve(self):
         self.client.force_login(self.user)
-        with self.assertNumQueries(4):
+        with self.assertNumQueries(5):
             response = self.client.get(
                 reverse("api:dataset-update", kwargs={"pk": self.dataset.pk})
             )
@@ -976,7 +991,13 @@ class TestDatasetsAPI(FixtureAPITestCase):
             "name": "First Dataset",
             "description": "dataset number one",
             "state": "open",
-            "sets": ["training", "test", "validation"],
+            "sets": [
+                {
+                    "id": str(ds.id),
+                    "name": ds.name
+                }
+                for ds in self.dataset.sets.all()
+            ],
             "set_elements": {"test": 0, "training": 0, "validation": 0},
             "creator": "Test user",
             "task_id": None,
@@ -989,7 +1010,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.client.force_login(self.user)
         self.dataset.task = self.task
         self.dataset.save()
-        with self.assertNumQueries(4):
+        with self.assertNumQueries(5):
             response = self.client.get(
                 reverse("api:dataset-update", kwargs={"pk": self.dataset.pk})
             )
@@ -1040,7 +1061,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
 
     def test_delete(self):
         self.client.force_login(self.user)
-        with self.assertNumQueries(6):
+        with self.assertNumQueries(7):
             response = self.client.delete(
                 reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}),
             )
@@ -1066,14 +1087,15 @@ class TestDatasetsAPI(FixtureAPITestCase):
         """
         DestroyDataset also deletes DatasetElements
         """
-        self.dataset.dataset_elements.create(element_id=self.vol.id, set="test")
-        self.dataset.dataset_elements.create(element_id=self.vol.id, set="training")
-        self.dataset.dataset_elements.create(element_id=self.page1.id, set="training")
-        self.dataset.dataset_elements.create(element_id=self.page2.id, set="validation")
-        self.dataset.dataset_elements.create(element_id=self.page3.id, set="validation")
+        test_set, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        test_set.set_elements.create(element_id=self.vol.id)
+        train_set.set_elements.create(element_id=self.vol.id)
+        train_set.set_elements.create(element_id=self.page1.id)
+        validation_set.set_elements.create(element_id=self.page2.id)
+        validation_set.set_elements.create(element_id=self.page3.id)
         self.client.force_login(self.user)
 
-        with self.assertNumQueries(6):
+        with self.assertNumQueries(7):
             response = self.client.delete(
                 reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}),
             )
@@ -1082,7 +1104,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
         with self.assertRaises(Dataset.DoesNotExist):
             self.dataset.refresh_from_db()
 
-        self.assertFalse(self.dataset.dataset_elements.exists())
+        self.assertFalse(DatasetElement.objects.filter(set__dataset_id=self.dataset.id).exists())
 
         # No elements should have been deleted
         self.vol.refresh_from_db()
@@ -1112,12 +1134,13 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.assertEqual(filter_rights_mock.call_args, call(self.user, Corpus, Role.Guest.value))
 
     def test_list_elements_set_filter_wrong_set(self):
-        self.dataset.dataset_elements.create(element_id=self.page1.id, set="test")
+        test_set = self.dataset.sets.order_by("name").first()
+        test_set.set_elements.create(element_id=self.page1.id)
         self.client.force_login(self.user)
         with self.assertNumQueries(4):
             response = self.client.get(
                 reverse("api:dataset-elements", kwargs={"pk": str(self.dataset.id)}),
-                data={"set": "aaaaa"}
+                data={"set": "training"}
             )
             self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertDictEqual(response.json(), {
@@ -1128,10 +1151,11 @@ class TestDatasetsAPI(FixtureAPITestCase):
         })
 
     def test_list_elements_set_filter(self):
-        self.dataset.dataset_elements.create(element_id=self.page1.id, set="test")
-        self.dataset.dataset_elements.create(element_id=self.page2.id, set="training")
+        test_set, train_set, _ = self.dataset.sets.all().order_by("name")
+        test_set.set_elements.create(element_id=self.page1.id)
+        train_set.set_elements.create(element_id=self.page2.id)
         self.client.force_login(self.user)
-        with self.assertNumQueries(5):
+        with self.assertNumQueries(6):
             response = self.client.get(
                 reverse("api:dataset-elements", kwargs={"pk": self.dataset.pk}),
                 data={"set": "training", "with_count": "true"},
@@ -1142,15 +1166,16 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.assertDictEqual(data, {"count": 1, "next": None, "previous": None})
         self.assertEqual(len(results), 1)
         dataset_element = results[0]
         self.assertEqual(dataset_element["element"]["id"], str(self.page2.id))
         self.assertEqual(dataset_element["set"], "training")
 
     @patch("arkindex.documents.models.Element.thumbnail", MagicMock(s3_url="s3_url"))
     def test_list_elements(self):
-        self.dataset.dataset_elements.create(element_id=self.vol.id, set="test")
-        self.dataset.dataset_elements.create(element_id=self.page1.id, set="training")
-        self.dataset.dataset_elements.create(element_id=self.page2.id, set="validation")
-        self.dataset.dataset_elements.create(element_id=self.page3.id, set="validation")
+        test_set, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        test_set.set_elements.create(element_id=self.vol.id)
+        train_set.set_elements.create(element_id=self.page1.id)
+        validation_set.set_elements.create(element_id=self.page2.id)
+        validation_set.set_elements.create(element_id=self.page3.id)
         self.page3.confidence = 0.42
         self.page3.mirrored = True
         self.page3.rotation_angle = 42
@@ -1339,7 +1365,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
             self.dataset.state = state
             self.dataset.save()
             with self.subTest(state=state):
-                with self.assertNumQueries(4):
+                with self.assertNumQueries(5):
                     response = self.client.get(
                         reverse("api:dataset-elements", kwargs={"pk": self.dataset.pk}),
                         {"page_size": 3},
@@ -1418,7 +1444,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
 
     def test_add_element_wrong_set(self):
         self.client.force_login(self.user)
-        with self.assertNumQueries(4):
+        with self.assertNumQueries(5):
             response = self.client.post(
                 reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}),
                 data={"set": "aaaaaaaaaaa", "element_id": str(self.vol.id)},
@@ -1426,7 +1452,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
             )
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertDictEqual(response.json(), {
-            "set": ["This dataset has no set named aaaaaaaaaaa."],
+            "set": ["Object with name=aaaaaaaaaaa does not exist."],
         })
 
     def test_add_element_dataset_requires_open(self):
@@ -1443,9 +1469,10 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.assertListEqual(response.json(), ["You can only add elements to a dataset in an open state."])
 
     def test_add_element_already_exists(self):
-        self.dataset.dataset_elements.create(element=self.page1, set="test")
+        test_set = self.dataset.sets.order_by("name").first()
+        test_set.set_elements.create(element=self.page1)
         self.client.force_login(self.user)
-        with self.assertNumQueries(5):
+        with self.assertNumQueries(6):
             response = self.client.post(
                 reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}),
                 data={"set": "test", "element_id": str(self.page1.id)},
@@ -1455,8 +1482,9 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.assertDictEqual(response.json(), {"non_field_errors": ["This element is already part of this set."]})
 
     def test_add_element(self):
+        train_set = self.dataset.sets.get(name="training")
         self.client.force_login(self.user)
-        with self.assertNumQueries(10):
+        with self.assertNumQueries(11):
             response = self.client.post(
                 reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}),
                 data={"set": "training", "element_id": str(self.page1.id)},
@@ -1464,7 +1492,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
             )
             self.assertEqual(response.status_code, status.HTTP_201_CREATED)
         self.assertQuerysetEqual(
-            self.dataset.dataset_elements.values_list("set", "element__name").order_by("element__name"),
+            train_set.set_elements.values_list("set__name", "element__name").order_by("element__name"),
             [("training", "Volume 1, page 1r")]
         )
 
@@ -1541,10 +1569,11 @@ class TestDatasetsAPI(FixtureAPITestCase):
             )
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertDictEqual(response.json(), {
-            "set_id": ["`AAA` is not a valid UUID."],
+            "set_id": ["“AAA” is not a valid UUID."],
         })
 
     def test_add_from_selection_wrong_dataset(self):
+        self.private_corpus.memberships.create(user=self.user, level=Role.Contributor.value)
         self.client.force_login(self.user)
         with self.assertNumQueries(4):
             response = self.client.post(
@@ -1554,7 +1583,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
             )
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertDictEqual(response.json(), {
-            "dataset_id": [f"Dataset {self.private_dataset.id} is not part of corpus Unit Tests."],
+            "set_id": [f"Dataset {self.private_dataset.id} is not part of corpus Unit Tests."],
         })
 
     def test_add_from_selection_completed_dataset(self):
@@ -1575,9 +1604,10 @@ class TestDatasetsAPI(FixtureAPITestCase):
         })
 
     def test_add_from_selection(self):
-        self.dataset.dataset_elements.create(element=self.page1, set="training")
+        train_set = self.dataset.sets.get(name="training")
+        train_set.set_elements.create(element=self.page1)
         self.assertQuerysetEqual(
-            self.dataset.dataset_elements.values_list("set", "element__name").order_by("element__name"),
+            train_set.set_elements.values_list("set__name", "element__name").order_by("element__name"),
             [("training", "Volume 1, page 1r")]
         )
         self.user.selected_elements.set([self.vol, self.page1, self.page2])
@@ -1586,12 +1616,12 @@ class TestDatasetsAPI(FixtureAPITestCase):
         with self.assertNumQueries(6):
             response = self.client.post(
                 reverse("api:dataset-elements-selection", kwargs={"pk": self.corpus.id}),
-                data={"set": "training", "dataset_id": self.dataset.id},
+                data={"set_id": str(train_set.id)},
                 format="json",
             )
             self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
         self.assertQuerysetEqual(
-            self.dataset.dataset_elements.values_list("set", "element__name").order_by("element__name"),
+            train_set.set_elements.values_list("set__name", "element__name").order_by("element__name"),
             [
                 ("training", "Volume 1"),
                 ("training", "Volume 1, page 1r"),
@@ -1623,8 +1653,9 @@ class TestDatasetsAPI(FixtureAPITestCase):
         """
         A non authenticated user can list datasets of a public element
         """
-        self.dataset.dataset_elements.create(element=self.vol, set="train")
-        with self.assertNumQueries(3):
+        train_set = self.dataset.sets.get(name="training")
+        train_set.set_elements.create(element=self.vol)
+        with self.assertNumQueries(4):
             response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.vol.id)}))
             self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertDictEqual(response.json(), {
@@ -1637,7 +1668,13 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "id": str(self.dataset.id),
                     "name": "First Dataset",
                     "description": "dataset number one",
-                    "sets": ["training", "test", "validation"],
+                    "sets": [
+                        {
+                            "id": str(ds.id),
+                            "name": ds.name
+                        }
+                        for ds in self.dataset.sets.all()
+                    ],
                     "set_elements": None,
                     "state": "open",
                     "corpus_id": str(self.corpus.id),
@@ -1646,7 +1683,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "created": self.dataset.created.isoformat().replace("+00:00", "Z"),
                     "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
                 },
-                "set": "train",
+                "set": "training",
                 "previous": None,
                 "next": None
             }]
@@ -1654,10 +1691,12 @@ class TestDatasetsAPI(FixtureAPITestCase):
 
     def test_element_datasets(self):
         self.client.force_login(self.user)
-        self.dataset.dataset_elements.create(element=self.page1, set="train")
-        self.dataset.dataset_elements.create(element=self.page1, set="validation")
-        self.dataset2.dataset_elements.create(element=self.page1, set="train")
-        with self.assertNumQueries(5):
+        _, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        train_set.set_elements.create(element=self.page1)
+        validation_set.set_elements.create(element=self.page1)
+        train_set_2 = self.dataset2.sets.get(name="training")
+        train_set_2.set_elements.create(element=self.page1)
+        with self.assertNumQueries(6):
             response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.page1.id)}))
             self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertDictEqual(response.json(), {
@@ -1670,7 +1709,13 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "id": str(self.dataset.id),
                     "name": "First Dataset",
                     "description": "dataset number one",
-                    "sets": ["training", "test", "validation"],
+                    "sets": [
+                        {
+                            "id": str(ds.id),
+                            "name": ds.name
+                        }
+                        for ds in self.dataset.sets.all()
+                    ],
                     "set_elements": None,
                     "state": "open",
                     "corpus_id": str(self.corpus.id),
@@ -1687,7 +1732,13 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "id": str(self.dataset.id),
                     "name": "First Dataset",
                     "description": "dataset number one",
-                    "sets": ["training", "test", "validation"],
+                    "sets": [
+                        {
+                            "id": str(ds.id),
+                            "name": ds.name
+                        }
+                        for ds in self.dataset.sets.all()
+                    ],
                     "set_elements": None,
                     "state": "open",
                     "corpus_id": str(self.corpus.id),
@@ -1704,7 +1755,13 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "id": str(self.dataset2.id),
                     "name": "Second Dataset",
                     "description": "dataset number two",
-                    "sets": ["training", "test", "validation"],
+                    "sets": [
+                        {
+                            "id": str(ds.id),
+                            "name": ds.name
+                        }
+                        for ds in self.dataset2.sets.all()
+                    ],
                     "set_elements": None,
                     "state": "open",
                     "corpus_id": str(self.corpus.id),
@@ -1721,10 +1778,12 @@ class TestDatasetsAPI(FixtureAPITestCase):
 
     def test_element_datasets_with_neighbors_false(self):
         self.client.force_login(self.user)
-        self.dataset.dataset_elements.create(element=self.page1, set="train")
-        self.dataset.dataset_elements.create(element=self.page1, set="validation")
-        self.dataset2.dataset_elements.create(element=self.page1, set="train")
-        with self.assertNumQueries(5):
+        _, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        train_set_2 = self.dataset2.sets.get(name="training")
+        train_set.set_elements.create(element=self.page1)
+        validation_set.set_elements.create(element=self.page1)
+        train_set_2.set_elements.create(element=self.page1)
+        with self.assertNumQueries(6):
             response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.page1.id)}), {"with_neighbors": False})
             self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertDictEqual(response.json(), {
@@ -1737,7 +1796,13 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "id": str(self.dataset.id),
                     "name": "First Dataset",
                     "description": "dataset number one",
-                    "sets": ["training", "test", "validation"],
+                    "sets": [
+                        {
+                            "id": str(ds.id),
+                            "name": ds.name
+                        }
+                        for ds in self.dataset.sets.all()
+                    ],
                     "set_elements": None,
                     "state": "open",
                     "corpus_id": str(self.corpus.id),
@@ -1754,7 +1819,13 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "id": str(self.dataset.id),
                     "name": "First Dataset",
                     "description": "dataset number one",
-                    "sets": ["training", "test", "validation"],
+                    "sets": [
+                        {
+                            "id": str(ds.id),
+                            "name": ds.name
+                        }
+                        for ds in self.dataset.sets.all()
+                    ],
                     "set_elements": None,
                     "state": "open",
                     "corpus_id": str(self.corpus.id),
@@ -1771,7 +1842,13 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "id": str(self.dataset2.id),
                     "name": "Second Dataset",
                     "description": "dataset number two",
-                    "sets": ["training", "test", "validation"],
+                    "sets": [
+                        {
+                            "id": str(ds.id),
+                            "name": ds.name
+                        }
+                        for ds in self.dataset2.sets.all()
+                    ],
                     "set_elements": None,
                     "state": "open",
                     "corpus_id": str(self.corpus.id),
@@ -1788,12 +1865,14 @@ class TestDatasetsAPI(FixtureAPITestCase):
 
     def test_element_datasets_with_neighbors(self):
         self.client.force_login(self.user)
-        self.dataset.dataset_elements.create(element=self.page1, set="train")
-        self.dataset.dataset_elements.create(element=self.page2, set="train")
-        self.dataset.dataset_elements.create(element=self.page3, set="train")
-        self.dataset.dataset_elements.create(element=self.page1, set="validation")
-        self.dataset2.dataset_elements.create(element=self.page1, set="train")
-        self.dataset2.dataset_elements.create(element=self.page3, set="train")
+        _, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        train_set_2 = self.dataset2.sets.get(name="training")
+        train_set.set_elements.create(element=self.page1)
+        train_set.set_elements.create(element=self.page2)
+        train_set.set_elements.create(element=self.page3)
+        validation_set.set_elements.create(element=self.page1)
+        train_set_2.set_elements.create(element=self.page1)
+        train_set_2.set_elements.create(element=self.page3)
 
         # Results are alphabetically ordered and must not depend on the random page UUIDs
         sorted_dataset_elements = sorted([str(self.page1.id), str(self.page2.id), str(self.page3.id)])
@@ -1814,7 +1893,13 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "id": str(self.dataset.id),
                     "name": "First Dataset",
                     "description": "dataset number one",
-                    "sets": ["training", "test", "validation"],
+                    "sets": [
+                        {
+                            "id": str(ds.id),
+                            "name": ds.name
+                        }
+                        for ds in self.dataset.sets.all()
+                    ],
                     "set_elements": None,
                     "state": "open",
                     "corpus_id": str(self.corpus.id),
@@ -1839,7 +1924,13 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "id": str(self.dataset.id),
                     "name": "First Dataset",
                     "description": "dataset number one",
-                    "sets": ["training", "test", "validation"],
+                    "sets": [
+                        {
+                            "id": str(ds.id),
+                            "name": ds.name
+                        }
+                        for ds in self.dataset.sets.all()
+                    ],
                     "set_elements": None,
                     "state": "open",
                     "corpus_id": str(self.corpus.id),
@@ -1856,7 +1947,13 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "id": str(self.dataset2.id),
                     "name": "Second Dataset",
                     "description": "dataset number two",
-                    "sets": ["training", "test", "validation"],
+                    "sets": [
+                        {
+                            "id": str(ds.id),
+                            "name": ds.name
+                        }
+                        for ds in self.dataset2.sets.all()
+                    ],
                     "set_elements": None,
                     "state": "open",
                     "corpus_id": str(self.corpus.id),
@@ -1940,13 +2037,14 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.dataset.state = DatasetState.Error
         self.dataset.task = self.task
         self.dataset.save()
-        self.dataset.dataset_elements.create(element=self.page1, set="test")
-        self.dataset.dataset_elements.create(element=self.page1, set="validation")
-        self.dataset.dataset_elements.create(element=self.vol, set="validation")
+        test_set, _, validation_set = self.dataset.sets.all().order_by("name")
+        test_set.set_elements.create(element=self.page1)
+        validation_set.set_elements.create(element=self.page1)
+        validation_set.set_elements.create(element=self.vol)
         self.assertCountEqual(self.corpus.datasets.values_list("name", flat=True), ["First Dataset", "Second Dataset"])
 
         self.client.force_login(self.user)
-        with self.assertNumQueries(12):
+        with self.assertNumQueries(16):
             response = self.client.post(
                 reverse("api:dataset-clone", kwargs={"pk": self.dataset.id}),
                 format="json",
@@ -1959,9 +2057,11 @@ class TestDatasetsAPI(FixtureAPITestCase):
         ])
         data = response.json()
         clone = self.corpus.datasets.get(id=data["id"])
+        test_clone, train_clone, val_clone = clone.sets.all().order_by("name")
         self.assertEqual(clone.creator, self.user)
         data.pop("created")
         data.pop("updated")
+        cloned_sets = data.pop("sets")
         self.assertDictEqual(
             response.json(),
             {
@@ -1970,14 +2070,27 @@ class TestDatasetsAPI(FixtureAPITestCase):
                 "description": self.dataset.description,
                 "creator": self.user.display_name,
                 "corpus_id": str(self.corpus.id),
-                "sets": ["training", "test", "validation"],
                 "set_elements": {"test": 1, "training": 0, "validation": 2},
                 "state": DatasetState.Open.value,
                 "task_id": str(self.task.id),
             },
         )
+        self.assertCountEqual(cloned_sets, [
+            {
+                "name": "training",
+                "id": str(train_clone.id)
+            },
+            {
+                "name": "test",
+                "id": str(test_clone.id)
+            },
+            {
+                "name": "validation",
+                "id": str(val_clone.id)
+            }
+        ])
         self.assertQuerysetEqual(
-            clone.dataset_elements.values_list("set", "element__name").order_by("element__name", "set"),
+            DatasetElement.objects.filter(set__dataset_id=clone.id).values_list("set__name", "element__name").order_by("element__name", "set__name"),
             [
                 ("validation", "Volume 1"),
                 ("test", "Volume 1, page 1r"),
@@ -2012,8 +2125,14 @@ class TestDatasetsAPI(FixtureAPITestCase):
                 "description": self.dataset.description,
                 "creator": self.user.display_name,
                 "corpus_id": str(self.corpus.id),
-                "sets": self.dataset.sets,
-                "set_elements": {k: 0 for k in self.dataset.sets},
+                "sets": [
+                    {
+                        "id": str(ds.id),
+                        "name": ds.name
+                    }
+                    for ds in self.dataset.sets.all()
+                ],
+                "set_elements": {s.name: 0 for s in self.dataset.sets.all()},
                 "state": DatasetState.Open.value,
                 "task_id": None,
             },
@@ -2040,7 +2159,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": str(self.dataset.id), "element": str(self.page1.id)})
-                + "?set=train"
+                + "?set=training"
             )
             self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
 
@@ -2051,28 +2170,29 @@ class TestDatasetsAPI(FixtureAPITestCase):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": str(self.dataset.id), "element": str(self.page1.id)})
-                + "?set=train"
+                + "?set=training"
             )
             self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
 
     @patch("arkindex.project.mixins.has_access", return_value=False)
     def test_destroy_dataset_element_requires_contributor(self, has_access_mock):
         self.client.force_login(self.read_user)
-        self.dataset.dataset_elements.create(element=self.page1, set="train")
-        self.dataset.dataset_elements.create(element=self.page1, set="validation")
-        self.assertEqual(self.dataset.dataset_elements.filter(set="train").count(), 1)
-        self.assertEqual(self.dataset.dataset_elements.filter(set="validation").count(), 1)
+        _, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        train_set.set_elements.create(element=self.page1)
+        validation_set.set_elements.create(element=self.page1)
+        self.assertEqual(train_set.set_elements.count(), 1)
+        self.assertEqual(validation_set.set_elements.count(), 1)
         with self.assertNumQueries(3):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": str(self.dataset.id), "element": str(self.page1.id)})
-                + "?set=train"
+                + "?set=training"
             )
             self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
         self.assertDictEqual(response.json(), {"detail": "You need a Contributor access to the dataset to perform this action."})
         self.dataset.refresh_from_db()
-        self.assertEqual(self.dataset.dataset_elements.filter(set="train").count(), 1)
-        self.assertEqual(self.dataset.dataset_elements.filter(set="validation").count(), 1)
+        self.assertEqual(train_set.set_elements.count(), 1)
+        self.assertEqual(validation_set.set_elements.count(), 1)
 
         self.assertEqual(has_access_mock.call_count, 1)
         self.assertEqual(has_access_mock.call_args, call(self.read_user, self.corpus, Role.Contributor.value, skip_public=False))
@@ -2089,23 +2209,24 @@ class TestDatasetsAPI(FixtureAPITestCase):
 
     def test_destroy_dataset_element_requires_open_dataset(self):
         self.client.force_login(self.user)
-        self.dataset.dataset_elements.create(element=self.page1, set="train")
-        self.dataset.dataset_elements.create(element=self.page1, set="validation")
+        _, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        train_set.set_elements.create(element=self.page1)
+        validation_set.set_elements.create(element=self.page1)
         self.dataset.state = DatasetState.Error
         self.dataset.save()
-        self.assertEqual(self.dataset.dataset_elements.filter(set="train").count(), 1)
-        self.assertEqual(self.dataset.dataset_elements.filter(set="validation").count(), 1)
+        self.assertEqual(train_set.set_elements.count(), 1)
+        self.assertEqual(validation_set.set_elements.count(), 1)
         with self.assertNumQueries(3):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": str(self.dataset.id), "element": str(self.page1.id)})
-                + "?set=train"
+                + "?set=training"
             )
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertDictEqual(response.json(), {"dataset": ["Elements can only be removed from open Datasets."]})
         self.dataset.refresh_from_db()
-        self.assertEqual(self.dataset.dataset_elements.filter(set="train").count(), 1)
-        self.assertEqual(self.dataset.dataset_elements.filter(set="validation").count(), 1)
+        self.assertEqual(train_set.set_elements.count(), 1)
+        self.assertEqual(validation_set.set_elements.count(), 1)
 
     def test_destroy_dataset_element_dataset_doesnt_exist(self):
         self.client.force_login(self.user)
@@ -2113,7 +2234,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", "element": str(self.page1.id)})
-                + "?set=train"
+                + "?set=training"
             )
             self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
         self.assertDictEqual(response.json(), {"detail": "Not found."})
@@ -2135,49 +2256,52 @@ class TestDatasetsAPI(FixtureAPITestCase):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": str(self.dataset.id), "element": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"})
-                + "?set=train"
+                + "?set=training"
             )
             self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
         self.assertDictEqual(response.json(), {"detail": "Not found."})
 
     def test_destroy_dataset_element_element_not_in_dataset(self):
-        self.dataset.dataset_elements.create(element=self.page1, set="train")
+        train_set = self.dataset.sets.get(name="training")
+        train_set.set_elements.create(element=self.page1)
         self.client.force_login(self.user)
         with self.assertNumQueries(3):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": str(self.dataset.id), "element": str(self.page2.id)})
-                + "?set=train"
+                + "?set=training"
             )
             self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
         self.assertDictEqual(response.json(), {"detail": "Not found."})
 
     def test_destroy_dataset_element_wrong_set(self):
-        self.dataset.dataset_elements.create(element=self.page1, set="train")
-        self.dataset.dataset_elements.create(element=self.page2, set="validation")
+        _, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        train_set.set_elements.create(element=self.page1)
+        validation_set.set_elements.create(element=self.page2)
         self.client.force_login(self.user)
         with self.assertNumQueries(3):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": str(self.dataset.id), "element": str(self.page2.id)})
-                + "?set=train"
+                + "?set=training"
             )
             self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
         self.assertDictEqual(response.json(), {"detail": "Not found."})
 
     def test_destroy_dataset_element(self):
         self.client.force_login(self.user)
-        self.dataset.dataset_elements.create(element=self.page1, set="train")
-        self.dataset.dataset_elements.create(element=self.page1, set="validation")
-        self.assertEqual(self.dataset.dataset_elements.filter(set="train").count(), 1)
-        self.assertEqual(self.dataset.dataset_elements.filter(set="validation").count(), 1)
+        _, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        train_set.set_elements.create(element=self.page1)
+        validation_set.set_elements.create(element=self.page1)
+        self.assertEqual(train_set.set_elements.count(), 1)
+        self.assertEqual(validation_set.set_elements.count(), 1)
         with self.assertNumQueries(4):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": str(self.dataset.id), "element": str(self.page1.id)})
-                + "?set=train"
+                + "?set=training"
             )
             self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
         self.dataset.refresh_from_db()
-        self.assertEqual(self.dataset.dataset_elements.filter(set="train").count(), 0)
-        self.assertEqual(self.dataset.dataset_elements.filter(set="validation").count(), 1)
+        self.assertEqual(train_set.set_elements.count(), 0)
+        self.assertEqual(validation_set.set_elements.count(), 1)