From 71ea7f3a70e3bf9ab7ab08fd9a51316ef64999a6 Mon Sep 17 00:00:00 2001
From: Erwan Rouchet <rouchet@teklia.com>
Date: Wed, 7 Oct 2020 12:18:48 +0000
Subject: [PATCH] Optimize get_ascending and ListElementParents

---
 arkindex/documents/api/elements.py            | 35 +----------
 arkindex/documents/managers.py                | 63 ++++++++-----------
 arkindex/documents/tests/test_classes.py      |  4 +-
 .../documents/tests/test_element_manager.py   | 12 ----
 .../documents/tests/test_parents_elements.py  |  2 +-
 arkindex/project/fields.py                    |  5 ++
 6 files changed, 37 insertions(+), 84 deletions(-)

diff --git a/arkindex/documents/api/elements.py b/arkindex/documents/api/elements.py
index 3f858a904f..1f8215dcd8 100644
--- a/arkindex/documents/api/elements.py
+++ b/arkindex/documents/api/elements.py
@@ -1,7 +1,7 @@
 from collections import defaultdict
 from django.conf import settings
 from django.db import transaction
-from django.db.models import Q, Prefetch, prefetch_related_objects, Count, Max
+from django.db.models import Q, Prefetch, Count, Max
 from django.utils.functional import cached_property
 from rest_framework.exceptions import ValidationError, NotFound
 from rest_framework.generics import (
@@ -350,37 +350,8 @@ class ElementParents(ElementsListMixin, ListAPIView):
         recursive_param = self.request.query_params.get('recursive')
         return recursive_param is not None and recursive_param.lower() not in ('false', '0')
 
-    def get_filters(self):
-        filters = super().get_filters()
-
-        if not self.is_recursive:
-            # List direct parents only: elements whose IDs are the last in the element's paths
-            filters['id__in'] = ElementPath.objects \
-                .filter(element_id=self.kwargs['pk']) \
-                .values('path__last')
-
-        return filters
-
-    def filter_queryset(self, queryset):
-        if not self.is_recursive:
-            return super().filter_queryset(queryset)
-
-        # Redefine the queryset here as get_ascending needs its filters as an argument
-        # and returns a list instead of an actual queryset
-        parents = Element.objects.get_ascending(self.kwargs['pk'], **self.get_filters())
-
-        # The list is unsorted; prefetch using prefetch_related_objects and sort with sorted
-        if not self.is_recursive:
-            # Apply .distinct() using a dict
-            parents = list({parent.id: parent for parent in parents}.values())
-
-        prefetch_related_objects(parents, *self.get_prefetch())
-
-        with_children_count = self.request.query_params.get('with_children_count')
-        if with_children_count and with_children_count.lower() not in ('false', '0'):
-            _fetch_children_count(parents)
-
-        return sorted(parents, key=operator.attrgetter('corpus', 'type.slug', 'name', 'id'))
+    def get_queryset(self):
+        return Element.objects.get_ascending(self.kwargs['pk'], recursive=self.is_recursive)
 
 
 class ElementChildren(ElementsListMixin, ListAPIView):
diff --git a/arkindex/documents/managers.py b/arkindex/documents/managers.py
index f2da97afcf..4888f5a8f2 100644
--- a/arkindex/documents/managers.py
+++ b/arkindex/documents/managers.py
@@ -1,54 +1,53 @@
 from django.db import models
 from itertools import groupby, chain
+from arkindex.project.fields import Unnest
 import uuid
 
 
 class ElementManager(models.Manager):
     """Model manager for elements"""
 
-    def get_ascending_paths(self, child_id, **filters):
-        """
-        Get all parent paths for a specific element ID.
-        """
-        # TODO: it should be possible to do in a single query through Django's filtered relations
-        from arkindex.documents.models import ElementPath
-        paths = ElementPath.objects.filter(element_id=child_id).values_list('path', flat=True)
-        parents = {
-            parent.id: parent
-            for parent in self.filter(id__in=chain(*paths), **filters)
-        }
-        return [
-            filter(None, [
-                parents.get(parent_id)
-                for parent_id in path
-            ])
-            for path in paths
-        ]
-
-    def get_ascending(self, child_id, **filters):
+    def get_ascending(self, child_id, recursive=True):
         """
         Get all parent elements for a specific element ID.
         > All the elements from paths
         """
-        return list(chain(*self.get_ascending_paths(child_id, **filters)))
+        from arkindex.documents.models import ElementPath
+
+        paths = ElementPath.objects.filter(element_id=child_id)
+        if recursive:
+            parent_ids = paths.annotate(parent_id=Unnest('path')).values('parent_id')
+        else:
+            parent_ids = paths.values('path__last')
+
+        return self.filter(id__in=parent_ids)
 
     def get_ascendings_paths(self, *children_ids, **filters):
         """
         Get all ascending paths for some elements IDs.
         """
-        # Loads paths and group them by element ids
         from arkindex.documents.models import ElementPath
+        # Load all parents
+        parents = {
+            parent.id: parent
+            for parent in self.filter(
+                **filters,
+                id__in=ElementPath
+                .objects
+                .filter(element_id__in=children_ids)
+                .annotate(parent_id=Unnest('path'))
+                .values('parent_id')
+            )
+        }
+
+        # Loads paths and group them by element ids
         paths = ElementPath.objects.filter(element_id__in=children_ids).order_by('element_id')
         tree = {
             elt_id: [p.path for p in elt_paths]
             for elt_id, elt_paths in groupby(paths, lambda e: e.element_id)
         }
 
-        # Load parents elements in bulk
-        parents = {
-            parent.id: parent
-            for parent in self.filter(id__in=chain(*[p.path for p in paths]), **filters)
-        }
+        # Put Element instances in paths
         return {
             elt_id: [
                 [
@@ -60,16 +59,6 @@ class ElementManager(models.Manager):
             for elt_id, paths in tree.items()
         }
 
-    def get_ascendings(self, *children_ids, **filters):
-        """
-        Get all parent elements for some element IDs.
-        """
-        # Just assemble the paths together
-        return {
-            elt_id: chain(*paths)
-            for elt_id, paths in self.get_ascendings_paths(*children_ids, **filters).items()
-        }
-
     def get_descending(self, parent_id, **filters):
         """
         Get all child elements for a specific element ID.
diff --git a/arkindex/documents/tests/test_classes.py b/arkindex/documents/tests/test_classes.py
index 84b0d66b37..1799c075e5 100644
--- a/arkindex/documents/tests/test_classes.py
+++ b/arkindex/documents/tests/test_classes.py
@@ -318,7 +318,7 @@ class TestClasses(FixtureAPITestCase):
         A non best class validated by a human is considered as best as it is for the human
         """
         self.populate_classified_elements()
-        parent = Element.objects.get_ascending(self.common_children.id)[-1]
+        parent = Element.objects.get_ascending(self.common_children.id).last()
         parent.classifications.all().filter(confidence=.7).update(state=ClassificationState.Validated)
         response = self.client.get(
             reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)}),
@@ -400,7 +400,7 @@ class TestClasses(FixtureAPITestCase):
 
     def test_class_filter_list_parents(self):
         self.populate_classified_elements()
-        parent = Element.objects.get_ascending(self.common_children.id)[-1]
+        parent = Element.objects.get_ascending(self.common_children.id).last()
         parent.classifications.all().filter(confidence=.7).update(state=ClassificationState.Validated)
         with self.assertNumQueries(5):
             response = self.client.get(
diff --git a/arkindex/documents/tests/test_element_manager.py b/arkindex/documents/tests/test_element_manager.py
index 1fabe46410..85ef5adfeb 100644
--- a/arkindex/documents/tests/test_element_manager.py
+++ b/arkindex/documents/tests/test_element_manager.py
@@ -28,13 +28,6 @@ class TestElementManager(FixtureTestCase):
         ids = Element.objects.get_ascending(self.vol.id)
         self.assertCountEqual(ids, [])
 
-    def test_get_ascendings(self):
-        # Use all three pages, expect Volume thrice
-        ids = Element.objects.get_ascendings(self.p1.id, self.p2.id, self.p3.id)
-        self.assertCountEqual(ids[self.p1.id], [self.vol])
-        self.assertCountEqual(ids[self.p2.id], [self.vol])
-        self.assertCountEqual(ids[self.p3.id], [self.vol])
-
     def test_get_descending(self):
         # Use volume, expect all three pages and the act
         ids = Element.objects.get_descending(self.vol.id)
@@ -52,11 +45,6 @@ class TestElementManager(FixtureTestCase):
         self.assertCountEqual(ids[self.p2.id], [])
         self.assertCountEqual(ids[self.p3.id], [])
 
-    def test_get_ascending_paths(self):
-        paths = Element.objects.get_ascending_paths(self.act.id)
-        self.assertEqual(len(paths), 1)
-        self.assertSequenceEqual(list(paths[0]), [self.vol, self.p1])
-
     def test_get_neighbors(self):
         with self.assertNumQueries(4):
             self.assertEqual(len(Element.objects.get_neighbors(self.p1, 0)), 1)
diff --git a/arkindex/documents/tests/test_parents_elements.py b/arkindex/documents/tests/test_parents_elements.py
index 1b25a0e96d..d4b7b197ef 100644
--- a/arkindex/documents/tests/test_parents_elements.py
+++ b/arkindex/documents/tests/test_parents_elements.py
@@ -112,7 +112,7 @@ class TestParentsElements(FixtureAPITestCase):
 
     def test_parents_with_children_count(self):
         surface = Element.objects.get(name='Surface A')
-        with self.assertNumQueries(6):
+        with self.assertNumQueries(7):
             response = self.client.get(
                 reverse('api:elements-parents', kwargs={'pk': str(surface.id)}),
                 data={'recursive': True, 'with_children_count': True},
diff --git a/arkindex/project/fields.py b/arkindex/project/fields.py
index a6b1a0e7ea..e4065e4ad5 100644
--- a/arkindex/project/fields.py
+++ b/arkindex/project/fields.py
@@ -30,6 +30,11 @@ class ArrayRemove(Func):
     output_field = fields.ArrayField
 
 
+class Unnest(Func):
+    function = 'unnest'
+    arity = 1
+
+
 class ArrayField(fields.ArrayField):
     """
     An enhanced PostgreSQL ArrayField, adding a `last` transform
-- 
GitLab