From 42d879437c676a21a848e564048cd0f3b50ffa63 Mon Sep 17 00:00:00 2001
From: ml bonhomme <bonhomme@teklia.com>
Date: Wed, 21 Feb 2024 10:29:50 +0000
Subject: [PATCH] Allow selecting dataset sets in dataset processes

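The ProcessDataset relation now carries a "sets" array listing which of
the dataset's sets the process will run on. Adding a dataset to a
process without an explicit list selects all of its sets by default;
the selection can then be replaced with PUT or PATCH on the same
endpoint, e.g. a body of {"sets": ["test"]}. Selected sets must all
exist on the dataset, be unique and non-empty, and cannot be changed
once the process has started. Conversely, dataset sets that are
selected in a process can no longer be renamed or removed when updating
the dataset.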
---
 .../tests/tasks/test_corpus_delete.py         |  15 +-
 arkindex/process/api.py                       |  45 +-
 .../migrations/0029_processdataset_sets.py    |  43 ++
 arkindex/process/models.py                    |   9 +-
 arkindex/process/serializers/training.py      | 109 ++--
 arkindex/process/tests/test_create_process.py |   8 +-
 .../process/tests/test_process_datasets.py    | 467 ++++++++++++++++--
 arkindex/process/tests/test_processes.py      |  12 +-
 arkindex/training/serializers.py              |   7 +-
 arkindex/training/tests/test_datasets_api.py  |  54 +-
 10 files changed, 665 insertions(+), 104 deletions(-)
 create mode 100644 arkindex/process/migrations/0029_processdataset_sets.py

diff --git a/arkindex/documents/tests/tasks/test_corpus_delete.py b/arkindex/documents/tests/tasks/test_corpus_delete.py
index a17590abe7..de915af8ff 100644
--- a/arkindex/documents/tests/tasks/test_corpus_delete.py
+++ b/arkindex/documents/tests/tasks/test_corpus_delete.py
@@ -3,7 +3,7 @@ from django.db.models.signals import pre_delete
 from arkindex.documents.models import Corpus, Element, EntityType, MetaType, Transcription
 from arkindex.documents.tasks import corpus_delete
 from arkindex.ponos.models import Farm, State, Task
-from arkindex.process.models import CorpusWorkerVersion, ProcessMode, Repository, WorkerVersion
+from arkindex.process.models import CorpusWorkerVersion, ProcessDataset, ProcessMode, Repository, WorkerVersion
 from arkindex.project.tests import FixtureTestCase, force_constraints_immediate
 from arkindex.training.models import Dataset
 
@@ -118,13 +118,14 @@ class TestDeleteCorpus(FixtureTestCase):
         cls.dataset2 = Dataset.objects.create(name="Dead Sea Scrolls", description="How to trigger a Third Impact", creator=cls.user, corpus=cls.corpus2)
         # Process on cls.corpus and with a dataset from cls.corpus
         dataset_process1 = cls.corpus.processes.create(creator=cls.user, mode=ProcessMode.Dataset)
-        dataset_process1.datasets.set([dataset1])
+        ProcessDataset.objects.create(process=dataset_process1, dataset=dataset1, sets=dataset1.sets)
         # Process on cls.corpus with a dataset from another corpus
         dataset_process2 = cls.corpus.processes.create(creator=cls.user, mode=ProcessMode.Dataset)
-        dataset_process2.datasets.set([dataset1, cls.dataset2])
+        ProcessDataset.objects.create(process=dataset_process2, dataset=dataset1, sets=dataset1.sets)
+        ProcessDataset.objects.create(process=dataset_process2, dataset=cls.dataset2, sets=cls.dataset2.sets)
         # Process on another corpus with a dataset from another corpus and none from cls.corpus
-        cls.dataset_process2 = cls.corpus2.processes.create(creator=cls.user, mode=ProcessMode.Dataset)
-        cls.dataset_process2.datasets.set([cls.dataset2])
+        cls.dataset_process3 = cls.corpus2.processes.create(creator=cls.user, mode=ProcessMode.Dataset)
+        ProcessDataset.objects.create(process=cls.dataset_process3, dataset=cls.dataset2, sets=cls.dataset2.sets)
 
         cls.rev = cls.repo.revisions.create(
             hash="42",
@@ -200,14 +201,14 @@ class TestDeleteCorpus(FixtureTestCase):
         self.df.refresh_from_db()
         self.vol.refresh_from_db()
         self.page.refresh_from_db()
-        self.dataset_process2.refresh_from_db()
+        self.dataset_process3.refresh_from_db()
 
         self.assertTrue(self.repo.revisions.filter(id=self.rev.id).exists())
         self.assertEqual(self.process.revision, self.rev)
         self.assertEqual(self.process.files.get(), self.df)
         self.assertTrue(Element.objects.get_descending(self.vol.id).filter(id=self.page.id).exists())
         self.assertTrue(self.corpus2.datasets.filter(id=self.dataset2.id).exists())
-        self.assertTrue(self.corpus2.processes.filter(id=self.dataset_process2.id).exists())
+        self.assertTrue(self.corpus2.processes.filter(id=self.dataset_process3.id).exists())
 
         md = self.vol.metadatas.get()
         self.assertEqual(md.name, "meta")
diff --git a/arkindex/process/api.py b/arkindex/process/api.py
index 26001aaab9..db75eac89a 100644
--- a/arkindex/process/api.py
+++ b/arkindex/process/api.py
@@ -46,6 +46,7 @@ from rest_framework.generics import (
     RetrieveDestroyAPIView,
     RetrieveUpdateAPIView,
     RetrieveUpdateDestroyAPIView,
+    UpdateAPIView,
 )
 from rest_framework.response import Response
 from rest_framework.serializers import Serializer
@@ -125,8 +126,7 @@ from arkindex.project.pagination import CountCursorPagination
 from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly
 from arkindex.project.tools import PercentileCont
 from arkindex.project.triggers import process_delete
-from arkindex.training.models import Dataset, Model
-from arkindex.training.serializers import DatasetSerializer
+from arkindex.training.models import Model
 from arkindex.users.models import Role, Scope
 
 logger = logging.getLogger(__name__)
@@ -695,8 +695,8 @@ class DataFileCreate(CreateAPIView):
 )
 class ProcessDatasets(ProcessACLMixin, ListAPIView):
     permission_classes = (IsVerified, )
-    serializer_class = DatasetSerializer
-    queryset = Dataset.objects.none()
+    serializer_class = ProcessDatasetSerializer
+    queryset = ProcessDataset.objects.none()
 
     @cached_property
     def process(self):
@@ -709,7 +709,11 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView):
         return process
 
     def get_queryset(self):
-        return self.process.datasets.select_related("creator").order_by("name")
+        return (
+            ProcessDataset.objects.filter(process_id=self.process.id)
+            .select_related("process__creator", "dataset__creator")
+            .order_by("dataset__name")
+        )
 
     def get_serializer_context(self):
         context = super().get_serializer_context()
@@ -717,6 +721,7 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView):
         if not self.kwargs:
             return context
         context["process"] = self.process
+        # Disable the set_elements counts in the serialized dataset
         context["sets_count"] = False
         return context
 
@@ -744,24 +749,30 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView):
         ),
     ),
 )
-class ProcessDatasetManage(CreateAPIView, DestroyAPIView):
+class ProcessDatasetManage(CreateAPIView, UpdateAPIView, DestroyAPIView):
     permission_classes = (IsVerified, )
     serializer_class = ProcessDatasetSerializer
 
-    def get_serializer_from_params(self, process=None, dataset=None, **kwargs):
-        data = {"process": process, "dataset": dataset}
-        kwargs["context"] = self.get_serializer_context()
-        return ProcessDatasetSerializer(data=data, **kwargs)
+    def get_object(self):
+        process_dataset = get_object_or_404(
+            ProcessDataset.objects
+            .select_related("dataset__creator", "process__corpus")
+            # Required to check for a process that has already started
+            .annotate(process_has_tasks=Exists(Task.objects.filter(process_id=self.kwargs["process"]))),
+            dataset_id=self.kwargs["dataset"], process_id=self.kwargs["process"]
+        )
+        # Copy the has_tasks annotation onto the process
+        process_dataset.process.has_tasks = process_dataset.process_has_tasks
+        return process_dataset
 
-    def create(self, request, *args, **kwargs):
-        serializer = self.get_serializer_from_params(**kwargs)
-        serializer.is_valid(raise_exception=True)
-        serializer.create(serializer.validated_data)
-        headers = self.get_success_headers(serializer.data)
-        return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
+    def get_serializer_context(self):
+        context = super().get_serializer_context()
+        # Disable the set_elements counts in the serialized dataset
+        context["sets_count"] = False
+        return context
 
     def destroy(self, request, *args, **kwargs):
-        serializer = self.get_serializer_from_params(**kwargs)
+        serializer = self.get_serializer(data=request.data)
         serializer.is_valid(raise_exception=True)
         get_object_or_404(ProcessDataset, **serializer.validated_data).delete()
         return Response(status=status.HTTP_204_NO_CONTENT)
diff --git a/arkindex/process/migrations/0029_processdataset_sets.py b/arkindex/process/migrations/0029_processdataset_sets.py
new file mode 100644
index 0000000000..868c1cc29f
--- /dev/null
+++ b/arkindex/process/migrations/0029_processdataset_sets.py
@@ -0,0 +1,43 @@
+import django.core.validators
+from django.db import migrations, models
+
+import arkindex.project.fields
+import arkindex.training.models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("process", "0028_remove_process_model_remove_process_test_folder_and_more"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="processdataset",
+            name="sets",
+            field=arkindex.project.fields.ArrayField(base_field=models.CharField(max_length=50), blank=True, default=list, size=None),
+        ),
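+        # Backfill: pre-existing process datasets select all of their dataset's sets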
+        migrations.RunSQL(
+            [
+                """
+                UPDATE process_processdataset p
+                SET sets = d.sets
+                FROM training_dataset d
+                WHERE p.dataset_id = d.id
+                """
+            ],
+            reverse_sql=migrations.RunSQL.noop,
+            elidable=True,
+        ),
+        migrations.AlterField(
+            model_name="processdataset",
+            name="sets",
+            field=arkindex.project.fields.ArrayField(
+                base_field=models.CharField(
+                    max_length=50,
+                    validators=[django.core.validators.MinLengthValidator(1)]
+                ),
+                size=None,
+                validators=[django.core.validators.MinLengthValidator(1), arkindex.training.models.validate_unique_set_names]
+            ),
+        ),
+    ]
diff --git a/arkindex/process/models.py b/arkindex/process/models.py
index 8f1db92144..ac26165f47 100644
--- a/arkindex/process/models.py
+++ b/arkindex/process/models.py
@@ -27,7 +27,7 @@ from arkindex.project.aws import S3FileMixin, S3FileStatus
 from arkindex.project.fields import ArrayField, MD5HashField
 from arkindex.project.models import IndexableModel
 from arkindex.project.validators import MaxValueValidator
-from arkindex.training.models import ModelVersion, ModelVersionState
+from arkindex.training.models import ModelVersion, ModelVersionState, validate_unique_set_names
 from arkindex.users.models import Role
 
 
@@ -466,6 +466,13 @@ class ProcessDataset(models.Model):
     id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
     process = models.ForeignKey(Process, on_delete=models.CASCADE, related_name="process_datasets")
     dataset = models.ForeignKey("training.Dataset", on_delete=models.CASCADE, related_name="process_datasets")
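+    # Dataset sets selected for this process; the serializer restricts this to
+    # a non-empty, duplicate-free subset of the dataset's own sets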
+    sets = ArrayField(
+        models.CharField(max_length=50, validators=[MinLengthValidator(1)]),
+        validators=[
+            MinLengthValidator(1),
+            validate_unique_set_names,
+        ]
+    )
 
     class Meta:
         constraints = [
diff --git a/arkindex/process/serializers/training.py b/arkindex/process/serializers/training.py
index 96db960576..f56fc916a9 100644
--- a/arkindex/process/serializers/training.py
+++ b/arkindex/process/serializers/training.py
@@ -5,62 +5,96 @@ from rest_framework import serializers
 from rest_framework.exceptions import PermissionDenied, ValidationError
 
 from arkindex.documents.models import Corpus
-from arkindex.ponos.models import Task
-from arkindex.process.models import Process, ProcessDataset, ProcessMode
+from arkindex.process.models import Process, ProcessDataset, ProcessMode, Task
 from arkindex.project.mixins import ProcessACLMixin
 from arkindex.training.models import Dataset
+from arkindex.training.serializers import DatasetSerializer
 from arkindex.users.models import Role
 
 
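+# Default callables for the HiddenFields below: DRF calls a default that sets
+# requires_context = True with the serializer field itself, which gives access
+# to the view's URL kwargs through the serializer context.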
+def _dataset_id_from_context(serializer_field):
+    return serializer_field.context.get("view").kwargs["dataset"]
+
+
+def _process_id_from_context(serializer_field):
+    return serializer_field.context.get("view").kwargs["process"]
+
+
+_dataset_id_from_context.requires_context = True
+_process_id_from_context.requires_context = True
+
+
 class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer):
-    process = serializers.PrimaryKeyRelatedField(
-        queryset=(
-            Process
-            .objects
-            # Avoid a stale read when adding a dataset right after creating a process
-            .using("default")
-            # Required for ACL checks
-            .select_related("corpus")
-            # Required to check for a process that have already started
-            .annotate(has_tasks=Exists(Task.objects.filter(process=OuterRef("pk"))))
-        ),
-        style={"base_template": "input.html"},
+    process_id = serializers.HiddenField(
+        write_only=True,
+        default=_process_id_from_context
     )
-    dataset = serializers.PrimaryKeyRelatedField(
-        queryset=Dataset.objects.none(),
-        style={"base_template": "input.html"},
+    dataset_id = serializers.HiddenField(
+        write_only=True,
+        default=_dataset_id_from_context
+    )
+    dataset = DatasetSerializer(read_only=True)
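+    # Requested subset of the dataset's sets; when omitted on creation, it
+    # defaults to all of the dataset's sets (see validate() below)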
+    sets = serializers.ListField(
+        child=serializers.CharField(max_length=50),
+        required=False,
+        allow_null=False,
+        allow_empty=False,
+        min_length=1
     )
 
     class Meta:
         model = ProcessDataset
-        fields = ("dataset", "process", "id", )
-        read_only_fields = ("process", "id", )
+        fields = ("dataset_id", "dataset", "process_id", "id", "sets", )
+        read_only_fields = ("process_id", "id", )
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if not self.context.get("request"):
             # Do not raise Error in order to create OpenAPI schema
             return
-
-        request_method = self.context["request"].method
         # Required for the ProcessACLMixin and readable corpora
         self._user = self.context["request"].user
 
-        if request_method == "DELETE":
-            # Allow deleting ProcessDatasets even if the user looses access to the corpus
-            self.fields["dataset"].queryset = Dataset.objects.all()
-        else:
-            self.fields["dataset"].queryset = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self._user))
+    def validate(self, data):
+        request_method = self.context["request"].method
+        errors = defaultdict(list)
 
-    def validate_process(self, process):
+        # Validate process
+        process_qs = (
+            Process
+            .objects
+            # Avoid a stale read when adding a dataset right after creating a process
+            .using("default")
+            # Required for ACL checks
+            .select_related("corpus", "revision__repo")
+            # Required to check for a process that has already started
+            .annotate(has_tasks=Exists(Task.objects.filter(process=OuterRef("pk"))))
+        )
+        if not self.instance:
+            try:
+                process = process_qs.get(pk=data["process_id"])
+            except Process.DoesNotExist:
+                raise ValidationError({"process": [f'Invalid pk "{str(data["process_id"])}" - object does not exist.']})
+        else:
+            process = self.instance.process
         access = self.process_access_level(process)
         if not access or not (access >= Role.Admin.value):
             raise PermissionDenied(detail="You do not have admin access to this process.")
-        return process
 
-    def validate(self, data):
-        process, dataset = data["process"], data["dataset"]
-        errors = defaultdict(list)
+        # Validate dataset
+        if not self.instance:
+            if request_method == "DELETE":
+                # Allow deleting ProcessDatasets even if the user loses access to the corpus
+                dataset_qs = Dataset.objects.all()
+            else:
+                dataset_qs = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self._user))
+            try:
+                dataset = dataset_qs.select_related("creator").get(pk=data["dataset_id"])
+            except Dataset.DoesNotExist:
+                raise ValidationError({"dataset": [f'Invalid pk "{str(data["dataset_id"])}" - object does not exist.']})
+        else:
+            dataset = self.instance.dataset
+        data["dataset"] = dataset
 
         if process.mode != ProcessMode.Dataset:
             errors["process"].append('Datasets can only be added to or removed from processes of mode "dataset".')
@@ -71,6 +105,19 @@ class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer):
         if self.context["request"].method == "POST" and process.datasets.filter(id=dataset.id).exists():
             errors["dataset"].append("This dataset is already selected in this process.")
 
+        # Validate sets
+        sets = data.get("sets")
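+        # On creation, an omitted selection defaults to all of the dataset's
+        # sets; on update, a missing or empty selection is rejected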
+        if not sets:
+            if not self.instance:
+                data["sets"] = dataset.sets
+            else:
+                errors["sets"].append("This field cannot be empty.")
+        else:
+            if any(s not in dataset.sets for s in sets):
+                errors["sets"].append("The specified sets must all exist in the specified dataset.")
+            if len(set(sets)) != len(sets):
+                errors["sets"].append("Sets must be unique.")
+
         if errors:
             raise ValidationError(errors)
 
diff --git a/arkindex/process/tests/test_create_process.py b/arkindex/process/tests/test_create_process.py
index 79b0e98ba4..008c963eaa 100644
--- a/arkindex/process/tests/test_create_process.py
+++ b/arkindex/process/tests/test_create_process.py
@@ -11,6 +11,7 @@ from arkindex.ponos.models import Farm, State
 from arkindex.process.models import (
     ActivityState,
     Process,
+    ProcessDataset,
     ProcessMode,
     Repository,
     WorkerActivity,
@@ -826,7 +827,8 @@ class TestCreateProcess(FixtureAPITestCase):
     def test_dataset_gpu_required(self):
         self.client.force_login(self.user)
         process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
-        process.datasets.add(self.corpus.datasets.first())
+        dataset = self.corpus.datasets.first()
+        ProcessDataset.objects.create(process=process, dataset=dataset, sets=dataset.sets)
         process.versions.set([self.version_2, self.version_3])
 
         with self.assertNumQueries(15):
@@ -854,9 +856,9 @@ class TestCreateProcess(FixtureAPITestCase):
         self.client.force_login(self.user)
         self.worker_1.archived = datetime.now(timezone.utc)
         self.worker_1.save()
-
         process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
-        process.datasets.add(self.corpus.datasets.first())
+        dataset = self.corpus.datasets.first()
+        ProcessDataset.objects.create(process=process, dataset=dataset, sets=dataset.sets)
         process.versions.add(self.version_1)
 
         with self.assertNumQueries(15):
diff --git a/arkindex/process/tests/test_process_datasets.py b/arkindex/process/tests/test_process_datasets.py
index 142708b126..16005a9513 100644
--- a/arkindex/process/tests/test_process_datasets.py
+++ b/arkindex/process/tests/test_process_datasets.py
@@ -1,10 +1,12 @@
 import uuid
+from datetime import datetime, timezone
 from unittest.mock import patch
 
 from django.urls import reverse
 from rest_framework import status
 
 from arkindex.documents.models import Corpus
+from arkindex.ponos.models import Farm
 from arkindex.process.models import Process, ProcessDataset, ProcessMode, Repository
 from arkindex.project.tests import FixtureAPITestCase
 from arkindex.training.models import Dataset
@@ -35,9 +37,11 @@ class TestProcessDatasets(FixtureAPITestCase):
         cls.dataset_process = Process.objects.create(
             creator_id=cls.user.id,
             mode=ProcessMode.Dataset,
-            corpus_id=cls.private_corpus.id
+            corpus_id=cls.private_corpus.id,
+            farm=Farm.objects.get(name="Wheat farm")
         )
-        cls.dataset_process.datasets.set([cls.dataset1, cls.private_dataset])
+        cls.process_dataset_1 = ProcessDataset.objects.create(process=cls.dataset_process, dataset=cls.dataset1, sets=cls.dataset1.sets)
+        cls.process_dataset_2 = ProcessDataset.objects.create(process=cls.dataset_process, dataset=cls.private_dataset, sets=cls.private_dataset.sets)
 
         # Control process to check that its datasets are not retrieved
         cls.dataset_process_2 = Process.objects.create(
@@ -45,6 +49,7 @@ class TestProcessDatasets(FixtureAPITestCase):
             mode=ProcessMode.Dataset,
             corpus_id=cls.corpus.id
         )
+        ProcessDataset.objects.create(process=cls.dataset_process_2, dataset=cls.dataset2, sets=cls.dataset2.sets)
-        cls.dataset_process_2.datasets.set([cls.dataset2])
 
         # For repository process
@@ -80,30 +85,38 @@ class TestProcessDatasets(FixtureAPITestCase):
             self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertEqual(response.json()["results"], [
             {
-                "id": str(self.private_dataset.id),
-                "name": "Dead sea scrolls",
-                "description": "Human instrumentality manual",
-                "creator": "Test user",
-                "sets": ["training", "test", "validation"],
-                "set_elements": None,
-                "corpus_id": str(self.private_corpus.id),
-                "state": "open",
-                "task_id": None,
-                "created": FAKE_CREATED,
-                "updated": FAKE_CREATED
+                "id": str(self.process_dataset_2.id),
+                "dataset": {
+                    "id": str(self.private_dataset.id),
+                    "name": "Dead sea scrolls",
+                    "description": "Human instrumentality manual",
+                    "creator": "Test user",
+                    "sets": ["training", "test", "validation"],
+                    "set_elements": None,
+                    "corpus_id": str(self.private_corpus.id),
+                    "state": "open",
+                    "task_id": None,
+                    "created": FAKE_CREATED,
+                    "updated": FAKE_CREATED
+                },
+                "sets": ["training", "test", "validation"]
             },
             {
-                "id": str(self.dataset1.id),
-                "name": "First Dataset",
-                "description": "dataset number one",
-                "creator": "Test user",
-                "sets": ["training", "test", "validation"],
-                "set_elements": None,
-                "corpus_id": str(self.corpus.id),
-                "state": "open",
-                "task_id": None,
-                "created": FAKE_CREATED,
-                "updated": FAKE_CREATED
+                "id": str(self.process_dataset_1.id),
+                "dataset": {
+                    "id": str(self.dataset1.id),
+                    "name": "First Dataset",
+                    "description": "dataset number one",
+                    "creator": "Test user",
+                    "sets": ["training", "test", "validation"],
+                    "set_elements": None,
+                    "corpus_id": str(self.corpus.id),
+                    "state": "open",
+                    "task_id": None,
+                    "created": FAKE_CREATED,
+                    "updated": FAKE_CREATED
+                },
+                "sets": ["training", "test", "validation"]
             }
         ])
 
@@ -113,6 +126,7 @@ class TestProcessDatasets(FixtureAPITestCase):
         with self.assertNumQueries(0):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
+                data={"sets": self.dataset2.sets}
             )
             self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
 
@@ -122,6 +136,7 @@ class TestProcessDatasets(FixtureAPITestCase):
         with self.assertNumQueries(2):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
+                data={"sets": self.dataset2.sets}
             )
             self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
 
@@ -133,9 +148,10 @@ class TestProcessDatasets(FixtureAPITestCase):
                 if level:
                     self.private_corpus.memberships.create(user=self.test_user, level=level.value)
                 self.client.force_login(self.test_user)
-                with self.assertNumQueries(6):
+                with self.assertNumQueries(5):
                     response = self.client.post(
                         reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
+                        data={"sets": self.dataset2.sets}
                     )
                     self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
                 self.assertEqual(response.json(), {"detail": "You do not have admin access to this process."})
@@ -151,6 +167,7 @@ class TestProcessDatasets(FixtureAPITestCase):
                 with self.assertNumQueries(8):
                     response = self.client.post(
                         reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
+                        data={"sets": self.dataset2.sets}
                     )
                     self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
 
@@ -159,9 +176,10 @@ class TestProcessDatasets(FixtureAPITestCase):
     def test_create_process_mode_local(self):
         self.client.force_login(self.user)
         local_process = Process.objects.get(creator=self.user, mode=ProcessMode.Local)
-        with self.assertNumQueries(6):
+        with self.assertNumQueries(3):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": local_process.id, "dataset": self.dataset2.id}),
+                data={"sets": self.dataset2.sets}
             )
             self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
         self.assertEqual(response.json(), {"detail": "You do not have admin access to this process."})
@@ -169,9 +187,10 @@ class TestProcessDatasets(FixtureAPITestCase):
     def test_create_process_mode_repository(self):
         self.client.force_login(self.user)
         process = Process.objects.create(creator=self.user, mode=ProcessMode.Repository, revision=self.rev)
-        with self.assertNumQueries(10):
+        with self.assertNumQueries(6):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": process.id, "dataset": self.dataset2.id}),
+                data={"sets": self.dataset2.sets}
             )
             self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
         self.assertEqual(response.json(), {"detail": "You do not have admin access to this process."})
@@ -179,9 +198,10 @@ class TestProcessDatasets(FixtureAPITestCase):
     def test_create_wrong_process_uuid(self):
         self.client.force_login(self.test_user)
         wrong_id = uuid.uuid4()
-        with self.assertNumQueries(6):
+        with self.assertNumQueries(3):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": wrong_id, "dataset": self.dataset2.id}),
+                data={"sets": self.dataset2.sets}
             )
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertEqual(response.json(), {"process": [f'Invalid pk "{str(wrong_id)}" - object does not exist.']})
@@ -192,6 +212,7 @@ class TestProcessDatasets(FixtureAPITestCase):
         with self.assertNumQueries(7):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": wrong_id}),
+                data={"sets": ["test"]}
             )
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertEqual(response.json(), {"dataset": [f'Invalid pk "{str(wrong_id)}" - object does not exist.']})
@@ -203,6 +224,7 @@ class TestProcessDatasets(FixtureAPITestCase):
         with self.assertNumQueries(7):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": new_dataset.id}),
+                data={"sets": new_dataset.sets}
             )
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertEqual(response.json(), {"dataset": [f'Invalid pk "{str(new_dataset.id)}" - object does not exist.']})
@@ -214,6 +236,7 @@ class TestProcessDatasets(FixtureAPITestCase):
         with self.assertNumQueries(8):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
+                data={"sets": self.dataset1.sets}
             )
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
 
@@ -226,11 +249,49 @@ class TestProcessDatasets(FixtureAPITestCase):
         with self.assertNumQueries(8):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
+                data={"sets": self.dataset2.sets}
             )
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
 
         self.assertDictEqual(response.json(), {"process": ["Datasets cannot be updated on processes that have already started."]})
 
+    def test_create_default_sets(self):
+        self.client.force_login(self.test_user)
+        self.assertEqual(ProcessDataset.objects.count(), 3)
+        self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists())
+        with self.assertNumQueries(9):
+            response = self.client.post(
+                reverse(
+                    "api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}
+                ),
+            )
+            self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+        self.assertEqual(ProcessDataset.objects.count(), 4)
+        self.assertTrue(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists())
+        self.assertQuerysetEqual(self.dataset_process.datasets.order_by("name"), [
+            self.private_dataset,
+            self.dataset1,
+            self.dataset2
+        ])
+        created = ProcessDataset.objects.get(process=self.dataset_process.id, dataset=self.dataset2.id)
+        self.assertDictEqual(response.json(), {
+            "id": str(created.id),
+            "dataset": {
+                "id": str(self.dataset2.id),
+                "name": "Second Dataset",
+                "description": "dataset number two",
+                "creator": "Test user",
+                "sets": ["training", "test", "validation"],
+                "set_elements": None,
+                "corpus_id": str(self.corpus.id),
+                "state": "open",
+                "task_id": None,
+                "created": FAKE_CREATED,
+                "updated": FAKE_CREATED
+            },
+            "sets": ["training", "test", "validation"]
+        })
+
     def test_create(self):
         self.client.force_login(self.test_user)
         self.assertEqual(ProcessDataset.objects.count(), 3)
@@ -238,6 +299,7 @@ class TestProcessDatasets(FixtureAPITestCase):
         with self.assertNumQueries(9):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
+                data={"sets": ["validation", "test"]}
             )
             self.assertEqual(response.status_code, status.HTTP_201_CREATED)
         self.assertEqual(ProcessDataset.objects.count(), 4)
@@ -247,6 +309,345 @@ class TestProcessDatasets(FixtureAPITestCase):
             self.dataset1,
             self.dataset2
         ])
+        created = ProcessDataset.objects.get(process=self.dataset_process.id, dataset=self.dataset2.id)
+        self.assertDictEqual(response.json(), {
+            "id": str(created.id),
+            "dataset": {
+                "id": str(self.dataset2.id),
+                "name": "Second Dataset",
+                "description": "dataset number two",
+                "creator": "Test user",
+                "sets": ["training", "test", "validation"],
+                "set_elements": None,
+                "corpus_id": str(self.corpus.id),
+                "state": "open",
+                "task_id": None,
+                "created": FAKE_CREATED,
+                "updated": FAKE_CREATED
+            },
+            "sets": ["validation", "test"]
+        })
+
+    def test_create_wrong_sets(self):
+        self.client.force_login(self.test_user)
+        self.assertEqual(ProcessDataset.objects.count(), 3)
+        self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists())
+        with self.assertNumQueries(8):
+            response = self.client.post(
+                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
+                data={"sets": ["Unit-01"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists())
+        self.assertDictEqual(response.json(), {"sets": ["The specified sets must all exist in the specified dataset."]})
+
+    # Update process dataset
+
+    def test_update_requires_login(self):
+        with self.assertNumQueries(0):
+            response = self.client.put(
+                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
+                data={"sets": ["test"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+    def test_update_requires_verified(self):
+        unverified_user = User.objects.create(email="email@mail.com")
+        self.client.force_login(unverified_user)
+        with self.assertNumQueries(2):
+            response = self.client.put(
+                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
+                data={"sets": ["test"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+    def test_update_access_level(self):
+        cases = [None, Role.Guest, Role.Contributor]
+        for level in cases:
+            with self.subTest(level=level):
+                self.private_corpus.memberships.filter(user=self.test_user).delete()
+                if level:
+                    self.private_corpus.memberships.create(user=self.test_user, level=level.value)
+                self.client.force_login(self.test_user)
+                with self.assertNumQueries(5):
+                    response = self.client.put(
+                        reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
+                        data={"sets": ["test"]}
+                    )
+                    self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+                self.assertEqual(response.json(), {"detail": "You do not have admin access to this process."})
+
+    def test_update_process_does_not_exist(self):
+        self.client.force_login(self.test_user)
+        with self.assertNumQueries(3):
+            response = self.client.put(
+                reverse("api:process-dataset", kwargs={"process": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", "dataset": self.dataset1.id}),
+                data={"sets": ["test"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+    def test_update_dataset_does_not_exist(self):
+        self.client.force_login(self.test_user)
+        with self.assertNumQueries(3):
+            response = self.client.put(
+                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}),
+                data={"sets": ["test"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+    def test_update_m2m_does_not_exist(self):
+        self.client.force_login(self.test_user)
+        with self.assertNumQueries(3):
+            response = self.client.put(
+                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
+                data={"sets": ["test"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+    def test_update(self):
+        self.client.force_login(self.test_user)
+        with self.assertNumQueries(7):
+            response = self.client.put(
+                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
+                data={"sets": ["test"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertDictEqual(response.json(), {
+            "id": str(self.process_dataset_1.id),
+            "dataset": {
+                "id": str(self.dataset1.id),
+                "name": "First Dataset",
+                "description": "dataset number one",
+                "creator": "Test user",
+                "sets": ["training", "test", "validation"],
+                "set_elements": None,
+                "corpus_id": str(self.corpus.id),
+                "state": "open",
+                "task_id": None,
+                "created": FAKE_CREATED,
+                "updated": FAKE_CREATED
+            },
+            "sets": ["test"]
+        })
+
+    def test_update_wrong_sets(self):
+        self.client.force_login(self.test_user)
+        with self.assertNumQueries(6):
+            response = self.client.put(
+                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
+                data={"sets": ["Unit-01", "Unit-02"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {"sets": ["The specified sets must all exist in the specified dataset."]})
+
+    def test_update_unique_sets(self):
+        self.client.force_login(self.test_user)
+        with self.assertNumQueries(6):
+            response = self.client.put(
+                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
+                data={"sets": ["test", "test"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {"sets": ["Sets must be unique."]})
+
+    def test_update_started_process(self):
+        self.dataset_process.tasks.create(
+            run=0,
+            depth=0,
+            slug="task",
+            expiry=datetime(1970, 1, 1, tzinfo=timezone.utc),
+        )
+        self.client.force_login(self.test_user)
+        with self.assertNumQueries(6):
+            response = self.client.put(
+                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
+                data={"sets": ["test"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {"process": ["Datasets cannot be updated on processes that have already started."]})
+
+    def test_update_only_sets(self):
+        """
+        Fields other than "sets" in the update request are ignored
+        """
+        self.client.force_login(self.test_user)
+        with self.assertNumQueries(7):
+            response = self.client.put(
+                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
+                data={"process": str(self.dataset_process_2.id), "dataset": str(self.dataset2.id), "sets": ["test"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertDictEqual(response.json(), {
+            "id": str(self.process_dataset_1.id),
+            "dataset": {
+                "id": str(self.dataset1.id),
+                "name": "First Dataset",
+                "description": "dataset number one",
+                "creator": "Test user",
+                "sets": ["training", "test", "validation"],
+                "set_elements": None,
+                "corpus_id": str(self.corpus.id),
+                "state": "open",
+                "task_id": None,
+                "created": FAKE_CREATED,
+                "updated": FAKE_CREATED
+            },
+            "sets": ["test"]
+        })
+
+    # Partial update process dataset
+
+    def test_partial_update_requires_login(self):
+        with self.assertNumQueries(0):
+            response = self.client.patch(
+                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
+                data={"sets": ["test"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+    def test_partial_update_requires_verified(self):
+        unverified_user = User.objects.create(email="email@mail.com")
+        self.client.force_login(unverified_user)
+        with self.assertNumQueries(2):
+            response = self.client.patch(
+                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
+                data={"sets": ["test"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+    def test_partial_update_access_level(self):
+        cases = [None, Role.Guest, Role.Contributor]
+        for level in cases:
+            with self.subTest(level=level):
+                self.private_corpus.memberships.filter(user=self.test_user).delete()
+                if level:
+                    self.private_corpus.memberships.create(user=self.test_user, level=level.value)
+                self.client.force_login(self.test_user)
+                with self.assertNumQueries(5):
+                    response = self.client.patch(
+                        reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
+                        data={"sets": ["test"]}
+                    )
+                    self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+                self.assertEqual(response.json(), {"detail": "You do not have admin access to this process."})
+
+    def test_partial_update_process_does_not_exist(self):
+        self.client.force_login(self.test_user)
+        with self.assertNumQueries(3):
+            response = self.client.patch(
+                reverse("api:process-dataset", kwargs={"process": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", "dataset": self.dataset1.id}),
+                data={"sets": ["test"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+    def test_partial_update_dataset_does_not_exist(self):
+        self.client.force_login(self.test_user)
+        with self.assertNumQueries(3):
+            response = self.client.patch(
+                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}),
+                data={"sets": ["test"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+    def test_partial_update_m2m_does_not_exist(self):
+        self.client.force_login(self.test_user)
+        with self.assertNumQueries(3):
+            response = self.client.patch(
+                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
+                data={"sets": ["test"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+    def test_partial_update(self):
+        self.client.force_login(self.test_user)
+        with self.assertNumQueries(7):
+            response = self.client.patch(
+                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
+                data={"sets": ["test"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertDictEqual(response.json(), {
+            "id": str(self.process_dataset_1.id),
+            "dataset": {
+                "id": str(self.dataset1.id),
+                "name": "First Dataset",
+                "description": "dataset number one",
+                "creator": "Test user",
+                "sets": ["training", "test", "validation"],
+                "set_elements": None,
+                "corpus_id": str(self.corpus.id),
+                "state": "open",
+                "task_id": None,
+                "created": FAKE_CREATED,
+                "updated": FAKE_CREATED
+            },
+            "sets": ["test"]
+        })
+
+    def test_partial_update_wrong_sets(self):
+        self.client.force_login(self.test_user)
+        with self.assertNumQueries(6):
+            response = self.client.patch(
+                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
+                data={"sets": ["Unit-01", "Unit-02"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {"sets": ["The specified sets must all exist in the specified dataset."]})
+
+    def test_partial_update_unique_sets(self):
+        self.client.force_login(self.test_user)
+        with self.assertNumQueries(6):
+            response = self.client.patch(
+                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
+                data={"sets": ["test", "test"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {"sets": ["Sets must be unique."]})
+
+    def test_partial_update_started_process(self):
+        self.dataset_process.tasks.create(
+            run=0,
+            depth=0,
+            slug="task",
+            expiry=datetime(1970, 1, 1, tzinfo=timezone.utc),
+        )
+        self.client.force_login(self.test_user)
+        with self.assertNumQueries(6):
+            response = self.client.patch(
+                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
+                data={"sets": ["test"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {"process": ["Datasets cannot be updated on processes that have already started."]})
+
+    def test_partial_update_only_sets(self):
+        """
+        Fields other than "sets" in the partial update request are ignored
+        """
+        self.client.force_login(self.test_user)
+        with self.assertNumQueries(7):
+            response = self.client.patch(
+                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
+                data={"process": str(self.dataset_process_2.id), "dataset": str(self.dataset2.id), "sets": ["test"]}
+            )
+            self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertDictEqual(response.json(), {
+            "id": str(self.process_dataset_1.id),
+            "dataset": {
+                "id": str(self.dataset1.id),
+                "name": "First Dataset",
+                "description": "dataset number one",
+                "creator": "Test user",
+                "sets": ["training", "test", "validation"],
+                "set_elements": None,
+                "corpus_id": str(self.corpus.id),
+                "state": "open",
+                "task_id": None,
+                "created": FAKE_CREATED,
+                "updated": FAKE_CREATED
+            },
+            "sets": ["test"]
+        })
 
     # Destroy process dataset
 
@@ -260,7 +661,7 @@ class TestProcessDatasets(FixtureAPITestCase):
     def test_destroy_process_does_not_exist(self):
         self.client.force_login(self.test_user)
         wrong_id = uuid.uuid4()
-        with self.assertNumQueries(4):
+        with self.assertNumQueries(3):
             response = self.client.delete(
                 reverse("api:process-dataset", kwargs={"process": wrong_id, "dataset": self.private_dataset.id})
             )
@@ -289,7 +690,7 @@ class TestProcessDatasets(FixtureAPITestCase):
     def test_destroy_process_access_level(self):
         self.private_corpus.memberships.filter(user=self.test_user).delete()
         self.client.force_login(self.test_user)
-        with self.assertNumQueries(6):
+        with self.assertNumQueries(5):
             response = self.client.delete(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.private_dataset.id})
             )
@@ -299,7 +700,7 @@ class TestProcessDatasets(FixtureAPITestCase):
     def test_destroy_no_dataset_access_requirement(self):
         new_corpus = Corpus.objects.create(name="NERV")
         new_dataset = new_corpus.datasets.create(name="Eva series", description="We created the Evas from Adam", creator=self.user)
-        self.dataset_process.datasets.add(new_dataset)
+        ProcessDataset.objects.create(process=self.dataset_process, dataset=new_dataset, sets=new_dataset.sets)
         self.assertTrue(ProcessDataset.objects.filter(process=self.dataset_process, dataset=new_dataset).exists())
         self.client.force_login(self.test_user)
         with self.assertNumQueries(9):
@@ -328,7 +729,7 @@ class TestProcessDatasets(FixtureAPITestCase):
     def test_destroy_process_mode_local(self):
         self.client.force_login(self.user)
         local_process = Process.objects.get(creator=self.user, mode=ProcessMode.Local)
-        with self.assertNumQueries(4):
+        with self.assertNumQueries(3):
             response = self.client.delete(
                 reverse("api:process-dataset", kwargs={"process": local_process.id, "dataset": self.dataset2.id}),
             )
@@ -338,7 +739,7 @@ class TestProcessDatasets(FixtureAPITestCase):
     def test_destroy_process_mode_repository(self):
         self.client.force_login(self.user)
         process = Process.objects.create(creator=self.user, mode=ProcessMode.Repository, revision=self.rev)
-        with self.assertNumQueries(9):
+        with self.assertNumQueries(6):
             response = self.client.delete(
                 reverse("api:process-dataset", kwargs={"process": process.id, "dataset": self.dataset2.id}),
             )
diff --git a/arkindex/process/tests/test_processes.py b/arkindex/process/tests/test_processes.py
index 3ec9cc6491..d8bdb5cef8 100644
--- a/arkindex/process/tests/test_processes.py
+++ b/arkindex/process/tests/test_processes.py
@@ -15,6 +15,7 @@ from arkindex.process.models import (
     ActivityState,
     DataFile,
     Process,
+    ProcessDataset,
     ProcessMode,
     Repository,
     WorkerActivity,
@@ -2436,7 +2437,7 @@ class TestProcesses(FixtureAPITestCase):
 
     def test_start_process_dataset_requires_dataset_in_same_corpus(self):
         process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
-        process2.datasets.set([self.private_dataset])
+        ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.private_dataset.sets)
         process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
         self.assertFalse(process2.tasks.exists())
 
@@ -2453,7 +2454,8 @@ class TestProcesses(FixtureAPITestCase):
 
     def test_start_process_dataset_unsupported_parameters(self):
         process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
-        process2.datasets.set([self.dataset1, self.private_dataset])
+        ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=self.dataset1.sets)
+        ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.private_dataset.sets)
         process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
 
         self.client.force_login(self.user)
@@ -2477,7 +2479,8 @@ class TestProcesses(FixtureAPITestCase):
 
     def test_start_process_dataset(self):
         process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
-        process2.datasets.set([self.dataset1, self.private_dataset])
+        ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=self.dataset1.sets)
+        ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.private_dataset.sets)
         run = process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
         self.assertFalse(process2.tasks.exists())
 
@@ -2669,7 +2672,8 @@ class TestProcesses(FixtureAPITestCase):
         It should be possible to pass chunks when starting a dataset process
         """
         process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
-        process.datasets.set([self.dataset1, self.dataset2])
+        ProcessDataset.objects.create(process=process, dataset=self.dataset1, sets=self.dataset1.sets)
+        ProcessDataset.objects.create(process=process, dataset=self.dataset2, sets=self.dataset2.sets)
         # Add a worker run to this process
         run = process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
 
diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index b92c44d1c5..4cff4c8f0f 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -13,7 +13,7 @@ from rest_framework.validators import UniqueTogetherValidator
 from arkindex.documents.models import Element
 from arkindex.documents.serializers.elements import ElementListSerializer
 from arkindex.ponos.models import Task
-from arkindex.process.models import Worker
+from arkindex.process.models import ProcessDataset, Worker
 from arkindex.project.serializer_fields import ArchivedField, DatasetSetsCountField, EnumField
 from arkindex.training.models import (
     Dataset,
@@ -560,8 +560,11 @@ class DatasetSerializer(serializers.ModelSerializer):
             raise ValidationError("Set names must be unique.")
 
         removed, added = self.sets_diff(sets)
+        if removed and ProcessDataset.objects.filter(sets__overlap=removed, dataset_id=self.instance.id).exists():
+            # Sets that are used in a ProcessDataset cannot be renamed or deleted
+            raise ValidationError("These sets cannot be updated because one or more are selected in a dataset process.")
         if not removed or not added:
-            # Some sets have been added or removed, do nothing
+            # Some sets have either been added or removed, but not both; do nothing
             return sets
         elif len(removed) == 1 and len(added) == 1:
             # A single set has been renamed. Move its elements later, while performing the update
diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py
index 653ffa152b..f6a935b162 100644
--- a/arkindex/training/tests/test_datasets_api.py
+++ b/arkindex/training/tests/test_datasets_api.py
@@ -6,7 +6,7 @@ from django.utils import timezone as DjangoTimeZone
 from rest_framework import status
 
 from arkindex.documents.models import Corpus
-from arkindex.process.models import Process, ProcessMode
+from arkindex.process.models import Process, ProcessDataset, ProcessMode
 from arkindex.project.tests import FixtureAPITestCase
 from arkindex.project.tools import fake_now
 from arkindex.training.models import Dataset, DatasetState
@@ -30,7 +30,8 @@ class TestDatasetsAPI(FixtureAPITestCase):
         cls.write_user = User.objects.get(email="user2@user.fr")
         cls.dataset = Dataset.objects.get(name="First Dataset")
         cls.dataset2 = Dataset.objects.get(name="Second Dataset")
-        cls.process.datasets.set((cls.dataset, cls.dataset2))
+        ProcessDataset.objects.create(process=cls.process, dataset=cls.dataset, sets=["training", "test", "validation"])
+        ProcessDataset.objects.create(process=cls.process, dataset=cls.dataset2, sets=["test"])
         cls.private_dataset = Dataset.objects.create(name="Private Dataset", description="Dead Sea Scrolls", corpus=cls.private_corpus, creator=cls.dataset_creator)
         cls.vol = cls.corpus.elements.get(name="Volume 1")
         cls.page1 = cls.corpus.elements.get(name="Volume 1, page 1r")
@@ -504,9 +505,11 @@ class TestDatasetsAPI(FixtureAPITestCase):
         """
         It is possible to remove many sets, no elements are moved
         """
+        # Remove the ProcessDataset relation so that the dataset's sets can be edited
+        ProcessDataset.objects.get(process=self.process, dataset=self.dataset).delete()
         self.client.force_login(self.user)
         dataset_elt = self.dataset.dataset_elements.create(element=self.page1, set="training")
-        with self.assertNumQueries(11):
+        with self.assertNumQueries(12):
             response = self.client.put(
                 reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}),
                 data={
@@ -532,13 +535,14 @@ class TestDatasetsAPI(FixtureAPITestCase):
 
     def test_update_sets_update_single_set(self):
         """
-        It is possible to rename a single set
+        It is possible to rename a single set, if it is not referenced by a ProcessDataset
         """
+        ProcessDataset.objects.get(process=self.process, dataset=self.dataset, sets=["training", "test", "validation"]).delete()
         self.client.force_login(self.user)
         self.dataset.dataset_elements.create(element_id=self.page1.id, set="training")
         self.dataset.dataset_elements.create(element_id=self.page2.id, set="validation")
         self.dataset.dataset_elements.create(element_id=self.page3.id, set="validation")
-        with self.assertNumQueries(12):
+        with self.assertNumQueries(13):
             response = self.client.put(
                 reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}),
                 data={
@@ -565,12 +569,50 @@ class TestDatasetsAPI(FixtureAPITestCase):
             ]
         )
 
+    def test_update_sets_processdataset_reference(self):
+        """
+        If a dataset's sets are referenced by a ProcessDataset, they cannot be updated
+        """
+        self.client.force_login(self.user)
+        self.dataset.dataset_elements.create(element_id=self.page1.id, set="training")
+        self.dataset.dataset_elements.create(element_id=self.page2.id, set="validation")
+        self.dataset.dataset_elements.create(element_id=self.page3.id, set="validation")
+        with self.assertNumQueries(7):
+            response = self.client.put(
+                reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}),
+                data={
+                    "name": "Shin Seiki Evangelion",
+                    "description": "Omedeto!",
+                    # validation set is renamed to AAAAAAA
+                    "sets": ["test", "training", "AAAAAAA"],
+                },
+                format="json"
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {"sets": ["These sets cannot be updated because one or more are selected in a dataset process."]})
+        self.dataset.refresh_from_db()
+        self.assertEqual(self.dataset.state, DatasetState.Open)
+        self.assertEqual(self.dataset.name, "First Dataset")
+        self.assertEqual(self.dataset.description, "dataset number one")
+        self.assertListEqual(self.dataset.sets, ["training", "test", "validation"])
+        self.assertIsNone(self.dataset.task_id)
+        self.assertQuerysetEqual(
+            self.dataset.dataset_elements.values_list("set", "element__name").order_by("element__name"),
+            [
+                ("training", "Volume 1, page 1r"),
+                ("validation", "Volume 1, page 1v"),
+                ("validation", "Volume 1, page 2r"),
+            ]
+        )
+
     def test_update_sets_ambiguous(self):
         """
         No more than one set can be updated
         """
+        # Remove the ProcessDataset relation so that the dataset's sets can be edited
+        ProcessDataset.objects.get(process=self.process, dataset=self.dataset).delete()
         self.client.force_login(self.user)
-        with self.assertNumQueries(6):
+        with self.assertNumQueries(7):
             response = self.client.put(
                 reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}),
                 data={
-- 
GitLab