From 60a2ea2e7136f34f77df0f3078838f929e9e567d Mon Sep 17 00:00:00 2001
From: Erwan Rouchet <rouchet@teklia.com>
Date: Tue, 1 Aug 2023 14:24:50 +0000
Subject: [PATCH] Add unique constraint on ProcessDataset

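ProcessDataset rows could link the same dataset to the same process more
than once. Add a unique constraint on (process, dataset), with a data
migration that deletes pre-existing duplicate rows before the constraint
is applied, and validate the pair in ProcessDatasetSerializer so that a
duplicate POST returns an HTTP 400 instead of failing with an
IntegrityError. Also put ProcessACLMixin first in the serializer's bases,
following the usual mixin-before-base-class ordering.
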
---
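Notes (not part of the commit message): the migration keeps an arbitrary
row per (process_id, dataset_id) pair, since the ROW_NUMBER() window has
no ORDER BY; every row numbered i > 1 within a partition is deleted. A
minimal sketch of what the new constraint enforces at the database level,
assuming existing process and dataset objects that are already linked by
a ProcessDataset row:

    from django.db import IntegrityError, transaction

    from arkindex.process.models import ProcessDataset

    try:
        with transaction.atomic():
            # A second link for the same pair now violates the
            # unique_process_dataset constraint.
            ProcessDataset.objects.create(process=process, dataset=dataset)
    except IntegrityError:
        # The API layer normally catches this earlier:
        # ProcessDatasetSerializer.validate() rejects the duplicate with
        # an HTTP 400 before the INSERT is attempted.
        pass
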
 ...3_processdataset_unique_process_dataset.py | 30 +++++++++++++++++++
 arkindex/process/models.py                    |  8 +++++
 arkindex/process/serializers/training.py      | 10 +++++--
 .../process/tests/test_process_datasets.py    | 14 ++++++++-
 4 files changed, 59 insertions(+), 3 deletions(-)
 create mode 100644 arkindex/process/migrations/0013_processdataset_unique_process_dataset.py

diff --git a/arkindex/process/migrations/0013_processdataset_unique_process_dataset.py b/arkindex/process/migrations/0013_processdataset_unique_process_dataset.py
new file mode 100644
index 0000000000..40865e386c
--- /dev/null
+++ b/arkindex/process/migrations/0013_processdataset_unique_process_dataset.py
@@ -0,0 +1,30 @@
+# Generated by Django 4.1.7 on 2023-08-01 12:57
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('process', '0012_corpusworkerversion_model_version_configuration'),
+    ]
+
+    operations = [
+        migrations.RunSQL(
+            # Remove any duplicates before applying the constraint
+            """
+            DELETE FROM process_processdataset
+            USING (
+                SELECT id, ROW_NUMBER() OVER (PARTITION BY process_id, dataset_id) AS i
+                FROM process_processdataset
+            ) duplicates
+            WHERE process_processdataset.id = duplicates.id
+            AND duplicates.i > 1
+            """,
+            reverse_sql=migrations.RunSQL.noop,
+        ),
+        migrations.AddConstraint(
+            model_name='processdataset',
+            constraint=models.UniqueConstraint(fields=('process', 'dataset'), name='unique_process_dataset'),
+        ),
+    ]
diff --git a/arkindex/process/models.py b/arkindex/process/models.py
index 7f9de2924a..a7a3e3f062 100644
--- a/arkindex/process/models.py
+++ b/arkindex/process/models.py
@@ -900,6 +900,14 @@ class ProcessDataset(models.Model):
     process = models.ForeignKey(Process, on_delete=models.CASCADE, related_name='process_datasets')
     dataset = models.ForeignKey('training.Dataset', on_delete=models.CASCADE, related_name='process_datasets')
 
+    class Meta:
+        constraints = [
+            models.UniqueConstraint(
+                fields=['process', 'dataset'],
+                name='unique_process_dataset',
+            )
+        ]
+
 
 class DataFile(S3FileMixin, models.Model):
     id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
diff --git a/arkindex/process/serializers/training.py b/arkindex/process/serializers/training.py
index 3bcf27262b..72ee3ccbb9 100644
--- a/arkindex/process/serializers/training.py
+++ b/arkindex/process/serializers/training.py
@@ -196,7 +196,7 @@ class StartTrainingSerializer(serializers.ModelSerializer, WorkerACLMixin, Train
         return self.instance
 
 
-class ProcessDatasetSerializer(serializers.ModelSerializer, ProcessACLMixin):
+class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer):
     process = serializers.PrimaryKeyRelatedField(
         queryset=Process.objects.using('default').select_related('corpus'),
         style={'base_template': 'input.html'},
@@ -206,7 +206,7 @@ class ProcessDatasetSerializer(serializers.ModelSerializer, ProcessACLMixin):
         style={'base_template': 'input.html'},
     )
 
-    class Meta():
+    class Meta:
         model = ProcessDataset
         fields = ('dataset', 'process', 'id', )
         read_only_fields = ('process', 'id', )
@@ -234,3 +234,9 @@ class ProcessDatasetSerializer(serializers.ModelSerializer, ProcessACLMixin):
         if not access or not (access >= Role.Admin.value):
             raise PermissionDenied(detail='You do not have admin access to this process.')
         return process
+
+    def validate(self, data):
+        process, dataset = data['process'], data['dataset']
+        if self.context['request'].method == 'POST' and process.datasets.filter(id=dataset.id).exists():
+            raise ValidationError({'detail': ['This dataset is already selected in this process.']})
+        return data
diff --git a/arkindex/process/tests/test_process_datasets.py b/arkindex/process/tests/test_process_datasets.py
index 1af573724d..d59a69c3f2 100644
--- a/arkindex/process/tests/test_process_datasets.py
+++ b/arkindex/process/tests/test_process_datasets.py
@@ -200,11 +200,23 @@ class TestProcessDatasets(FixtureAPITestCase):
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertEqual(response.json(), {'dataset': [f'Invalid pk "{str(new_dataset.id)}" - object does not exist.']})
 
+    def test_create_unique(self):
+        self.client.force_login(self.test_user)
+        self.assertTrue(self.dataset_process.datasets.filter(id=self.dataset1.id).exists())
+
+        with self.assertNumQueries(8):
+            response = self.client.post(
+                reverse('api:process-dataset', kwargs={'process': self.dataset_process.id, 'dataset': self.dataset1.id}),
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+
+        self.assertDictEqual(response.json(), {'detail': ['This dataset is already selected in this process.']})
+
     def test_create_process_dataset(self):
         self.client.force_login(self.test_user)
         self.assertEqual(ProcessDataset.objects.count(), 3)
         self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists())
-        with self.assertNumQueries(8):
+        with self.assertNumQueries(9):
             response = self.client.post(
                 reverse('api:process-dataset', kwargs={'process': self.dataset_process.id, 'dataset': self.dataset2.id}),
             )
-- 
GitLab