From 05bd4c11aabfd13472b9ca842d6fc85f4885c3fd Mon Sep 17 00:00:00 2001
From: Valentin Rigal <rigal@teklia.com>
Date: Tue, 2 Apr 2024 14:22:47 +0200
Subject: [PATCH] Update dataset serializer validation

---
 arkindex/training/serializers.py | 29 +++++++++++++++++------------
 1 file changed, 17 insertions(+), 12 deletions(-)

diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index 9c4e357380..837db06673 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -4,7 +4,7 @@ from collections import defaultdict
 from textwrap import dedent
 
 from django.db import transaction
-from django.db.models import Count, Exists, OuterRef, Q
+from django.db.models import Count, Q
 from drf_spectacular.utils import extend_schema_field
 from rest_framework import permissions, serializers
 from rest_framework.exceptions import PermissionDenied, ValidationError
@@ -576,12 +576,13 @@ class DatasetSerializer(serializers.ModelSerializer):
         return set_names
 
     def validate_unique_elements(self, unique):
-        if unique is True and self.instance and Exists(
+        if unique is True and self.instance and (
             DatasetElement.objects
-            .filter(set__dataset_id=OuterRef(self.instance.pk))
+            .filter(set__dataset_id=self.instance.pk)
             .values("element")
             .annotate(dups=Count("element"))
             .filter(dups__gte=2)
+            .exists()
         ):
             raise ValidationError("Elements are currently contained by multiple sets.")
         return unique
@@ -712,20 +713,24 @@ class DatasetElementSerializer(serializers.ModelSerializer):
             self.fields["element_id"].queryset = Element.objects.filter(corpus=dataset.corpus)
             self.fields["set"].queryset = dataset.sets.all()
 
-    def validate(self, data):
-        data = super().validate(data)
-        dataset = data.pop("dataset")
-        if dataset.unique_elements and (
-            set := (
+    def validate_element_id(self, element):
+        dataset = self.context.get("dataset")
+        if dataset and dataset.unique_elements and (
+            existing_set := (
                 dataset.sets
-                .filter(set_elements__element_id=data["element_id"])
+                .filter(set_elements__element=element)
                 .values_list("name", flat=True)
                 .first()
             )
         ):
-            raise ValidationError({"element_id": [
-                f"The dataset prevent duplication and this element is already present in set {set}."
-            ]})
+            raise ValidationError([
+                f"The dataset prevent duplication and this element is already present in set {existing_set}."
+            ])
+        return element
+
+    def validate(self, data):
+        data = super().validate(data)
+        data.pop("dataset")
         return data
 
 
-- 
GitLab