Skip to content
Snippets Groups Projects
Commit 07779363 authored by ml bonhomme's avatar ml bonhomme :bee: Committed by Erwan Rouchet
Browse files

Retrieve dataset element neighbors with ListElementDatasets

parent 9f198d4b
No related branches found
No related tags found
1 merge request!2234Retrieve dataset element neighbors with ListElementDatasets
......@@ -2,7 +2,7 @@ import copy
from textwrap import dedent
from uuid import UUID
from django.db import transaction
from django.db import connection, transaction
from django.db.models import Q
from django.shortcuts import get_object_or_404
from django.utils.functional import cached_property
......@@ -25,6 +25,7 @@ from arkindex.documents.models import Corpus, Element
from arkindex.project.mixins import ACLMixin, CorpusACLMixin, TrainingModelMixin
from arkindex.project.pagination import CountCursorPagination
from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly
from arkindex.project.tools import BulkMap
from arkindex.training.models import (
Dataset,
DatasetElement,
......@@ -53,6 +54,64 @@ from arkindex.users.models import Role
from arkindex.users.utils import get_max_level
def _fetch_datasetelement_neighbors(datasetelements):
    """
    Retrieve the neighbors for a list of DatasetElements, and annotate these DatasetElements
    with next and previous attributes.

    Neighbors are computed in a single raw SQL query using the lag/lead window
    functions, partitioned by (dataset, set) and ordered by element ID, so one
    round-trip covers every DatasetElement in the page.

    The ElementDatasets endpoint uses arkindex.project.tools.BulkMap to apply this method and
    perform the second request *after* DRF's pagination, because there is no way to perform
    post-processing after pagination in Django without having to use Django private methods.

    :param datasetelements: Iterable of DatasetElement instances to annotate in place.
    :returns: The same DatasetElements, each with ``previous`` and ``next`` set to the
        neighboring element ID (or None at the start/end of a set).
    """
    # Nothing to annotate on an empty page; skip the extra query entirely.
    if not datasetelements:
        return datasetelements
    with connection.cursor() as cursor:
        cursor.execute(
            """
            WITH neighbors AS (
                SELECT
                    n.id,
                    lag(element_id) OVER (
                        partition BY (n.dataset_id, n.set)
                        order by
                            n.element_id
                    ) as previous,
                    lead(element_id) OVER (
                        partition BY (n.dataset_id, n.set)
                        order by
                            n.element_id
                    ) as next
                FROM training_datasetelement as n
                WHERE (dataset_id, set) IN (
                    SELECT dataset_id, set
                    FROM training_datasetelement
                    WHERE id IN %(ids)s
                )
                ORDER BY n.element_id
            )
            SELECT
                neighbors.id, neighbors.previous, neighbors.next
            FROM
                neighbors
            WHERE neighbors.id in %(ids)s
            """,
            {"ids": tuple(datasetelement.id for datasetelement in datasetelements)}
        )
        # Map each DatasetElement ID to its neighboring element IDs.
        # Variable names avoid shadowing the `id` and `next` builtins.
        neighbors = {
            row_id: {"previous": previous_id, "next": next_id}
            for row_id, previous_id, next_id in cursor.fetchall()
        }
    for datasetelement in datasetelements:
        datasetelement.previous = neighbors[datasetelement.id]["previous"]
        datasetelement.next = neighbors[datasetelement.id]["next"]
    return datasetelements
@extend_schema(tags=["training"])
@extend_schema_view(
get=extend_schema(
......@@ -823,6 +882,12 @@ class CreateDatasetElementsSelection(CorpusACLMixin, CreateAPIView):
location=OpenApiParameter.PATH,
description="ID of the element.",
required=True,
),
OpenApiParameter(
"with_neighbors",
type=bool,
description="Load previous and next elements in the same dataset and set.",
required=False,
)
],
)
......@@ -844,12 +909,18 @@ class ElementDatasets(CorpusACLMixin, ListAPIView):
)
def get_queryset(self):
    """
    List the DatasetElements linking this element to datasets, ordered by
    dataset name, set name, then dataset ID.

    When the ``with_neighbors`` query parameter is set to anything other than
    ``false``/``0``, wrap the queryset in a BulkMap so that
    `_fetch_datasetelement_neighbors` annotates each paginated result with
    its previous/next element IDs *after* DRF's pagination runs.
    """
    qs = (
        self.element.dataset_elements.all()
        .select_related("dataset__creator")
        .order_by("dataset__name", "set", "dataset_id")
    )
    # Neighbors require an extra SQL query, so they are opt-in.
    with_neighbors = self.request.query_params.get("with_neighbors", "false")
    if with_neighbors.lower() not in ("false", "0"):
        qs = BulkMap(_fetch_datasetelement_neighbors, qs)
    return qs
