From d67c18b005b60defb0c6510620afded952ad399d Mon Sep 17 00:00:00 2001
From: Erwan Rouchet <rouchet@teklia.com>
Date: Wed, 20 Mar 2024 18:45:11 +0100
Subject: [PATCH] Add prefetch for both sets and set_elements

---
 arkindex/training/api.py | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/arkindex/training/api.py b/arkindex/training/api.py
index d0740434d4..454663899e 100644
--- a/arkindex/training/api.py
+++ b/arkindex/training/api.py
@@ -3,7 +3,7 @@ from textwrap import dedent
 from uuid import UUID
 
 from django.db import connection, transaction
-from django.db.models import Q
+from django.db.models import Count, Prefetch, Q, prefetch_related_objects
 from django.shortcuts import get_object_or_404
 from django.utils.functional import cached_property
 from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view
@@ -54,6 +54,12 @@ from arkindex.training.serializers import (
 from arkindex.users.models import Role
 from arkindex.users.utils import get_max_level
 
+# A prefetch object that includes the number of elements per set.
+SET_COUNTS_PREFETCH = Prefetch(
+    "sets",
+    DatasetSet.objects.annotate(element_count=Count("set_elements"))
+)
+
 
 def _fetch_datasetelement_neighbors(datasetelements):
     """
@@ -684,8 +690,14 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView):
     serializer_class = DatasetSerializer
 
     def get_queryset(self):
-        queryset = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user))
-        return queryset.select_related("corpus", "creator")
+        queryset = (
+            Dataset.objects
+            .filter(corpus__in=Corpus.objects.readable(self.request.user))
+            .select_related("corpus", "creator")
+        )
+        if self.request.method != "DELETE":
+            queryset = queryset.prefetch_related(SET_COUNTS_PREFETCH)
+        return queryset
 
     def check_object_permissions(self, request, obj):
         super().check_object_permissions(request, obj)
@@ -998,6 +1010,9 @@ class DatasetClone(CorpusACLMixin, CreateAPIView):
             .values_list("element_id", "set__name")
         ])
 
+        # Add the set counts to the API response
+        prefetch_related_objects([clone], SET_COUNTS_PREFETCH)
+
         return Response(
             DatasetSerializer(clone).data,
             status=status.HTTP_201_CREATED,
-- 
GitLab