From 077793632edeb56626163c2a67058378d7b43937 Mon Sep 17 00:00:00 2001
From: ml bonhomme <bonhomme@teklia.com>
Date: Mon, 19 Feb 2024 15:22:22 +0000
Subject: [PATCH] Retrieve dataset element neighbors with ListElementDatasets

---
 arkindex/training/api.py                     |  75 +++++++-
 arkindex/training/serializers.py             |   5 +-
 arkindex/training/tests/test_datasets_api.py | 170 +++++++++++++++++++
 3 files changed, 247 insertions(+), 3 deletions(-)

diff --git a/arkindex/training/api.py b/arkindex/training/api.py
index 9e71b20fda..c8d6ce5830 100644
--- a/arkindex/training/api.py
+++ b/arkindex/training/api.py
@@ -2,7 +2,7 @@ import copy
 from textwrap import dedent
 from uuid import UUID
 
-from django.db import transaction
+from django.db import connection, transaction
 from django.db.models import Q
 from django.shortcuts import get_object_or_404
 from django.utils.functional import cached_property
@@ -25,6 +25,7 @@ from arkindex.documents.models import Corpus, Element
 from arkindex.project.mixins import ACLMixin, CorpusACLMixin, TrainingModelMixin
 from arkindex.project.pagination import CountCursorPagination
 from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly
+from arkindex.project.tools import BulkMap
 from arkindex.training.models import (
     Dataset,
     DatasetElement,
@@ -53,6 +54,64 @@ from arkindex.users.models import Role
 from arkindex.users.utils import get_max_level
 
 
+def _fetch_datasetelement_neighbors(datasetelements):
+    """
+    Retrieve the neighbors for a list of DatasetElements, and annotate these DatasetElements
+    with next and previous attributes.
+    The ElementDatasets endpoint uses arkindex.project.tools.BulkMap to apply this method and
+    perform the second request *after* DRF's pagination, because there is no way to perform
+    post-processing after pagination in Django without having to use Django private methods.
+    """
+    if not datasetelements:
+        return datasetelements
+    with connection.cursor() as cursor:
+        cursor.execute(
+            """
+            WITH neighbors AS (
+                SELECT
+                    n.id,
+                    lag(element_id) OVER (
+                        partition BY (n.dataset_id, n.set)
+                        order by
+                            n.element_id
+                    ) as previous,
+                    lead(element_id) OVER (
+                        partition BY (n.dataset_id, n.set)
+                        order by
+                            n.element_id
+                    ) as next
+                FROM training_datasetelement as n
+                WHERE (dataset_id, set) IN (
+                    SELECT dataset_id, set
+                    FROM training_datasetelement
+                    WHERE id IN %(ids)s
+                )
+                ORDER BY n.element_id
+            )
+            SELECT
+                neighbors.id, neighbors.previous, neighbors.next
+            FROM
+                neighbors
+            WHERE neighbors.id in %(ids)s
+            """,
+            {"ids": tuple(datasetelement.id for datasetelement in datasetelements)}
+        )
+
+        neighbors = {
+            id: {
+                "previous": previous,
+                "next": next
+            }
+            for id, previous, next in cursor.fetchall()
+        }
+
+    for datasetelement in datasetelements:
+        datasetelement.previous = neighbors[datasetelement.id]["previous"]
+        datasetelement.next = neighbors[datasetelement.id]["next"]
+
+    return datasetelements
+
+
 @extend_schema(tags=["training"])
 @extend_schema_view(
     get=extend_schema(
@@ -823,6 +882,12 @@ class CreateDatasetElementsSelection(CorpusACLMixin, CreateAPIView):
             location=OpenApiParameter.PATH,
             description="ID of the element.",
             required=True,
+        ),
+        OpenApiParameter(
+            "with_neighbors",
+            type=bool,
+            description="Load previous and next elements in the same dataset and set.",
+            required=False,
         )
     ],
 )
@@ -844,12 +909,18 @@ class ElementDatasets(CorpusACLMixin, ListAPIView):
         )
 
     def get_queryset(self):
