Skip to content
Snippets Groups Projects
Commit a3df9e3d authored by Bastien Abadie
Browse files

Merge branch 'create-transcriptions' into 'master'

CreateTranscriptions endpoint

Closes #465

See merge request !1047
parents d39e6d6e c51e18ef
No related branches found
No related tags found
1 merge request: !1047 CreateTranscriptions endpoint
......@@ -3,8 +3,8 @@ from django.db import transaction
from django.db.models import Q, Count
from rest_framework import status
from rest_framework.generics import (
GenericAPIView, ListAPIView, ListCreateAPIView,
CreateAPIView, UpdateAPIView, RetrieveDestroyAPIView, RetrieveUpdateDestroyAPIView
GenericAPIView, ListAPIView, ListCreateAPIView, CreateAPIView,
RetrieveDestroyAPIView, RetrieveUpdateDestroyAPIView
)
from rest_framework.exceptions import PermissionDenied, ValidationError
from rest_framework.response import Response
......@@ -13,9 +13,10 @@ from arkindex.documents.models import \
from arkindex_common.ml_tool import MLToolType
from arkindex.documents.serializers.ml import (
ClassificationsSerializer, ClassificationCreateSerializer, ClassificationSerializer,
TranscriptionSerializer, TranscriptionCreateSerializer, ElementTranscriptionsBulkSerializer,
DataSourceStatsSerializer, ClassificationsSelectionSerializer, ClassificationMode,
CountMLClassSerializer, AnnotatedElementSerializer
ClassificationsSelectionSerializer, ClassificationMode,
TranscriptionSerializer, TranscriptionCreateSerializer, TranscriptionBulkSerializer,
ElementTranscriptionsBulkSerializer, AnnotatedElementSerializer,
DataSourceStatsSerializer, CountMLClassSerializer
)
from arkindex.images.models import Zone
from arkindex.project.filters import SafeSearchFilter
......@@ -292,21 +293,16 @@ class ElementTranscriptionsBulk(CreateAPIView):
return annotations
class TranscriptionBulk(DeprecatedMixin, CreateAPIView, UpdateAPIView):
class TranscriptionBulk(CreateAPIView):
'''
Create multiple transcriptions at once, all linked to the same page
and to the same recognizer.
Create multiple transcriptions at once on existing elements
'''
# Force DRF to ignore PATCH
http_method_names = ['post', 'put', 'head', 'options', 'trace']
openapi_overrides = {
'operationId': 'CreateTranscriptions',
'tags': ['transcriptions'],
}
deprecation_message = (
'Creating or updating transcriptions with their own zones is now deprecated. '
'Please use CreateElementTranscriptions to push transcriptions in bulk '
'attached to sub-elements.'
)
permission_classes = (IsVerified, )
serializer_class = TranscriptionBulkSerializer
class CorpusMLClassList(CorpusACLMixin, ListCreateAPIView):
......
......@@ -369,6 +369,56 @@ class AnnotatedElementSerializer(serializers.Serializer):
created = serializers.BooleanField(default=False)
class TranscriptionBulkItemSerializer(serializers.Serializer):
    """
    One transcription entry inside a bulk creation payload.
    """

    # Element retrieval and checks are done in the BulkSerializer
    # to avoid duplicate queries
    element_id = serializers.UUIDField(
        help_text='ID of an existing element to add the transcription to',
    )
    type = EnumField(
        TranscriptionType,
        help_text='Type of the transcription',
    )
    text = serializers.CharField()
    # Score is restricted to the [0, 1] range
    score = serializers.FloatField(
        min_value=0,
        max_value=1,
    )
class TranscriptionBulkSerializer(serializers.Serializer):
    """
    Create multiple transcriptions at once on existing elements,
    all attached to the same worker version.
    """
    worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all())
    transcriptions = TranscriptionBulkItemSerializer(many=True)

    def validate(self, data):
        """
        Check that every targeted element exists in a corpus returned by
        `Corpus.objects.writable` for the requesting user.

        All element IDs are looked up together here, rather than in the item
        serializer, to avoid duplicate queries.

        Raises a ValidationError holding one entry per submitted transcription
        when at least one element is missing or not writable.
        """
        element_ids = {transcription['element_id'] for transcription in data['transcriptions']}
        found_ids = set(Element.objects.filter(
            id__in=element_ids,
            corpus__in=Corpus.objects.writable(self.context['request'].user)
        ).values_list('id', flat=True))

        missing_ids = element_ids - found_ids
        if not missing_ids:
            return data

        # Return an error message as a list just like DRF's ListField, for easier debugging:
        # valid items get an empty dict so that each error keeps the position of its item.
        raise ValidationError({'transcriptions': [
            {"element_id": [f'Element {transcription["element_id"]} was not found or cannot be written to.']}
            if transcription['element_id'] in missing_ids
            else {}
            for transcription in data['transcriptions']
        ]})

    def create(self, validated_data):
        """
        Build one Transcription per validated item and insert them all
        with a single `bulk_create` call.

        Returns `validated_data` with its `transcriptions` entry replaced
        by the list of Transcription instances that were built.
        """
        transcriptions = [
            Transcription(
                worker_version=validated_data['worker_version'],
                element_id=transcription['element_id'],
                type=transcription['type'],
                text=transcription['text'],
                score=transcription['score'],
            )
            for transcription in validated_data['transcriptions']
        ]
        Transcription.objects.bulk_create(transcriptions)
        validated_data['transcriptions'] = transcriptions
        return validated_data
