Skip to content
Snippets Groups Projects
Commit 3ee6fbe5 authored by Bastien Abadie's avatar Bastien Abadie
Browse files

Merge branch 'recursive-list-transcriptions' into 'master'

Recursive parameter in ListElementTranscriptions endpoint

Closes #375

See merge request !881
parents 9ae3df10 df7a849e
No related branches found
No related tags found
1 merge request!881Recursive parameter in ListElementTranscriptions endpoint
from django.conf import settings
from django.db.models import Q, Prefetch, prefetch_related_objects, Count
from rest_framework.exceptions import ValidationError, NotFound
from rest_framework.generics import \
ListAPIView, CreateAPIView, DestroyAPIView, ListCreateAPIView, RetrieveUpdateDestroyAPIView
from rest_framework.generics import (
ListAPIView, CreateAPIView, DestroyAPIView, ListCreateAPIView, RetrieveUpdateDestroyAPIView,
get_object_or_404
)
from rest_framework import status, serializers
from rest_framework.response import Response
from arkindex_common.enums import TranscriptionType
......@@ -17,7 +19,7 @@ from arkindex.documents.serializers.elements import (
)
from arkindex.project.openapi import AutoSchema
from arkindex.documents.serializers.light import CorpusAllowedMetaDataSerializer
from arkindex.documents.serializers.ml import TranscriptionSerializer
from arkindex.documents.serializers.ml import ElementTranscriptionSerializer
from arkindex.project.mixins import CorpusACLMixin, SelectionMixin
from arkindex.project.pagination import PageNumberPagination
from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly, IsAuthenticated
......@@ -657,9 +659,11 @@ class TranscriptionsPagination(PageNumberPagination):
class ElementTranscriptions(ListAPIView):
"""
List all transcriptions for an element, optionally filtered by type
List all transcriptions for an element, optionally filtered by type.
Recursive parameter allow listing transcriptions on sub-elements,
otherwise element fields in the response will be set to null.
"""
serializer_class = TranscriptionSerializer
serializer_class = ElementTranscriptionSerializer
pagination_class = TranscriptionsPagination
openapi_overrides = {
'security': [],
......@@ -674,15 +678,59 @@ class ElementTranscriptions(ListAPIView):
'type': 'string',
'enum': [ts_type.value for ts_type in TranscriptionType],
}
}, {
'name': 'recursive',
'in': 'query',
'required': False,
'description': 'Recursively list transcriptions on sub-elements',
'schema': {
'type': 'boolean',
}
},
]
}
@property
def is_recursive(self):
if not self.request:
return
recursive = self.request.query_params.get('recursive')
return recursive and recursive not in ('false', 0)
def get_serializer_context(self):
context = super().get_serializer_context()
# Do serialize the element attached to each transcription in recursive mode only
context['ignore_element'] = not self.is_recursive
return context
def check_object_permissions(self, request, element):
super().check_object_permissions(request, element)
if element.type.folder:
self.permission_denied(request, message='Element is a folder')
def get_queryset(self):
queryset = Transcription.objects.filter(
element_id=self.kwargs['pk'],
element__corpus__in=Corpus.objects.readable(self.request.user),
).prefetch_related('zone__image__server', 'source').order_by('id')
element = get_object_or_404(Element.objects.filter(
id=self.kwargs['pk'],
corpus__in=Corpus.objects.readable(self.request.user)
))
self.check_object_permissions(self.request, element)
queryset = Transcription.objects \
.prefetch_related('zone__image__server', 'source') \
.extra(
# ORDER BY casting IDs as char to avoid PostgreSQL optimizer inefficient scan
select={'char_id': 'CAST(id AS CHAR(36))'},
order_by=['char_id']
)
if self.is_recursive:
queryset = queryset.filter(
# Retrieve both element and sub-elements transcriptions
element__in=[element.id, *Element.objects.get_descending(element.id).values_list('id')]
).prefetch_related('element__type', 'element__zone')
else:
queryset = queryset.filter(element_id=element.id)
req_type = self.request.query_params.get('type')
if req_type:
try:
......
......@@ -4,6 +4,7 @@ from django.db.models import Max
from arkindex.documents.models import Element, ElementType, Corpus, MetaData, AllowedMetaData
from arkindex_common.enums import MetaType, TranscriptionType
from arkindex.documents.dates import DateType
from arkindex.images.serializers import ZoneLightSerializer
from arkindex.dataimport.serializers.git import RevisionSerializer
from arkindex.project.serializer_fields import EnumField
from arkindex.project.triggers import reindex_start
......@@ -42,6 +43,17 @@ class ElementLightSerializer(serializers.ModelSerializer):
)
class ElementZoneSerializer(ElementLightSerializer):
"""
Lightly serialises an element with its type and zone
"""
zone = ZoneLightSerializer()
class Meta(ElementLightSerializer.Meta):
model = Element
fields = ElementLightSerializer.Meta.fields + ('zone', )
class ElementTypeSerializer(serializers.ModelSerializer):
allowed_transcription = EnumField(TranscriptionType)
......
......@@ -5,12 +5,13 @@ from rest_framework.validators import UniqueTogetherValidator
from rest_framework.exceptions import ValidationError
from arkindex_common.ml_tool import MLToolType
from arkindex_common.enums import TranscriptionType
from arkindex.project.serializer_fields import EnumField, DataSourceSlugField, PolygonField
from arkindex.dataimport.models import WorkerVersion
from arkindex.documents.models import (
Corpus, Element, ElementType, Transcription, DataSource, MLClass, Classification, ClassificationState
)
from arkindex.project.serializer_fields import EnumField, DataSourceSlugField, PolygonField
from arkindex.images.serializers import ZoneSerializer
from arkindex.documents.serializers.light import ElementZoneSerializer
class ClassificationMode(Enum):
......@@ -247,6 +248,22 @@ class TranscriptionSerializer(serializers.ModelSerializer):
)
class ElementTranscriptionSerializer(TranscriptionSerializer):
"""
Serialises a transcription with its element basic informations (e.g. image zone)
"""
element = ElementZoneSerializer(allow_null=True)
class Meta(TranscriptionSerializer.Meta):
fields = TranscriptionSerializer.Meta.fields + ('element', )
def to_representation(self, obj):
if self.context.get('ignore_element'):
# Skip transcription element zone serialization
obj.element = None
return super().to_representation(obj)
class TranscriptionCreateSerializer(serializers.ModelSerializer):
"""
Allows the insertion of a manual transcription attached to an element
......
......@@ -3,7 +3,8 @@ from rest_framework import status
from arkindex.project.tests import FixtureAPITestCase
from arkindex.project.polygon import Polygon
from arkindex_common.enums import TranscriptionType
from arkindex.documents.models import DataSource
from arkindex.documents.models import Corpus, DataSource
from arkindex.users.models import User
class TestTranscriptions(FixtureAPITestCase):
......@@ -15,24 +16,119 @@ class TestTranscriptions(FixtureAPITestCase):
def setUpTestData(cls):
super().setUpTestData()
cls.page = cls.corpus.elements.get(name='Volume 1, page 1r')
cls.volume = cls.corpus.elements.get(name='Volume 1')
cls.line = cls.corpus.elements.get(name='Text line')
cls.private_corpus = Corpus.objects.create(name='Private')
cls.private_page = cls.private_corpus.elements.create(type=cls.page.type)
cls.src = DataSource.objects.get(slug='test')
# Create an user with a read right only on the private corpus
cls.private_read_user = User.objects.create_user('a@bc.de', 'a')
cls.private_read_user.verified_email = True
cls.private_read_user.save()
def test_get_transcriptions(self):
def test_list_transcriptions_read_right(self):
# A read right on the element corpus is required to access transcriptions
self.client.force_login(self.private_read_user)
url = reverse('api:element-transcriptions', kwargs={'pk': str(self.private_page.id)})
response = self.client.get(url)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_list_transcriptions_non_folder(self):
# Transcriptions should be listed for non folder elements only
self.client.force_login(self.user)
url = reverse('api:element-transcriptions', kwargs={'pk': str(self.volume.id)})
response = self.client.get(url)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {'detail': 'Element is a folder'})
def test_list_transcriptions_wrong_type(self):
# Wrong transcription type
url = reverse('api:element-transcriptions', kwargs={'pk': str(self.page.id)})
url += '?type=potato'
response = self.client.get(url)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
def test_list_element_transcriptions(self):
self.page.transcriptions.all().delete()
transcriptions = []
for i in range(10):
zone, _ = self.page.zone.image.zones.get_or_create(polygon=Polygon.from_coords(0, 0, i + 1, i + 1))
# Create transcriptions on the page with their own zones
transcriptions.append(self.page.transcriptions.create(
source_id=DataSource.objects.get(slug='test').id,
type=TranscriptionType.Word,
zone=zone,
source_id=self.src.id, type=TranscriptionType.Word, zone=zone,
))
self.client.force_login(self.user)
response = self.client.get(reverse('api:element-transcriptions', kwargs={'pk': str(self.page.id)}))
with self.assertNumQueries(10):
response = self.client.get(reverse('api:element-transcriptions', kwargs={'pk': str(self.page.id)}))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.json()['results']), 10)
self.assertEqual(len(transcriptions), 10)
# Wrong transcription type
url = reverse('api:element-transcriptions', kwargs={'pk': str(self.page.id)})
url += '?type=potato'
response = self.client.get(url)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
results = response.json()['results']
self.assertEqual(len(results), 10)
self.assertCountEqual(
[(tr['id'], tr['element']) for tr in results],
# Element should not be serialized in case recursive parameter is not set
[(str(tr.id), None) for tr in transcriptions]
)
def test_list_transcriptions_recursive(self):
for i in range(1, 5):
# Add 4 transcriptions on the page line
self.line.transcriptions.create(source_id=self.src.id, type=TranscriptionType.Line, text=f'Text {i}')
for i in range(1, 5):
# Add 4 transcribed line children
zone, _ = self.page.zone.image.zones.get_or_create(polygon=Polygon.from_coords(0, 0, i + 1, i + 1))
line = self.page.corpus.elements.create(zone=zone, type=self.line.type, name=f'Added line {i}')
line.transcriptions.create(source_id=self.src.id, type=TranscriptionType.Line, text=f'Added text {i}')
line.add_parent(self.page)
self.client.force_login(self.user)
with self.assertNumQueries(14):
response = self.client.get(
reverse('api:element-transcriptions', kwargs={'pk': str(self.page.id)}),
data={'recursive': 'true'}
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
results = response.json()['results']
self.assertEqual(len(results), 12)
page_polygon = [[0, 0], [0, 1000], [1000, 1000], [1000, 0], [0, 0]]
line_polygon = [[400, 400], [400, 500], [500, 500], [500, 400], [400, 400]]
self.assertCountEqual(
[(data['element']['type'], data['element']['zone']['polygon'], data['text']) for data in results],
[
('page', page_polygon, 'PARIS'),
('page', page_polygon, 'ROY'),
('page', page_polygon, 'Lorem ipsum dolor sit amet'),
('page', page_polygon, 'DATUM'),
('text_line', line_polygon, 'Text 1'),
('text_line', line_polygon, 'Text 2'),
('text_line', line_polygon, 'Text 3'),
('text_line', line_polygon, 'Text 4'),
('text_line', [[0, 0], [0, 2], [2, 2], [2, 0], [0, 0]], 'Added text 1'),
('text_line', [[0, 0], [0, 3], [3, 3], [3, 0], [0, 0]], 'Added text 2'),
('text_line', [[0, 0], [0, 4], [4, 4], [4, 0], [0, 0]], 'Added text 3'),
('text_line', [[0, 0], [0, 5], [5, 5], [5, 0], [0, 0]], 'Added text 4')
]
)
def test_list_transcriptions_recursive_filtered(self):
for i in range(1, 5):
# Add 4 transcriptions on the page line
self.line.transcriptions.create(source_id=self.src.id, type=TranscriptionType.Line, text=f'Text {i}')
self.client.force_login(self.user)
with self.assertNumQueries(12):
response = self.client.get(
reverse('api:element-transcriptions', kwargs={'pk': str(self.page.id)}),
data={'recursive': 'true', 'type': 'line'}
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
results = response.json()['results']
self.assertEqual(len(results), 4)
for tr in results:
self.assertEqual(tr.get('type'), 'line')
......@@ -230,17 +230,25 @@ class ImageUploadSerializer(ImageSerializer):
return obj.s3_put_url
class ZoneSerializer(serializers.ModelSerializer):
class ZoneLightSerializer(serializers.ModelSerializer):
"""
Serialize a zone by its polygon and image ID only
"""
polygon = PolygonField()
class Meta:
model = Zone
fields = ('id', 'polygon')
class ZoneSerializer(ZoneLightSerializer):
"""
Serialize a complete zone with its computed center, url and image informations
"""
center = PointField(source='polygon.center')
# Override the field to fully serialize the image
image = ImageSerializer()
class Meta:
class Meta(ZoneLightSerializer.Meta):
model = Zone
fields = (
'id',
'polygon',
'center',
'url',
'image',
)
fields = ZoneLightSerializer.Meta.fields + ('image', 'url', 'polygon', 'center')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment