From 71ea7f3a70e3bf9ab7ab08fd9a51316ef64999a6 Mon Sep 17 00:00:00 2001 From: Erwan Rouchet <rouchet@teklia.com> Date: Wed, 7 Oct 2020 12:18:48 +0000 Subject: [PATCH] Optimize get_ascending and ListElementParents --- arkindex/documents/api/elements.py | 35 +---------- arkindex/documents/managers.py | 63 ++++++++----------- arkindex/documents/tests/test_classes.py | 4 +- .../documents/tests/test_element_manager.py | 12 ---- .../documents/tests/test_parents_elements.py | 2 +- arkindex/project/fields.py | 5 ++ 6 files changed, 37 insertions(+), 84 deletions(-) diff --git a/arkindex/documents/api/elements.py b/arkindex/documents/api/elements.py index 3f858a904f..1f8215dcd8 100644 --- a/arkindex/documents/api/elements.py +++ b/arkindex/documents/api/elements.py @@ -1,7 +1,7 @@ from collections import defaultdict from django.conf import settings from django.db import transaction -from django.db.models import Q, Prefetch, prefetch_related_objects, Count, Max +from django.db.models import Q, Prefetch, Count, Max from django.utils.functional import cached_property from rest_framework.exceptions import ValidationError, NotFound from rest_framework.generics import ( @@ -350,37 +350,8 @@ class ElementParents(ElementsListMixin, ListAPIView): recursive_param = self.request.query_params.get('recursive') return recursive_param is not None and recursive_param.lower() not in ('false', '0') - def get_filters(self): - filters = super().get_filters() - - if not self.is_recursive: - # List direct parents only: elements whose IDs are the last in the element's paths - filters['id__in'] = ElementPath.objects \ - .filter(element_id=self.kwargs['pk']) \ - .values('path__last') - - return filters - - def filter_queryset(self, queryset): - if not self.is_recursive: - return super().filter_queryset(queryset) - - # Redefine the queryset here as get_ascending needs its filters as an argument - # and returns a list instead of an actual queryset - parents = Element.objects.get_ascending(self.kwargs['pk'], **self.get_filters()) - - # The list is unsorted; prefetch using prefetch_related_objects and sort with sorted - if not self.is_recursive: - # Apply .distinct() using a dict - parents = list({parent.id: parent for parent in parents}.values()) - - prefetch_related_objects(parents, *self.get_prefetch()) - - with_children_count = self.request.query_params.get('with_children_count') - if with_children_count and with_children_count.lower() not in ('false', '0'): - _fetch_children_count(parents) - - return sorted(parents, key=operator.attrgetter('corpus', 'type.slug', 'name', 'id')) + def get_queryset(self): + return Element.objects.get_ascending(self.kwargs['pk'], recursive=self.is_recursive) class ElementChildren(ElementsListMixin, ListAPIView): diff --git a/arkindex/documents/managers.py b/arkindex/documents/managers.py index f2da97afcf..4888f5a8f2 100644 --- a/arkindex/documents/managers.py +++ b/arkindex/documents/managers.py @@ -1,54 +1,53 @@ from django.db import models from itertools import groupby, chain +from arkindex.project.fields import Unnest import uuid class ElementManager(models.Manager): """Model manager for elements""" - def get_ascending_paths(self, child_id, **filters): - """ - Get all parent paths for a specific element ID. - """ - # TODO: it should be possible to do in a single query through Django's filtered relations - from arkindex.documents.models import ElementPath - paths = ElementPath.objects.filter(element_id=child_id).values_list('path', flat=True) - parents = { - parent.id: parent - for parent in self.filter(id__in=chain(*paths), **filters) - } - return [ - filter(None, [ - parents.get(parent_id) - for parent_id in path - ]) - for path in paths - ] - - def get_ascending(self, child_id, **filters): + def get_ascending(self, child_id, recursive=True): """ Get all parent elements for a specific element ID. > All the elements from paths """ - return list(chain(*self.get_ascending_paths(child_id, **filters))) + from arkindex.documents.models import ElementPath + + paths = ElementPath.objects.filter(element_id=child_id) + if recursive: + parent_ids = paths.annotate(parent_id=Unnest('path')).values('parent_id') + else: + parent_ids = paths.values('path__last') + + return self.filter(id__in=parent_ids) def get_ascendings_paths(self, *children_ids, **filters): """ Get all ascending paths for some elements IDs. """ - # Loads paths and group them by element ids from arkindex.documents.models import ElementPath + # Load all parents + parents = { + parent.id: parent + for parent in self.filter( + **filters, + id__in=ElementPath + .objects + .filter(element_id__in=children_ids) + .annotate(parent_id=Unnest('path')) + .values('parent_id') + ) + } + + # Loads paths and group them by element ids paths = ElementPath.objects.filter(element_id__in=children_ids).order_by('element_id') tree = { elt_id: [p.path for p in elt_paths] for elt_id, elt_paths in groupby(paths, lambda e: e.element_id) } - # Load parents elements in bulk - parents = { - parent.id: parent - for parent in self.filter(id__in=chain(*[p.path for p in paths]), **filters) - } + # Put Element instances in paths return { elt_id: [ [ @@ -60,16 +59,6 @@ class ElementManager(models.Manager): for elt_id, paths in tree.items() } - def get_ascendings(self, *children_ids, **filters): - """ - Get all parent elements for some element IDs. - """ - # Just assemble the paths together - return { - elt_id: chain(*paths) - for elt_id, paths in self.get_ascendings_paths(*children_ids, **filters).items() - } - def get_descending(self, parent_id, **filters): """ Get all child elements for a specific element ID. diff --git a/arkindex/documents/tests/test_classes.py b/arkindex/documents/tests/test_classes.py index 84b0d66b37..1799c075e5 100644 --- a/arkindex/documents/tests/test_classes.py +++ b/arkindex/documents/tests/test_classes.py @@ -318,7 +318,7 @@ class TestClasses(FixtureAPITestCase): A non best class validated by a human is considered as best as it is for the human """ self.populate_classified_elements() - parent = Element.objects.get_ascending(self.common_children.id)[-1] + parent = Element.objects.get_ascending(self.common_children.id).last() parent.classifications.all().filter(confidence=.7).update(state=ClassificationState.Validated) response = self.client.get( reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)}), @@ -400,7 +400,7 @@ class TestClasses(FixtureAPITestCase): def test_class_filter_list_parents(self): self.populate_classified_elements() - parent = Element.objects.get_ascending(self.common_children.id)[-1] + parent = Element.objects.get_ascending(self.common_children.id).last() parent.classifications.all().filter(confidence=.7).update(state=ClassificationState.Validated) with self.assertNumQueries(5): response = self.client.get( diff --git a/arkindex/documents/tests/test_element_manager.py b/arkindex/documents/tests/test_element_manager.py index 1fabe46410..85ef5adfeb 100644 --- a/arkindex/documents/tests/test_element_manager.py +++ b/arkindex/documents/tests/test_element_manager.py @@ -28,13 +28,6 @@ class TestElementManager(FixtureTestCase): ids = Element.objects.get_ascending(self.vol.id) self.assertCountEqual(ids, []) - def test_get_ascendings(self): - # Use all three pages, expect Volume thrice - ids = Element.objects.get_ascendings(self.p1.id, self.p2.id, self.p3.id) - self.assertCountEqual(ids[self.p1.id], [self.vol]) - self.assertCountEqual(ids[self.p2.id], [self.vol]) - self.assertCountEqual(ids[self.p3.id], [self.vol]) - def test_get_descending(self): # Use volume, expect all three pages and the act ids = Element.objects.get_descending(self.vol.id) @@ -52,11 +45,6 @@ class TestElementManager(FixtureTestCase): self.assertCountEqual(ids[self.p2.id], []) self.assertCountEqual(ids[self.p3.id], []) - def test_get_ascending_paths(self): - paths = Element.objects.get_ascending_paths(self.act.id) - self.assertEqual(len(paths), 1) - self.assertSequenceEqual(list(paths[0]), [self.vol, self.p1]) - def test_get_neighbors(self): with self.assertNumQueries(4): self.assertEqual(len(Element.objects.get_neighbors(self.p1, 0)), 1) diff --git a/arkindex/documents/tests/test_parents_elements.py b/arkindex/documents/tests/test_parents_elements.py index 1b25a0e96d..d4b7b197ef 100644 --- a/arkindex/documents/tests/test_parents_elements.py +++ b/arkindex/documents/tests/test_parents_elements.py @@ -112,7 +112,7 @@ class TestParentsElements(FixtureAPITestCase): def test_parents_with_children_count(self): surface = Element.objects.get(name='Surface A') - with self.assertNumQueries(6): + with self.assertNumQueries(7): response = self.client.get( reverse('api:elements-parents', kwargs={'pk': str(surface.id)}), data={'recursive': True, 'with_children_count': True}, diff --git a/arkindex/project/fields.py b/arkindex/project/fields.py index a6b1a0e7ea..e4065e4ad5 100644 --- a/arkindex/project/fields.py +++ b/arkindex/project/fields.py @@ -30,6 +30,11 @@ class ArrayRemove(Func): output_field = fields.ArrayField +class Unnest(Func): + function = 'unnest' + arity = 1 + + class ArrayField(fields.ArrayField): """ An enhanced PostgreSQL ArrayField, adding a `last` transform -- GitLab