-        return (
+        qs = (
             self.element.dataset_elements.all()
             .select_related("dataset__creator")
             .order_by("dataset__name", "set", "dataset_id")
         )
 
+        with_neighbors = self.request.query_params.get("with_neighbors", "false")
+        if with_neighbors.lower() not in ("false", "0"):
+            qs = BulkMap(_fetch_datasetelement_neighbors, qs)
+
+        return qs
+
     def get_serializer_context(self):
         context = super().get_serializer_context()
         # Avoids aggregating the number of elements per set on each
diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index 3d5f9e56a2..b92c44d1c5 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -713,10 +713,13 @@ class DatasetElementInfoSerializer(DatasetElementSerializer):
 
 class ElementDatasetSerializer(serializers.ModelSerializer):
     dataset = DatasetSerializer()
+    previous = serializers.UUIDField(allow_null=True, read_only=True)
+    next = serializers.UUIDField(allow_null=True, read_only=True)
 
     class Meta:
         model = DatasetElement
-        fields = ("dataset", "set")
+        fields = ("dataset", "set", "previous", "next")
+        read_only_fields = ("previous", "next")
 
 
 class SelectionDatasetElementSerializer(serializers.Serializer):
diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py
index dde3c5abaf..653ffa152b 100644
--- a/arkindex/training/tests/test_datasets_api.py
+++ b/arkindex/training/tests/test_datasets_api.py
@@ -1705,6 +1705,8 @@ class TestDatasetsAPI(FixtureAPITestCase):
             ]
         )
 
+    # ListElementDatasets
+
     def test_element_datasets_requires_read_access(self):
         self.client.force_login(self.user)
         private_elt = self.private_corpus.elements.create(type=self.private_corpus.types.create(slug="t"), name="elt")
@@ -1749,6 +1751,8 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
                 },
                 "set": "train",
+                "previous": None,
+                "next": None
             }]
         })
 
@@ -1780,6 +1784,75 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
                 },
                 "set": "train",
