Skip to content
Snippets Groups Projects
Commit c97e6df3 authored by Erwan Rouchet's avatar Erwan Rouchet Committed by Bastien Abadie
Browse files

Append transcriptions in bulk

parent a11592ad
No related branches found
No related tags found
No related merge requests found
from django.conf import settings
from rest_framework import status
from rest_framework.generics import CreateAPIView
from rest_framework.generics import CreateAPIView, UpdateAPIView
from rest_framework.exceptions import ValidationError
from rest_framework.response import Response
from arkindex.documents.models import Classification, DataSource, Transcription, TranscriptionType
......@@ -68,7 +68,7 @@ class TranscriptionCreate(CreateAPIView):
return Response({'id': obj.id}, status=status.HTTP_201_CREATED, headers=headers)
class TranscriptionBulk(CreateAPIView):
class TranscriptionBulk(CreateAPIView, UpdateAPIView):
'''
Create transcriptions in bulk, all linked to the same image
and parent element
......@@ -76,8 +76,16 @@ class TranscriptionBulk(CreateAPIView):
serializer_class = TranscriptionsSerializer
permission_classes = (IsVerified, )
def get_object(self):
    '''
    Always return None, so that the get_object step
    of UpdateAPIView.update is effectively ignored.
    '''
    return None
def perform_create(self, serializer):
    '''
    Creation hook: hand the serializer over to run()
    with delete=False.
    '''
    self.run(serializer, delete=False)
def perform_update(self, serializer):
    '''
    Update hook: hand the serializer over to run()
    with delete=True.
    '''
    self.run(serializer, delete=True)
def run(self, serializer, delete=False):
parent = serializer.validated_data['parent']
source = DataSource.from_ml_tool(serializer.validated_data['recognizer'])
trpolygons = build_transcriptions(
......@@ -96,8 +104,10 @@ class TranscriptionBulk(CreateAPIView):
)
if not trpolygons:
return
# Clear all previous transcriptions from source
parent.transcriptions.filter(source=source).delete()
if delete:
# Clear all previous transcriptions from source
parent.transcriptions.filter(source=source).delete()
transcriptions, _ = save_transcriptions(*trpolygons)
......
......@@ -192,3 +192,70 @@ class TestTranscriptionCreate(FixtureAPITestCase):
]
})
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
@patch('arkindex.project.serializer_fields.MLTool.get')
@patch('arkindex.images.importer.Indexer')
def test_update_bulk_transcription(self, indexer, ml_get_mock):
    # Sends two successive PUT requests to the bulk endpoint and checks
    # the page's transcriptions and the indexer calls after each of them.

    # Make the mocked ML tool lookup mirror the fixture data source
    fake_tool = ml_get_mock.return_value
    fake_tool.type = self.src.type
    fake_tool.slug = self.src.slug
    fake_tool.version = self.src.revision
    self.src.internal = True
    self.src.save()

    small_square = [(0, 0), (100, 0), (100, 100), (0, 100), (0, 0)]
    large_square = [(0, 0), (200, 0), (200, 200), (0, 200), (0, 0)]
    payload = {
        "parent": str(self.page.id),
        "recognizer": self.src.slug,
        "transcriptions": [
            {
                "type": "word",
                "polygon": small_square,
                "text": "NEKUDOTAYIM",
                "score": 0.83,
            },
            {
                "type": "line",
                "polygon": large_square,
                "text": "This is a test",
                "score": 0.75,
            },
        ]
    }

    self.client.force_login(self.user)
    first_response = self.client.put(
        reverse('api:transcription-bulk'),
        format='json',
        data=payload,
    )
    self.assertEqual(first_response.status_code, status.HTTP_200_OK)

    # The word transcription is stored with its zone, score and source
    word_ts = self.page.transcriptions.get(
        text="NEKUDOTAYIM",
        type=TranscriptionType.Word,
    )
    self.assertEqual(word_ts.zone.polygon, Polygon.from_coords(0, 0, 100, 100))
    self.assertEqual(word_ts.score, 0.83)
    self.assertEqual(word_ts.source, self.src)

    # A page-level transcription exists with the line text and the page zone
    page_level = self.page.transcriptions.get(type=TranscriptionType.Page)
    self.assertEqual(page_level.text, 'This is a test')
    self.assertEqual(page_level.score, None)
    self.assertEqual(page_level.source, self.src)
    self.assertEqual(page_level.zone, self.page.zone)

    # Indexer called
    self.assertEqual(indexer.return_value.run_index.call_count, 2)

    # Update again
    indexer.reset_mock()
    payload['transcriptions'] = [
        {
            "type": "word",
            "polygon": large_square,
            "text": "Hi",
            "score": 0.99,
        },
    ]
    second_response = self.client.put(
        reverse('api:transcription-bulk'),
        format='json',
        data=payload,
    )
    self.assertEqual(second_response.status_code, status.HTTP_200_OK)
    self.assertEqual(indexer.return_value.run_index.call_count, 2)

    # Previous transcriptions should be replaced by a Word and a Page transcription
    self.assertEqual(self.page.transcriptions.count(), 2)
    remaining_word = self.page.transcriptions.get(type=TranscriptionType.Word)
    self.assertEqual(remaining_word.text, "Hi")
    page_level = self.page.transcriptions.get(type=TranscriptionType.Page)
    self.assertEqual(page_level.text, '')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment