From 3608c38463688762c874ff30bee62d945d466396 Mon Sep 17 00:00:00 2001
From: Erwan Rouchet <rouchet@teklia.com>
Date: Fri, 24 Sep 2021 07:18:50 +0000
Subject: [PATCH] Add sorting parameters on element lists

---
 arkindex/documents/api/elements.py            | 58 ++++++++++++++--
 .../documents/tests/test_children_elements.py | 69 +++++++++++++++++++
 .../documents/tests/test_corpus_elements.py   | 43 ++++++++++++
 .../documents/tests/test_parents_elements.py  | 43 ++++++++++++
 4 files changed, 206 insertions(+), 7 deletions(-)

diff --git a/arkindex/documents/api/elements.py b/arkindex/documents/api/elements.py
index 818ee8a412..8bdaaf5c83 100644
--- a/arkindex/documents/api/elements.py
+++ b/arkindex/documents/api/elements.py
@@ -134,6 +134,8 @@ def _fetch_has_children(elements):
 
 METADATA_OPERATORS = {'eq', 'lt', 'gt', 'lte', 'gte'}
 
+ELEMENT_ORDERING_FIELDS = {'name', 'created'}
+
 
 class ElementsListAutoSchema(AutoSchema):
     """
@@ -213,6 +215,20 @@ class ElementsListAutoSchema(AutoSchema):
                     default='eq',
                     required=False,
                 ),
+                OpenApiParameter(
+                    'order',
+                    description='Sort elements by a specific field',
+                    enum=ELEMENT_ORDERING_FIELDS,
+                    default='name',
+                    required=False,
+                ),
+                OpenApiParameter(
+                    'order_direction',
+                    description='Direction in which to sort elements',
+                    enum={'asc', 'desc'},
+                    default='asc',
+                    required=False,
+                ),
                 OpenApiParameter(
                     'with_best_classes',
                     description='Returns best classifications for each element. '
@@ -499,8 +515,21 @@ class ElementsListBase(CorpusACLMixin, DestroyModelMixin, ListAPIView):
         return prefetch
 
     def get_order_by(self):
-        # Ordering by ID is required by postgres to order elements with the same name
-        return ('name', 'id')
+        errors = {}
+
+        sort_field = self.clean_params.get('order', 'name').lower()
+        if sort_field not in ELEMENT_ORDERING_FIELDS:
+            errors['order'] = ['Unknown sorting field']
+
+        direction = self.clean_params.get('order_direction', 'asc').lower()
+        if direction not in ('asc', 'desc'):
+            errors['order_direction'] = ['Unknown sorting direction']
+
+        if errors:
+            raise ValidationError(errors)
+
+        # The sorting field is not unique; fallback on ordering by ID to ensure a consistent ordering
+        return (f'{"-" if direction == "desc" else ""}{sort_field}', 'id')
 
     def filter_queryset(self, queryset):
         queryset = queryset \
@@ -583,14 +612,14 @@ class ElementsListBase(CorpusACLMixin, DestroyModelMixin, ListAPIView):
     get=extend_schema(operation_id='ListElements'),
     delete=extend_schema(operation_id='DestroyElements', description=(
         'Destroy elements in bulk.\n\n'
-        'Requires an **admin** access to the process.'
+        'Requires an **admin** access to the corpus.'
     )),
 )
 class CorpusElements(ElementsListBase):
     """
-    List elements in a corpus and filter by type, name, ML class.
+    List elements in a corpus.
 