+                "previous": None,
+                "next": None
+            }, {
+                "dataset": {
+                    "id": str(self.dataset.id),
+                    "name": "First Dataset",
+                    "description": "dataset number one",
+                    "sets": ["training", "test", "validation"],
+                    "set_elements": None,
+                    "state": "open",
+                    "corpus_id": str(self.corpus.id),
+                    "creator": "Test user",
+                    "task_id": None,
+                    "created": self.dataset.created.isoformat().replace("+00:00", "Z"),
+                    "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
+                },
+                "set": "validation",
+                "previous": None,
+                "next": None
+            }, {
+                "dataset": {
+                    "id": str(self.dataset2.id),
+                    "name": "Second Dataset",
+                    "description": "dataset number two",
+                    "sets": ["training", "test", "validation"],
+                    "set_elements": None,
+                    "state": "open",
+                    "corpus_id": str(self.corpus.id),
+                    "creator": "Test user",
+                    "task_id": None,
+                    "created": self.dataset2.created.isoformat().replace("+00:00", "Z"),
+                    "updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"),
+                },
+                "set": "train",
+                "previous": None,
+                "next": None
+            }]
+        })
+
+    def test_element_datasets_with_neighbors_false(self):
+        self.client.force_login(self.user)
+        self.dataset.dataset_elements.create(element=self.page1, set="train")
+        self.dataset.dataset_elements.create(element=self.page1, set="validation")
+        self.dataset2.dataset_elements.create(element=self.page1, set="train")
+        with self.assertNumQueries(7):
+            response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.page1.id)}), {"with_neighbors": False})
+            self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertDictEqual(response.json(), {
+            "count": 3,
+            "next": None,
+            "number": 1,
+            "previous": None,
+            "results": [{
+                "dataset": {
+                    "id": str(self.dataset.id),
+                    "name": "First Dataset",
+                    "description": "dataset number one",
+                    "sets": ["training", "test", "validation"],
+                    "set_elements": None,
+                    "state": "open",
+                    "corpus_id": str(self.corpus.id),
+                    "creator": "Test user",
+                    "task_id": None,
+                    "created": self.dataset.created.isoformat().replace("+00:00", "Z"),
+                    "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
+                },
+                "set": "train",
+                "previous": None,
+                "next": None
             }, {
                 "dataset": {
                     "id": str(self.dataset.id),
@@ -1795,6 +1868,8 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
                 },
                 "set": "validation",
+                "previous": None,
+                "next": None
             }, {
                 "dataset": {
                     "id": str(self.dataset2.id),
@@ -1810,6 +1885,101 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"),
                 },
                 "set": "train",
+                "previous": None,
+                "next": None
+            }]
+        })
+
+    def test_element_datasets_with_neighbors(self):
+        self.client.force_login(self.user)
+        self.dataset.dataset_elements.create(element=self.page1, set="train")
+        self.dataset.dataset_elements.create(element=self.page2, set="train")
+        self.dataset.dataset_elements.create(element=self.page3, set="train")
+        self.dataset.dataset_elements.create(element=self.page1, set="validation")
+        self.dataset2.dataset_elements.create(element=self.page1, set="train")
+        self.dataset2.dataset_elements.create(element=self.page3, set="train")
+
+        # Results are alphabetically ordered and must not depend on the random page UUIDs
+        sorted_dataset_elements = sorted([str(self.page1.id), str(self.page2.id), str(self.page3.id)])
+        page1_index_1 = sorted_dataset_elements.index(str(self.page1.id))
+        sorted_dataset2_elements = sorted([str(self.page1.id), str(self.page3.id)])
+        page1_index_2 = sorted_dataset2_elements.index(str(self.page1.id))
+
+        with self.assertNumQueries(9):
+            response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.page1.id)}), {"with_neighbors": True})
+            self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertDictEqual(response.json(), {
+            "count": 3,
+            "next": None,
+            "number": 1,
+            "previous": None,
+            "results": [{
+                "dataset": {
+                    "id": str(self.dataset.id),
+                    "name": "First Dataset",
+                    "description": "dataset number one",
+                    "sets": ["training", "test", "validation"],
+                    "set_elements": None,
+                    "state": "open",
+                    "corpus_id": str(self.corpus.id),
+                    "creator": "Test user",
+                    "task_id": None,
+                    "created": self.dataset.created.isoformat().replace("+00:00", "Z"),
+                    "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
+                },
+                "set": "train",
+                "previous": (
+                    sorted_dataset_elements[page1_index_1 - 1]
+                    if page1_index_1 - 1 >= 0
+                    else None
+                ),
+                "next": (
+                    sorted_dataset_elements[page1_index_1 + 1]
+                    if page1_index_1 + 1 <= 2
+                    else None
+                )
+            }, {
+                "dataset": {
+                    "id": str(self.dataset.id),
+                    "name": "First Dataset",
+                    "description": "dataset number one",
+                    "sets": ["training", "test", "validation"],
+                    "set_elements": None,
+                    "state": "open",
+                    "corpus_id": str(self.corpus.id),
+                    "creator": "Test user",
+                    "task_id": None,
+                    "created": self.dataset.created.isoformat().replace("+00:00", "Z"),
+                    "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
+                },
+                "set": "validation",
+                "previous": None,
+                "next": None
+            }, {
+                "dataset": {
+                    "id": str(self.dataset2.id),
+                    "name": "Second Dataset",
+                    "description": "dataset number two",
+                    "sets": ["training", "test", "validation"],
+                    "set_elements": None,
+                    "state": "open",
+                    "corpus_id": str(self.corpus.id),
+                    "creator": "Test user",
+                    "task_id": None,
+                    "created": self.dataset2.created.isoformat().replace("+00:00", "Z"),
+                    "updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"),
+                },
+                "set": "train",
+                "previous": (
+                    sorted_dataset2_elements[page1_index_2 - 1]
+                    if page1_index_1 - 1 >= 0
+                    else None
+                ),
+                "next": (
+                    sorted_dataset2_elements[page1_index_2 + 1]
+                    if page1_index_1 + 1 <= 1
+                    else None
+                )
             }]
         })
 
-- 
GitLab