From 6cafe5012ffc989782bb055b2e7873b444e23fa1 Mon Sep 17 00:00:00 2001
From: Erwan Rouchet <rouchet@teklia.com>
Date: Tue, 26 Mar 2024 11:51:57 +0100
Subject: [PATCH] Restore set_elements on create responses

---
 arkindex/training/serializers.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index 819042749e..9fc5c2bae4 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -15,6 +15,7 @@ from arkindex.documents.serializers.elements import ElementListSerializer
 from arkindex.ponos.models import Task
 from arkindex.process.models import Worker
 from arkindex.project.serializer_fields import ArchivedField, DatasetSetsCountField, EnumField
+from arkindex.project.tools import add_as_prefetch
 from arkindex.training.models import (
     Dataset,
     DatasetElement,
@@ -591,14 +592,19 @@ class DatasetSerializer(serializers.ModelSerializer):
 
     @transaction.atomic
     def create(self, validated_data):
-        sets = validated_data.pop("set_names")
+        set_names = validated_data.pop("set_names")
         dataset = Dataset.objects.create(**validated_data)
-        DatasetSet.objects.bulk_create(
+        sets = DatasetSet.objects.bulk_create(
             DatasetSet(
                 name=set_name,
                 dataset_id=dataset.id
-            ) for set_name in sets
+            ) for set_name in set_names
         )
+        # We will output set element counts in the API, but we know there are zero,
+        # so no need to make another query to prefetch the sets and count them
+        for set in sets:
+            set.element_count = 0
+        add_as_prefetch(dataset.sets, sets)
         return dataset
 
     class Meta:
-- 
GitLab