
Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: arkindex/backend
......@@ -109,7 +109,7 @@ from arkindex.training.api import (
DatasetElementDestroy,
DatasetElements,
DatasetUpdate,
ElementDatasets,
ElementDatasetSets,
MetricValueBulkCreate,
MetricValueCreate,
ModelCompatibleWorkerManage,
......@@ -184,7 +184,7 @@ api = [
# Datasets
path("corpus/<uuid:pk>/datasets/", CorpusDataset.as_view(), name="corpus-datasets"),
path("corpus/<uuid:pk>/datasets/selection/", CreateDatasetElementsSelection.as_view(), name="dataset-elements-selection"),
path("element/<uuid:pk>/datasets/", ElementDatasets.as_view(), name="element-datasets"),
path("element/<uuid:pk>/sets/", ElementDatasetSets.as_view(), name="element-sets"),
path("datasets/<uuid:pk>/", DatasetUpdate.as_view(), name="dataset-update"),
path("datasets/<uuid:pk>/clone/", DatasetClone.as_view(), name="dataset-clone"),
path("datasets/<uuid:pk>/elements/", DatasetElements.as_view(), name="dataset-elements"),
......
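A small sketch of what this route rename means for anything resolving these endpoints by name, e.g. in tests; the `api` URL namespace is an assumption, only the route names come from the diff above.

from uuid import uuid4
from django.urls import reverse

element_id = uuid4()
# Before this change: reverse("api:element-datasets", kwargs={"pk": element_id})
url = reverse("api:element-sets", kwargs={"pk": element_id})
# -> .../element/<element_id>/sets/, served by the renamed ElementDatasetSets view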
......@@ -4,7 +4,6 @@ from urllib.parse import quote, unquote
import bleach
from django.contrib.gis.geos import Point
from django.db.models import Count
from drf_spectacular.utils import extend_schema_field
from rest_framework import serializers
......@@ -12,6 +11,7 @@ from arkindex.documents.models import MetaType
from arkindex.ponos.utils import get_process_from_task_auth
from arkindex.process.models import ProcessMode, WorkerRun
from arkindex.project.gis import ensure_linear_ring
from arkindex.project.tools import is_prefetched
class EnumField(serializers.ChoiceField):
......@@ -269,7 +269,9 @@ class ArchivedField(serializers.BooleanField):
class DatasetSetsCountField(serializers.DictField):
"""
Serialize the number of elements per set on a dataset.
This value can be disabled by setting `sets_count` to False in the context.
This field is None unless the sets have been prefetched
with an `element_count` annotation holding the number of elements per set.
"""
def __init__(self, **kwargs):
......@@ -281,16 +283,17 @@ class DatasetSetsCountField(serializers.DictField):
)
def get_attribute(self, instance):
if not self.context.get("sets_count", True):
# Skip this field if sets are not prefetched, or if they are missing a count
if (
not is_prefetched(instance.sets)
or not all(hasattr(set, "element_count") for set in instance.sets.all())
):
return None
elts_count = {k: 0 for k in instance.sets}
elts_count.update(
instance.dataset_elements
.values("set")
.annotate(count=Count("id"))
.values_list("set", "count")
)
return elts_count
return {
set.name: set.element_count
for set in instance.sets.all()
}
class NullField(serializers.CharField):
......
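The reworked DatasetSetsCountField above no longer relies on a `sets_count` context flag: it only returns counts when the sets were prefetched with an `element_count` annotation. A minimal sketch of a queryset that satisfies it, assuming an Arkindex Django shell; it mirrors the DATASET_SET_COUNTS_PREFETCH object introduced in the views further down.

from django.db.models import Count, Prefetch
from arkindex.training.models import Dataset, DatasetSet

datasets = Dataset.objects.prefetch_related(
    Prefetch(
        "sets",
        DatasetSet.objects.annotate(element_count=Count("set_elements")).order_by("name"),
    )
)
for dataset in datasets:
    # No extra query here: get_attribute() finds the prefetched, annotated sets
    print({s.name: s.element_count for s in dataset.sets.all()})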
......@@ -188,3 +188,28 @@ def fake_now():
Fake creation date for fixtures and test objects
"""
return datetime(2020, 2, 2, 1, 23, 45, 678000, tzinfo=timezone.utc)
def is_prefetched(related_manager) -> bool:
"""
Determines whether the related items for a reverse foreign key have been prefetched;
that is, if calling `instance.things.all()` will not cause an SQL query.
Usage: `is_prefetched(instance.things)`
"""
return (
related_manager.field.remote_field.get_cache_name()
in getattr(related_manager.instance, "_prefetched_objects_cache", {})
)
def add_as_prefetch(related_manager, items) -> None:
"""
Manually set a list of related items on an instance, as if they were actually prefetched from the database.
Usage: `add_as_prefetch(instance.things, [thing1, thing2])`
"""
assert (
isinstance(items, list) and all(isinstance(item, related_manager.model) for item in items)
), f"Prefetched items should be a list of {related_manager.model} instances."
cache = getattr(related_manager.instance, "_prefetched_objects_cache", {})
cache[related_manager.field.remote_field.get_cache_name()] = items
related_manager.instance._prefetched_objects_cache = cache
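A minimal sketch of how the two helpers above interact, assuming an Arkindex Django shell with at least one existing dataset; the lookups are only illustrative.

from arkindex.project.tools import add_as_prefetch, is_prefetched
from arkindex.training.models import Dataset, DatasetSet

dataset = Dataset.objects.first()
assert not is_prefetched(dataset.sets)   # dataset.sets.all() would hit the database

sets = list(DatasetSet.objects.filter(dataset=dataset).order_by("name"))
add_as_prefetch(dataset.sets, sets)
assert is_prefetched(dataset.sets)       # dataset.sets.all() now reads from the prefetch cache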
......@@ -185,6 +185,15 @@ FROM "training_datasetelement"
WHERE "training_datasetelement"."id" IN
(SELECT U0."id"
FROM "training_datasetelement" U0
INNER JOIN "training_datasetset" U1 ON (U0."set_id" = U1."id")
INNER JOIN "training_dataset" U2 ON (U1."dataset_id" = U2."id")
WHERE U2."corpus_id" = '{corpus_id}'::uuid);
DELETE
FROM "training_datasetset"
WHERE "training_datasetset"."id" IN
(SELECT U0."id"
FROM "training_datasetset" U0
INNER JOIN "training_dataset" U1 ON (U0."dataset_id" = U1."id")
WHERE U1."corpus_id" = '{corpus_id}'::uuid);
......
......@@ -189,8 +189,17 @@ FROM "training_datasetelement"
WHERE "training_datasetelement"."id" IN
(SELECT U0."id"
FROM "training_datasetelement" U0
INNER JOIN "training_dataset" U1 ON (U0."dataset_id" = U1."id")
WHERE U1."corpus_id" = '{corpus_id}'::uuid);
INNER JOIN "training_datasetset" U1 ON (U0."set_id" = U1."id")
INNER JOIN "training_dataset" U2 ON (U1."dataset_id" = U2."id")
WHERE U2."corpus_id" = '{corpus_id}'::uuid);
DELETE
FROM "training_datasetset"
WHERE "training_datasetset"."id" IN
(SELECT U0."id"
FROM "training_datasetset" U0
INNER JOIN "training_dataset" U1 ON (U0."dataset_id" = U1."id")
WHERE U1."corpus_id" = '{corpus_id}'::uuid);
DELETE
FROM "training_dataset"
......
......@@ -2,7 +2,7 @@ from django.contrib import admin
from enumfields.admin import EnumFieldListFilter
from arkindex.project.admin import ArchivedListFilter
from arkindex.training.models import Dataset, MetricKey, MetricValue, Model, ModelVersion
from arkindex.training.models import Dataset, DatasetSet, MetricKey, MetricValue, Model, ModelVersion
class ModelAdmin(admin.ModelAdmin):
......@@ -31,21 +31,17 @@ class MetricKeyAdmin(admin.ModelAdmin):
inlines = [MetricValueInline, ]
class DatasetSetInLine(admin.StackedInline):
model = DatasetSet
class DatasetAdmin(admin.ModelAdmin):
list_display = ("name", "corpus", "state")
list_filter = (("state", EnumFieldListFilter), "corpus")
search_fields = ("name", "description")
fields = ("id", "name", "created", "updated", "description", "corpus", "creator", "task", "sets")
fields = ("id", "name", "created", "updated", "description", "corpus", "creator", "task")
readonly_fields = ("id", "created", "updated", "task")
def get_form(self, *args, **kwargs):
form = super().get_form(*args, **kwargs)
# Add a help text to mention that the set names should be comma-separated.
# This is only done here and not through the usual `help_text=…` in the model
# because this is only relevant to the Django admin and should not appear in
# DRF serializers or the API docs.
form.base_fields["sets"].help_text = "Comma-separated list of set names"
return form
inlines = [DatasetSetInLine, ]
admin.site.register(Model, ModelAdmin)
......
......@@ -3,7 +3,7 @@ from textwrap import dedent
from uuid import UUID
from django.db import connection, transaction
from django.db.models import Q
from django.db.models import Count, Prefetch, Q, prefetch_related_objects
from django.shortcuts import get_object_or_404
from django.utils.functional import cached_property
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view
......@@ -29,6 +29,7 @@ from arkindex.project.tools import BulkMap
from arkindex.training.models import (
Dataset,
DatasetElement,
DatasetSet,
DatasetState,
MetricValue,
Model,
......@@ -40,7 +41,7 @@ from arkindex.training.serializers import (
DatasetElementInfoSerializer,
DatasetElementSerializer,
DatasetSerializer,
ElementDatasetSerializer,
ElementDatasetSetSerializer,
MetricValueBulkSerializer,
MetricValueCreateSerializer,
ModelCompatibleWorkerSerializer,
......@@ -53,12 +54,18 @@ from arkindex.training.serializers import (
from arkindex.users.models import Role
from arkindex.users.utils import get_max_level
# A prefetch object that includes the number of elements per set.
DATASET_SET_COUNTS_PREFETCH = Prefetch(
"sets",
DatasetSet.objects.annotate(element_count=Count("set_elements")).order_by("name")
)
def _fetch_datasetelement_neighbors(datasetelements):
"""
Retrieve the neighbors for a list of DatasetElements, and annotate these DatasetElements
with next and previous attributes.
The ElementDatasets endpoint uses arkindex.project.tools.BulkMap to apply this method and
The ElementDatasetSets endpoint uses arkindex.project.tools.BulkMap to apply this method and
perform the second request *after* DRF's pagination, because there is no way to perform
post-processing after pagination in Django without having to use Django private methods.
"""
......@@ -71,18 +78,18 @@ def _fetch_datasetelement_neighbors(datasetelements):
SELECT
n.id,
lag(element_id) OVER (
partition BY (n.dataset_id, n.set)
partition BY (n.set_id)
order by
n.element_id
) as previous,
lead(element_id) OVER (
partition BY (n.dataset_id, n.set)
partition BY (n.set_id)
order by
n.element_id
) as next
FROM training_datasetelement as n
WHERE (dataset_id, set) IN (
SELECT dataset_id, set
WHERE set_id IN (
SELECT set_id
FROM training_datasetelement
WHERE id IN %(ids)s
)
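For clarity, a pure-Python rendering of what the lag()/lead() window functions compute on the simplified rows below: within each set (the PARTITION BY), elements are ordered by element_id and each one is annotated with the IDs of its previous and next neighbours.

# Each row stands for a DatasetElement: (datasetelement_id, set_id, element_id)
rows = [
    ("de1", "setA", "elt1"),
    ("de2", "setA", "elt2"),
    ("de3", "setA", "elt3"),
    ("de4", "setB", "elt1"),
]

by_set = {}
for de_id, set_id, elt_id in rows:
    by_set.setdefault(set_id, []).append((de_id, elt_id))

neighbors = {}
for members in by_set.values():
    members.sort(key=lambda item: item[1])  # ORDER BY element_id
    for index, (de_id, elt_id) in enumerate(members):
        previous = members[index - 1][1] if index > 0 else None              # lag(element_id)
        next_ = members[index + 1][1] if index + 1 < len(members) else None  # lead(element_id)
        neighbors[de_id] = {"previous": previous, "next": next_}

# neighbors == {
#     "de1": {"previous": None, "next": "elt2"},
#     "de2": {"previous": "elt1", "next": "elt3"},
#     "de3": {"previous": "elt2", "next": None},
#     "de4": {"previous": None, "next": None},
# }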
......@@ -609,6 +616,11 @@ class CorpusDataset(CorpusACLMixin, ListCreateAPIView):
def get_queryset(self):
return Dataset.objects \
.select_related("creator") \
.prefetch_related(Prefetch(
"sets",
# Prefetch sets, but ensure they are ordered by name
DatasetSet.objects.order_by("name")
)) \
.filter(corpus=self.corpus) \
.order_by("name")
......@@ -625,10 +637,6 @@ class CorpusDataset(CorpusACLMixin, ListCreateAPIView):
if not self.kwargs:
return context
context["corpus"] = self.corpus
# Avoids aggregating the number of elements per set on each
# entry, which would cause 1 extra query per dataset
if self.request.method in permissions.SAFE_METHODS:
context["sets_count"] = False
return context
......@@ -686,8 +694,14 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView):
serializer_class = DatasetSerializer
def get_queryset(self):
queryset = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user))
return queryset.select_related("corpus", "creator")
queryset = (
Dataset.objects
.filter(corpus__in=Corpus.objects.readable(self.request.user))
.select_related("corpus", "creator")
)
if self.request.method != "DELETE":
queryset = queryset.prefetch_related(DATASET_SET_COUNTS_PREFETCH)
return queryset
def check_object_permissions(self, request, obj):
super().check_object_permissions(request, obj)
......@@ -706,8 +720,20 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView):
if obj.state == DatasetState.Complete:
raise ValidationError(detail="This dataset is in complete state and cannot be modified anymore.")
def update(self, request, *args, **kwargs):
# Do exactly what DRF's UpdateModelMixin does, but without the automatic prefetch cache removal:
# https://github.com/encode/django-rest-framework/blob/2da473c8c8e024e80c13a624782f1da6272812da/rest_framework/mixins.py#L70
# This allows `set_elements` to still be returned after the update.
partial = kwargs.pop("partial", False)
instance = self.get_object()
serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
return Response(serializer.data)
def perform_destroy(self, dataset):
dataset.dataset_elements.all().delete()
DatasetElement.objects.filter(set__dataset_id=dataset.id).delete()
dataset.sets.all().delete()
super().perform_destroy(dataset)
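For context on the update() override above, here is a paraphrase of rest_framework.mixins.UpdateModelMixin.update() as found in recent DRF versions (not Arkindex code): the prefetch-cache invalidation near the end is the step DatasetUpdate deliberately skips, so the prefetched sets, and therefore set_elements, survive into the response.

# Paraphrase of DRF's UpdateModelMixin.update(); `self` is the view.
def update(self, request, *args, **kwargs):
    partial = kwargs.pop("partial", False)
    instance = self.get_object()
    serializer = self.get_serializer(instance, data=request.data, partial=partial)
    serializer.is_valid(raise_exception=True)
    self.perform_update(serializer)

    if getattr(instance, "_prefetched_objects_cache", None):
        # The cache invalidation that DatasetUpdate.update() drops,
        # so the prefetched sets survive into the serialized response.
        instance._prefetched_objects_cache = {}

    return Response(serializer.data)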
......@@ -768,13 +794,13 @@ class DatasetElements(CorpusACLMixin, ListCreateAPIView):
def get_queryset(self):
qs = (
self.dataset.dataset_elements
.prefetch_related("element")
DatasetElement.objects.filter(set__dataset_id=self.dataset.id)
.prefetch_related("element", "set")
.select_related("element__type", "element__corpus", "element__image__server")
.order_by("element_id", "id")
)
if "set" in self.request.query_params:
qs = qs.filter(set=self.request.query_params["set"])
qs = qs.filter(set__name=self.request.query_params["set"])
return qs
def get_serializer_context(self):
......@@ -800,7 +826,7 @@ class DatasetElements(CorpusACLMixin, ListCreateAPIView):
)
class DatasetElementDestroy(CorpusACLMixin, DestroyAPIView):
"""
Remove an element from a dataset.
Remove an element from a dataset set.
Elements can only be removed from **open** datasets.
......@@ -811,17 +837,15 @@ class DatasetElementDestroy(CorpusACLMixin, DestroyAPIView):
lookup_url_kwarg = "element"
def destroy(self, request, *args, **kwargs):
if not self.request.query_params.get("set"):
if not (set_name := self.request.query_params.get("set")):
raise ValidationError({"set": ["This field is required."]})
dataset_element = get_object_or_404(
DatasetElement.objects.select_related("dataset__corpus"),
dataset_id=self.kwargs["dataset"],
element_id=self.kwargs["element"],
set=self.request.query_params.get("set")
DatasetElement.objects.select_related("set__dataset__corpus").filter(set__dataset_id=self.kwargs["dataset"], set__name=set_name),
element_id=self.kwargs["element"]
)
if dataset_element.dataset.state != DatasetState.Open:
if dataset_element.set.dataset.state != DatasetState.Open:
raise ValidationError({"dataset": ["Elements can only be removed from open Datasets."]})
if not self.has_write_access(dataset_element.dataset.corpus):
if not self.has_write_access(dataset_element.set.dataset.corpus):
raise PermissionDenied(detail="You need a Contributor access to the dataset to perform this action.")
dataset_element.delete()
return Response(status=status.HTTP_204_NO_CONTENT)
......@@ -896,14 +920,14 @@ class CreateDatasetElementsSelection(CorpusACLMixin, CreateAPIView):
)
],
)
class ElementDatasets(CorpusACLMixin, ListAPIView):
class ElementDatasetSets(CorpusACLMixin, ListAPIView):
"""
List all datasets containing a specific element.
List all dataset sets containing a specific element.
Requires a **guest** access to the element's corpus.
"""
permission_classes = (IsVerifiedOrReadOnly, )
serializer_class = ElementDatasetSerializer
serializer_class = ElementDatasetSetSerializer
@cached_property
def element(self):
......@@ -915,9 +939,14 @@ class ElementDatasets(CorpusACLMixin, ListAPIView):
def get_queryset(self):
qs = (
self.element.dataset_elements.all()
.select_related("dataset__creator")
.order_by("dataset__name", "set", "dataset_id")
self.element.dataset_elements
.select_related("set__dataset__creator")
.prefetch_related(Prefetch(
"set__dataset__sets",
# Prefetch sets, but ensure they are ordered by name
DatasetSet.objects.order_by("name")
))
.order_by("set__dataset__name", "set__name")
)
with_neighbors = self.request.query_params.get("with_neighbors", "false")
......@@ -926,13 +955,6 @@ class ElementDatasets(CorpusACLMixin, ListAPIView):
return qs
def get_serializer_context(self):
context = super().get_serializer_context()
# Avoids aggregating the number of elements per set on each
# entry, which would cause 1 extra query per dataset
context["sets_count"] = False
return context
@extend_schema_view(
post=extend_schema(
......@@ -995,11 +1017,24 @@ class DatasetClone(CorpusACLMixin, CreateAPIView):
clone.creator = request.user
clone.save()
# Clone dataset sets
cloned_sets = DatasetSet.objects.bulk_create([
DatasetSet(dataset_id=clone.id, name=set.name)
for set in dataset.sets.all()
])
set_map = {set.name: set for set in cloned_sets}
# Associate all elements to the clone
DatasetElement.objects.bulk_create([
DatasetElement(element_id=elt_id, dataset_id=clone.id, set=set_name)
for elt_id, set_name in dataset.dataset_elements.values_list("element_id", "set")
DatasetElement(element_id=elt_id, set=set_map[set_name])
for elt_id, set_name in DatasetElement.objects.filter(set__dataset_id=dataset.id)
.values_list("element_id", "set__name")
.iterator()
])
# Add the set counts to the API response
prefetch_related_objects([clone], DATASET_SET_COUNTS_PREFETCH)
return Response(
DatasetSerializer(clone).data,
status=status.HTTP_201_CREATED,
......
......@@ -8,9 +8,13 @@ import django.db.models.deletion
import enumfields.fields
from django.db import migrations, models
import arkindex.process.models
import arkindex.project.aws
import arkindex.project.fields
import arkindex.training.models
def default_sets():
return ["training", "test", "validation"]
class Migration(migrations.Migration):
......@@ -32,7 +36,7 @@ class Migration(migrations.Migration):
("name", models.CharField(max_length=100, validators=[django.core.validators.MinLengthValidator(1)])),
("description", models.TextField(validators=[django.core.validators.MinLengthValidator(1)])),
("state", enumfields.fields.EnumField(default="open", enum=arkindex.training.models.DatasetState, max_length=10)),
("sets", django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=50, validators=[django.core.validators.MinLengthValidator(1)]), default=arkindex.training.models.default_sets, size=None, validators=[django.core.validators.MinLengthValidator(1), arkindex.training.models.validate_unique_set_names])),
("sets", django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=50, validators=[django.core.validators.MinLengthValidator(1)]), default=default_sets, size=None, validators=[django.core.validators.MinLengthValidator(1), arkindex.process.models.validate_unique_set_names])),
("corpus", models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, related_name="datasets", to="documents.corpus")),
],
),
......
# Generated by Django 4.1.7 on 2024-03-05 16:28
import uuid
import django.core.validators
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("training", "0006_model_archived"),
("process", "0029_processdataset_sets"),
]
operations = [
migrations.CreateModel(
name="DatasetSet",
fields=[
("id", models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)),
("name", models.CharField(max_length=50, validators=[django.core.validators.MinLengthValidator(1)])),
("dataset", models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, related_name="tmp_sets", to="training.dataset")),
],
),
migrations.AddConstraint(
model_name="datasetset",
constraint=models.UniqueConstraint(fields=("dataset", "name"), name="unique_dataset_sets"),
),
# Make the old set name and dataset ID fields nullable
# so that they can be filled in when rolling the migration back
migrations.AlterField(
model_name="datasetelement",
name="set",
field=models.CharField(max_length=50, null=True),
),
migrations.AlterField(
model_name="datasetelement",
name="dataset",
field=models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, related_name="elements", to="training.dataset", null=True),
),
migrations.AddField(
model_name="datasetelement",
# Temporary name, because the `set` column already existed as the set name.
# This is referred to as `set_id_id` in the SQL migration,
# and renamed to `set` afterwards.
name="set_id",
field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.DO_NOTHING, related_name="dataset_elements", to="training.datasetset"),
),
migrations.RunSQL(
[
"""
INSERT INTO training_datasetset (id, dataset_id, name)
SELECT uuid_generate_v4(), ds.id, ds.set
FROM (
SELECT DISTINCT id, unnest(sets) AS set
FROM training_dataset
) ds
""",
"""
INSERT INTO training_datasetset (id, dataset_id, name)
SELECT uuid_generate_v4(), sets.dataset_id, sets.set
FROM (
SELECT DISTINCT dataset_id, set
FROM training_datasetelement
) sets
ON CONFLICT DO NOTHING
""",
"""
UPDATE training_datasetelement de
SET set_id_id = ds.id
FROM training_datasetset ds
WHERE de.dataset_id = ds.dataset_id AND de.set = ds.name
""",
],
reverse_sql=[
"""
UPDATE training_dataset
SET sets = ARRAY(
SELECT name
FROM training_datasetset
WHERE dataset_id = training_dataset.id
)
""",
"""
UPDATE training_datasetelement de
SET dataset_id = ds.dataset_id, set = ds.name
FROM training_datasetset ds
WHERE ds.id = de.set_id_id
""",
],
),
migrations.RemoveConstraint(
model_name="datasetelement",
name="unique_dataset_elements",
),
migrations.RemoveField(
model_name="datasetelement",
name="dataset"
),
migrations.RemoveField(
model_name="datasetelement",
name="set"
),
migrations.RenameField(
model_name="datasetelement",
old_name="set_id",
new_name="set"
),
migrations.AlterField(
model_name="datasetelement",
name="set",
field=models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, related_name="set_elements", to="training.datasetset"),
),
migrations.AddConstraint(
model_name="datasetelement",
constraint=models.UniqueConstraint(fields=("element_id", "set_id"), name="unique_set_element"),
),
migrations.RemoveField(
model_name="dataset",
name="sets"
),
migrations.RemoveField(
model_name="dataset",
name="elements",
),
migrations.AlterField(
model_name="datasetset",
name="dataset",
field=models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, related_name="sets", to="training.dataset"),
),
]
......@@ -4,8 +4,6 @@ from hashlib import sha256
from django.conf import settings
from django.contrib.contenttypes.fields import GenericRelation
from django.contrib.postgres.fields import ArrayField
from django.core.exceptions import ValidationError
from django.core.validators import MinLengthValidator
from django.db import models
from django.db.models import Q
......@@ -242,15 +240,6 @@ class DatasetState(Enum):
Error = "error"
def validate_unique_set_names(sets):
if len(set(sets)) != len(sets):
raise ValidationError("Set names must be unique.")
def default_sets():
return ["training", "test", "validation"]
class Dataset(models.Model):
id = models.UUIDField(default=uuid.uuid4, primary_key=True, editable=False)
created = models.DateTimeField(auto_now_add=True)
......@@ -278,21 +267,6 @@ class Dataset(models.Model):
description = models.TextField(validators=[MinLengthValidator(1)])
state = EnumField(DatasetState, default=DatasetState.Open, max_length=50)
sets = ArrayField(
models.CharField(max_length=50, validators=[MinLengthValidator(1)]),
validators=[
MinLengthValidator(1),
validate_unique_set_names,
],
default=default_sets,
)
elements = models.ManyToManyField(
"documents.Element",
through="training.DatasetElement",
related_name="datasets",
)
class Meta:
constraints = [
models.UniqueConstraint(
......@@ -305,24 +279,41 @@ class Dataset(models.Model):
return self.name
class DatasetElement(models.Model):
class DatasetSet(models.Model):
id = models.UUIDField(default=uuid.uuid4, primary_key=True, editable=False)
name = models.CharField(max_length=50, validators=[MinLengthValidator(1)])
dataset = models.ForeignKey(
Dataset,
related_name="dataset_elements",
related_name="sets",
on_delete=models.DO_NOTHING,
)
class Meta:
constraints = [
models.UniqueConstraint(
fields=["dataset", "name"],
name="unique_dataset_sets",
),
]
class DatasetElement(models.Model):
id = models.UUIDField(default=uuid.uuid4, primary_key=True, editable=False)
element = models.ForeignKey(
"documents.Element",
related_name="dataset_elements",
on_delete=models.DO_NOTHING,
)
set = models.CharField(max_length=50, validators=[MinLengthValidator(1)])
set = models.ForeignKey(
DatasetSet,
related_name="set_elements",
on_delete=models.DO_NOTHING,
)
class Meta:
constraints = [
models.UniqueConstraint(
fields=["dataset", "element", "set"],
name="unique_dataset_elements",
fields=["element_id", "set_id"],
name="unique_set_element",
),
]
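A minimal sketch of the new model layout from a Django shell (assumes an Arkindex environment with an existing corpus and element; required fields such as the dataset creator are elided): a Dataset now owns DatasetSet rows, and DatasetElement links an element to a set instead of pointing at the dataset directly.

from arkindex.documents.models import Corpus, Element
from arkindex.training.models import Dataset, DatasetElement, DatasetSet

corpus = Corpus.objects.first()
dataset = Dataset.objects.create(corpus=corpus, name="Demo", description="Demo dataset")
train = DatasetSet.objects.create(dataset=dataset, name="training")

element = Element.objects.filter(corpus=corpus).first()
DatasetElement.objects.create(element=element, set=train)

dataset.sets.count()        # 1, via DatasetSet.dataset related_name="sets"
train.set_elements.count()  # 1, via DatasetElement.set related_name="set_elements"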
......@@ -6,18 +6,20 @@ from textwrap import dedent
from django.db import transaction
from django.db.models import Count, Q
from drf_spectacular.utils import extend_schema_field
from rest_framework import serializers
from rest_framework import permissions, serializers
from rest_framework.exceptions import PermissionDenied, ValidationError
from rest_framework.validators import UniqueTogetherValidator
from arkindex.documents.models import Element
from arkindex.documents.models import Corpus, Element
from arkindex.documents.serializers.elements import ElementListSerializer
from arkindex.ponos.models import Task
from arkindex.process.models import ProcessDataset, Worker
from arkindex.process.models import Worker
from arkindex.project.serializer_fields import ArchivedField, DatasetSetsCountField, EnumField
from arkindex.project.tools import add_as_prefetch
from arkindex.training.models import (
Dataset,
DatasetElement,
DatasetSet,
DatasetState,
MetricKey,
MetricMode,
......@@ -479,6 +481,12 @@ class MetricValueBulkSerializer(serializers.Serializer):
return validated_data
class DatasetSetSerializer(serializers.ModelSerializer):
class Meta:
model = DatasetSet
fields = ("id", "name",)
class DatasetSerializer(serializers.ModelSerializer):
state = EnumField(
DatasetState,
......@@ -505,6 +513,13 @@ class DatasetSerializer(serializers.ModelSerializer):
help_text="Display name of the user who created the dataset.",
)
set_names = serializers.ListField(
child=serializers.CharField(max_length=50),
write_only=True,
default=serializers.CreateOnlyDefault(["training", "validation", "test"])
)
sets = DatasetSetSerializer(many=True, read_only=True)
# When creating the dataset, the dataset's corpus comes from the URL, so the APIView passes it through
corpus = serializers.HiddenField(default=_corpus_from_context)
......@@ -517,16 +532,6 @@ class DatasetSerializer(serializers.ModelSerializer):
help_text="Distribution of elements in sets. This value is set to null when listing multiple datasets.",
)
def sets_diff(self, new_sets):
"""
Returns a list of removed sets, and a list of added sets for updates
"""
if not isinstance(self.instance, Dataset):
return [], new_sets
current_sets = set(self.instance.sets)
new_sets = set(new_sets)
return list(current_sets - new_sets), list(new_sets - current_sets)
def validate_state(self, state):
"""
Dataset's state update is limited to these transitions:
......@@ -550,27 +555,14 @@ class DatasetSerializer(serializers.ModelSerializer):
raise ValidationError(f"Transition from {self.instance.state} to {state} is not allowed.")
return state
def validate_sets(self, sets):
"""
Ensure sets are updated in a comprehensible way.
It is either possible to add multiple sets,
remove multiple sets or update a single set.
"""
if sets is not None and len(set(sets)) != len(sets):
def validate_set_names(self, set_names):
if self.context["request"].method not in permissions.SAFE_METHODS and self.context["request"].method != "POST":
raise ValidationError("This API endpoint does not allow updating a dataset's sets.")
if set_names is not None and len(set(set_names)) != len(set_names):
raise ValidationError("Set names must be unique.")
removed, added = self.sets_diff(sets)
if removed and ProcessDataset.objects.filter(sets__overlap=removed, dataset_id=self.instance.id).exists():
# Sets that are used in a ProcessDataset cannot be renamed or deleted
raise ValidationError("These sets cannot be updated because one or more are selected in a dataset process.")
if not removed or not added:
# Some sets have either been added or removed, but not both; do nothing
return sets
elif len(removed) == 1 and len(added) == 1:
# A single set has been renamed. Move its elements later, while performing the update
return sets
else:
raise ValidationError("Updating those sets is ambiguous because several have changed.")
if set_names is not None and len(set_names) == 0:
raise ValidationError("Either do not specify set names to use the default values, or specify a non-empty list of names.")
return set_names
def validate(self, data):
data = super().validate(data)
......@@ -598,13 +590,22 @@ class DatasetSerializer(serializers.ModelSerializer):
return data
@transaction.atomic()
def update(self, instance, validated_data):
removed, added = self.sets_diff(validated_data.get("sets", self.instance.sets))
if len(removed) == 1 and len(added) == 1:
set_from, set_to = *removed, *added
instance.dataset_elements.filter(set=set_from).update(set=set_to)
return super().update(instance, validated_data)
@transaction.atomic
def create(self, validated_data):
set_names = validated_data.pop("set_names")
dataset = Dataset.objects.create(**validated_data)
sets = DatasetSet.objects.bulk_create(
DatasetSet(
name=set_name,
dataset_id=dataset.id
) for set_name in sorted(set_names)
)
# We will output set element counts in the API, but we know they are all zero,
# so no need to make another query to prefetch the sets and count them
for set in sets:
set.element_count = 0
add_as_prefetch(dataset.sets, sets)
return dataset
class Meta:
model = Dataset
......@@ -613,6 +614,7 @@ class DatasetSerializer(serializers.ModelSerializer):
"name",
"description",
"sets",
"set_names",
"set_elements",
"state",
# Only the corpus ID is actually serialized
......@@ -647,13 +649,7 @@ class DatasetSerializer(serializers.ModelSerializer):
"sets": {
"error_messages": {
"empty": "Either do not specify set names to use the default values, or specify a non-empty list of names."
},
"help_text": dedent(
"""
Updating the sets array must either add or remove sets (in this case nothing specific is done),
or rename a single set within the array (all elements linked to the previous set will be moved).
"""
).strip(),
}
}
}
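A sketch of how the new write-only set_names field behaves on dataset creation, assuming an Arkindex test environment; the `api` URL namespace and the `user`/`corpus` objects are assumptions, while the route name and payload fields come from the diff.

from django.urls import reverse
from rest_framework.test import APIClient

client = APIClient()
client.force_login(user)  # assumed: a user with contributor access to `corpus`
response = client.post(
    reverse("api:corpus-datasets", kwargs={"pk": corpus.id}),
    {
        "name": "My dataset",
        "description": "Hand-picked pages",
        "set_names": ["train", "dev", "test"],
    },
    format="json",
)
assert response.status_code == 201
# The created sets are listed in `sets`, and `set_elements` reports 0 for each of them.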
......@@ -673,6 +669,7 @@ class DatasetElementSerializer(serializers.ModelSerializer):
default=_dataset_from_context,
write_only=True,
)
set = serializers.SlugRelatedField(queryset=DatasetSet.objects.none(), slug_field="name")
class Meta:
model = DatasetElement
......@@ -681,7 +678,7 @@ class DatasetElementSerializer(serializers.ModelSerializer):
validators = [
UniqueTogetherValidator(
queryset=DatasetElement.objects.all(),
fields=["dataset", "element_id", "set"],
fields=["element_id", "set"],
message="This element is already part of this set.",
)
]
......@@ -690,13 +687,12 @@ class DatasetElementSerializer(serializers.ModelSerializer):
super().__init__(*args, **kwargs)
if dataset := self.context.get("dataset"):
self.fields["element_id"].queryset = Element.objects.filter(corpus=dataset.corpus)
self.fields["set"].queryset = dataset.sets.all()
def validate_set(self, value):
# The set must match the `sets` array defined at the dataset level
dataset = self.context["dataset"]
if dataset and value not in dataset.sets:
raise ValidationError(f"This dataset has no set named {value}.")
return value
def validate(self, data):
data = super().validate(data)
data.pop("dataset")
return data
class DatasetElementInfoSerializer(DatasetElementSerializer):
......@@ -714,10 +710,11 @@ class DatasetElementInfoSerializer(DatasetElementSerializer):
fields = DatasetElementSerializer.Meta.fields + ("dataset",)
class ElementDatasetSerializer(serializers.ModelSerializer):
dataset = DatasetSerializer()
class ElementDatasetSetSerializer(serializers.ModelSerializer):
dataset = DatasetSerializer(source="set.dataset")
previous = serializers.UUIDField(allow_null=True, read_only=True)
next = serializers.UUIDField(allow_null=True, read_only=True)
set = serializers.SlugRelatedField(slug_field="name", read_only=True)
class Meta:
model = DatasetElement
......@@ -726,35 +723,32 @@ class ElementDatasetSerializer(serializers.ModelSerializer):
class SelectionDatasetElementSerializer(serializers.Serializer):
dataset_id = serializers.PrimaryKeyRelatedField(
queryset=Dataset.objects.all(),
source="dataset",
set_id = serializers.PrimaryKeyRelatedField(
queryset=DatasetSet.objects.none(),
source="set",
write_only=True,
help_text="UUID of a dataset to add elements from your corpus' selection.",
help_text="UUID of a dataset set the elements will be added to.",
style={"base_template": "input.html"},
)
set = serializers.CharField(
max_length=50,
write_only=True,
help_text="Name of the set elements will be added to.",
)
def validate_dataset_id(self, dataset):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# For OpenAPI schema generation
if "request" not in self.context:
return
self.fields["set_id"].queryset = DatasetSet.objects.filter(
dataset__corpus_id__in=Corpus.objects.readable(self.context["request"].user)
).select_related("dataset")
def validate_set_id(self, set):
if (
(corpus := self.context.get("corpus"))
and dataset.corpus_id != corpus.id
and set.dataset.corpus_id != corpus.id
):
raise ValidationError(f"Dataset {dataset.id} is not part of corpus {corpus.name}.")
if dataset.state == DatasetState.Complete:
raise ValidationError(f"Dataset {dataset.id} is marked as completed.")
return dataset
def validate(self, data):
data = super().validate(data)
dataset = data["dataset"]
if data["set"] not in dataset.sets:
raise ValidationError({"set": [f'This dataset only allows one of {", ".join(dataset.sets)}.']})
return data
raise ValidationError(f"Dataset {set.dataset.id} is not part of corpus {corpus.name}.")
if set.dataset.state == DatasetState.Complete:
raise ValidationError(f"Dataset {set.dataset.id} is marked as completed.")
return set
def create(self, validated_data):
user = self.context["request"].user
......
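Finally, a sketch of the corresponding CreateDatasetElementsSelection request after this change: the body now carries a single set_id instead of a dataset_id plus a set name. The path comes from the URL diff at the top; the /api/v1 prefix, token and UUIDs are placeholders.

import requests  # any HTTP client works; requests is used here for illustration

base_url = "https://arkindex.example.com/api/v1"    # placeholder
headers = {"Authorization": "Token <api-token>"}    # placeholder
corpus_id = "11111111-1111-1111-1111-111111111111"  # placeholder
set_id = "22222222-2222-2222-2222-222222222222"     # placeholder

# Before: {"dataset_id": "<uuid>", "set": "train"}
response = requests.post(
    f"{base_url}/corpus/{corpus_id}/datasets/selection/",
    json={"set_id": set_id},
    headers=headers,
)
response.raise_for_status()  # 201 when the selected elements were added to the set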