diff --git a/arkindex/documents/export/dataset.sql b/arkindex/documents/export/dataset.sql
index fd7990752bddaf22b3103dea95054f08f5e8f3e2..48d54910f083f8c3c81ea74b13b1387f23595bb9 100644
--- a/arkindex/documents/export/dataset.sql
+++ b/arkindex/documents/export/dataset.sql
@@ -2,6 +2,8 @@ SELECT
     dataset.id,
     dataset.name,
     dataset.state,
-    ARRAY_TO_STRING(dataset.sets, ',', '')
+    string_agg(datasetset.name, ',')
 FROM training_dataset dataset
+INNER JOIN training_datasetset datasetset ON datasetset.dataset_id = dataset.id
 WHERE dataset.corpus_id = '{corpus_id}'::uuid
+GROUP BY dataset.id
diff --git a/arkindex/documents/export/dataset_element.sql b/arkindex/documents/export/dataset_element.sql
index 4084e2e0cde82b451187c36ac6db2304d4576486..c75624c81d1c5129b2864e453cc6d61be562f480 100644
--- a/arkindex/documents/export/dataset_element.sql
+++ b/arkindex/documents/export/dataset_element.sql
@@ -1,8 +1,9 @@
 SELECT
     dataset_element.id,
     dataset_element.element_id,
-    dataset_element.dataset_id,
-    dataset_element.set
+    dataset_set.dataset_id,
+    dataset_set.name
 FROM training_datasetelement dataset_element
-INNER JOIN training_dataset dataset ON (dataset_element.dataset_id = dataset.id)
+INNER JOIN training_datasetset dataset_set ON (dataset_element.set_id = dataset_set.id)
+INNER JOIN training_dataset dataset ON (dataset_set.dataset_id = dataset.id)
 WHERE dataset.corpus_id = '{corpus_id}'::uuid
diff --git a/arkindex/documents/tasks.py b/arkindex/documents/tasks.py
index acd269aede5489fcea12ae307f90fb61b7a452f4..ee4c680a850683f409fda99ca42e28d3d7e8efc7 100644
--- a/arkindex/documents/tasks.py
+++ b/arkindex/documents/tasks.py
@@ -24,7 +24,7 @@ from arkindex.documents.models import (
 )
 from arkindex.ponos.models import Task
 from arkindex.process.models import Process, ProcessDataset, ProcessElement, WorkerActivity, WorkerRun
-from arkindex.training.models import DatasetElement
+from arkindex.training.models import DatasetElement, DatasetSet
 from arkindex.users.models import User
 
 logger = logging.getLogger(__name__)
@@ -73,7 +73,8 @@ def corpus_delete(corpus_id: str) -> None:
         # ProcessDataset M2M
         ProcessDataset.objects.filter(dataset__corpus_id=corpus_id),
         ProcessDataset.objects.filter(process__corpus_id=corpus_id),
-        DatasetElement.objects.filter(dataset__corpus_id=corpus_id),
+        DatasetElement.objects.filter(set__dataset__corpus_id=corpus_id),
+        DatasetSet.objects.filter(dataset__corpus_id=corpus_id),
         corpus.datasets.all(),
         # Delete the hidden M2M task parents table
         Task.parents.through.objects.filter(from_task__process__corpus_id=corpus_id),
diff --git a/arkindex/documents/tests/tasks/test_corpus_delete.py b/arkindex/documents/tests/tasks/test_corpus_delete.py
index de915af8ff86d040a5931354dab0954978288ec5..b007c1391c5cd2958a99b33435219a6595abf626 100644
--- a/arkindex/documents/tests/tasks/test_corpus_delete.py
+++ b/arkindex/documents/tests/tasks/test_corpus_delete.py
@@ -5,7 +5,7 @@ from arkindex.documents.tasks import corpus_delete
 from arkindex.ponos.models import Farm, State, Task
 from arkindex.process.models import CorpusWorkerVersion, ProcessDataset, ProcessMode, Repository, WorkerVersion
 from arkindex.project.tests import FixtureTestCase, force_constraints_immediate
-from arkindex.training.models import Dataset
+from arkindex.training.models import Dataset, DatasetSet
 
 
 class TestDeleteCorpus(FixtureTestCase):
@@ -114,18 +114,25 @@ class TestDeleteCorpus(FixtureTestCase):
         cls.corpus2 = Corpus.objects.create(name="Other corpus")
 
         dataset1 = Dataset.objects.get(name="First Dataset")
-        dataset1.dataset_elements.create(element=element, set="test")
+        test_set_1 = dataset1.sets.get(name="test")
+        test_set_1.set_elements.create(element=element)
         cls.dataset2 = Dataset.objects.create(name="Dead Sea Scrolls", description="How to trigger a Third Impact", creator=cls.user, corpus=cls.corpus2)
+        DatasetSet.objects.bulk_create(
+            DatasetSet(
+                dataset=cls.dataset2,
+                name=set_name
+            ) for set_name in ["test", "training", "validation"]
+        )
         # Process on cls.corpus and with a dataset from cls.corpus
         dataset_process1 = cls.corpus.processes.create(creator=cls.user, mode=ProcessMode.Dataset)
-        ProcessDataset.objects.create(process=dataset_process1, dataset=dataset1, sets=dataset1.sets)
+        ProcessDataset.objects.create(process=dataset_process1, dataset=dataset1, sets=list(dataset1.sets.values_list("name", flat=True)))
         # Process on cls.corpus with a dataset from another corpus
         dataset_process2 = cls.corpus.processes.create(creator=cls.user, mode=ProcessMode.Dataset)
-        ProcessDataset.objects.create(process=dataset_process2, dataset=dataset1, sets=dataset1.sets)
-        ProcessDataset.objects.create(process=dataset_process2, dataset=cls.dataset2, sets=cls.dataset2.sets)
+        ProcessDataset.objects.create(process=dataset_process2, dataset=dataset1, sets=list(dataset1.sets.values_list("name", flat=True)))
+        ProcessDataset.objects.create(process=dataset_process2, dataset=cls.dataset2, sets=list(cls.dataset2.sets.values_list("name", flat=True)))
         # Process on another corpus with a dataset from another corpus and none from cls.corpus
         cls.dataset_process3 = cls.corpus2.processes.create(creator=cls.user, mode=ProcessMode.Dataset)
-        ProcessDataset.objects.create(process=cls.dataset_process3, dataset=cls.dataset2, sets=cls.dataset2.sets)
+        ProcessDataset.objects.create(process=cls.dataset_process3, dataset=cls.dataset2, sets=list(cls.dataset2.sets.values_list("name", flat=True)))
 
         cls.rev = cls.repo.revisions.create(
             hash="42",
diff --git a/arkindex/documents/tests/tasks/test_export.py b/arkindex/documents/tests/tasks/test_export.py
index 32770443300e641ed4d2b1675e71e2c6da670093..c7da921c30f56ef5af641964d8cee48ce972449f 100644
--- a/arkindex/documents/tests/tasks/test_export.py
+++ b/arkindex/documents/tests/tasks/test_export.py
@@ -25,6 +25,7 @@ from arkindex.images.models import Image, ImageServer
 from arkindex.ponos.models import Artifact
 from arkindex.process.models import Repository, WorkerType, WorkerVersion, WorkerVersionState
 from arkindex.project.tests import FixtureTestCase
+from arkindex.training.models import DatasetElement
 
 TABLE_NAMES = {
     "export_version",
@@ -132,8 +133,9 @@ class TestExport(FixtureTestCase):
         )
 
         dataset = self.corpus.datasets.get(name="First Dataset")
-        dataset.dataset_elements.create(element=element, set="train")
-        dataset.dataset_elements.create(element=element, set="validation")
+        _, train_set, validation_set = dataset.sets.all().order_by("name")
+        train_set.set_elements.create(element=element)
+        validation_set.set_elements.create(element=element)
 
         export = self.corpus.exports.create(user=self.user)
 
@@ -489,7 +491,7 @@ class TestExport(FixtureTestCase):
                 (
                     str(dataset.id),
                     dataset.name,
-                    ",".join(dataset.sets),
+                    ",".join(list(dataset.sets.values_list("name", flat=True))),
                 ) for dataset in self.corpus.datasets.all()
             ]
         )
