New DatasetSet model

Merged: ml bonhomme requested to merge dataset-sets-reset into master

1 file changed, +75 −40
@@ -3,7 +3,7 @@ from textwrap import dedent
 from uuid import UUID

 from django.db import connection, transaction
-from django.db.models import Q
+from django.db.models import Count, Prefetch, Q, prefetch_related_objects
 from django.shortcuts import get_object_or_404
 from django.utils.functional import cached_property
 from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view
@@ -29,6 +29,7 @@ from arkindex.project.tools import BulkMap
 from arkindex.training.models import (
     Dataset,
     DatasetElement,
+    DatasetSet,
     DatasetState,
     MetricValue,
     Model,
@@ -40,7 +41,7 @@ from arkindex.training.serializers import (
     DatasetElementInfoSerializer,
     DatasetElementSerializer,
     DatasetSerializer,
-    ElementDatasetSerializer,
+    ElementDatasetSetSerializer,
     MetricValueBulkSerializer,
     MetricValueCreateSerializer,
     ModelCompatibleWorkerSerializer,
@@ -53,12 +54,18 @@ from arkindex.training.serializers import (
 from arkindex.users.models import Role
 from arkindex.users.utils import get_max_level

+# A prefetch object that includes the number of elements per set.
+DATASET_SET_COUNTS_PREFETCH = Prefetch(
+    "sets",
+    DatasetSet.objects.annotate(element_count=Count("set_elements")).order_by("name")
+)
+

 def _fetch_datasetelement_neighbors(datasetelements):
     """
     Retrieve the neighbors for a list of DatasetElements, and annotate these DatasetElements
     with next and previous attributes.

-    The ElementDatasets endpoint uses arkindex.project.tools.BulkMap to apply this method and
+    The ElementDatasetSets endpoint uses arkindex.project.tools.BulkMap to apply this method and
     perform the second request *after* DRF's pagination, because there is no way to perform
     post-processing after pagination in Django without having to use Django private methods.
     """
@@ -71,18 +78,18 @@ def _fetch_datasetelement_neighbors(datasetelements):
             SELECT
                 n.id,
                 lag(element_id) OVER (
-                    partition BY (n.dataset_id, n.set)
+                    partition BY (n.set_id)
                     order by
                         n.element_id
                 ) as previous,
                 lead(element_id) OVER (
-                    partition BY (n.dataset_id, n.set)
+                    partition BY (n.set_id)
                     order by
                         n.element_id
                 ) as next
             FROM training_datasetelement as n
-            WHERE (dataset_id, set) IN (
-                SELECT dataset_id, set
+            WHERE set_id IN (
+                SELECT set_id
                 FROM training_datasetelement
                 WHERE id IN %(ids)s
             )
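
Note: `lag()`/`lead()` give each row its previous and next `element_id` within its set in a single pass. Django can express the same thing with `Window` expressions; a rough ORM equivalent of the raw SQL above, for illustration only (the view keeps the raw SQL):

```python
from django.db.models import F, Window
from django.db.models.functions import Lag, Lead

neighbors = DatasetElement.objects.annotate(
    previous=Window(Lag("element_id"), partition_by=[F("set_id")], order_by=F("element_id").asc()),
    next=Window(Lead("element_id"), partition_by=[F("set_id")], order_by=F("element_id").asc()),
).values("id", "previous", "next")
```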
@@ -609,6 +616,11 @@ class CorpusDataset(CorpusACLMixin, ListCreateAPIView):
     def get_queryset(self):
         return Dataset.objects \
             .select_related("creator") \
+            .prefetch_related(Prefetch(
+                "sets",
+                # Prefetch sets, but ensure they are ordered by name
+                DatasetSet.objects.order_by("name")
+            )) \
             .filter(corpus=self.corpus) \
             .order_by("name")
@@ -625,10 +637,6 @@ class CorpusDataset(CorpusACLMixin, ListCreateAPIView):
         if not self.kwargs:
             return context
         context["corpus"] = self.corpus
-        # Avoids aggregating the number of elements per set on each
-        # entry, which would cause 1 extra query per dataset
-        if self.request.method in permissions.SAFE_METHODS:
-            context["sets_count"] = False
         return context
@@ -686,8 +694,14 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView):
     serializer_class = DatasetSerializer

     def get_queryset(self):
-        queryset = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user))
-        return queryset.select_related("corpus", "creator")
+        queryset = (
+            Dataset.objects
+            .filter(corpus__in=Corpus.objects.readable(self.request.user))
+            .select_related("corpus", "creator")
+        )
+        if self.request.method != "DELETE":
+            queryset = queryset.prefetch_related(DATASET_SET_COUNTS_PREFETCH)
+        return queryset

     def check_object_permissions(self, request, obj):
         super().check_object_permissions(request, obj)
@@ -706,8 +720,20 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView):
         if obj.state == DatasetState.Complete:
             raise ValidationError(detail="This dataset is in complete state and cannot be modified anymore.")

+    def update(self, request, *args, **kwargs):
+        # Do exactly the same thing as what DRF does, but without the automatic prefetch cache removal:
+        # https://github.com/encode/django-rest-framework/blob/2da473c8c8e024e80c13a624782f1da6272812da/rest_framework/mixins.py#L70
+        # This allows `set_elements` to still be returned after the update.
+        partial = kwargs.pop("partial", False)
+        instance = self.get_object()
+        serializer = self.get_serializer(instance, data=request.data, partial=partial)
+        serializer.is_valid(raise_exception=True)
+        self.perform_update(serializer)
+        return Response(serializer.data)
+
     def perform_destroy(self, dataset):
-        dataset.dataset_elements.all().delete()
+        DatasetElement.objects.filter(set__dataset_id=dataset.id).delete()
+        dataset.sets.all().delete()
         super().perform_destroy(dataset)
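
Note: the only behaviour `update()` drops from DRF's `UpdateModelMixin` is the prefetch cache invalidation that runs after `perform_update`; the linked mixins.py reads roughly:

```python
if getattr(instance, '_prefetched_objects_cache', None):
    # If 'prefetch_related' has been applied to a queryset, we need to
    # forcibly invalidate the prefetch cache on the instance.
    instance._prefetched_objects_cache = {}
```

Keeping that cache is what lets the prefetched sets appear in the PATCH/PUT response. `perform_destroy`, meanwhile, deletes bottom-up along the new FK chain: elements first, then sets, then the dataset itself.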
@@ -768,13 +794,13 @@ class DatasetElements(CorpusACLMixin, ListCreateAPIView):
     def get_queryset(self):
         qs = (
-            self.dataset.dataset_elements
-            .prefetch_related("element")
+            DatasetElement.objects.filter(set__dataset_id=self.dataset.id)
+            .prefetch_related("element", "set")
             .select_related("element__type", "element__corpus", "element__image__server")
             .order_by("element_id", "id")
         )
         if "set" in self.request.query_params:
-            qs = qs.filter(set=self.request.query_params["set"])
+            qs = qs.filter(set__name=self.request.query_params["set"])
         return qs

     def get_serializer_context(self):
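
Note: the `?set=` query parameter keeps its public meaning, but the lookup now traverses the new foreign key instead of comparing a plain text column; illustratively:

```python
qs.filter(set="train")        # before: `set` was a text field on DatasetElement
qs.filter(set__name="train")  # after: `set` is a ForeignKey to DatasetSet
```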
@@ -800,7 +826,7 @@ class DatasetElements(CorpusACLMixin, ListCreateAPIView):
 )
 class DatasetElementDestroy(CorpusACLMixin, DestroyAPIView):
     """
-    Remove an element from a dataset.
+    Remove an element from a dataset set.

     Elements can only be removed from **open** datasets.
@@ -811,17 +837,15 @@ class DatasetElementDestroy(CorpusACLMixin, DestroyAPIView):
     lookup_url_kwarg = "element"

     def destroy(self, request, *args, **kwargs):
-        if not self.request.query_params.get("set"):
+        if not (set_name := self.request.query_params.get("set")):
             raise ValidationError({"set": ["This field is required."]})
         dataset_element = get_object_or_404(
-            DatasetElement.objects.select_related("dataset__corpus"),
-            dataset_id=self.kwargs["dataset"],
-            element_id=self.kwargs["element"],
-            set=self.request.query_params.get("set")
+            DatasetElement.objects.select_related("set__dataset__corpus").filter(set__dataset_id=self.kwargs["dataset"], set__name=set_name),
+            element_id=self.kwargs["element"]
         )
-        if dataset_element.dataset.state != DatasetState.Open:
+        if dataset_element.set.dataset.state != DatasetState.Open:
             raise ValidationError({"dataset": ["Elements can only be removed from open Datasets."]})
-        if not self.has_write_access(dataset_element.dataset.corpus):
+        if not self.has_write_access(dataset_element.set.dataset.corpus):
             raise PermissionDenied(detail="You need a Contributor access to the dataset to perform this action.")
         dataset_element.delete()
         return Response(status=status.HTTP_204_NO_CONTENT)
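
Note: `get_object_or_404` accepts any queryset, so scoping by dataset and set name up front folds a "wrong set" lookup into the same 404. A condensed sketch, where `dataset_id`, `set_name` and `element_id` stand in for the URL kwargs and query parameter:

```python
from django.shortcuts import get_object_or_404

scoped = DatasetElement.objects.select_related("set__dataset__corpus").filter(
    set__dataset_id=dataset_id,  # URL kwarg
    set__name=set_name,          # required ?set= parameter
)
dataset_element = get_object_or_404(scoped, element_id=element_id)
```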
@@ -896,14 +920,14 @@ class CreateDatasetElementsSelection(CorpusACLMixin, CreateAPIView):
         )
     ],
 )
-class ElementDatasets(CorpusACLMixin, ListAPIView):
+class ElementDatasetSets(CorpusACLMixin, ListAPIView):
     """
-    List all datasets containing a specific element.
+    List all dataset sets containing a specific element.

     Requires a **guest** access to the element's corpus.
     """
     permission_classes = (IsVerifiedOrReadOnly, )
-    serializer_class = ElementDatasetSerializer
+    serializer_class = ElementDatasetSetSerializer

     @cached_property
     def element(self):
@@ -915,9 +939,14 @@ class ElementDatasets(CorpusACLMixin, ListAPIView):
     def get_queryset(self):
         qs = (
-            self.element.dataset_elements.all()
-            .select_related("dataset__creator")
-            .order_by("dataset__name", "set", "dataset_id")
+            self.element.dataset_elements
+            .select_related("set__dataset__creator")
+            .prefetch_related(Prefetch(
+                "set__dataset__sets",
+                # Prefetch sets, but ensure they are ordered by name
+                DatasetSet.objects.order_by("name")
+            ))
+            .order_by("set__dataset__name", "set__name")
         )

         with_neighbors = self.request.query_params.get("with_neighbors", "false")
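
Note: the nested `set__dataset__sets` prefetch lets each result enumerate all sets of its parent dataset without extra queries. A sketch of what the serializer can then rely on (the loop is illustrative):

```python
for dataset_element in qs:
    dataset = dataset_element.set.dataset         # select_related: no query
    names = [s.name for s in dataset.sets.all()]  # prefetched: no query
```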
@@ -926,13 +955,6 @@ class ElementDatasets(CorpusACLMixin, ListAPIView):
         return qs

-    def get_serializer_context(self):
-        context = super().get_serializer_context()
-        # Avoids aggregating the number of elements per set on each
-        # entry, which would cause 1 extra query per dataset
-        context["sets_count"] = False
-        return context

 @extend_schema_view(
     post=extend_schema(
@@ -995,11 +1017,24 @@ class DatasetClone(CorpusACLMixin, CreateAPIView):
         clone.creator = request.user
         clone.save()

+        # Clone dataset sets
+        cloned_sets = DatasetSet.objects.bulk_create([
+            DatasetSet(dataset_id=clone.id, name=set.name)
+            for set in dataset.sets.all()
+        ])
+        set_map = {set.name: set for set in cloned_sets}
+
         # Associate all elements to the clone
         DatasetElement.objects.bulk_create([
-            DatasetElement(element_id=elt_id, dataset_id=clone.id, set=set_name)
-            for elt_id, set_name in dataset.dataset_elements.values_list("element_id", "set")
+            DatasetElement(element_id=elt_id, set=set_map[set_name])
+            for elt_id, set_name in DatasetElement.objects.filter(set__dataset_id=dataset.id)
+            .values_list("element_id", "set__name")
+            .iterator()
         ])
+
+        # Add the set counts to the API response
+        prefetch_related_objects([clone], DATASET_SET_COUNTS_PREFETCH)
         return Response(
             DatasetSerializer(clone).data,
             status=status.HTTP_201_CREATED,
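
Note: `prefetch_related_objects` attaches a prefetch to instances that were never fetched through a queryset, which is exactly the fresh clone's situation; illustratively:

```python
from django.db.models import prefetch_related_objects

# `clone` was just created and saved, so it never went through
# .prefetch_related(); attach the annotated sets to it directly.
prefetch_related_objects([clone], DATASET_SET_COUNTS_PREFETCH)
data = DatasetSerializer(clone).data  # sets now include element counts
```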