diff --git a/arkindex/dataimport/api.py b/arkindex/dataimport/api.py index 35879bfb5d4925f9fee3e2f903e7facbbea218ea..dfb651edbd1fad6d1c145a8fdce7a8258437066b 100644 --- a/arkindex/dataimport/api.py +++ b/arkindex/dataimport/api.py @@ -13,7 +13,7 @@ from rest_framework.views import APIView from rest_framework.response import Response from rest_framework import status from rest_framework.exceptions import ValidationError -from arkindex.project.mixins import CorpusACLMixin, SelectionMixin, DeprecatedMixin +from arkindex.project.mixins import CorpusACLMixin, SelectionMixin, DeprecatedMixin, CustomPaginationViewMixin from arkindex.project.permissions import IsVerified from arkindex.project.openapi import AutoSchema from arkindex.documents.models import Corpus, ElementType, Element, ClassificationState @@ -861,7 +861,7 @@ class ImportTranskribus(CreateAPIView): self.dataimport.start(thumbnails=True) -class ListProcessElements(ListAPIView): +class ListProcessElements(CustomPaginationViewMixin, ListAPIView): """ List all elements for a specific process """ @@ -966,4 +966,4 @@ class ListProcessElements(ListAPIView): if dataimport.mode not in (DataImportMode.Elements, DataImportMode.Workers): return Element.objects.none() - return self.retrieve_elements(dataimport).order_by('id') + return self.retrieve_elements(dataimport) diff --git a/arkindex/dataimport/tests/test_process_elements.py b/arkindex/dataimport/tests/test_process_elements.py index 6caef5db62c89b603f84eb607d212739946075b9..fcebe116408debb879abd20fdebe2553387a78ca 100644 --- a/arkindex/dataimport/tests/test_process_elements.py +++ b/arkindex/dataimport/tests/test_process_elements.py @@ -4,6 +4,7 @@ from rest_framework import status from arkindex.project.tests import FixtureAPITestCase from arkindex.documents.models import Element, Classification, DataSource, MLClass, ClassificationState, Corpus from arkindex.dataimport.models import DataImport, DataImportMode +import uuid class TestProcessElements(FixtureAPITestCase): @@ -165,44 +166,42 @@ class TestProcessElements(FixtureAPITestCase): source=source ) - def test_requires_login(self): - dataimport = DataImport.objects.create( - mode=DataImportMode.Images, - corpus_id=self.private_corpus.id, - creator_id=self.superuser.id + def setUp(self): + super().setUp() + self.dataimport = DataImport.objects.create( + creator_id=self.user.id, + mode=DataImportMode.Elements, + corpus_id=self.private_corpus.id ) - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + + def test_requires_login(self): + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_wrong_id(self): self.client.force_login(self.superuser) - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': 'aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa'})) + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': str(uuid.uuid4())})) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_no_access(self): - corpus = Corpus.objects.create(name='private') - dataimport = DataImport.objects.create( - mode=DataImportMode.Images, - corpus_id=corpus.id, - creator_id=self.superuser.id - ) + self.dataimport.corpus = Corpus.objects.create(name='private') + self.dataimport.save() self.client.force_login(self.user) - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_filter_elements_wrong_corpus(self): - corpus = Corpus.objects.create(name='private') - dataimport = DataImport.objects.create( - mode=DataImportMode.Elements, - corpus_id=corpus.id, - creator_id=self.superuser.id, - ) - dataimport.elements.add(self.page_1.id, self.folder_2.id) + self.dataimport.corpus = Corpus.objects.create(name='private') + self.dataimport.save() + self.dataimport.elements.add(self.page_1.id, self.folder_2.id) + self.client.force_login(self.superuser) - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) data = response.json() - self.assertEqual(data, {'elements': [f'Some elements on this process are not part of corpus {corpus.id}']}) + self.assertEqual(data, { + 'elements': [f'Some elements on this process are not part of corpus {self.dataimport.corpus.id}'] + }) def test_filter_elements_multiple_corpus(self): corpus = Corpus.objects.create(name='private') @@ -211,46 +210,41 @@ class TestProcessElements(FixtureAPITestCase): type=self.folder_type, name="Folder", ) - dataimport = DataImport.objects.create( - mode=DataImportMode.Elements, - corpus_id=corpus.id, - creator_id=self.superuser.id, - ) - dataimport.elements.add(element.id, self.page_1.id) + self.dataimport.elements.add(element.id, self.page_1.id) + self.client.force_login(self.superuser) - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) data = response.json() - self.assertEqual(data, {'elements': [f'Some elements on this process are not part of corpus {corpus.id}']}) + self.assertEqual(data, { + 'elements': [f'Some elements on this process are not part of corpus {self.dataimport.corpus.id}'] + }) def test_filter_element_wrong_corpus(self): - corpus = Corpus.objects.create(name='private') - dataimport = DataImport.objects.create( - mode=DataImportMode.Elements, - corpus_id=corpus.id, - creator_id=self.superuser.id, - element=self.page_1 - ) + self.dataimport.corpus = Corpus.objects.create(name='private') + self.dataimport.element = self.page_1 + self.dataimport.save() + self.client.force_login(self.superuser) - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) data = response.json() - self.assertEqual(data, {'element_id': [f'Element {self.page_1.id} is not part of corpus {corpus.id}']}) + self.assertEqual(data, { + 'element_id': [f'Element {self.page_1.id} is not part of corpus {self.dataimport.corpus.id}'] + }) def test_filter_name(self): - dataimport = DataImport.objects.create( - mode=DataImportMode.Elements, - corpus_id=self.private_corpus.id, - creator_id=self.superuser.id, - name_contains="rhum" - ) + self.dataimport.name_contains = "rhum" + self.dataimport.save() elements = [self.folder_1, self.page_1, self.page_5] + self.client.force_login(self.superuser) - with self.assertNumQueries(7): - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + with self.assertNumQueries(6): + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(data["count"], len(elements)) + self.assertEqual(data["count"], None) + self.assertEqual(data["next"], None) self.assertCountEqual(data["results"], [ { 'id': str(element.id), @@ -261,21 +255,19 @@ class TestProcessElements(FixtureAPITestCase): ]) def test_filter_name_from_element(self): - dataimport = DataImport.objects.create( - creator_id=self.superuser.id, - mode=DataImportMode.Elements, - corpus_id=self.private_corpus.id, - element=self.folder_1, - name_contains="rhum", - load_children=True - ) + self.dataimport.element = self.folder_1 + self.dataimport.name_contains = "rhum" + self.dataimport.load_children = True + self.dataimport.save() elements = [self.folder_1, self.page_1] + self.client.force_login(self.superuser) - with self.assertNumQueries(10): - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + with self.assertNumQueries(9): + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(data["count"], len(elements)) + self.assertEqual(data["count"], None) + self.assertEqual(data["next"], None) self.assertCountEqual(data["results"], [ { 'id': str(element.id), @@ -286,19 +278,17 @@ class TestProcessElements(FixtureAPITestCase): ]) def test_filter_type(self): - dataimport = DataImport.objects.create( - mode=DataImportMode.Elements, - corpus_id=self.private_corpus.id, - creator_id=self.superuser.id, - element_type=self.folder_type - ) + self.dataimport.element_type = self.folder_type + self.dataimport.save() elements = [self.folder_1, self.folder_2] + self.client.force_login(self.superuser) - with self.assertNumQueries(8): - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + with self.assertNumQueries(7): + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(data["count"], len(elements)) + self.assertEqual(data["count"], None) + self.assertEqual(data["next"], None) self.assertCountEqual(data["results"], [ { 'id': str(element.id), @@ -309,21 +299,19 @@ class TestProcessElements(FixtureAPITestCase): ]) def test_filter_type_from_element(self): - dataimport = DataImport.objects.create( - creator_id=self.superuser.id, - mode=DataImportMode.Elements, - corpus_id=self.private_corpus.id, - element=self.folder_1, - element_type=self.line_type, - load_children=True - ) + self.dataimport.element = self.folder_1 + self.dataimport.element_type = self.line_type + self.dataimport.load_children = True + self.dataimport.save() elements = [self.line_1, self.line_2, self.line_3] + self.client.force_login(self.superuser) - with self.assertNumQueries(11): - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + with self.assertNumQueries(10): + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(data["count"], len(elements)) + self.assertEqual(data["count"], None) + self.assertEqual(data["next"], None) self.assertCountEqual(data["results"], [ { 'id': str(element.id), @@ -334,19 +322,17 @@ class TestProcessElements(FixtureAPITestCase): ]) def test_filter_best_class_by_id(self): - dataimport = DataImport.objects.create( - mode=DataImportMode.Elements, - corpus_id=self.private_corpus.id, - creator_id=self.superuser.id, - best_class=self.food_source.id - ) + self.dataimport.best_class = self.food_source.id + self.dataimport.save() elements = [self.page_5, self.page_3, self.folder_2, self.page_2] + self.client.force_login(self.superuser) - with self.assertNumQueries(7): - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + with self.assertNumQueries(6): + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(data["count"], len(elements)) + self.assertEqual(data["count"], None) + self.assertEqual(data["next"], None) self.assertCountEqual(data["results"], [ { 'id': str(element.id), @@ -357,19 +343,17 @@ class TestProcessElements(FixtureAPITestCase): ]) def test_filter_any_best_class(self): - dataimport = DataImport.objects.create( - mode=DataImportMode.Elements, - corpus_id=self.private_corpus.id, - creator_id=self.superuser.id, - best_class="true" - ) + self.dataimport.best_class = "true" + self.dataimport.save() elements = [self.page_1, self.page_5, self.page_3, self.folder_2, self.page_2] + self.client.force_login(self.superuser) - with self.assertNumQueries(7): - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + with self.assertNumQueries(6): + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(data["count"], len(elements)) + self.assertEqual(data["count"], None) + self.assertEqual(data["next"], None) self.assertCountEqual(data["results"], [ { 'id': str(element.id), @@ -380,19 +364,17 @@ class TestProcessElements(FixtureAPITestCase): ]) def test_filter_no_best_class(self): - dataimport = DataImport.objects.create( - mode=DataImportMode.Elements, - corpus_id=self.private_corpus.id, - creator_id=self.superuser.id, - best_class="false" - ) + self.dataimport.best_class = "false" + self.dataimport.save() elements = [self.folder_1, self.line_1, self.line_2, self.line_3, self.line_4, self.line_5, self.page_4] + self.client.force_login(self.superuser) - with self.assertNumQueries(7): - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + with self.assertNumQueries(6): + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(data["count"], len(elements)) + self.assertEqual(data["count"], None) + self.assertEqual(data["next"], None) self.assertCountEqual(data["results"], [ { 'id': str(element.id), @@ -403,18 +385,16 @@ class TestProcessElements(FixtureAPITestCase): ]) def test_filter_element(self): - dataimport = DataImport.objects.create( - mode=DataImportMode.Elements, - corpus_id=self.private_corpus.id, - creator_id=self.superuser.id, - element=self.page_1 - ) + self.dataimport.element = self.page_1 + self.dataimport.save() + self.client.force_login(self.superuser) - with self.assertNumQueries(8): - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + with self.assertNumQueries(7): + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(data["count"], 1) + self.assertEqual(data["count"], None) + self.assertEqual(data["next"], None) self.assertCountEqual(data["results"], [ { 'id': str(self.page_1.id), @@ -424,18 +404,15 @@ class TestProcessElements(FixtureAPITestCase): ]) def test_filter_elements(self): - dataimport = DataImport.objects.create( - mode=DataImportMode.Elements, - corpus_id=self.private_corpus.id, - creator_id=self.superuser.id, - ) - dataimport.elements.add(self.page_1.id, self.folder_2.id) + self.dataimport.elements.add(self.page_1.id, self.folder_2.id) + self.client.force_login(self.superuser) - with self.assertNumQueries(10): - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + with self.assertNumQueries(9): + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(data["count"], 2) + self.assertEqual(data["count"], None) + self.assertEqual(data["next"], None) self.assertCountEqual(data["results"], [ { 'id': str(self.page_1.id), @@ -449,43 +426,19 @@ class TestProcessElements(FixtureAPITestCase): } ]) - def test_any_filter(self): - dataimport = DataImport.objects.create( - mode=DataImportMode.Elements, - corpus_id=self.private_corpus.id, - creator_id=self.superuser.id, - ) - elements = Element.objects.filter(corpus=self.private_corpus).order_by('name', 'type__slug') - self.client.force_login(self.superuser) - with self.assertNumQueries(7): - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) - self.assertEqual(response.status_code, status.HTTP_200_OK) - data = response.json() - self.assertEqual(data["count"], elements.count()) - self.assertCountEqual(data["results"], [ - { - 'id': str(element.id), - 'type': element.type.slug, - 'name': element.name - } - for element in elements - ]) - def test_load_children_and_filter_name(self): - dataimport = DataImport.objects.create( - mode=DataImportMode.Elements, - corpus_id=self.private_corpus.id, - creator_id=self.superuser.id, - name_contains="rhum", - load_children=True - ) + self.dataimport.name_contains = "rhum" + self.dataimport.load_children = True + self.dataimport.save() elements = [self.folder_1, self.page_1, self.page_5] + self.client.force_login(self.superuser) - with self.assertNumQueries(7): - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + with self.assertNumQueries(6): + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(data["count"], len(elements)) + self.assertEqual(data["count"], None) + self.assertEqual(data["next"], None) self.assertCountEqual(data["results"], [ { 'id': str(element.id), @@ -496,19 +449,18 @@ class TestProcessElements(FixtureAPITestCase): ]) def test_corpus_and_filter_type(self): - dataimport = DataImport.objects.create( - mode=DataImportMode.Elements, - corpus_id=self.private_corpus.id, - creator_id=self.superuser.id, - element_type=self.folder_type, - load_children=True - ) + self.dataimport.element_type = self.folder_type + self.dataimport.load_children = True + self.dataimport.save() elements = [self.folder_1, self.folder_2] + self.client.force_login(self.superuser) - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + with self.assertNumQueries(7): + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(data["count"], len(elements)) + self.assertEqual(data["count"], None) + self.assertEqual(data["next"], None) self.assertCountEqual(data["results"], [ { 'id': str(element.id), @@ -519,21 +471,18 @@ class TestProcessElements(FixtureAPITestCase): ]) def test_load_children_and_filter_best_class_by_id(self): - dataimport = DataImport.objects.create( - mode=DataImportMode.Elements, - corpus_id=self.private_corpus.id, - creator_id=self.superuser.id, - best_class=self.food_source.id, - load_children=True - - ) + self.dataimport.best_class = self.food_source.id + self.dataimport.load_children = True + self.dataimport.save() elements = [self.folder_2, self.page_2, self.page_3, self.page_5] + self.client.force_login(self.superuser) - with self.assertNumQueries(7): - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + with self.assertNumQueries(6): + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(data["count"], len(elements)) + self.assertEqual(data["count"], None) + self.assertEqual(data["next"], None) self.assertCountEqual(data["results"], [ { 'id': str(element.id), @@ -544,22 +493,19 @@ class TestProcessElements(FixtureAPITestCase): ]) def test_load_children_and_filter_best_class(self): - dataimport = DataImport.objects.create( - mode=DataImportMode.Elements, - corpus_id=self.private_corpus.id, - creator_id=self.superuser.id, - best_class="true", - load_children=True - - ) + self.dataimport.best_class = "true", + self.dataimport.load_children = True + self.dataimport.save() elements = [self.page_1, self.page_5, self.page_3, self.folder_2, self.page_2] + self.client.force_login(self.superuser) - with self.assertNumQueries(7): + with self.assertNumQueries(6): response = self.client.get( - reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(data["count"], len(elements)) + self.assertEqual(data["count"], None) + self.assertEqual(data["next"], None) self.assertCountEqual(data["results"], [ { 'id': str(element.id), @@ -570,20 +516,18 @@ class TestProcessElements(FixtureAPITestCase): ]) def test_load_children_and_filter_no_best_class(self): - dataimport = DataImport.objects.create( - mode=DataImportMode.Elements, - corpus_id=self.private_corpus.id, - creator_id=self.superuser.id, - best_class="false", - load_children=True - ) + self.dataimport.best_class = "false" + self.dataimport.load_children = True + self.dataimport.save() elements = [self.folder_1, self.line_1, self.line_2, self.line_3, self.line_4, self.line_5, self.page_4] + self.client.force_login(self.superuser) - with self.assertNumQueries(7): - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + with self.assertNumQueries(6): + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(data["count"], len(elements)) + self.assertEqual(data["count"], None) + self.assertEqual(data["next"], None) self.assertCountEqual(data["results"], [ { 'id': str(element.id), @@ -594,20 +538,18 @@ class TestProcessElements(FixtureAPITestCase): ]) def test_load_children_and_filter_element(self): - dataimport = DataImport.objects.create( - mode=DataImportMode.Elements, - corpus_id=self.private_corpus.id, - creator_id=self.superuser.id, - element=self.folder_1, - load_children=True - ) + self.dataimport.element = self.folder_1 + self.dataimport.load_children = True + self.dataimport.save() elements = [self.folder_1, self.page_1, self.page_3, self.line_1, self.line_2, self.line_3, self.page_2] + self.client.force_login(self.superuser) - with self.assertNumQueries(10): - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + with self.assertNumQueries(9): + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(data["count"], len(elements)) + self.assertEqual(data["count"], None) + self.assertEqual(data["next"], None) self.assertCountEqual(data["results"], [ { 'id': str(element.id), @@ -618,20 +560,18 @@ class TestProcessElements(FixtureAPITestCase): ]) def test_load_children_and_filter_elements(self): - dataimport = DataImport.objects.create( - mode=DataImportMode.Elements, - corpus_id=self.private_corpus.id, - creator_id=self.superuser.id, - load_children=True - ) - dataimport.elements.add(self.page_1.id, self.folder_2.id) + self.dataimport.load_children = True + self.dataimport.save() + self.dataimport.elements.add(self.page_1.id, self.folder_2.id) elements = [self.page_1, self.page_5, self.page_3, self.line_1, self.line_3, self.line_4, self.line_5, self.folder_2, self.page_4] + self.client.force_login(self.superuser) - with self.assertNumQueries(12): - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + with self.assertNumQueries(11): + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(data["count"], len(elements)) + self.assertEqual(data["count"], None) + self.assertEqual(data["next"], None) self.assertCountEqual(data["results"], [ { 'id': str(element.id), @@ -642,19 +582,17 @@ class TestProcessElements(FixtureAPITestCase): ]) def test_load_children_and_any_filter(self): - dataimport = DataImport.objects.create( - mode=DataImportMode.Elements, - corpus_id=self.private_corpus.id, - creator_id=self.superuser.id, - load_children=True - ) + self.dataimport.load_children = True + self.dataimport.save() elements = Element.objects.filter(corpus=self.private_corpus).order_by('name', 'type__slug') + self.client.force_login(self.superuser) - with self.assertNumQueries(7): - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + with self.assertNumQueries(6): + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(data["count"], elements.count()) + self.assertEqual(data["count"], None) + self.assertEqual(data["next"], None) self.assertCountEqual(data["results"], [ { 'id': str(element.id), @@ -665,35 +603,29 @@ class TestProcessElements(FixtureAPITestCase): ]) def test_all_modes(self): + self.dataimport.load_children = True + self.dataimport.save() self.client.force_login(self.superuser) for mode in (DataImportMode.Images, DataImportMode.PDF, DataImportMode.Repository, DataImportMode.IIIF, DataImportMode.Transkribus): - dataimport = DataImport.objects.create( - mode=mode, - corpus_id=self.private_corpus.id, - creator_id=self.superuser.id, - load_children=True - ) + self.dataimport.mode = mode + self.dataimport.save() with self.assertNumQueries(3): - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(data["count"], 0) - self.assertCountEqual(data["results"], []) + self.assertCountEqual(data, {'count': None, 'next': None, 'previous': None, 'results': []}) elements = Element.objects.filter(corpus=self.private_corpus).order_by('name', 'type__slug') for mode in (DataImportMode.Elements, DataImportMode.Workers): - dataimport = DataImport.objects.create( - mode=mode, - corpus_id=self.private_corpus.id, - creator_id=self.superuser.id, - load_children=True - ) - with self.assertNumQueries(7): - response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) + self.dataimport.mode = mode + self.dataimport.save() + with self.assertNumQueries(6): + response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(data["count"], elements.count()) + self.assertEqual(data["count"], None) + self.assertEqual(data["next"], None) self.assertCountEqual(data["results"], [ { 'id': str(element.id), @@ -703,33 +635,92 @@ class TestProcessElements(FixtureAPITestCase): for element in elements ]) - def test_list_elements_ordering(self): + def test_list_elements_cursor_pagination(self): """ No element is duplicated or dropped """ elts = Element.objects.bulk_create([ Element( corpus=self.private_corpus, - name="Similar name", + name='Similar name', type=self.page_type ) for i in range(40) ]) for elt in elts: elt.add_parent(self.folder_1) - dataimport = DataImport.objects.create( - mode=DataImportMode.Elements, - corpus_id=self.private_corpus.id, - creator_id=self.superuser.id, - element=self.folder_1, - load_children=True - ) + self.dataimport.element = self.folder_1 + self.dataimport.name_contains = 'Similar' + self.dataimport.load_children = True + self.dataimport.save() + self.client.force_login(self.superuser) - with self.assertNumQueries(10): - page_1 = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) - with self.assertNumQueries(10): - page_2 = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id}), {'page': 2}) + with self.assertNumQueries(9): + page_1 = self.client.get(reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id})) + self.assertEqual(len(page_1.json()['results']), 20) + next_page = page_1.json().get('next') + self.assertIsNotNone(next_page) + with self.assertNumQueries(9): + page_2 = self.client.get(next_page) + self.assertIsNone(page_2.json()['next']) qs_1 = Element.objects.filter(id__in=[elt['id'] for elt in page_1.json()['results']]) qs_2 = Element.objects.filter(id__in=[elt['id'] for elt in page_2.json()['results']]) self.assertEqual(qs_1.intersection(qs_2).count(), 0) self.assertEqual(qs_1.union(qs_2).distinct().count(), 40) + + def test_elements_invalid_cursor(self): + self.client.force_login(self.superuser) + response = self.client.get( + reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id}), + {'cursor': 'ABC', 'with_count': True} + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + self.assertEqual(response.json(), {'detail': 'Invalid cursor'}) + + def test_elements_count(self): + """ + Elements count can be retrieved when no cursor is set + """ + self.client.force_login(self.superuser) + with self.assertNumQueries(7): + response = self.client.get( + reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id}), + {'page_size': 6, 'with_count': True} + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + self.assertIsNotNone(data['next']) + self.assertEqual(data['count'], 12) + self.assertEqual(len(data['results']), 6) + next_url = data.get('next') + self.assertIn('cursor', next_url) + + second_page = self.client.get(next_url) + data = second_page.json() + self.assertIsNone(data['count']) + self.assertIsNone(data['next']) + self.assertEqual(len(data['results']), 6) + + def test_cursor_pagination_page_size(self): + """ + Page size may be changed for cursor pagination + """ + Element.objects.bulk_create([ + Element( + corpus=self.private_corpus, + name='Similar name', + type=self.page_type + ) for i in range(51) + ]) + self.dataimport.name_contains = 'Similar' + self.dataimport.save() + + self.client.force_login(self.superuser) + with self.assertNumQueries(6): + response = self.client.get( + reverse('api:process-elements-list', kwargs={'pk': self.dataimport.id}), + {'page_size': 50} + ) + data = response.json() + self.assertEqual(len(data['results']), 50) + self.assertIsNotNone(data['next']) diff --git a/arkindex/dataimport/tests/test_workflows_api.py b/arkindex/dataimport/tests/test_workflows_api.py index 21c864446ca752d03dcb6f63681a17e3b51e1580..b43afb03ba8ec56e83183d20faa39e8e9492e772 100644 --- a/arkindex/dataimport/tests/test_workflows_api.py +++ b/arkindex/dataimport/tests/test_workflows_api.py @@ -338,7 +338,7 @@ class TestWorkflows(FixtureAPITestCase): response = self.client.get(reverse('api:process-elements-list', kwargs={'pk': dataimport.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(data['count'], 1) + self.assertEqual(len(data['results']), 1) self.assertEqual(data['results'][0]['id'], str(page.id)) @override_settings(ARKINDEX_FEATURES={'selection': False}) diff --git a/arkindex/project/mixins.py b/arkindex/project/mixins.py index bcfdb7046712b8b91463d0970e6837717df4f637..8ce92cc564dd2246421bcf9eade971c00f386499 100644 --- a/arkindex/project/mixins.py +++ b/arkindex/project/mixins.py @@ -8,6 +8,7 @@ from arkindex.documents.models import Corpus, Right from arkindex.documents.serializers.search import SearchQuerySerializer from arkindex.project.elastic import ESQuerySet from arkindex.project.openapi import AutoSchema, SearchAutoSchema +from arkindex.project.pagination import CustomCursorPagination class CorpusACLMixin(object): @@ -158,3 +159,22 @@ class CachedViewMixin(object): self.cache_timeout, key_prefix=self.cache_prefix, )(self.dispatch) + + +class CustomPaginationViewMixin(object): + """ + A custom cursor pagination mixin + Elements count can be retrieved with the `with_count` parameter if there is no cursor + """ + pagination_class = CustomCursorPagination + + @property + def paginator(self): + if not hasattr(self, '_paginator'): + params = self.request.query_params + with_count = ( + not params.get('cursor') + and params.get('with_count') not in (None, '', 'false', '0') + ) + self._paginator = self.pagination_class(with_count=with_count) + return self._paginator diff --git a/arkindex/project/pagination.py b/arkindex/project/pagination.py index c4160e51f29e25d68b318aea629f9b864fe260c2..d88073ccea63ae8fd8efbe19f8ea16115a6b1c3e 100644 --- a/arkindex/project/pagination.py +++ b/arkindex/project/pagination.py @@ -26,3 +26,53 @@ class PageNumberPagination(pagination.PageNumberPagination): 'example': 123, } return schema + + +class CustomCursorPagination(pagination.CursorPagination): + """ + A custom cursor pagination class + Count attribute and ordering may be updated when instanciating the class + """ + count = None + page_size = 20 + page_size_query_param = 'page_size' + max_page_size = 1000 + + def __init__(self, *args, **kwargs): + self.with_count = kwargs.pop('with_count', False) + self.ordering = kwargs.pop('ordering', 'id') + super().__init__(*args, **kwargs) + + def paginate_queryset(self, queryset, *args, **kwargs): + if self.with_count: + self.count = queryset.count() + return super().paginate_queryset(queryset, *args, **kwargs) + + def get_paginated_response(self, data): + return Response(OrderedDict([ + ('next', self.get_next_link()), + ('previous', self.get_previous_link()), + ('count', self.count), + ('results', data) + ])) + + def get_schema_operation_parameters(self, view): + parameters = super().get_schema_operation_parameters(view) + parameters.append({ + 'name': 'with_count', + 'required': False, + 'in': 'query', + 'description': 'Count the total number of elements. Incompatible with `cursor` parameter.', + 'schema': { + 'type': 'boolean', + } + }) + return parameters + + def get_paginated_response_schema(self, schema): + schema = super().get_paginated_response_schema(schema) + schema['properties']['count'] = { + 'type': 'integer', + 'example': None, + } + return schema