Skip to content
Snippets Groups Projects
Commit b7129a0a authored by Bastien Abadie
Browse files

Merge branch 'optimize-get-ascending' into 'master'

Optimize get_ascending and ListElementParents

Closes #481

See merge request !1006
parents 9ecc2e1b 71ea7f3a
No related branches found
No related tags found
1 merge request: !1006 Optimize get_ascending and ListElementParents
from collections import defaultdict
from django.conf import settings
from django.db import transaction
from django.db.models import Q, Prefetch, prefetch_related_objects, Count, Max
from django.db.models import Q, Prefetch, Count, Max
from django.utils.functional import cached_property
from rest_framework.exceptions import ValidationError, NotFound
from rest_framework.generics import (
......@@ -350,37 +350,8 @@ class ElementParents(ElementsListMixin, ListAPIView):
recursive_param = self.request.query_params.get('recursive')
return recursive_param is not None and recursive_param.lower() not in ('false', '0')
def get_filters(self):
    """
    Build the filters applied to the parents list.

    Starts from the filters of the parent class. For a non-recursive
    listing, an extra `id__in` filter restricts the results to direct
    parents of the element identified by the `pk` view kwarg.
    """
    filters = super().get_filters()
    if self.is_recursive:
        return filters

    # Direct parents are the elements whose IDs come last in the element's paths
    direct_parent_ids = ElementPath.objects \
        .filter(element_id=self.kwargs['pk']) \
        .values('path__last')
    filters['id__in'] = direct_parent_ids
    return filters
def filter_queryset(self, queryset):
    """
    Filter the parents list, with a dedicated code path for recursive listing.

    Non-recursive requests are delegated to the parent class
    implementation. Recursive requests ignore `queryset`: the parents are
    fetched through `Element.objects.get_ascending`, called with this
    view's filters as keyword arguments, then prefetched, optionally
    passed to `_fetch_children_count`, and finally sorted in Python.
    """
    if not self.is_recursive:
        return super().filter_queryset(queryset)

    # Redefine the queryset here as get_ascending needs its filters as an argument
    # and returns a list instead of an actual queryset
    parents = Element.objects.get_ascending(self.kwargs['pk'], **self.get_filters())

    # NOTE(review): no de-duplication happens on this (recursive) code path.
    # If get_ascending returns the same element once per path it appears in,
    # duplicates are kept as-is; confirm this is intended.

    # The list is unsorted; prefetch using prefetch_related_objects and sort with sorted
    prefetch_related_objects(parents, *self.get_prefetch())

    # Any value other than 'false' or '0' (case-insensitive) enables the children count
    with_children_count = self.request.query_params.get('with_children_count')
    if with_children_count and with_children_count.lower() not in ('false', '0'):
        _fetch_children_count(parents)

    return sorted(parents, key=operator.attrgetter('corpus', 'type.slug', 'name', 'id'))
def get_queryset(self):
    """
    Return the parents of the element identified by the `pk` view kwarg.

    The lookup is delegated to `Element.objects.get_ascending`, with
    `recursive` taken from `self.is_recursive`.
    """
    element_id = self.kwargs['pk']
    return Element.objects.get_ascending(element_id, recursive=self.is_recursive)
class ElementChildren(ElementsListMixin, ListAPIView):
......
from django.db import models
from itertools import groupby, chain
from arkindex.project.fields import Unnest
import uuid
class ElementManager(models.Manager):
"""Model manager for elements"""
def get_ascending_paths(self, child_id, **filters):
    """
    List the parent paths of a single element, as Element instances.

    One entry is returned per ElementPath of `child_id`. Each entry is a
    `filter` iterator over the Element instances of that path, in path
    order; IDs with no matching element once `filters` are applied are
    skipped.
    """
    # TODO: it should be possible to do in a single query through Django's filtered relations
    from arkindex.documents.models import ElementPath

    paths = ElementPath.objects.filter(element_id=child_id).values_list('path', flat=True)

    # Index every element found in any of the paths by its ID
    parents = {}
    for parent in self.filter(id__in=chain(*paths), **filters):
        parents[parent.id] = parent

    ascending_paths = []
    for path in paths:
        resolved = [parents.get(parent_id) for parent_id in path]
        # filter(None, ...) drops the IDs that did not resolve to an element
        ascending_paths.append(filter(None, resolved))
    return ascending_paths
def get_ascending(self, child_id, **filters):
def get_ascending(self, child_id, recursive=True):
"""
Get all parent elements for a specific element ID.
> All the elements from paths
"""
return list(chain(*self.get_ascending_paths(child_id, **filters)))
from arkindex.documents.models import ElementPath
paths = ElementPath.objects.filter(element_id=child_id)
if recursive:
parent_ids = paths.annotate(parent_id=Unnest('path')).values('parent_id')
else:
parent_ids = paths.values('path__last')
return self.filter(id__in=parent_ids)
def get_ascendings_paths(self, *children_ids, **filters):
"""
Get all ascending paths for some elements IDs.
"""
# Loads paths and group them by element ids
from arkindex.documents.models import ElementPath
# Load all parents
parents = {
parent.id: parent
for parent in self.filter(
**filters,
id__in=ElementPath
.objects
.filter(element_id__in=children_ids)
.annotate(parent_id=Unnest('path'))
.values('parent_id')
)
}
# Loads paths and group them by element ids
paths = ElementPath.objects.filter(element_id__in=children_ids).order_by('element_id')
tree = {
elt_id: [p.path for p in elt_paths]
for elt_id, elt_paths in groupby(paths, lambda e: e.element_id)
}
# Load parents elements in bulk
parents = {
parent.id: parent
for parent in self.filter(id__in=chain(*[p.path for p in paths]), **filters)
}
# Put Element instances in paths
return {
elt_id: [
[
......@@ -60,16 +59,6 @@ class ElementManager(models.Manager):
for elt_id, paths in tree.items()
}
def get_ascendings(self, *children_ids, **filters):
    """
    Get the parent elements of several elements at once.

    Returns a dict mapping each key returned by `get_ascendings_paths`
    to an `itertools.chain` iterator that yields the items of all of that
    key's paths, one path after the other.
    """
    paths_by_element = self.get_ascendings_paths(*children_ids, **filters)

    # Flatten the list of paths of each element into a single iterator
    ascendings = {}
    for elt_id, paths in paths_by_element.items():
        ascendings[elt_id] = chain(*paths)
    return ascendings
def get_descending(self, parent_id, **filters):
"""
Get all child elements for a specific element ID.
......
......@@ -318,7 +318,7 @@ class TestClasses(FixtureAPITestCase):
A non best class validated by a human is considered as best as it is for the human
"""
self.populate_classified_elements()
parent = Element.objects.get_ascending(self.common_children.id)[-1]
parent = Element.objects.get_ascending(self.common_children.id).last()
parent.classifications.all().filter(confidence=.7).update(state=ClassificationState.Validated)
response = self.client.get(
reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)}),
......@@ -400,7 +400,7 @@ class TestClasses(FixtureAPITestCase):
def test_class_filter_list_parents(self):
self.populate_classified_elements()
parent = Element.objects.get_ascending(self.common_children.id)[-1]
parent = Element.objects.get_ascending(self.common_children.id).last()
parent.classifications.all().filter(confidence=.7).update(state=ClassificationState.Validated)
with self.assertNumQueries(5):
response = self.client.get(
......
......@@ -28,13 +28,6 @@ class TestElementManager(FixtureTestCase):
ids = Element.objects.get_ascending(self.vol.id)
self.assertCountEqual(ids, [])
def test_get_ascendings(self):
    """
    get_ascendings called with the three pages yields the volume for each page.
    """
    # Use all three pages, expect Volume thrice
    pages = (self.p1, self.p2, self.p3)
    ascendings = Element.objects.get_ascendings(*[page.id for page in pages])
    for page in pages:
        self.assertCountEqual(ascendings[page.id], [self.vol])
def test_get_descending(self):
# Use volume, expect all three pages and the act
ids = Element.objects.get_descending(self.vol.id)
......@@ -52,11 +45,6 @@ class TestElementManager(FixtureTestCase):
self.assertCountEqual(ids[self.p2.id], [])
self.assertCountEqual(ids[self.p3.id], [])
def test_get_ascending_paths(self):
    """
    The act has a single ascending path: the volume, then the first page.
    """
    ascending_paths = Element.objects.get_ascending_paths(self.act.id)
    self.assertEqual(len(ascending_paths), 1)
    # Each path is an iterator: turn it into a list before comparing
    first_path = list(ascending_paths[0])
    self.assertSequenceEqual(first_path, [self.vol, self.p1])
def test_get_neighbors(self):
with self.assertNumQueries(4):
self.assertEqual(len(Element.objects.get_neighbors(self.p1, 0)), 1)
......
......@@ -112,7 +112,7 @@ class TestParentsElements(FixtureAPITestCase):
def test_parents_with_children_count(self):
surface = Element.objects.get(name='Surface A')
with self.assertNumQueries(6):
with self.assertNumQueries(7):
response = self.client.get(
reverse('api:elements-parents', kwargs={'pk': str(surface.id)}),
data={'recursive': True, 'with_children_count': True},
......
......@@ -30,6 +30,11 @@ class ArrayRemove(Func):
output_field = fields.ArrayField
class Unnest(Func):
    """
    Database function expression for `unnest`, taking exactly one argument.
    """
    # Name of the SQL function to call
    function = 'unnest'
    # Number of arguments the expression accepts
    arity = 1
class ArrayField(fields.ArrayField):
"""
An enhanced PostgreSQL ArrayField, adding a `last` transform
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment