diff --git a/arkindex/documents/api/elements.py b/arkindex/documents/api/elements.py index b3c654aedbc8b06af5d72da662433b91ab5af958..9a1ec1396d18e2e71da72d7308485c36ea83792a 100644 --- a/arkindex/documents/api/elements.py +++ b/arkindex/documents/api/elements.py @@ -1907,7 +1907,10 @@ class ElementTypeUpdate(RetrieveUpdateDestroyAPIView): ) class ElementBulkCreate(CreateAPIView): """ - Create multiple child elements at once on a single parent + Create multiple child elements at once on a single parent. + + Exactly one of `worker_version` or `worker_run_id` fields must be set. + If `worker_run_id` is set, the worker version will be deduced from it. """ serializer_class = ElementBulkSerializer permission_classes = (IsVerified, ) diff --git a/arkindex/documents/serializers/elements.py b/arkindex/documents/serializers/elements.py index 07728cb91e8470fa8204e53204be749d35e82136..4260ebef0aabfb4a2973ab141ec362aa95aa1561 100644 --- a/arkindex/documents/serializers/elements.py +++ b/arkindex/documents/serializers/elements.py @@ -944,7 +944,11 @@ class ElementBulkSerializer(serializers.Serializer): element_errors[i] = {'polygon': ["An element's polygon must not exceed its image's bounds."]} worker_run = data.get('worker_run', None) - if worker_run: + if not worker_run and not data.get('worker_version'): + errors['non_field_errors'] = [ + 'Exactly one of `worker_version` or `worker_run_id` must be set.' + ] + elif worker_run: data['worker_version'] = WorkerVersion(id=worker_run.version_id) if element_errors: diff --git a/arkindex/documents/tests/test_bulk_elements.py b/arkindex/documents/tests/test_bulk_elements.py index 21ca35d5c8b4ff4a5a50b70a3d2d205932ac92ed..3040e7ef07ab1032b772280f440ffc42ba7eadcd 100644 --- a/arkindex/documents/tests/test_bulk_elements.py +++ b/arkindex/documents/tests/test_bulk_elements.py @@ -1,5 +1,6 @@ from uuid import uuid4 +from django.contrib.gis.geos import LineString from django.urls import reverse from rest_framework import status @@ -143,65 +144,6 @@ class TestBulkElements(FixtureAPITestCase): 'non_field_errors': ["Element types with slugs nope do not exist in the parent element's corpus"] }) - def test_bulk_create(self): - self.client.force_login(self.user) - with self.assertNumQueries(13): - response = self.client.post( - reverse('api:elements-bulk-create', kwargs={'pk': str(self.element.id)}), - data=self.payload, - format='json', - ) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - - element_path = self.element.paths.get() - a, b, c = Element \ - .objects \ - .get_descending(self.element.id) \ - .filter(type__slug__in=('act', 'surface')) \ - .order_by('name') - a_path, b_path, c_path = a.paths.get(), b.paths.get(), c.paths.get() - - self.assertListEqual( - response.json(), - [ - {'id': str(a.id)}, - {'id': str(b.id)}, - {'id': str(c.id)} - ] - ) - - self.assertEqual(a.name, 'A') - self.assertEqual(b.name, 'B') - self.assertEqual(c.name, 'C') - self.assertEqual(a.type.slug, 'act') - self.assertEqual(b.type.slug, 'surface') - self.assertEqual(c.type.slug, 'surface') - self.assertEqual(a.worker_version, self.worker_version) - self.assertEqual(b.worker_version, self.worker_version) - self.assertEqual(c.worker_version, self.worker_version) - self.assertEqual(a.worker_run, None) - self.assertEqual(b.worker_run, None) - self.assertEqual(c.worker_run, None) - self.assertEqual(a.image_id, self.element.image_id) - self.assertEqual(b.image_id, self.element.image_id) - self.assertEqual(c.image_id, self.element.image_id) - self.assertTupleEqual(a.polygon.coords, ((0, 0), (0, 1), (1, 1), (1, 0), (0, 0))) - self.assertTupleEqual(b.polygon.coords, ((0, 0), (0, 1), (1, 1), (1, 0), (0, 0))) - self.assertTupleEqual(c.polygon.coords, ((0, 0), (0, 9), (9, 9), (9, 0), (0, 0))) - self.assertEqual(a.rotation_angle, self.element.rotation_angle) - self.assertEqual(b.rotation_angle, self.element.rotation_angle) - self.assertEqual(c.rotation_angle, self.element.rotation_angle) - self.assertEqual(a.mirrored, self.element.mirrored) - self.assertEqual(b.mirrored, self.element.mirrored) - self.assertEqual(c.mirrored, self.element.mirrored) - - self.assertListEqual(a_path.path, element_path.path + [self.element.id]) - self.assertListEqual(b_path.path, element_path.path + [self.element.id]) - self.assertListEqual(c_path.path, element_path.path + [self.element.id]) - self.assertEqual(a_path.ordering, 0) - self.assertEqual(b_path.ordering, 0) - self.assertEqual(c_path.ordering, 1) - def test_bulk_create_multiple_parent_paths(self): parent = self.corpus.elements.create( name='Parent 2', @@ -349,7 +291,6 @@ class TestBulkElements(FixtureAPITestCase): """ Cannot create elements outside their image """ - self.maxDiff = None self.client.force_login(self.user) payload = { 'worker_version': str(self.worker_version.id), @@ -427,7 +368,32 @@ class TestBulkElements(FixtureAPITestCase): 'worker_run_id': [f'Invalid pk "{random_uuid}" - object does not exist.'] }) - def test_bulk_create_invalid_parameters_worker_run(self): + def test_bulk_create_worker_run_or_version(self): + """Either a worker run or a worker version is required""" + + self.client.force_login(self.user) + with self.assertNumQueries(6): + response = self.client.post( + reverse('api:elements-bulk-create', kwargs={'pk': str(self.element.id)}), + data={ + 'elements': [ + { + 'name': 'Blah', + 'type': 'surface', + 'polygon': [[0, 10], [10, 20], [30, 40], [50, 60], [0, 10]] + } + ] + }, + format='json', + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + 'non_field_errors': ['Exactly one of `worker_version` or `worker_run_id` must be set.'] + }) + + def test_bulk_create_worker_run_and_version(self): + """Worker run and worker version cannot be set at the same time""" + self.client.force_login(self.user) with self.assertNumQueries(7): response = self.client.post( @@ -450,6 +416,48 @@ class TestBulkElements(FixtureAPITestCase): 'non_field_errors': ['Only one of `worker_version` and `worker_run_id` may be set.'] }) + def test_bulk_create_with_worker_version(self): + self.client.force_login(self.user) + with self.assertNumQueries(13): + response = self.client.post( + reverse('api:elements-bulk-create', kwargs={'pk': str(self.element.id)}), + data=self.payload, + format='json', + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + elements = ( + Element.objects + .get_descending(self.element.id) + .filter(type__slug__in=('act', 'surface')) + .order_by('name') + ) + element_ids = list(elements.values_list("id", flat=True)) + # Only IDs are returned in the payload + self.assertListEqual(response.json(), [{'id': str(elt_id)} for elt_id in element_ids]) + # Test common attributes + parent_path = self.element.paths.get().path + self.assertListEqual( + list(elements.values_list( + 'worker_version_id', + 'worker_run_id', + 'image_id', + 'rotation_angle', + 'mirrored', + 'paths__path', + )), + 3 * [(self.worker_version.id, None, self.element.image_id, 0, False, [*parent_path, self.element.id])] + ) + # Test specific attributes + self.assertListEqual( + list(elements.values_list('name', 'type__slug', 'paths__ordering', 'polygon')), + [ + ('A', 'act', 0, LineString(self.payload['elements'][0]['polygon'])), + ('B', 'surface', 0, LineString(self.payload['elements'][1]['polygon'])), + ('C', 'surface', 1, LineString(self.payload['elements'][2]['polygon'])), + ], + ) + def test_bulk_create_with_worker_run(self): self.client.force_login(self.user) with self.assertNumQueries(13): @@ -460,51 +468,34 @@ class TestBulkElements(FixtureAPITestCase): ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) - element_path = self.element.paths.get() - a, b, c = Element \ - .objects \ - .get_descending(self.element.id) \ - .filter(type__slug__in=('act', 'surface')) \ + elements = ( + Element.objects + .get_descending(self.element.id) + .filter(type__slug__in=('act', 'surface')) .order_by('name') - a_path, b_path, c_path = a.paths.get(), b.paths.get(), c.paths.get() - + ) + element_ids = list(elements.values_list("id", flat=True)) + # Only IDs are returned in the payload + self.assertListEqual(response.json(), [{'id': str(elt_id)} for elt_id in element_ids]) + # Test common attributes + parent_path = self.element.paths.get().path self.assertListEqual( - response.json(), + list(elements.values_list( + 'worker_version_id', + 'worker_run_id', + 'image_id', + 'rotation_angle', + 'mirrored', + 'paths__path', + )), + 3 * [(self.worker_version.id, self.worker_run.id, self.element.image_id, 0, False, [*parent_path, self.element.id])] + ) + # Test specific attributes + self.assertListEqual( + list(elements.values_list('name', 'type__slug', 'paths__ordering', 'polygon')), [ - {'id': str(a.id)}, - {'id': str(b.id)}, - {'id': str(c.id)} - ] + ('A', 'act', 0, LineString(self.payload['elements'][0]['polygon'])), + ('B', 'surface', 0, LineString(self.payload['elements'][1]['polygon'])), + ('C', 'surface', 1, LineString(self.payload['elements'][2]['polygon'])), + ], ) - - self.assertEqual(a.name, 'A') - self.assertEqual(b.name, 'B') - self.assertEqual(c.name, 'C') - self.assertEqual(a.type.slug, 'act') - self.assertEqual(b.type.slug, 'surface') - self.assertEqual(c.type.slug, 'surface') - self.assertEqual(a.worker_version, self.worker_run.version) - self.assertEqual(b.worker_version, self.worker_run.version) - self.assertEqual(c.worker_version, self.worker_run.version) - self.assertEqual(a.worker_run, self.worker_run) - self.assertEqual(b.worker_run, self.worker_run) - self.assertEqual(c.worker_run, self.worker_run) - self.assertEqual(a.image_id, self.element.image_id) - self.assertEqual(b.image_id, self.element.image_id) - self.assertEqual(c.image_id, self.element.image_id) - self.assertTupleEqual(a.polygon.coords, ((0, 0), (0, 1), (1, 1), (1, 0), (0, 0))) - self.assertTupleEqual(b.polygon.coords, ((0, 0), (0, 1), (1, 1), (1, 0), (0, 0))) - self.assertTupleEqual(c.polygon.coords, ((0, 0), (0, 9), (9, 9), (9, 0), (0, 0))) - self.assertEqual(a.rotation_angle, self.element.rotation_angle) - self.assertEqual(b.rotation_angle, self.element.rotation_angle) - self.assertEqual(c.rotation_angle, self.element.rotation_angle) - self.assertEqual(a.mirrored, self.element.mirrored) - self.assertEqual(b.mirrored, self.element.mirrored) - self.assertEqual(c.mirrored, self.element.mirrored) - - self.assertListEqual(a_path.path, element_path.path + [self.element.id]) - self.assertListEqual(b_path.path, element_path.path + [self.element.id]) - self.assertListEqual(c_path.path, element_path.path + [self.element.id]) - self.assertEqual(a_path.ordering, 0) - self.assertEqual(b_path.ordering, 0) - self.assertEqual(c_path.ordering, 1)