diff --git a/arkindex/documents/api/elements.py b/arkindex/documents/api/elements.py index 5e2cf1d6ca7b9f14d89fc7aa6be44d84683da738..18cae43eb22e5acb6a6ba02b5deb1bc9a2c71058 100644 --- a/arkindex/documents/api/elements.py +++ b/arkindex/documents/api/elements.py @@ -268,6 +268,13 @@ class ElementsList(ElementsListMixin, CorpusACLMixin, ListAPIView): 'required': False, 'schema': {'type': 'string', 'format': 'uuid'}, }, + { + 'name': 'top_level', + 'in': 'query', + 'description': 'Only include elements without parent elements (top-level elements).', + 'required': False, + 'schema': {'type': 'boolean', 'default': False}, + } ] }) @@ -281,6 +288,9 @@ class ElementsList(ElementsListMixin, CorpusACLMixin, ListAPIView): raise ValidationError({'corpus': ['Not a valid uuid']}) filters['corpus'] = self.get_corpus(corpus_id) + if self.request.query_params.get('top_level') not in (None, 'false', '0'): + filters['paths__isnull'] = True + return filters diff --git a/arkindex/documents/tests/test_elements_api.py b/arkindex/documents/tests/test_elements_api.py index 376f009d5c3700d9c8ac1ad114bb24282b841f3e..5379ec30ba51068c0fe18888d550a1ead5db338a 100644 --- a/arkindex/documents/tests/test_elements_api.py +++ b/arkindex/documents/tests/test_elements_api.py @@ -1,5 +1,6 @@ from uuid import UUID from django.urls import reverse +from django.db.models.sql.constants import LOUTER from rest_framework import status from arkindex_common.ml_tool import MLToolType from arkindex_common.enums import MetaType, TranscriptionType, EntityType @@ -997,3 +998,25 @@ class TestElementsAPI(FixtureAPITestCase): 'Act 1': 2, } ) + + def test_list_top_level_left_join(self): + """ + Ensure the top_level option on ListElements triggers a left join. + """ + query = self.corpus.elements.filter(paths__isnull=True, name='something').query + # Ensure the documents_elementpath table is joined via a LEFT JOIN + self.assertEqual(query.alias_map['documents_elementpath'].join_type, LOUTER) + # Ensure the IS NULL goes last + self.assertTrue(str(query).endswith('AND "documents_elementpath"."id" IS NULL)')) + + def test_list_top_level(self): + with self.assertNumQueries(5): + response = self.client.get( + reverse('api:elements'), + data={'top_level': True}, + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertListEqual( + [element['name'] for element in response.json()['results']], + ['Volume 1', 'Volume 2'], + )