Commit 2a64e13d authored by ml bonhomme

WIP fix API endpoints :'(

parent 5fe13227
This commit is part of merge request !2256.
@@ -5579,7 +5579,7 @@
         "model": "training.datasetset",
         "pk": "00e3b37b-f1ed-4adb-a1de-a1c103edaa24",
         "fields": {
-            "name": "Test",
+            "name": "test",
             "dataset": "170bc276-0e11-483a-9486-88a4bf2079b2"
         }
     },
@@ -5587,7 +5587,7 @@
         "model": "training.datasetset",
         "pk": "95255ff4-7bca-424c-8e7e-d8b5c33c585f",
         "fields": {
-            "name": "Train",
+            "name": "training",
             "dataset": "170bc276-0e11-483a-9486-88a4bf2079b2"
         }
     },
@@ -5595,7 +5595,7 @@
         "model": "training.datasetset",
         "pk": "b21e4b31-1dc1-4015-9933-3a8e180cf2e0",
         "fields": {
-            "name": "Validation",
+            "name": "validation",
             "dataset": "170bc276-0e11-483a-9486-88a4bf2079b2"
         }
     },
@@ -5603,7 +5603,7 @@
         "model": "training.datasetset",
         "pk": "b76d919c-ce7e-4896-94c3-9b47862b997a",
         "fields": {
-            "name": "Validation",
+            "name": "validation",
             "dataset": "6b32dab0-e3f2-4fff-b0ad-2ea1ee550810"
         }
     },
@@ -5611,7 +5611,7 @@
         "model": "training.datasetset",
         "pk": "d5f4d410-0588-4d19-8d08-b57a11bad67f",
         "fields": {
-            "name": "Train",
+            "name": "training",
             "dataset": "6b32dab0-e3f2-4fff-b0ad-2ea1ee550810"
         }
     },
@@ -5619,7 +5619,7 @@
         "model": "training.datasetset",
         "pk": "db70ad8a-8d6b-44c6-b8d3-837bf5a4ab8e",
         "fields": {
-            "name": "Test",
+            "name": "test",
             "dataset": "6b32dab0-e3f2-4fff-b0ad-2ea1ee550810"
         }
     }
@@ -109,7 +109,7 @@ from arkindex.training.api import (
     DatasetElementDestroy,
     DatasetElements,
     DatasetUpdate,
-    ElementDatasets,
+    ElementDatasetSets,
     MetricValueBulkCreate,
     MetricValueCreate,
     ModelCompatibleWorkerManage,
@@ -184,7 +184,7 @@ api = [
     # Datasets
     path("corpus/<uuid:pk>/datasets/", CorpusDataset.as_view(), name="corpus-datasets"),
     path("corpus/<uuid:pk>/datasets/selection/", CreateDatasetElementsSelection.as_view(), name="dataset-elements-selection"),
-    path("element/<uuid:pk>/datasets/", ElementDatasets.as_view(), name="element-datasets"),
+    path("element/<uuid:pk>/datasets/", ElementDatasetSets.as_view(), name="element-datasets"),
     path("datasets/<uuid:pk>/", DatasetUpdate.as_view(), name="dataset-update"),
     path("datasets/<uuid:pk>/clone/", DatasetClone.as_view(), name="dataset-clone"),
     path("datasets/<uuid:pk>/elements/", DatasetElements.as_view(), name="dataset-elements"),
@@ -286,7 +286,7 @@ class DatasetSetsCountField(serializers.DictField):
             return None
         elts_count = {k.name: 0 for k in instance.sets.all()}
         elts_count.update(
-            DatasetElement.objects.filter(set__dataset_id=instance.id)
+            DatasetElement.objects.filter(set_id__in=instance.sets.values_list("id"))
             .values("set__name")
             .annotate(count=Count("id"))
             .values_list("set__name", "count")
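For context, the count field above now keys element counts off the dataset's DatasetSet rows rather than a plain set-name column. A minimal sketch of the same aggregation, assuming only the model and field names visible in this diff (the sample output values are hypothetical):

    # Sketch: build {set_name: element_count} for one dataset, seeding every
    # set with 0 so that empty sets still appear in the result.
    from django.db.models import Count

    from arkindex.training.models import DatasetElement

    def sets_element_counts(dataset):
        counts = {s.name: 0 for s in dataset.sets.all()}
        counts.update(
            DatasetElement.objects.filter(set_id__in=dataset.sets.values_list("id"))
            .values("set__name")
            .annotate(count=Count("id"))
            .values_list("set__name", "count")
        )
        return counts  # e.g. {"training": 12, "validation": 3, "test": 0}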
@@ -29,6 +29,7 @@ from arkindex.project.tools import BulkMap
 from arkindex.training.models import (
     Dataset,
     DatasetElement,
+    DatasetSet,
     DatasetState,
     MetricValue,
     Model,
@@ -40,7 +41,7 @@ from arkindex.training.serializers import (
     DatasetElementInfoSerializer,
     DatasetElementSerializer,
     DatasetSerializer,
-    ElementDatasetSerializer,
+    ElementDatasetSetSerializer,
     MetricValueBulkSerializer,
     MetricValueCreateSerializer,
     ModelCompatibleWorkerSerializer,
@@ -58,7 +59,7 @@ def _fetch_datasetelement_neighbors(datasetelements):
     """
     Retrieve the neighbors for a list of DatasetElements, and annotate these DatasetElements
     with next and previous attributes.
-    The ElementDatasets endpoint uses arkindex.project.tools.BulkMap to apply this method and
+    The ElementDatasetSets endpoint uses arkindex.project.tools.BulkMap to apply this method and
     perform the second request *after* DRF's pagination, because there is no way to perform
     post-processing after pagination in Django without having to use Django private methods.
     """
@@ -687,7 +688,7 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView):
     serializer_class = DatasetSerializer

     def get_queryset(self):
-        queryset = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user))
+        queryset = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user)).prefetch_related("sets")
         return queryset.select_related("corpus", "creator")

     def check_object_permissions(self, request, obj):
@@ -708,7 +709,8 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView):
             raise ValidationError(detail="This dataset is in complete state and cannot be modified anymore.")

     def perform_destroy(self, dataset):
-        dataset.dataset_elements.all().delete()
+        DatasetElement.objects.filter(set__dataset_id=dataset.id).delete()
+        dataset.sets.all().delete()
         super().perform_destroy(dataset)
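Since the reworked DatasetElement.set foreign key is declared with on_delete=models.DO_NOTHING (see the models hunk further down), nothing cascades at the database level, which is why perform_destroy now deletes leaf rows first. A minimal sketch of the required ordering, using only relations shown in this diff:

    # Sketch: delete leaf-first; DO_NOTHING foreign keys will not cascade,
    # so skipping a step would leave orphans or raise integrity errors.
    from arkindex.training.models import DatasetElement

    def destroy_dataset(dataset):
        DatasetElement.objects.filter(set__dataset_id=dataset.id).delete()  # memberships
        dataset.sets.all().delete()  # the sets themselves
        dataset.delete()             # finally the dataset row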
@@ -769,13 +771,13 @@ class DatasetElements(CorpusACLMixin, ListCreateAPIView):
     def get_queryset(self):
         qs = (
-            self.dataset.dataset_elements
-            .prefetch_related("element")
+            DatasetElement.objects.filter(set__dataset_id=self.dataset.id)
+            .prefetch_related("element", "set")
             .select_related("element__type", "element__corpus", "element__image__server")
             .order_by("element_id", "id")
         )
         if "set" in self.request.query_params:
-            qs = qs.filter(set=self.request.query_params["set"])
+            qs = qs.filter(set__name=self.request.query_params["set"])
         return qs

     def get_serializer_context(self):
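A hedged usage sketch for the reworked set filter, assuming the dataset-elements URL name from the routing hunk above; the api URL namespace is an assumption, and the UUID (taken from the fixtures) is illustrative:

    # Sketch: the "set" query parameter now matches DatasetSet.name
    # instead of a plain string column on DatasetElement.
    from django.urls import reverse
    from rest_framework.test import APIClient

    client = APIClient()
    url = reverse("api:dataset-elements", kwargs={"pk": "170bc276-0e11-483a-9486-88a4bf2079b2"})
    response = client.get(url, {"set": "training"})
    assert response.status_code == 200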
@@ -788,20 +790,12 @@ class DatasetElements(CorpusACLMixin, ListCreateAPIView):
 @extend_schema_view(
     delete=extend_schema(
         operation_id="DestroyDatasetElement",
-        parameters=[
-            OpenApiParameter(
-                "set",
-                type=str,
-                description="Name of the set from which to remove the element.",
-                required=True,
-            )
-        ],
         tags=["datasets"]
     )
 )
 class DatasetElementDestroy(CorpusACLMixin, DestroyAPIView):
     """
-    Remove an element from a dataset.
+    Remove an element from a dataset set.

     Elements can only be removed from **open** datasets.
@@ -812,17 +806,15 @@ class DatasetElementDestroy(CorpusACLMixin, DestroyAPIView):
     lookup_url_kwarg = "element"

     def destroy(self, request, *args, **kwargs):
-        if not self.request.query_params.get("set"):
+        if not (set_name := self.request.query_params.get("set")):
             raise ValidationError({"set": ["This field is required."]})
         dataset_element = get_object_or_404(
-            DatasetElement.objects.select_related("dataset__corpus"),
-            dataset_id=self.kwargs["dataset"],
-            element_id=self.kwargs["element"],
-            set=self.request.query_params.get("set")
+            DatasetElement.objects.select_related("set__dataset__corpus").filter(set__dataset_id=self.kwargs["dataset"], set__name=set_name),
+            element_id=self.kwargs["element"]
         )
-        if dataset_element.dataset.state != DatasetState.Open:
+        if dataset_element.set.dataset.state != DatasetState.Open:
             raise ValidationError({"dataset": ["Elements can only be removed from open Datasets."]})
-        if not self.has_write_access(dataset_element.dataset.corpus):
+        if not self.has_write_access(dataset_element.set.dataset.corpus):
             raise PermissionDenied(detail="You need a Contributor access to the dataset to perform this action.")
         dataset_element.delete()
         return Response(status=status.HTTP_204_NO_CONTENT)
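One readability note on the destroy rework: the walrus assignment binds the query parameter once instead of reading query_params twice. A sketch of the equivalence:

    # Before: the parameter is read again after the presence check
    set_param = self.request.query_params.get("set")
    if not set_param:
        raise ValidationError({"set": ["This field is required."]})

    # After: bound once, inside the condition itself
    if not (set_name := self.request.query_params.get("set")):
        raise ValidationError({"set": ["This field is required."]})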
@@ -897,14 +889,14 @@ class CreateDatasetElementsSelection(CorpusACLMixin, CreateAPIView):
         )
     ],
 )
-class ElementDatasets(CorpusACLMixin, ListAPIView):
+class ElementDatasetSets(CorpusACLMixin, ListAPIView):
     """
-    List all datasets containing a specific element.
+    List all dataset sets containing a specific element.

     Requires a **guest** access to the element's corpus.
     """

     permission_classes = (IsVerifiedOrReadOnly, )
-    serializer_class = ElementDatasetSerializer
+    serializer_class = ElementDatasetSetSerializer

     @cached_property
     def element(self):
@@ -916,9 +908,9 @@ class ElementDatasets(CorpusACLMixin, ListAPIView):
     def get_queryset(self):
         qs = (
-            self.element.dataset_elements.all()
-            .select_related("dataset__creator")
-            .order_by("dataset__name", "set", "dataset_id")
+            self.element.dataset_elements
+            .select_related("set__dataset__creator")
+            .order_by("set__name", "id")
         )

         with_neighbors = self.request.query_params.get("with_neighbors", "false")
@@ -996,11 +988,18 @@ class DatasetClone(CorpusACLMixin, CreateAPIView):
         clone.creator = request.user
         clone.save()

+        # Clone dataset sets
+        cloned_sets = DatasetSet.objects.bulk_create([
+            DatasetSet(dataset_id=clone.id, name=set.name)
+            for set in dataset.sets.all()
+        ])
+
         # Associate all elements to the clone
         DatasetElement.objects.bulk_create([
-            DatasetElement(element_id=elt_id, dataset_id=clone.id, set=set_name)
-            for elt_id, set_name in dataset.dataset_elements.values_list("element_id", "set")
+            DatasetElement(element_id=elt_id, set=next(new_set for new_set in cloned_sets if new_set.name == set_name))
+            for elt_id, set_name in DatasetElement.objects.filter(set__dataset_id=dataset.id)
+            .values_list("element_id", "set__name")
         ])

         return Response(
             DatasetSerializer(clone).data,
             status=status.HTTP_201_CREATED,
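The next(...) scan inside the comprehension re-walks cloned_sets once per element, which is fine for the handful of sets a dataset typically holds, but a name-to-set mapping reads more directly and is O(1) per element. A hedged alternative sketch, using only names from this diff:

    # Sketch: resolve each cloned set by name through a dict
    # instead of re-scanning the cloned_sets list per element.
    sets_by_name = {new_set.name: new_set for new_set in cloned_sets}
    DatasetElement.objects.bulk_create([
        DatasetElement(element_id=elt_id, set=sets_by_name[set_name])
        for elt_id, set_name in DatasetElement.objects.filter(set__dataset_id=dataset.id)
        .values_list("element_id", "set__name")
    ])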
@@ -59,7 +59,7 @@ class Migration(migrations.Migration):
             UPDATE training_datasetelement de
             SET set_id_id = ds.id
             FROM training_datasetset ds
-            WHERE de.dataset_id = ds.dataset_id
+            WHERE de.dataset_id = ds.dataset_id AND de.set = ds.name
             """,
             reverse_sql=migrations.RunSQL.noop,
         ),
@@ -85,6 +85,10 @@ class Migration(migrations.Migration):
             name="set",
             field=models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, related_name="set_elements", to="training.datasetset"),
         ),
+        migrations.AddConstraint(
+            model_name="datasetelement",
+            constraint=models.UniqueConstraint(fields=("element_id", "set_id"), name="unique_set_element"),
+        ),
         migrations.RemoveField(
             model_name="dataset",
             name="sets"
@@ -309,3 +309,11 @@ class DatasetElement(models.Model):
         related_name="set_elements",
         on_delete=models.DO_NOTHING,
     )
+
+    class Meta:
+        constraints = [
+            models.UniqueConstraint(
+                fields=["element_id", "set_id"],
+                name="unique_set_element",
+            ),
+        ]
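The unique_set_element constraint enforces at the database level what the serializer's UniqueTogetherValidator (below) reports as a friendly API error. A minimal sketch of the behaviour it guarantees, with hypothetical element and train_set objects:

    # Sketch: a second membership of the same element in the same set
    # now fails with IntegrityError instead of creating a duplicate row.
    from django.db import IntegrityError

    from arkindex.training.models import DatasetElement

    DatasetElement.objects.create(element_id=element.id, set=train_set)
    try:
        DatasetElement.objects.create(element_id=element.id, set=train_set)
    except IntegrityError:
        pass  # blocked by unique_set_element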
@@ -10,7 +10,7 @@ from rest_framework import permissions, serializers
 from rest_framework.exceptions import PermissionDenied, ValidationError
 from rest_framework.validators import UniqueTogetherValidator

-from arkindex.documents.models import Element
+from arkindex.documents.models import Corpus, Element
 from arkindex.documents.serializers.elements import ElementListSerializer
 from arkindex.ponos.models import Task
 from arkindex.process.models import Worker
@@ -512,7 +512,7 @@ class DatasetSerializer(serializers.ModelSerializer):
         help_text="Display name of the user who created the dataset.",
     )
-    set_names = serializers.ListField(child=serializers.CharField(max_length=50), write_only=True, default=["Training", "Validation", "Test"])
+    set_names = serializers.ListField(child=serializers.CharField(max_length=50), write_only=True, required=False)
     sets = DatasetSetSerializer(many=True, read_only=True)

     # When creating the dataset, the dataset's corpus comes from the URL, so the APIView passes it through
@@ -555,6 +555,8 @@ class DatasetSerializer(serializers.ModelSerializer):
             raise ValidationError("This API endpoint does not allow updating a dataset's sets.")
         if set_names is not None and len(set(set_names)) != len(set_names):
             raise ValidationError("Set names must be unique.")
+        if set_names is not None and len(set_names) == 0:
+            raise ValidationError("Either do not specify set names to use the default values, or specify a non-empty list of names.")
         return set_names

     def validate(self, data):
@@ -585,7 +587,10 @@ class DatasetSerializer(serializers.ModelSerializer):
     @transaction.atomic
     def create(self, validated_data):
-        sets = validated_data.pop("set_names")
+        if "set_names" not in validated_data:
+            sets = ["training", "validation", "test"]
+        else:
+            sets = validated_data.pop("set_names")
         dataset = Dataset.objects.create(**validated_data)
         DatasetSet.objects.bulk_create(
             DatasetSet(
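The branchy default handling above could equally be written as a single pop with a default; a sketch of the equivalent one-liner (the explicit if/else in the diff arguably makes the fallback more visible):

    # Sketch: equivalent default handling in one expression.
    sets = validated_data.pop("set_names", ["training", "validation", "test"])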
@@ -657,6 +662,7 @@ class DatasetElementSerializer(serializers.ModelSerializer):
         default=_dataset_from_context,
         write_only=True,
     )
+    set = serializers.SlugRelatedField(queryset=DatasetSet.objects.none(), slug_field="name")

     class Meta:
         model = DatasetElement
@@ -665,7 +671,7 @@ class DatasetElementSerializer(serializers.ModelSerializer):
     validators = [
         UniqueTogetherValidator(
             queryset=DatasetElement.objects.all(),
-            fields=["dataset", "element_id", "set"],
+            fields=["element_id", "set"],
             message="This element is already part of this set.",
         )
     ]
@@ -674,13 +680,12 @@ class DatasetElementSerializer(serializers.ModelSerializer):
         super().__init__(*args, **kwargs)
         if dataset := self.context.get("dataset"):
             self.fields["element_id"].queryset = Element.objects.filter(corpus=dataset.corpus)
+            self.fields["set"].queryset = dataset.sets.all()

-    def validate_set(self, value):
-        # The set must match the `sets` array defined at the dataset level
-        dataset = self.context["dataset"]
-        if dataset and value not in dataset.sets:
-            raise ValidationError(f"This dataset has no set named {value}.")
-        return value
+    def validate(self, data):
+        data = super().validate(data)
+        data.pop("dataset")
+        return data

 class DatasetElementInfoSerializer(DatasetElementSerializer):
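With set declared as a SlugRelatedField whose queryset is narrowed to dataset.sets.all() in __init__, DRF itself rejects names that do not belong to the dataset, making the old validate_set hook redundant; the write-only dataset key is popped in validate since it only exists to scope the lookups. A hedged sketch of how the serializer now resolves a set name (the UUID and context contents are illustrative):

    # Sketch: "set" is submitted as a name and resolved to a DatasetSet
    # restricted to the dataset passed through the serializer context.
    serializer = DatasetElementSerializer(
        data={"element_id": "aaaaaaaa-0000-0000-0000-000000000000", "set": "training"},
        context={"dataset": dataset},
    )
    serializer.is_valid(raise_exception=True)  # fails if no "training" set exists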
@@ -698,10 +703,11 @@ class DatasetElementInfoSerializer(DatasetElementSerializer):
         fields = DatasetElementSerializer.Meta.fields + ("dataset",)

-class ElementDatasetSerializer(serializers.ModelSerializer):
-    dataset = DatasetSerializer()
+class ElementDatasetSetSerializer(serializers.ModelSerializer):
+    dataset = DatasetSerializer(source="set.dataset")
     previous = serializers.UUIDField(allow_null=True, read_only=True)
     next = serializers.UUIDField(allow_null=True, read_only=True)
+    set = serializers.SlugRelatedField(slug_field="name", read_only=True)

     class Meta:
         model = DatasetElement
@@ -721,7 +727,7 @@ class SelectionDatasetElementSerializer(serializers.Serializer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.fields["set_id"].queryset = DatasetSet.objects.filter(
-            dataset__corpus_id=self.context["corpus"].id
+            dataset__corpus_id__in=Corpus.objects.readable(self.context["request"].user)
         ).select_related("dataset")

     def validate_set_id(self, set):