-    Requires an **read** access to the corpus.
+    Requires a **read** access to the corpus.
     """
 
     @property
@@ -668,7 +697,15 @@ class ElementParents(ElementsListBase):
     ]
 )
 @extend_schema_view(
-    get=extend_schema(operation_id='ListElementChildren'),
+    get=extend_schema(operation_id='ListElementChildren', parameters=[
+        OpenApiParameter(
+            'order',
+            description='Sort elements by a specific field',
+            enum=ELEMENT_ORDERING_FIELDS | {'position'},
+            default='name',
+            required=False,
+        )
+    ]),
     delete=extend_schema(operation_id='DestroyElementChildren', description=(
         'Delete child elements in bulk.\n\n'
         "Requires an **admin** access to the element's corpus."
@@ -711,7 +748,14 @@ class ElementChildren(ElementsListBase):
         return filters
 
     def get_order_by(self):
-        return ('paths__ordering', 'id')
+        """
+        Override the ElementsListBase ordering to default by path ordering.
+        The only endpoint where a path ordering is made visible in the API
+        is ListElementNeighbors, in a field named `position`.
+        """
+        if self.clean_params.get('order', 'position').lower() == 'position':
+            return ('paths__ordering', 'id')
+        return super().get_order_by()
 
     def get_queryset(self):
         self.element = get_object_or_404(
diff --git a/arkindex/documents/tests/test_children_elements.py b/arkindex/documents/tests/test_children_elements.py
index 0c14cb40c0..32d04bbaaa 100644
--- a/arkindex/documents/tests/test_children_elements.py
+++ b/arkindex/documents/tests/test_children_elements.py
@@ -1,4 +1,5 @@
 import uuid
+from datetime import datetime, timedelta, timezone
 
 from django.urls import reverse
 from rest_framework import status
@@ -477,3 +478,71 @@ class TestChildrenElements(FixtureAPITestCase):
                     [element['name'] for element in response.json()['results']],
                     expected_elements
                 )
+
+    def test_children_invalid_sort(self):
+        cases = [
+            ({'order': 'blah', 'order_direction': 'asc'}, {'order': ['Unknown sorting field']}),
+            ({'order': 'name', 'order_direction': 'left'}, {'order_direction': ['Unknown sorting direction']}),
+            (
+                {'order': 'blah', 'order_direction': 'left'},
+                {'order': ['Unknown sorting field'], 'order_direction': ['Unknown sorting direction']}
+            ),
+        ]
+        for params, expected in cases:
+            with self.subTest(**params):
+                response = self.client.get(
+                    reverse('api:elements-children', kwargs={'pk': self.vol.id}),
+                    params,
+                )
+                self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+                self.assertDictEqual(response.json(), expected)
+
+    def test_children_sort_created(self):
+        elements = list(Element.objects.get_descending(self.vol.id).order_by('?').only('id'))
+        created = datetime(2020, 2, 2, tzinfo=timezone.utc)
+        for element in elements:
+            element.created = created
+            created += timedelta(seconds=1)
+        Element.objects.bulk_update(elements, ['created'])
+
+        cases = [
+            ('asc', elements[:20]),
+            ('desc', reversed(elements[-20:])),
+        ]
+        for direction, expected in cases:
+            with self.subTest(direction=direction), self.assertNumQueries(5):
+                response = self.client.get(
+                    reverse('api:elements-children', kwargs={'pk': self.vol.id}),
+                    {'order': 'created', 'order_direction': direction, 'recursive': True}
+                )
+                self.assertEqual(response.status_code, status.HTTP_200_OK)
+                self.assertListEqual(
+                    [element['id'] for element in response.json()['results']],
+                    [str(element.id) for element in expected]
+                )
+
+    def test_children_sort_name(self):
+        names = list(Element.objects.get_descending(self.vol.id).order_by('name', 'id').values_list('name', flat=True))
+
+        # Ensure the default position ordering returns a different ordering than by name;
+        # otherwise, this test would be useless
+        self.assertNotEqual(
+            list(Element.objects.get_descending(self.vol.id).values_list('name', flat=True)),
+            names,
+        )
+
+        cases = [
+            ('asc', names[:20]),
+            ('desc', list(reversed(names[-20:]))),
+        ]
+        for direction, expected in cases:
+            with self.subTest(direction=direction), self.assertNumQueries(5):
+                response = self.client.get(
+                    reverse('api:elements-children', kwargs={'pk': self.vol.id}),
+                    {'order': 'name', 'order_direction': direction, 'recursive': True}
+                )
+                self.assertEqual(response.status_code, status.HTTP_200_OK)
+                self.assertListEqual(
+                    [element['name'] for element in response.json()['results']],
+                    expected
+                )
diff --git a/arkindex/documents/tests/test_corpus_elements.py b/arkindex/documents/tests/test_corpus_elements.py
index d23d99166d..0513a1a364 100644
--- a/arkindex/documents/tests/test_corpus_elements.py
+++ b/arkindex/documents/tests/test_corpus_elements.py
@@ -1,4 +1,5 @@
 import uuid
+from datetime import datetime, timedelta, timezone
 
 import sqlparse
 from django.db.models.sql.constants import LOUTER
@@ -571,3 +572,45 @@ class TestListElements(FixtureAPITestCase):
         data = response.json()
         self.assertEqual(len(data['results']), 500)
         self.assertEqual(data['count'], 999)
+
+    def test_list_elements_invalid_sort(self):
+        cases = [
+            ({'order': 'blah', 'order_direction': 'asc'}, {'order': ['Unknown sorting field']}),
+            ({'order': 'name', 'order_direction': 'left'}, {'order_direction': ['Unknown sorting direction']}),
+            (
+                {'order': 'blah', 'order_direction': 'left'},
+                {'order': ['Unknown sorting field'], 'order_direction': ['Unknown sorting direction']}
+            ),
+        ]
+        for params, expected in cases:
+            with self.subTest(**params):
+                response = self.client.get(
+                    reverse('api:corpus-elements', kwargs={'corpus': self.corpus.id}),
+                    params,
+                )
+                self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+                self.assertDictEqual(response.json(), expected)
+
+    def test_list_elements_sort_created(self):
+        elements = list(self.corpus.elements.order_by('?').only('id'))
+        created = datetime(2020, 2, 2, tzinfo=timezone.utc)
+        for element in elements:
+            element.created = created
+            created += timedelta(seconds=1)
+        Element.objects.bulk_update(elements, ['created'])
+
+        cases = [
+            ('asc', elements[:20]),
+            ('desc', reversed(elements[-20:])),
+        ]
+        for direction, expected in cases:
+            with self.subTest(direction=direction), self.assertNumQueries(5):
+                response = self.client.get(
+                    reverse('api:corpus-elements', kwargs={'corpus': self.corpus.id}),
+                    {'order': 'created', 'order_direction': direction}
+                )
+                self.assertEqual(response.status_code, status.HTTP_200_OK)
+                self.assertListEqual(
+                    [element['id'] for element in response.json()['results']],
+                    [str(element.id) for element in expected]
+                )
diff --git a/arkindex/documents/tests/test_parents_elements.py b/arkindex/documents/tests/test_parents_elements.py
index 4d9e72752c..19c3cc4aff 100644
--- a/arkindex/documents/tests/test_parents_elements.py
+++ b/arkindex/documents/tests/test_parents_elements.py
@@ -1,4 +1,5 @@
 import uuid
+from datetime import datetime, timedelta, timezone
 
 from django.urls import reverse
 from rest_framework import status
@@ -294,3 +295,45 @@ class TestParentsElements(FixtureAPITestCase):
                     [element['name'] for element in response.json()['results']],
                     expected_elements
                 )
+
+    def test_parents_invalid_sort(self):
+        cases = [
+            ({'order': 'blah', 'order_direction': 'asc'}, {'order': ['Unknown sorting field']}),
+            ({'order': 'name', 'order_direction': 'left'}, {'order_direction': ['Unknown sorting direction']}),
+            (
+                {'order': 'blah', 'order_direction': 'left'},
+                {'order': ['Unknown sorting field'], 'order_direction': ['Unknown sorting direction']}
+            ),
+        ]
+        for params, expected in cases:
+            with self.subTest(**params):
+                response = self.client.get(
+                    reverse('api:elements-parents', kwargs={'pk': self.page.id}),
+                    params,
+                )
+                self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+                self.assertDictEqual(response.json(), expected)
+
+    def test_parents_sort_created(self):
+        elements = list(Element.objects.get_ascending(self.page.id).order_by('?').only('id'))
+        created = datetime(2020, 2, 2, tzinfo=timezone.utc)
+        for element in elements:
+            element.created = created
+            created += timedelta(seconds=1)
+        Element.objects.bulk_update(elements, ['created'])
+
+        cases = [
+            ('asc', elements[:20]),
+            ('desc', reversed(elements[-20:])),
+        ]
+        for direction, expected in cases:
+            with self.subTest(direction=direction), self.assertNumQueries(4):
+                response = self.client.get(
+                    reverse('api:elements-parents', kwargs={'pk': self.page.id}),
+                    {'order': 'created', 'order_direction': direction, 'recursive': True}
+                )
+                self.assertEqual(response.status_code, status.HTTP_200_OK)
+                self.assertListEqual(
+                    [element['id'] for element in response.json()['results']],
+                    [str(element.id) for element in expected]
+                )
-- 
GitLab