Skip to content
Snippets Groups Projects
Commit 7d752b9c authored by Bastien Abadie
Browse files

Merge branch 'append-transcriptions-bulk' into 'master'

Append transcriptions in bulk

See merge request !275
parents a11592ad c97e6df3
No related branches found
No related tags found
1 merge request: !275 Append transcriptions in bulk
from django.conf import settings
from rest_framework import status
from rest_framework.generics import CreateAPIView
from rest_framework.generics import CreateAPIView, UpdateAPIView
from rest_framework.exceptions import ValidationError
from rest_framework.response import Response
from arkindex.documents.models import Classification, DataSource, Transcription, TranscriptionType
......@@ -68,7 +68,7 @@ class TranscriptionCreate(CreateAPIView):
return Response({'id': obj.id}, status=status.HTTP_201_CREATED, headers=headers)
class TranscriptionBulk(CreateAPIView):
class TranscriptionBulk(CreateAPIView, UpdateAPIView):
'''
Create transcriptions in bulk, all linked to the same image
and parent element
......@@ -76,8 +76,16 @@ class TranscriptionBulk(CreateAPIView):
serializer_class = TranscriptionsSerializer
permission_classes = (IsVerified, )
def get_object(self):
    '''
    Always return None.

    UpdateAPIView.update calls get_object; this override makes that
    step a no-op so the update path can be used without looking up an
    instance.
    '''
    return None
def perform_create(self, serializer):
    '''
    Creation hook: hand the validated serializer to run() in
    non-deleting mode.
    '''
    # delete=False: run() is asked not to clear anything beforehand
    self.run(serializer, delete=False)
def perform_update(self, serializer):
    '''
    Update hook: hand the validated serializer to run() in deleting
    mode.
    '''
    # delete=True: run() is asked to clear beforehand
    self.run(serializer, delete=True)
def run(self, serializer, delete=False):
parent = serializer.validated_data['parent']
source = DataSource.from_ml_tool(serializer.validated_data['recognizer'])
trpolygons = build_transcriptions(
......@@ -96,8 +104,10 @@ class TranscriptionBulk(CreateAPIView):
)
if not trpolygons:
return
# Clear all previous transcriptions from source
parent.transcriptions.filter(source=source).delete()
if delete:
# Clear all previous transcriptions from source
parent.transcriptions.filter(source=source).delete()
transcriptions, _ = save_transcriptions(*trpolygons)
......
......@@ -192,3 +192,70 @@ class TestTranscriptionCreate(FixtureAPITestCase):
]
})
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
@patch('arkindex.project.serializer_fields.MLTool.get')
@patch('arkindex.images.importer.Indexer')
def test_update_bulk_transcription(self, indexer, ml_get_mock):
    '''
    PUT on the bulk transcription endpoint stores the submitted
    transcriptions, and a second PUT for the same parent and recognizer
    replaces the ones stored by the first.
    '''
    # Make the mocked ML tool describe the fixture data source
    ml_get_mock.return_value.type = self.src.type
    ml_get_mock.return_value.slug = self.src.slug
    ml_get_mock.return_value.version = self.src.revision
    self.src.internal = True
    self.src.save()

    url = reverse('api:transcription-bulk')
    data = {
        "parent": str(self.page.id),
        "recognizer": self.src.slug,
        "transcriptions": [
            {
                "type": "word",
                "polygon": [(0, 0), (100, 0), (100, 100), (0, 100), (0, 0)],
                "text": "NEKUDOTAYIM",
                "score": 0.83,
            },
            {
                "type": "line",
                "polygon": [(0, 0), (200, 0), (200, 200), (0, 200), (0, 0)],
                "text": "This is a test",
                "score": 0.75,
            },
        ]
    }

    # First update
    self.client.force_login(self.user)
    response = self.client.put(url, format='json', data=data)
    self.assertEqual(response.status_code, status.HTTP_200_OK)

    new_ts = self.page.transcriptions.get(text="NEKUDOTAYIM", type=TranscriptionType.Word)
    self.assertEqual(new_ts.zone.polygon, Polygon.from_coords(0, 0, 100, 100))
    self.assertEqual(new_ts.score, 0.83)
    self.assertEqual(new_ts.source, self.src)

    page_ts = self.page.transcriptions.get(type=TranscriptionType.Page)
    self.assertEqual(page_ts.text, 'This is a test')
    # Identity check: the page transcription must carry no score at all
    self.assertIsNone(page_ts.score)
    self.assertEqual(page_ts.source, self.src)
    self.assertEqual(page_ts.zone, self.page.zone)

    # Indexer called
    self.assertEqual(indexer.return_value.run_index.call_count, 2)

    # Update again
    indexer.reset_mock()
    data['transcriptions'] = [
        {
            "type": "word",
            "polygon": [(0, 0), (200, 0), (200, 200), (0, 200), (0, 0)],
            "text": "Hi",
            "score": 0.99,
        },
    ]
    response = self.client.put(url, format='json', data=data)
    self.assertEqual(response.status_code, status.HTTP_200_OK)
    self.assertEqual(indexer.return_value.run_index.call_count, 2)

    # Previous transcriptions should be replaced by a Word and a Page transcription
    self.assertEqual(self.page.transcriptions.count(), 2)
    ts = self.page.transcriptions.get(type=TranscriptionType.Word)
    self.assertEqual(ts.text, "Hi")
    page_ts = self.page.transcriptions.get(type=TranscriptionType.Page)
    self.assertEqual(page_ts.text, '')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment