Commit ca9ea129 authored by ml bonhomme :bee:

WIP fix API endpoints :'(

parent 2e89a321
1 merge request: !2256 New DatasetSet model
......@@ -5579,7 +5579,7 @@
"model": "training.datasetset",
"pk": "00e3b37b-f1ed-4adb-a1de-a1c103edaa24",
"fields": {
"name": "Test",
"name": "test",
"dataset": "170bc276-0e11-483a-9486-88a4bf2079b2"
}
},
......@@ -5587,7 +5587,7 @@
"model": "training.datasetset",
"pk": "95255ff4-7bca-424c-8e7e-d8b5c33c585f",
"fields": {
"name": "Train",
"name": "training",
"dataset": "170bc276-0e11-483a-9486-88a4bf2079b2"
}
},
......@@ -5595,7 +5595,7 @@
"model": "training.datasetset",
"pk": "b21e4b31-1dc1-4015-9933-3a8e180cf2e0",
"fields": {
"name": "Validation",
"name": "validation",
"dataset": "170bc276-0e11-483a-9486-88a4bf2079b2"
}
},
......@@ -5603,7 +5603,7 @@
"model": "training.datasetset",
"pk": "b76d919c-ce7e-4896-94c3-9b47862b997a",
"fields": {
"name": "Validation",
"name": "validation",
"dataset": "6b32dab0-e3f2-4fff-b0ad-2ea1ee550810"
}
},
......@@ -5611,7 +5611,7 @@
"model": "training.datasetset",
"pk": "d5f4d410-0588-4d19-8d08-b57a11bad67f",
"fields": {
"name": "Train",
"name": "training",
"dataset": "6b32dab0-e3f2-4fff-b0ad-2ea1ee550810"
}
},
......@@ -5619,7 +5619,7 @@
"model": "training.datasetset",
"pk": "db70ad8a-8d6b-44c6-b8d3-837bf5a4ab8e",
"fields": {
"name": "Test",
"name": "test",
"dataset": "6b32dab0-e3f2-4fff-b0ad-2ea1ee550810"
}
}
......
......@@ -109,7 +109,7 @@ from arkindex.training.api import (
DatasetElementDestroy,
DatasetElements,
DatasetUpdate,
ElementDatasets,
ElementDatasetSets,
MetricValueBulkCreate,
MetricValueCreate,
ModelCompatibleWorkerManage,
......@@ -184,7 +184,7 @@ api = [
# Datasets
path("corpus/<uuid:pk>/datasets/", CorpusDataset.as_view(), name="corpus-datasets"),
path("corpus/<uuid:pk>/datasets/selection/", CreateDatasetElementsSelection.as_view(), name="dataset-elements-selection"),
path("element/<uuid:pk>/datasets/", ElementDatasets.as_view(), name="element-datasets"),
path("element/<uuid:pk>/datasets/", ElementDatasetSets.as_view(), name="element-datasets"),
path("datasets/<uuid:pk>/", DatasetUpdate.as_view(), name="dataset-update"),
path("datasets/<uuid:pk>/clone/", DatasetClone.as_view(), name="dataset-clone"),
path("datasets/<uuid:pk>/elements/", DatasetElements.as_view(), name="dataset-elements"),
......
......@@ -286,7 +286,7 @@ class DatasetSetsCountField(serializers.DictField):
return None
elts_count = {k.name: 0 for k in instance.sets.all()}
elts_count.update(
DatasetElement.objects.filter(set__dataset_id=instance.id)
DatasetElement.objects.filter(set_id__in=instance.sets.values_list("id"))
.values("set__name")
.annotate(count=Count("id"))
.values_list("set__name", "count")
......
......@@ -29,6 +29,7 @@ from arkindex.project.tools import BulkMap
from arkindex.training.models import (
Dataset,
DatasetElement,
DatasetSet,
DatasetState,
MetricValue,
Model,
......@@ -40,7 +41,7 @@ from arkindex.training.serializers import (
DatasetElementInfoSerializer,
DatasetElementSerializer,
DatasetSerializer,
ElementDatasetSerializer,
ElementDatasetSetSerializer,
MetricValueBulkSerializer,
MetricValueCreateSerializer,
ModelCompatibleWorkerSerializer,
......@@ -58,7 +59,7 @@ def _fetch_datasetelement_neighbors(datasetelements):
"""
Retrieve the neighbors for a list of DatasetElements, and annotate these DatasetElements
with next and previous attributes.
The ElementDatasets endpoint uses arkindex.project.tools.BulkMap to apply this method and
The ElementDatasetSets endpoint uses arkindex.project.tools.BulkMap to apply this method and
perform the second request *after* DRF's pagination, because there is no way to perform
post-processing after pagination in Django without having to use Django private methods.
"""
......@@ -687,7 +688,7 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView):
serializer_class = DatasetSerializer
def get_queryset(self):
queryset = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user))
queryset = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user)).prefetch_related("sets")
return queryset.select_related("corpus", "creator")
def check_object_permissions(self, request, obj):
......@@ -708,7 +709,8 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView):
raise ValidationError(detail="This dataset is in complete state and cannot be modified anymore.")
def perform_destroy(self, dataset):
dataset.dataset_elements.all().delete()
DatasetElement.objects.filter(set__dataset_id=dataset.id).delete()
dataset.sets.all().delete()
super().perform_destroy(dataset)
......@@ -769,13 +771,13 @@ class DatasetElements(CorpusACLMixin, ListCreateAPIView):
def get_queryset(self):
qs = (
self.dataset.dataset_elements
.prefetch_related("element")
DatasetElement.objects.filter(set__dataset_id=self.dataset.id)
.prefetch_related("element", "set")
.select_related("element__type", "element__corpus", "element__image__server")
.order_by("element_id", "id")
)
if "set" in self.request.query_params:
qs = qs.filter(set=self.request.query_params["set"])
qs = qs.filter(set__name=self.request.query_params["set"])
return qs
def get_serializer_context(self):
......@@ -788,20 +790,12 @@ class DatasetElements(CorpusACLMixin, ListCreateAPIView):
@extend_schema_view(
delete=extend_schema(
operation_id="DestroyDatasetElement",
parameters=[
OpenApiParameter(
"set",
type=str,
description="Name of the set from which to remove the element.",
required=True,
)
],
tags=["datasets"]
)
)
class DatasetElementDestroy(CorpusACLMixin, DestroyAPIView):
"""
Remove an element from a dataset.
Remove an element from a dataset set.
Elements can only be removed from **open** datasets.
......@@ -812,17 +806,15 @@ class DatasetElementDestroy(CorpusACLMixin, DestroyAPIView):
lookup_url_kwarg = "element"
def destroy(self, request, *args, **kwargs):
if not self.request.query_params.get("set"):
if not (set_name := self.request.query_params.get("set")):
raise ValidationError({"set": ["This field is required."]})
dataset_element = get_object_or_404(
DatasetElement.objects.select_related("dataset__corpus"),
dataset_id=self.kwargs["dataset"],
element_id=self.kwargs["element"],
set=self.request.query_params.get("set")
DatasetElement.objects.select_related("set__dataset__corpus").filter(set__dataset_id=self.kwargs["dataset"], set__name=set_name),
element_id=self.kwargs["element"]
)
if dataset_element.dataset.state != DatasetState.Open:
if dataset_element.set.dataset.state != DatasetState.Open:
raise ValidationError({"dataset": ["Elements can only be removed from open Datasets."]})
if not self.has_write_access(dataset_element.dataset.corpus):
if not self.has_write_access(dataset_element.set.dataset.corpus):
raise PermissionDenied(detail="You need a Contributor access to the dataset to perform this action.")
dataset_element.delete()
return Response(status=status.HTTP_204_NO_CONTENT)
......@@ -897,14 +889,14 @@ class CreateDatasetElementsSelection(CorpusACLMixin, CreateAPIView):
)
],
)
class ElementDatasets(CorpusACLMixin, ListAPIView):
class ElementDatasetSets(CorpusACLMixin, ListAPIView):
"""
List all datasets containing a specific element.
List all dataset sets containing a specific element.
Requires a **guest** access to the element's corpus.
"""
permission_classes = (IsVerifiedOrReadOnly, )
serializer_class = ElementDatasetSerializer
serializer_class = ElementDatasetSetSerializer
@cached_property
def element(self):
......@@ -916,9 +908,9 @@ class ElementDatasets(CorpusACLMixin, ListAPIView):
def get_queryset(self):
qs = (
self.element.dataset_elements.all()
.select_related("dataset__creator")
.order_by("dataset__name", "set", "dataset_id")
self.element.dataset_elements
.select_related("set__dataset__creator")
.order_by("set__name", "id")
)
with_neighbors = self.request.query_params.get("with_neighbors", "false")
......@@ -996,11 +988,18 @@ class DatasetClone(CorpusACLMixin, CreateAPIView):
clone.creator = request.user
clone.save()
# Clone dataset sets
cloned_sets = DatasetSet.objects.bulk_create([
DatasetSet(dataset_id=clone.id, name=set.name)
for set in dataset.sets.all()
])
# Associate all elements to the clone
DatasetElement.objects.bulk_create([
DatasetElement(element_id=elt_id, dataset_id=clone.id, set=set_name)
for elt_id, set_name in dataset.dataset_elements.values_list("element_id", "set")
DatasetElement(element_id=elt_id, set=next(new_set for new_set in cloned_sets if new_set.name == set_name))
for elt_id, set_name in DatasetElement.objects.filter(set__dataset_id=dataset.id)
.values_list("element_id", "set__name")
])
return Response(
DatasetSerializer(clone).data,
status=status.HTTP_201_CREATED,
......
......@@ -59,7 +59,7 @@ class Migration(migrations.Migration):
UPDATE training_datasetelement de
SET set_id_id = ds.id
FROM training_datasetset ds
WHERE de.dataset_id = ds.dataset_id
WHERE de.dataset_id = ds.dataset_id AND de.set = ds.name
""",
reverse_sql=migrations.RunSQL.noop,
),
......@@ -85,6 +85,10 @@ class Migration(migrations.Migration):
name="set",
field=models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, related_name="set_elements", to="training.datasetset"),
),
migrations.AddConstraint(
model_name="datasetelement",
constraint=models.UniqueConstraint(fields=("element_id", "set_id"), name="unique_set_element"),
),
migrations.RemoveField(
model_name="dataset",
name="sets"
......
......@@ -309,3 +309,11 @@ class DatasetElement(models.Model):
related_name="set_elements",
on_delete=models.DO_NOTHING,
)
class Meta:
constraints = [
models.UniqueConstraint(
fields=["element_id", "set_id"],
name="unique_set_element",
),
]
......@@ -10,7 +10,7 @@ from rest_framework import permissions, serializers
from rest_framework.exceptions import PermissionDenied, ValidationError
from rest_framework.validators import UniqueTogetherValidator
from arkindex.documents.models import Element
from arkindex.documents.models import Corpus, Element
from arkindex.documents.serializers.elements import ElementListSerializer
from arkindex.ponos.models import Task
from arkindex.process.models import Worker
......@@ -512,7 +512,7 @@ class DatasetSerializer(serializers.ModelSerializer):
help_text="Display name of the user who created the dataset.",
)
set_names = serializers.ListField(child=serializers.CharField(max_length=50), write_only=True, default=["Training", "Validation", "Test"])
set_names = serializers.ListField(child=serializers.CharField(max_length=50), write_only=True, required=False)
sets = DatasetSetSerializer(many=True, read_only=True)
# When creating the dataset, the dataset's corpus comes from the URL, so the APIView passes it through
......@@ -555,6 +555,8 @@ class DatasetSerializer(serializers.ModelSerializer):
raise ValidationError("This API endpoint does not allow updating a dataset's sets.")
if set_names is not None and len(set(set_names)) != len(set_names):
raise ValidationError("Set names must be unique.")
if set_names is not None and len(set_names) == 0:
raise ValidationError("Either do not specify set names to use the default values, or specify a non-empty list of names.")
return set_names
def validate(self, data):
......@@ -585,7 +587,10 @@ class DatasetSerializer(serializers.ModelSerializer):
@transaction.atomic
def create(self, validated_data):
sets = validated_data.pop("set_names")
if "set_names" not in validated_data:
sets = ["training", "validation", "test"]
else:
sets = validated_data.pop("set_names")
dataset = Dataset.objects.create(**validated_data)
DatasetSet.objects.bulk_create(
DatasetSet(
......@@ -657,6 +662,7 @@ class DatasetElementSerializer(serializers.ModelSerializer):
default=_dataset_from_context,
write_only=True,
)
set = serializers.SlugRelatedField(queryset=DatasetSet.objects.none(), slug_field="name")
class Meta:
model = DatasetElement
......@@ -665,7 +671,7 @@ class DatasetElementSerializer(serializers.ModelSerializer):
validators = [
UniqueTogetherValidator(
queryset=DatasetElement.objects.all(),
fields=["dataset", "element_id", "set"],
fields=["element_id", "set"],
message="This element is already part of this set.",
)
]
......@@ -674,13 +680,12 @@ class DatasetElementSerializer(serializers.ModelSerializer):
super().__init__(*args, **kwargs)
if dataset := self.context.get("dataset"):
self.fields["element_id"].queryset = Element.objects.filter(corpus=dataset.corpus)
self.fields["set"].queryset = dataset.sets.all()
def validate_set(self, value):
# The set must match the `sets` array defined at the dataset level
dataset = self.context["dataset"]
if dataset and value not in dataset.sets:
raise ValidationError(f"This dataset has no set named {value}.")
return value
def validate(self, data):
data = super().validate(data)
data.pop("dataset")
return data
class DatasetElementInfoSerializer(DatasetElementSerializer):
......@@ -698,10 +703,11 @@ class DatasetElementInfoSerializer(DatasetElementSerializer):
fields = DatasetElementSerializer.Meta.fields + ("dataset",)
class ElementDatasetSerializer(serializers.ModelSerializer):
dataset = DatasetSerializer()
class ElementDatasetSetSerializer(serializers.ModelSerializer):
dataset = DatasetSerializer(source="set.dataset")
previous = serializers.UUIDField(allow_null=True, read_only=True)
next = serializers.UUIDField(allow_null=True, read_only=True)
set = serializers.SlugRelatedField(slug_field="name", read_only=True)
class Meta:
model = DatasetElement
......@@ -721,7 +727,7 @@ class SelectionDatasetElementSerializer(serializers.Serializer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fields["set_id"].queryset = DatasetSet.objects.filter(
dataset__corpus_id=self.context["corpus"].id
dataset__corpus_id__in=Corpus.objects.readable(self.context["request"].user)
).select_related("dataset")
def validate_set_id(self, set):
......
This diff is collapsed.