def get_serializer_context(self):
context = super().get_serializer_context()
# Avoids aggregating the number of elements per set on each
......
......@@ -713,10 +713,13 @@ class DatasetElementInfoSerializer(DatasetElementSerializer):
class ElementDatasetSerializer(serializers.ModelSerializer):
    """
    Serialize a DatasetElement from an element's point of view: the dataset,
    the set it belongs to, and optionally its neighboring element IDs.
    """
    dataset = DatasetSerializer()
    # previous/next are annotated onto instances by
    # _fetch_datasetelement_neighbors when ?with_neighbors is enabled;
    # they hold the neighboring element UUIDs within the same dataset set.
    previous = serializers.UUIDField(allow_null=True, read_only=True)
    next = serializers.UUIDField(allow_null=True, read_only=True)

    class Meta:
        model = DatasetElement
        fields = ("dataset", "set", "previous", "next")
        read_only_fields = ("previous", "next")
class SelectionDatasetElementSerializer(serializers.Serializer):
......
......@@ -1705,6 +1705,8 @@ class TestDatasetsAPI(FixtureAPITestCase):
]
)
# ListElementDatasets
def test_element_datasets_requires_read_access(self):
self.client.force_login(self.user)
private_elt = self.private_corpus.elements.create(type=self.private_corpus.types.create(slug="t"), name="elt")
......@@ -1749,6 +1751,8 @@ class TestDatasetsAPI(FixtureAPITestCase):
"updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
},
"set": "train",
"previous": None,
"next": None
}]
})
......@@ -1780,6 +1784,75 @@ class TestDatasetsAPI(FixtureAPITestCase):
"updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
},
"set": "train",
"previous": None,
"next": None
}, {
"dataset": {
"id": str(self.dataset.id),
"name": "First Dataset",
"description": "dataset number one",
"sets": ["training", "test", "validation"],
"set_elements": None,
"state": "open",
"corpus_id": str(self.corpus.id),
"creator": "Test user",
"task_id": None,
"created": self.dataset.created.isoformat().replace("+00:00", "Z"),
"updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
},
"set": "validation",
"previous": None,
"next": None
}, {
"dataset": {
"id": str(self.dataset2.id),
"name": "Second Dataset",
"description": "dataset number two",
"sets": ["training", "test", "validation"],
"set_elements": None,
"state": "open",
"corpus_id": str(self.corpus.id),
"creator": "Test user",
"task_id": None,
"created": self.dataset2.created.isoformat().replace("+00:00", "Z"),
"updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"),
},
"set": "train",
"previous": None,
"next": None
}]
})
def test_element_datasets_with_neighbors_false(self):
self.client.force_login(self.user)
self.dataset.dataset_elements.create(element=self.page1, set="train")
self.dataset.dataset_elements.create(element=self.page1, set="validation")
self.dataset2.dataset_elements.create(element=self.page1, set="train")
with self.assertNumQueries(7):
response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.page1.id)}), {"with_neighbors": False})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertDictEqual(response.json(), {
"count": 3,
"next": None,
"number": 1,
"previous": None,
"results": [{
"dataset": {
"id": str(self.dataset.id),
"name": "First Dataset",
"description": "dataset number one",
"sets": ["training", "test", "validation"],
"set_elements": None,
"state": "open",
"corpus_id": str(self.corpus.id),
"creator": "Test user",
"task_id": None,
"created": self.dataset.created.isoformat().replace("+00:00", "Z"),
"updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
},
"set": "train",
"previous": None,
"next": None
}, {
"dataset": {
"id": str(self.dataset.id),
......@@ -1795,6 +1868,8 @@ class TestDatasetsAPI(FixtureAPITestCase):
"updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
},
"set": "validation",
"previous": None,
"next": None
}, {
"dataset": {
"id": str(self.dataset2.id),
......@@ -1810,6 +1885,101 @@ class TestDatasetsAPI(FixtureAPITestCase):
"updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"),
},
"set": "train",
"previous": None,
"next": None
}]
})
def test_element_datasets_with_neighbors(self):
    """
    ListElementDatasets with ?with_neighbors=true must return the previous and
    next element IDs within each (dataset, set), ordered by element ID.
    """
    self.client.force_login(self.user)
    self.dataset.dataset_elements.create(element=self.page1, set="train")
    self.dataset.dataset_elements.create(element=self.page2, set="train")
    self.dataset.dataset_elements.create(element=self.page3, set="train")
    self.dataset.dataset_elements.create(element=self.page1, set="validation")
    self.dataset2.dataset_elements.create(element=self.page1, set="train")
    self.dataset2.dataset_elements.create(element=self.page3, set="train")

    # Results are alphabetically ordered and must not depend on the random page UUIDs
    sorted_dataset_elements = sorted([str(self.page1.id), str(self.page2.id), str(self.page3.id)])
    page1_index_1 = sorted_dataset_elements.index(str(self.page1.id))
    sorted_dataset2_elements = sorted([str(self.page1.id), str(self.page3.id)])
    page1_index_2 = sorted_dataset2_elements.index(str(self.page1.id))

    with self.assertNumQueries(9):
        response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.page1.id)}), {"with_neighbors": True})
        self.assertEqual(response.status_code, status.HTTP_200_OK)

    self.assertDictEqual(response.json(), {
        "count": 3,
        "next": None,
        "number": 1,
        "previous": None,
        "results": [{
            "dataset": {
                "id": str(self.dataset.id),
                "name": "First Dataset",
                "description": "dataset number one",
                "sets": ["training", "test", "validation"],
                "set_elements": None,
                "state": "open",
                "corpus_id": str(self.corpus.id),
                "creator": "Test user",
                "task_id": None,
                "created": self.dataset.created.isoformat().replace("+00:00", "Z"),
                "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
            },
            "set": "train",
            "previous": (
                sorted_dataset_elements[page1_index_1 - 1]
                if page1_index_1 - 1 >= 0
                else None
            ),
            "next": (
                sorted_dataset_elements[page1_index_1 + 1]
                if page1_index_1 + 1 <= 2
                else None
            )
        }, {
            "dataset": {
                "id": str(self.dataset.id),
                "name": "First Dataset",
                "description": "dataset number one",
                "sets": ["training", "test", "validation"],
                "set_elements": None,
                "state": "open",
                "corpus_id": str(self.corpus.id),
                "creator": "Test user",
                "task_id": None,
                "created": self.dataset.created.isoformat().replace("+00:00", "Z"),
                "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
            },
            "set": "validation",
            "previous": None,
            "next": None
        }, {
            "dataset": {
                "id": str(self.dataset2.id),
                "name": "Second Dataset",
                "description": "dataset number two",
                "sets": ["training", "test", "validation"],
                "set_elements": None,
                "state": "open",
                "corpus_id": str(self.corpus.id),
                "creator": "Test user",
                "task_id": None,
                "created": self.dataset2.created.isoformat().replace("+00:00", "Z"),
                "updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"),
            },
            "set": "train",
            # Bugfix: these guards previously tested page1_index_1, the index
            # in the *first* dataset's set, while indexing
            # sorted_dataset2_elements with page1_index_2. Both the index and
            # the guard must use page1_index_2 (dataset2's set only has 2
            # elements, so the last valid index is 1).
            "previous": (
                sorted_dataset2_elements[page1_index_2 - 1]
                if page1_index_2 - 1 >= 0
                else None
            ),
            "next": (
                sorted_dataset2_elements[page1_index_2 + 1]
                if page1_index_2 + 1 <= 1
                else None
            )
        }]
    })
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment