Skip to content
Snippets Groups Projects
Commit c51e18ef authored by Erwan Rouchet's avatar Erwan Rouchet Committed by Bastien Abadie
Browse files

CreateTranscriptions endpoint

parent d39e6d6e
No related branches found
No related tags found
No related merge requests found
......@@ -3,8 +3,8 @@ from django.db import transaction
from django.db.models import Q, Count
from rest_framework import status
from rest_framework.generics import (
GenericAPIView, ListAPIView, ListCreateAPIView,
CreateAPIView, UpdateAPIView, RetrieveDestroyAPIView, RetrieveUpdateDestroyAPIView
GenericAPIView, ListAPIView, ListCreateAPIView, CreateAPIView,
RetrieveDestroyAPIView, RetrieveUpdateDestroyAPIView
)
from rest_framework.exceptions import PermissionDenied, ValidationError
from rest_framework.response import Response
......@@ -13,9 +13,10 @@ from arkindex.documents.models import \
from arkindex_common.ml_tool import MLToolType
from arkindex.documents.serializers.ml import (
ClassificationsSerializer, ClassificationCreateSerializer, ClassificationSerializer,
TranscriptionSerializer, TranscriptionCreateSerializer, ElementTranscriptionsBulkSerializer,
DataSourceStatsSerializer, ClassificationsSelectionSerializer, ClassificationMode,
CountMLClassSerializer, AnnotatedElementSerializer
ClassificationsSelectionSerializer, ClassificationMode,
TranscriptionSerializer, TranscriptionCreateSerializer, TranscriptionBulkSerializer,
ElementTranscriptionsBulkSerializer, AnnotatedElementSerializer,
DataSourceStatsSerializer, CountMLClassSerializer
)
from arkindex.images.models import Zone
from arkindex.project.filters import SafeSearchFilter
......@@ -292,21 +293,16 @@ class ElementTranscriptionsBulk(CreateAPIView):
return annotations
class TranscriptionBulk(DeprecatedMixin, CreateAPIView, UpdateAPIView):
class TranscriptionBulk(CreateAPIView):
'''
Create multiple transcriptions at once, all linked to the same page
and to the same recognizer.
Create multiple transcriptions at once on existing elements
'''
# Force DRF to ignore PATCH
http_method_names = ['post', 'put', 'head', 'options', 'trace']
openapi_overrides = {
'operationId': 'CreateTranscriptions',
'tags': ['transcriptions'],
}
deprecation_message = (
'Creating or updating transcriptions with their own zones is now deprecated. '
'Please use CreateElementTranscriptions to push transcriptions in bulk '
'attached to sub-elements.'
)
permission_classes = (IsVerified, )
serializer_class = TranscriptionBulkSerializer
class CorpusMLClassList(CorpusACLMixin, ListCreateAPIView):
......
......@@ -369,6 +369,56 @@ class AnnotatedElementSerializer(serializers.Serializer):
created = serializers.BooleanField(default=False)
class TranscriptionBulkItemSerializer(serializers.Serializer):
# Element retrieval and checks is done in the BulkSerializer to avoid duplicate queries
element_id = serializers.UUIDField(
help_text='ID of an existing element to add the transcription to'
)
type = EnumField(TranscriptionType, help_text='Type of the transcription')
text = serializers.CharField()
score = serializers.FloatField(min_value=0, max_value=1)
class TranscriptionBulkSerializer(serializers.Serializer):
worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all())
transcriptions = TranscriptionBulkItemSerializer(many=True)
def validate(self, data):
element_ids = set(transcription['element_id'] for transcription in data['transcriptions'])
found_ids = set(Element.objects.filter(
id__in=element_ids,
corpus__in=Corpus.objects.writable(self.context['request'].user)
).values_list('id', flat=True))
missing_ids = element_ids - found_ids
if not missing_ids:
return data
# Return an error message as a list just like DRF's ListField, for easier debugging
raise ValidationError({'transcriptions': [
{"element_id": [f'Element {transcription["element_id"]} was not found or cannot be written to.']}
if transcription['element_id'] in missing_ids
else {}
for i, transcription in enumerate(data['transcriptions'])
]})
def create(self, validated_data):
transcriptions = [
Transcription(
worker_version=validated_data['worker_version'],
element_id=transcription['element_id'],
type=transcription['type'],
text=transcription['text'],
score=transcription['score'],
)
for transcription in validated_data['transcriptions']
]
Transcription.objects.bulk_create(transcriptions)
validated_data['transcriptions'] = transcriptions
return validated_data
class ClassificationBulkSerializer(serializers.Serializer):
"""
Single classification serializer for bulk insertion
......
from django.urls import reverse
from rest_framework import status
from arkindex_common.enums import TranscriptionType
from arkindex.project.tests import FixtureAPITestCase
from arkindex.dataimport.models import WorkerVersion
class TestBulkTranscriptions(FixtureAPITestCase):
@classmethod
def setUpTestData(cls):
super().setUpTestData()
cls.worker_version = WorkerVersion.objects.get(worker__slug='reco')
def test_bulk_transcriptions_requires_login(self):
with self.assertNumQueries(0):
response = self.client.post(reverse('api:transcription-bulk'))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_bulk_transcriptions_not_found(self):
self.client.force_login(self.user)
self.user.corpus_right.all().delete()
forbidden_element = self.corpus.elements.get(name='Volume 1, page 1r')
with self.assertNumQueries(4):
response = self.client.post(reverse('api:transcription-bulk'), {
"worker_version": str(self.worker_version.id),
"transcriptions": [
{
"element_id": str(forbidden_element.id),
"type": TranscriptionType.Word.value,
"text": "lol",
"score": 0.4
},
{
"element_id": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
"type": TranscriptionType.Word.value,
"text": "lol",
"score": 0.4
}
],
}, format='json')
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.json(), {
'transcriptions': [
{'element_id': [f'Element {forbidden_element.id} was not found or cannot be written to.']},
{'element_id': ['Element aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa was not found or cannot be written to.']},
]
})
def test_bulk_transcriptions(self):
self.client.force_login(self.user)
element1 = self.corpus.elements.get(name='Volume 2')
element2 = self.corpus.elements.get(name='Volume 2, page 1r')
self.assertFalse(element1.transcriptions.exists())
self.assertFalse(element2.transcriptions.exists())
with self.assertNumQueries(5):
response = self.client.post(reverse('api:transcription-bulk'), {
"worker_version": str(self.worker_version.id),
"transcriptions": [
{
"element_id": str(element1.id),
"type": TranscriptionType.Word.value,
"text": "Sneasel",
"score": 0.54
},
{
"element_id": str(element2.id),
"type": TranscriptionType.Line.value,
"text": "Charizard",
"score": 0.85
},
{
"element_id": str(element1.id),
"type": TranscriptionType.Word.value,
"text": "Raticate",
"score": 0.12
},
],
}, format='json')
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertCountEqual(
list(element1.transcriptions.values('type', 'text', 'score')),
[
{
"type": TranscriptionType.Word,
"text": "Sneasel",
"score": 0.54
},
{
"type": TranscriptionType.Word,
"text": "Raticate",
"score": 0.12
},
]
)
self.assertCountEqual(
list(element2.transcriptions.values('type', 'text', 'score')),
[{
"type": TranscriptionType.Line,
"text": "Charizard",
"score": 0.85
}]
)
......@@ -283,14 +283,6 @@ paths:
description: Update the text of a manual transcription
delete:
description: Delete a manual transcription
/api/v1/transcription/bulk/:
post:
operationId: CreateTranscriptions
put:
operationId: UpdateTranscriptions
description: >-
Replace all existing transcriptions from a given recognizer on a page
with other transcriptions.
/api/v1/metadata/{id}/:
get:
operationId: RetrieveMetaData
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment