From 077793632edeb56626163c2a67058378d7b43937 Mon Sep 17 00:00:00 2001
From: ml bonhomme <bonhomme@teklia.com>
Date: Mon, 19 Feb 2024 15:22:22 +0000
Subject: [PATCH] Retrieve dataset element neighbors with ListElementDatasets

---
 arkindex/training/api.py                     |  75 +++++++-
 arkindex/training/serializers.py             |   5 +-
 arkindex/training/tests/test_datasets_api.py | 170 +++++++++++++++++++
 3 files changed, 247 insertions(+), 3 deletions(-)

diff --git a/arkindex/training/api.py b/arkindex/training/api.py
index 9e71b20fda..c8d6ce5830 100644
--- a/arkindex/training/api.py
+++ b/arkindex/training/api.py
@@ -2,7 +2,7 @@
 import copy
 from textwrap import dedent
 from uuid import UUID
-from django.db import transaction
+from django.db import connection, transaction
 from django.db.models import Q
 from django.shortcuts import get_object_or_404
 from django.utils.functional import cached_property
@@ -25,6 +25,7 @@ from arkindex.documents.models import Corpus, Element
 from arkindex.project.mixins import ACLMixin, CorpusACLMixin, TrainingModelMixin
 from arkindex.project.pagination import CountCursorPagination
 from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly
+from arkindex.project.tools import BulkMap
 from arkindex.training.models import (
     Dataset,
     DatasetElement,
@@ -53,6 +54,64 @@ from arkindex.users.models import Role
 from arkindex.users.utils import get_max_level
 
 
+def _fetch_datasetelement_neighbors(datasetelements):
+    """
+    Retrieve the neighbors for a list of DatasetElements, and annotate these DatasetElements
+    with next and previous attributes.
+    The ElementDatasets endpoint uses arkindex.project.tools.BulkMap to apply this method and
+    perform the second request *after* DRF's pagination, because there is no way to perform
+    post-processing after pagination in Django without having to use Django private methods.
+    """
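+    # A single raw SQL query retrieves every neighbor at once: the lag() and
+    # lead() window functions, partitioned by (dataset_id, set) and ordered by
+    # element_id, yield the previous and next element ID for each row.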
+ """ + if not datasetelements: + return datasetelements + with connection.cursor() as cursor: + cursor.execute( + """ + WITH neighbors AS ( + SELECT + n.id, + lag(element_id) OVER ( + partition BY (n.dataset_id, n.set) + order by + n.element_id + ) as previous, + lead(element_id) OVER ( + partition BY (n.dataset_id, n.set) + order by + n.element_id + ) as next + FROM training_datasetelement as n + WHERE (dataset_id, set) IN ( + SELECT dataset_id, set + FROM training_datasetelement + WHERE id IN %(ids)s + ) + ORDER BY n.element_id + ) + SELECT + neighbors.id, neighbors.previous, neighbors.next + FROM + neighbors + WHERE neighbors.id in %(ids)s + """, + {"ids": tuple(datasetelement.id for datasetelement in datasetelements)} + ) + + neighbors = { + id: { + "previous": previous, + "next": next + } + for id, previous, next in cursor.fetchall() + } + + for datasetelement in datasetelements: + datasetelement.previous = neighbors[datasetelement.id]["previous"] + datasetelement.next = neighbors[datasetelement.id]["next"] + + return datasetelements + + @extend_schema(tags=["training"]) @extend_schema_view( get=extend_schema( @@ -823,6 +882,12 @@ class CreateDatasetElementsSelection(CorpusACLMixin, CreateAPIView): location=OpenApiParameter.PATH, description="ID of the element.", required=True, + ), + OpenApiParameter( + "with_neighbors", + type=bool, + description="Load previous and next elements in the same dataset and set.", + required=False, ) ], ) @@ -844,12 +909,18 @@ class ElementDatasets(CorpusACLMixin, ListAPIView): ) def get_queryset(self): - return ( + qs = ( self.element.dataset_elements.all() .select_related("dataset__creator") .order_by("dataset__name", "set", "dataset_id") ) + with_neighbors = self.request.query_params.get("with_neighbors", "false") + if with_neighbors.lower() not in ("false", "0"): + qs = BulkMap(_fetch_datasetelement_neighbors, qs) + + return qs + def get_serializer_context(self): context = super().get_serializer_context() # Avoids aggregating the number of elements per set on each diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py index 3d5f9e56a2..b92c44d1c5 100644 --- a/arkindex/training/serializers.py +++ b/arkindex/training/serializers.py @@ -713,10 +713,13 @@ class DatasetElementInfoSerializer(DatasetElementSerializer): class ElementDatasetSerializer(serializers.ModelSerializer): dataset = DatasetSerializer() + previous = serializers.UUIDField(allow_null=True, read_only=True) + next = serializers.UUIDField(allow_null=True, read_only=True) class Meta: model = DatasetElement - fields = ("dataset", "set") + fields = ("dataset", "set", "previous", "next") + read_only_fields = ("previous", "next") class SelectionDatasetElementSerializer(serializers.Serializer): diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py index dde3c5abaf..653ffa152b 100644 --- a/arkindex/training/tests/test_datasets_api.py +++ b/arkindex/training/tests/test_datasets_api.py @@ -1705,6 +1705,8 @@ class TestDatasetsAPI(FixtureAPITestCase): ] ) + # ListElementDatasets + def test_element_datasets_requires_read_access(self): self.client.force_login(self.user) private_elt = self.private_corpus.elements.create(type=self.private_corpus.types.create(slug="t"), name="elt") @@ -1749,6 +1751,8 @@ class TestDatasetsAPI(FixtureAPITestCase): "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"), }, "set": "train", + "previous": None, + "next": None }] }) @@ -1780,6 +1784,75 @@ class 
+        with_neighbors = self.request.query_params.get("with_neighbors", "false")
+        if with_neighbors.lower() not in ("false", "0"):
+            qs = BulkMap(_fetch_datasetelement_neighbors, qs)
+
+        return qs
+
     def get_serializer_context(self):
         context = super().get_serializer_context()
         # Avoids aggregating the number of elements per set on each
diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index 3d5f9e56a2..b92c44d1c5 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -713,10 +713,13 @@ class DatasetElementInfoSerializer(DatasetElementSerializer):
 
 class ElementDatasetSerializer(serializers.ModelSerializer):
     dataset = DatasetSerializer()
+    previous = serializers.UUIDField(allow_null=True, read_only=True)
+    next = serializers.UUIDField(allow_null=True, read_only=True)
 
     class Meta:
         model = DatasetElement
-        fields = ("dataset", "set")
+        fields = ("dataset", "set", "previous", "next")
+        read_only_fields = ("previous", "next")
 
 
 class SelectionDatasetElementSerializer(serializers.Serializer):
diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py
index dde3c5abaf..653ffa152b 100644
--- a/arkindex/training/tests/test_datasets_api.py
+++ b/arkindex/training/tests/test_datasets_api.py
@@ -1705,6 +1705,8 @@ class TestDatasetsAPI(FixtureAPITestCase):
             ]
         )
 
+    # ListElementDatasets
+
     def test_element_datasets_requires_read_access(self):
         self.client.force_login(self.user)
         private_elt = self.private_corpus.elements.create(type=self.private_corpus.types.create(slug="t"), name="elt")
@@ -1749,6 +1751,8 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
                 },
                 "set": "train",
+                "previous": None,
+                "next": None
             }]
         })
 
@@ -1780,6 +1784,75 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
                 },
                 "set": "train",
+                "previous": None,
+                "next": None
+            }, {
+                "dataset": {
+                    "id": str(self.dataset.id),
+                    "name": "First Dataset",
+                    "description": "dataset number one",
+                    "sets": ["training", "test", "validation"],
+                    "set_elements": None,
+                    "state": "open",
+                    "corpus_id": str(self.corpus.id),
+                    "creator": "Test user",
+                    "task_id": None,
+                    "created": self.dataset.created.isoformat().replace("+00:00", "Z"),
+                    "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
+                },
+                "set": "validation",
+                "previous": None,
+                "next": None
+            }, {
+                "dataset": {
+                    "id": str(self.dataset2.id),
+                    "name": "Second Dataset",
+                    "description": "dataset number two",
+                    "sets": ["training", "test", "validation"],
+                    "set_elements": None,
+                    "state": "open",
+                    "corpus_id": str(self.corpus.id),
+                    "creator": "Test user",
+                    "task_id": None,
+                    "created": self.dataset2.created.isoformat().replace("+00:00", "Z"),
+                    "updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"),
+                },
+                "set": "train",
+                "previous": None,
+                "next": None
+            }]
+        })
+
+    def test_element_datasets_with_neighbors_false(self):
+        self.client.force_login(self.user)
+        self.dataset.dataset_elements.create(element=self.page1, set="train")
+        self.dataset.dataset_elements.create(element=self.page1, set="validation")
+        self.dataset2.dataset_elements.create(element=self.page1, set="train")
+        with self.assertNumQueries(7):
+            response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.page1.id)}), {"with_neighbors": False})
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertDictEqual(response.json(), {
+            "count": 3,
+            "next": None,
+            "number": 1,
+            "previous": None,
+            "results": [{
+                "dataset": {
+                    "id": str(self.dataset.id),
+                    "name": "First Dataset",
+                    "description": "dataset number one",
+                    "sets": ["training", "test", "validation"],
+                    "set_elements": None,
+                    "state": "open",
+                    "corpus_id": str(self.corpus.id),
+                    "creator": "Test user",
+                    "task_id": None,
+                    "created": self.dataset.created.isoformat().replace("+00:00", "Z"),
+                    "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
+                },
+                "set": "train",
+                "previous": None,
+                "next": None
             }, {
                 "dataset": {
                     "id": str(self.dataset.id),
@@ -1795,6 +1868,8 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
                 },
                 "set": "validation",
+                "previous": None,
+                "next": None
             }, {
                 "dataset": {
                     "id": str(self.dataset2.id),
@@ -1810,6 +1885,101 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"),
                 },
                 "set": "train",
+                "previous": None,
+                "next": None
+            }]
+        })
+
+    def test_element_datasets_with_neighbors(self):
+        self.client.force_login(self.user)
+        self.dataset.dataset_elements.create(element=self.page1, set="train")
+        self.dataset.dataset_elements.create(element=self.page2, set="train")
+        self.dataset.dataset_elements.create(element=self.page3, set="train")
+        self.dataset.dataset_elements.create(element=self.page1, set="validation")
+        self.dataset2.dataset_elements.create(element=self.page1, set="train")
+        self.dataset2.dataset_elements.create(element=self.page3, set="train")
+
+        # Results are alphabetically ordered and must not depend on the random page UUIDs
+        sorted_dataset_elements = sorted([str(self.page1.id), str(self.page2.id), str(self.page3.id)])
+        page1_index_1 = sorted_dataset_elements.index(str(self.page1.id))
+        sorted_dataset2_elements = sorted([str(self.page1.id), str(self.page3.id)])
+        page1_index_2 = sorted_dataset2_elements.index(str(self.page1.id))
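+        # page1's expected neighbors depend on where its UUID happens to fall in
+        # each sorted list: an element at either end of a set has None on that side.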
+
+        with self.assertNumQueries(9):
+            response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.page1.id)}), {"with_neighbors": True})
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertDictEqual(response.json(), {
+            "count": 3,
+            "next": None,
+            "number": 1,
+            "previous": None,
+            "results": [{
+                "dataset": {
+                    "id": str(self.dataset.id),
+                    "name": "First Dataset",
+                    "description": "dataset number one",
+                    "sets": ["training", "test", "validation"],
+                    "set_elements": None,
+                    "state": "open",
+                    "corpus_id": str(self.corpus.id),
+                    "creator": "Test user",
+                    "task_id": None,
+                    "created": self.dataset.created.isoformat().replace("+00:00", "Z"),
+                    "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
+                },
+                "set": "train",
+                "previous": (
+                    sorted_dataset_elements[page1_index_1 - 1]
+                    if page1_index_1 - 1 >= 0
+                    else None
+                ),
+                "next": (
+                    sorted_dataset_elements[page1_index_1 + 1]
+                    if page1_index_1 + 1 <= 2
+                    else None
+                )
+            }, {
+                "dataset": {
+                    "id": str(self.dataset.id),
+                    "name": "First Dataset",
+                    "description": "dataset number one",
+                    "sets": ["training", "test", "validation"],
+                    "set_elements": None,
+                    "state": "open",
+                    "corpus_id": str(self.corpus.id),
+                    "creator": "Test user",
+                    "task_id": None,
+                    "created": self.dataset.created.isoformat().replace("+00:00", "Z"),
+                    "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
+                },
+                "set": "validation",
+                "previous": None,
+                "next": None
+            }, {
+                "dataset": {
+                    "id": str(self.dataset2.id),
+                    "name": "Second Dataset",
+                    "description": "dataset number two",
+                    "sets": ["training", "test", "validation"],
+                    "set_elements": None,
+                    "state": "open",
+                    "corpus_id": str(self.corpus.id),
+                    "creator": "Test user",
+                    "task_id": None,
+                    "created": self.dataset2.created.isoformat().replace("+00:00", "Z"),
+                    "updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"),
+                },
+                "set": "train",
+                "previous": (
+                    sorted_dataset2_elements[page1_index_2 - 1]
+                    if page1_index_2 - 1 >= 0
+                    else None
+                ),
+                "next": (
+                    sorted_dataset2_elements[page1_index_2 + 1]
+                    if page1_index_2 + 1 <= 1
+                    else None
+                )
             }]
         })
 
-- 
GitLab
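Note on BulkMap: the wrapper imported from arkindex.project.tools is not part of
this patch. The sketch below is a minimal illustration of the kind of object the
get_queryset() change relies on, inferred solely from its usage here (a callable
applied to each evaluated slice of a queryset, after pagination). Its names and
behavior are assumptions, not the actual Arkindex implementation.

    # Assumed sketch only: the real arkindex.project.tools.BulkMap may differ.
    class BulkMap:
        """Lazily apply `func` to a queryset's results once they are evaluated."""

        def __init__(self, func, queryset):
            self.func = func
            self.queryset = queryset

        def __len__(self):
            # Lets paginators count rows without running `func`.
            return len(self.queryset)

        def __getitem__(self, key):
            # DRF pagination slices the queryset down to a single page; `func`
            # then post-processes only the rows of that page.
            if isinstance(key, slice):
                return self.func(list(self.queryset[key]))
            return self.queryset[key]

        def __iter__(self):
            return iter(self.func(list(self.queryset)))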