diff --git a/arkindex/documents/api/ml.py b/arkindex/documents/api/ml.py
index ab87ada1c6076d97a1b9312b2c99739e2d687f62..093cb312a9795cae3500c35a32252b5435168787 100644
--- a/arkindex/documents/api/ml.py
+++ b/arkindex/documents/api/ml.py
@@ -271,7 +271,10 @@ class ElementTranscriptionsBulk(CreateAPIView):
             .only('id', 'zone_id')
         }
         # Load the paths immediately to avoid iterating over them for each element
-        paths = list(self.element.paths.all())
+        paths = list(self.element.paths.values_list('path', flat=True))
+        if not paths:
+            # Support top level elements, by adding an empty initial path to trigger loops below
+            paths = [[]]
         next_path_ordering = self.element.get_next_order(elt_type)
         for annotation in annotations:
             # Look for a direct children with the right type and zone
@@ -289,7 +292,7 @@ class ElementTranscriptionsBulk(CreateAPIView):
                 children[annotation['zone_id']] = annotation['element']
                 missing_elements.append(annotation['element'])
                 for parent_path in paths:
-                    new_path = parent_path.path + [self.element.id]
+                    new_path = parent_path + [self.element.id]
                     # Add the children to all of its parent paths
                     missing_paths.append(ElementPath(
                         element=annotation['element'],
diff --git a/arkindex/documents/tests/test_bulk_element_transcriptions.py b/arkindex/documents/tests/test_bulk_element_transcriptions.py
index 33ff67d5319a80fc2ba91e258b85d54ad6fefec5..a6a8bcd43ff269d13102d51550893702c83699fc 100644
--- a/arkindex/documents/tests/test_bulk_element_transcriptions.py
+++ b/arkindex/documents/tests/test_bulk_element_transcriptions.py
@@ -319,3 +319,68 @@ class TestBulkElementTranscriptions(FixtureAPITestCase):
         )
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertEqual(response.json(), {'element': ['Cannot create transcriptions on an element without a zone.']})
+
+    @patch('arkindex.project.triggers.tasks.reindex_start.delay')
+    def test_top_level_element(self, delay_mock):
+        """
+        Create transcriptions on a top level element
+        """
+        # Create a top level page
+        top_level = self.corpus.elements.create(
+            type=self.page.type,
+            name='Top level page',
+            zone=self.page.zone,
+        )
+
+        transcriptions = [
+            ([[13, 37], [133, 37], [133, 137], [13, 137], [13, 37]], 'Hello world !', 0.1337),
+            ([[24, 42], [64, 42], [64, 142], [24, 142], [24, 42]], 'I <3 JavaScript', 0.42),
+        ]
+        data = {
+            'element_type': 'text_line',
+            'transcription_type': 'line',
+            'worker_version': str(self.worker_version.id),
+            'transcriptions': [{
+                'polygon': poly,
+                'text': text,
+                'score': score
+            } for poly, text, score in transcriptions]
+        }
+        self.client.force_login(self.user)
+        response = self.client.post(
+            reverse('api:element-transcriptions-bulk', kwargs={'pk': top_level.id}),
+            format='json',
+            data=data
+        )
+        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+
+        created_elts = Element.objects.get_descending(top_level.id)
+        self.assertEqual(created_elts.count(), 2)
+        self.assertTrue(all(map(lambda elt: elt.zone.image == self.page.zone.image, created_elts)))
+        self.assertListEqual(
+            [
+                (elt.paths.first().ordering, elt.name, elt.zone.polygon.coords)
+                for elt in created_elts
+            ],
+            [
+                (0, '1', ((13, 37), (13, 137), (133, 137), (133, 37), (13, 37))),
+                (1, '2', ((24, 42), (24, 142), (64, 142), (64, 42), (24, 42)))
+            ]
+        )
+        self.assertCountEqual(
+            created_elts.values_list('transcriptions__type', 'transcriptions__text', 'transcriptions__source', 'transcriptions__worker_version'),
+            [
+                (TranscriptionType.Line, ('Hello world !'), None, self.worker_version.id),
+                (TranscriptionType.Line, ('I <3 JavaScript'), None, self.worker_version.id)
+            ]
+        )
+        self.assertEqual(delay_mock.call_count, 1)
+        self.assertEqual(delay_mock.call_args, call(
+            element_id=str(top_level.id),
+            corpus_id=None,
+            entity_id=None,
+            transcriptions=True,
+            elements=True,
+            entities=False,
+            drop=False,
+        ))