@@ -507,9 +509,9 @@ class TestExport(FixtureTestCase):
                 (
                     str(dataset_element.id),
                     str(dataset_element.element_id),
-                    str(dataset_element.dataset_id),
-                    dataset_element.set
-                ) for dataset_element in dataset.dataset_elements.all()
+                    str(dataset_element.set.dataset_id),
+                    dataset_element.set.name
+                ) for dataset_element in DatasetElement.objects.filter(set__dataset_id=dataset.id)
             ]
         )
 
diff --git a/arkindex/documents/tests/tasks/test_selection_worker_results_delete.py b/arkindex/documents/tests/tasks/test_selection_worker_results_delete.py
index 1f25a1aa0716d5ed7d452606e325f312da4a3102..f766bc0b0220384bab1b7ffccfb10d0d876418b0 100644
--- a/arkindex/documents/tests/tasks/test_selection_worker_results_delete.py
+++ b/arkindex/documents/tests/tasks/test_selection_worker_results_delete.py
@@ -183,7 +183,8 @@ class TestDeleteSelectionWorkerResults(FixtureTestCase):
         job_mock.return_value.user_id = self.user.id
         self.page1.worker_version = self.version
         self.page1.save()
-        Dataset.objects.get(name="First Dataset").dataset_elements.create(element=self.page1, set="test")
+        dataset = Dataset.objects.get(name="First Dataset")
+        dataset.sets.get(name="test").set_elements.create(element=self.page1)
         self.user.selected_elements.set([self.page1])
 
         selection_worker_results_delete(corpus_id=self.corpus.id, version_id=self.version.id)
diff --git a/arkindex/documents/tests/tasks/test_worker_results_delete.py b/arkindex/documents/tests/tasks/test_worker_results_delete.py
index 0fb898cf350c4bcf2eb3bb07ac6f960ee752c78f..52546e1bac401f5cd53028fb17a460e07c755b31 100644
--- a/arkindex/documents/tests/tasks/test_worker_results_delete.py
+++ b/arkindex/documents/tests/tasks/test_worker_results_delete.py
@@ -6,7 +6,7 @@ from arkindex.documents.models import Entity, EntityType, MLClass, Transcription
 from arkindex.documents.tasks import worker_results_delete
 from arkindex.process.models import ProcessMode, WorkerVersion
 from arkindex.project.tests import FixtureTestCase
-from arkindex.training.models import Dataset, Model, ModelVersionState
+from arkindex.training.models import DatasetSet, Model, ModelVersionState
 
 
 class TestDeleteWorkerResults(FixtureTestCase):
@@ -270,7 +270,7 @@ class TestDeleteWorkerResults(FixtureTestCase):
         self.page1.worker_run = self.worker_run_1
         self.page1.worker_version = self.version_1
         self.page1.save()
-        Dataset.objects.get(name="First Dataset").dataset_elements.create(element=self.page1, set="test")
+        DatasetSet.objects.get(name="test", dataset__name="First Dataset").set_elements.create(element=self.page1)
 
         worker_results_delete(corpus_id=self.corpus.id)
         # Prevent delaying constraints check at end of the test transaction
diff --git a/arkindex/documents/tests/test_destroy_elements.py b/arkindex/documents/tests/test_destroy_elements.py
index da6312ac19d0fcd0863ce87752baa5bd7dcf95ef..af9fe5fb44f798e919d50a245102ab6a1c68a0f9 100644
--- a/arkindex/documents/tests/test_destroy_elements.py
+++ b/arkindex/documents/tests/test_destroy_elements.py
@@ -148,7 +148,8 @@ class TestDestroyElements(FixtureAPITestCase):
         """
         An element cannot be deleted via the API if linked to a dataset
         """
-        Dataset.objects.get(name="First Dataset").dataset_elements.create(element=self.vol, set="test")
+        dataset = Dataset.objects.get(name="First Dataset")
+        dataset.sets.get(name="test").set_elements.create(element=self.vol)
         self.client.force_login(self.user)
         with self.assertNumQueries(3):
             response = self.client.delete(reverse("api:element-retrieve", kwargs={"pk": str(self.vol.id)}))
@@ -179,9 +180,9 @@ class TestDestroyElements(FixtureAPITestCase):
         """
         Elements that are part of a dataset cannot be deleted
         """
-        Dataset.objects.get(name="First Dataset").dataset_elements.create(
-            element=Element.objects.get_descending(self.vol.id).first(),
-            set="test",
+        dataset = Dataset.objects.get(name="First Dataset")
+        dataset.sets.get(name="test").set_elements.create(
+            element=Element.objects.get_descending(self.vol.id).first()
         )
 
         Element.objects.filter(id=self.vol.id).trash()
diff --git a/arkindex/sql_validation/corpus_delete.sql b/arkindex/sql_validation/corpus_delete.sql
index a454d05cdbb69ded46a1d1f175a0af1d2f630b59..766566825d6557a396a95ec21749ab734761f9b4 100644
--- a/arkindex/sql_validation/corpus_delete.sql
+++ b/arkindex/sql_validation/corpus_delete.sql
@@ -185,6 +185,15 @@ FROM "training_datasetelement"
 WHERE "training_datasetelement"."id" IN
         (SELECT U0."id"
          FROM "training_datasetelement" U0
+         INNER JOIN "training_datasetset" U1 ON (U0."set_id" = U1."id")
+         INNER JOIN "training_dataset" U2 ON (U1."dataset_id" = U2."id")
+         WHERE U2."corpus_id" = '{corpus_id}'::uuid);
+
+DELETE
+FROM "training_datasetset"
+WHERE "training_datasetset"."id" IN
+        (SELECT U0."id"
+         FROM "training_datasetset" U0
          INNER JOIN "training_dataset" U1 ON (U0."dataset_id" = U1."id")
          WHERE U1."corpus_id" = '{corpus_id}'::uuid);
 
diff --git a/arkindex/sql_validation/corpus_delete_top_level_type.sql b/arkindex/sql_validation/corpus_delete_top_level_type.sql
index 9c7ae8cd60709aa0bd672e1f72dbb020aea86153..d64cf0bb8b2eaabeb9a56b71612ba54223d8f0d2 100644
--- a/arkindex/sql_validation/corpus_delete_top_level_type.sql
+++ b/arkindex/sql_validation/corpus_delete_top_level_type.sql
@@ -189,8 +189,17 @@ FROM "training_datasetelement"
 WHERE "training_datasetelement"."id" IN
         (SELECT U0."id"
          FROM "training_datasetelement" U0
-         INNER JOIN "training_dataset" U1 ON (U0."dataset_id" = U1."id")
-         WHERE U1."corpus_id" = '{corpus_id}'::uuid);
+         INNER JOIN "training_datasetset" U1 ON (U0."set_id" = U1."id")
+         INNER JOIN "training_dataset" U2 ON (U1."dataset_id" = U2."id")
+         WHERE U2."corpus_id" = '{corpus_id}'::uuid);
+
+DELETE
+FROM "training_datasetset"
+WHERE "training_datasetset"."id" IN
+        (SELECT U0."id"
+         FROM "training_datasetset" U0
+         INNER JOIN "training_dataset" U1 ON (U0."dataset_id" = U1."id")
+         WHERE U1."corpus_id" = '{corpus_id}'::uuid);
 
 DELETE
 FROM "training_dataset"