class ClassificationBulkSerializer(serializers.Serializer):
"""
Single classification serializer for bulk insertion
......
from django.urls import reverse
from rest_framework import status
from arkindex_common.enums import TranscriptionType
from arkindex.project.tests import FixtureAPITestCase
from arkindex.dataimport.models import WorkerVersion
class TestBulkTranscriptions(FixtureAPITestCase):
    """
    Tests for the `api:transcription-bulk` endpoint.
    """

    @classmethod
    def setUpTestData(cls):
        super().setUpTestData()
        cls.worker_version = WorkerVersion.objects.get(worker__slug='reco')

    def test_bulk_transcriptions_requires_login(self):
        """
        An unauthenticated POST is rejected without running any SQL query.
        """
        url = reverse('api:transcription-bulk')
        with self.assertNumQueries(0):
            response = self.client.post(url)
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_bulk_transcriptions_not_found(self):
        """
        Elements the user has no right on, as well as unknown element IDs,
        are each reported in the error response.
        """
        self.client.force_login(self.user)
        self.user.corpus_right.all().delete()
        forbidden_element = self.corpus.elements.get(name='Volume 1, page 1r')

        url = reverse('api:transcription-bulk')
        payload = {
            "worker_version": str(self.worker_version.id),
            "transcriptions": [
                {
                    "element_id": str(forbidden_element.id),
                    "type": TranscriptionType.Word.value,
                    "text": "lol",
                    "score": 0.4
                },
                {
                    "element_id": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
                    "type": TranscriptionType.Word.value,
                    "text": "lol",
                    "score": 0.4
                }
            ],
        }
        with self.assertNumQueries(4):
            response = self.client.post(url, payload, format='json')

        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        expected_errors = [
            {'element_id': [f'Element {forbidden_element.id} was not found or cannot be written to.']},
            {'element_id': ['Element aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa was not found or cannot be written to.']},
        ]
        self.assertEqual(response.json(), {'transcriptions': expected_errors})

    def test_bulk_transcriptions(self):
        """
        Transcriptions are created on each targeted element,
        including several transcriptions on the same element.
        """
        self.client.force_login(self.user)
        volume = self.corpus.elements.get(name='Volume 2')
        page = self.corpus.elements.get(name='Volume 2, page 1r')
        # Neither element has transcriptions before the request
        self.assertFalse(volume.transcriptions.exists())
        self.assertFalse(page.transcriptions.exists())

        url = reverse('api:transcription-bulk')
        payload = {
            "worker_version": str(self.worker_version.id),
            "transcriptions": [
                {
                    "element_id": str(volume.id),
                    "type": TranscriptionType.Word.value,
                    "text": "Sneasel",
                    "score": 0.54
                },
                {
                    "element_id": str(page.id),
                    "type": TranscriptionType.Line.value,
                    "text": "Charizard",
                    "score": 0.85
                },
                {
                    "element_id": str(volume.id),
                    "type": TranscriptionType.Word.value,
                    "text": "Raticate",
                    "score": 0.12
                },
            ],
        }
        with self.assertNumQueries(5):
            response = self.client.post(url, payload, format='json')
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)

        volume_transcriptions = list(volume.transcriptions.values('type', 'text', 'score'))
        self.assertCountEqual(
            volume_transcriptions,
            [
                {
                    "type": TranscriptionType.Word,
                    "text": "Sneasel",
                    "score": 0.54
                },
                {
                    "type": TranscriptionType.Word,
                    "text": "Raticate",
                    "score": 0.12
                },
            ]
        )

        page_transcriptions = list(page.transcriptions.values('type', 'text', 'score'))
        self.assertCountEqual(
            page_transcriptions,
            [{
                "type": TranscriptionType.Line,
                "text": "Charizard",
                "score": 0.85
            }]
        )
......@@ -283,14 +283,6 @@ paths:
description: Update the text of a manual transcription
delete:
description: Delete a manual transcription
/api/v1/transcription/bulk/:
post:
operationId: CreateTranscriptions
put:
operationId: UpdateTranscriptions
description: >-
Replace all existing transcriptions from a given recognizer on a page
with other transcriptions.
/api/v1/metadata/{id}/:
get:
operationId: RetrieveMetaData
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment