diff --git a/arkindex/training/api.py b/arkindex/training/api.py index d0740434d46567c6d090f55bd778bcde234283cf..454663899ebd7c982c69f7a97ab8ac02fac47090 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -3,7 +3,7 @@ from textwrap import dedent from uuid import UUID from django.db import connection, transaction -from django.db.models import Q +from django.db.models import Count, Prefetch, Q, prefetch_related_objects from django.shortcuts import get_object_or_404 from django.utils.functional import cached_property from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view @@ -54,6 +54,12 @@ from arkindex.training.serializers import ( from arkindex.users.models import Role from arkindex.users.utils import get_max_level +# A prefetch object that includes the number of elements per set. +SET_COUNTS_PREFETCH = Prefetch( + "sets", + DatasetSet.objects.annotate(element_count=Count("set_elements")) +) + def _fetch_datasetelement_neighbors(datasetelements): """ @@ -684,8 +690,14 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView): serializer_class = DatasetSerializer def get_queryset(self): - queryset = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user)) - return queryset.select_related("corpus", "creator") + queryset = ( + Dataset.objects + .filter(corpus__in=Corpus.objects.readable(self.request.user)) + .select_related("corpus", "creator") + ) + if self.request.method != "DELETE": + queryset = queryset.prefetch_related(SET_COUNTS_PREFETCH) + return queryset def check_object_permissions(self, request, obj): super().check_object_permissions(request, obj) @@ -998,6 +1010,9 @@ class DatasetClone(CorpusACLMixin, CreateAPIView): .values_list("element_id", "set__name") ]) + # Add the set counts to the API response + prefetch_related_objects([clone], SET_COUNTS_PREFETCH) + return Response( DatasetSerializer(clone).data, status=status.HTTP_201_CREATED,