diff --git a/arkindex/process/api.py b/arkindex/process/api.py index 71408a746b3f05fc7d2f479146930947c6e54afc..783ecd1f350f857bb433d001d9249a98e8788115 100644 --- a/arkindex/process/api.py +++ b/arkindex/process/api.py @@ -8,7 +8,6 @@ from django.core.mail import send_mail from django.db import transaction from django.db.models import ( Avg, - CharField, Count, DurationField, Exists, @@ -21,7 +20,7 @@ from django.db.models import ( Q, Value, ) -from django.db.models.functions import Cast, Coalesce, Concat, Now, NullIf +from django.db.models.functions import Coalesce, Now from django.db.models.query import Prefetch from django.shortcuts import get_object_or_404 from django.template.loader import render_to_string @@ -108,6 +107,7 @@ from arkindex.process.serializers.workers import ( WorkerVersionCreateSerializer, WorkerVersionSerializer, ) +from arkindex.process.utils import annotate_image_url from arkindex.project.aws import get_ingest_resource from arkindex.project.fields import ArrayRemove from arkindex.project.mixins import ( @@ -121,7 +121,7 @@ from arkindex.project.mixins import ( from arkindex.project.openapi import UUID_OR_STR from arkindex.project.pagination import CustomCursorPagination from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly -from arkindex.project.tools import PercentileCont, RTrimChr +from arkindex.project.tools import PercentileCont from arkindex.project.triggers import process_delete from arkindex.training.models import Dataset from arkindex.training.serializers import DatasetSerializer @@ -1610,20 +1610,7 @@ class ListProcessElements(CorpusACLMixin, ListAPIView): if not self.with_image: return queryset.values('id', 'type_id', 'name', 'confidence') - return queryset.annotate( - # Build the image URL by concatenating the server's URL to the image's path - # Server URLs might end in a slash, but not all the time, - # so we strip any trailing slashes and add our own. 
def annotate_image_url(field="image"):
    """
    Return a database expression that computes the full URL of an image.

    ``field`` is the lookup path leading to the image relation from the
    queryset being annotated (``"image"`` by default, or for example
    ``"element__image"`` when going through another relation).

    The URL is the image server's URL joined to the image's path with a
    single slash. Server URLs may or may not end with a slash, so trailing
    slashes are stripped before our own separator is added.

    When there is no image, the concatenation yields a bare ``'/'``;
    NullIf turns that value into NULL.
    """
    # Server URL without any trailing slash
    server_url = RTrimChr(f'{field}__server__url', Value('/'))
    # The image path is cast to a character field before concatenation
    image_path = Cast(f'{field}__path', CharField())
    full_url = Concat(server_url, Value('/'), image_path)
    return NullIf(full_url, Value('/'))
@extend_schema(tags=['datasets'])
class DatasetElements(CorpusACLMixin, ListAPIView):
    """
    List all elements in a dataset.\n\n
    Requires a **guest** access to the dataset corpus.
    """
    permission_classes = (IsVerified, )
    # Placeholder queryset; the real one is built per request in get_queryset()
    queryset = DatasetElement.objects.none()
    serializer_class = DatasetElementSerializer

    @property
    def paginator(self):
        """
        Cursor paginator ordering results by set, then by element ID.

        The paginator is created on first access and cached on the view.
        """
        if not hasattr(self, '_paginator'):
            self._paginator = CustomCursorPagination(ordering=('set', 'element_id'))
        return self._paginator

    def paginate_queryset(self, queryset):
        """
        Paginate the queryset, then copy onto each related element the
        attributes read by the nested element serializer: ``image_url``,
        ``image__width`` and ``image__height``.

        Elements without an image get ``None`` for both dimensions instead
        of raising an AttributeError; ``image_url`` is already NULL for them.
        """
        page = super().paginate_queryset(queryset)
        for dataset_elt in page:
            element = dataset_elt.element
            # The image relation is loaded through select_related and is None
            # when the element has no image.
            image = element.image
            element.image_url = dataset_elt.image_url
            element.image__width = image.width if image is not None else None
            element.image__height = image.height if image is not None else None
        return page

    def get_queryset(self):
        """
        Return the elements of the dataset identified by the ``pk`` URL
        parameter, annotated with their image URL.

        Responds with HTTP 404 when the dataset does not exist, and raises
        PermissionDenied when the user cannot read the dataset's corpus.
        """
        dataset = get_object_or_404(
            Dataset.objects.select_related('corpus'),
            id=self.kwargs['pk'],
        )
        if not self.has_read_access(dataset.corpus):
            raise PermissionDenied(detail='You do not have access to the corpus of this dataset')
        return (
            dataset.dataset_elements
            # Fetch elements and their images in the same query
            .select_related('element__image')
            .annotate(image_url=annotate_image_url(field='element__image'))
            # Restrict the loaded columns to the listed fields
            .only(
                'id',
                'set',
                'dataset_id',
                'element__name',
                'element__type_id',
                'element__confidence',
                'element__image_id',
                'element__image__width',
                'element__image__height',
                'element__polygon',
                'element__rotation_angle',
                'element__mirrored',
            )
        )
def test_list_elements(self):
    """
    Listing a dataset's elements returns them ordered by set then element ID,
    with cursor pagination and the nested element attributes.
    """
    self.dataset.dataset_elements.create(element_id=self.page1.id, set="test")
    self.dataset.dataset_elements.create(element_id=self.page2.id, set="training")
    self.dataset.dataset_elements.create(element_id=self.page3.id, set="training")
    self.page1.confidence = 0.42
    self.page1.mirrored = True
    self.page1.rotation_angle = 42
    self.page1.save()
    # Within the "training" set, results are ordered by element ID
    first_training_elt = sorted((self.page2, self.page3), key=lambda x: x.id)[0]
    # Both expected elements share this polygon
    full_page_polygon = [[0, 0], [0, 1000], [1000, 1000], [1000, 0], [0, 0]]

    self.client.force_login(self.user)
    with self.assertNumQueries(4):
        response = self.client.get(
            reverse('api:dataset-elements', kwargs={'pk': self.dataset.pk}),
            {'page_size': 2},
        )
        self.assertEqual(response.status_code, status.HTTP_200_OK)
    data = response.json()
    # Three elements with a page size of two: a next page must exist
    self.assertIn('?cursor=', data['next'])
    self.assertIsNone(data['count'])
    self.assertListEqual(data['results'], [{
        'set': 'test',
        'element': {
            'id': str(self.page1.id),
            'confidence': 0.42,
            'type_id': str(self.page1.type_id),
            'image_id': str(self.page1.image.id),
            'image_height': 1000,
            'image_width': 1000,
            'image_url': 'http://server/img1',
            'mirrored': True,
            'name': 'Volume 1, page 1r',
            'polygon': full_page_polygon,
            'rotation_angle': 42,
        },
    }, {
        'set': 'training',
        'element': {
            'id': str(first_training_elt.id),
            'confidence': None,
            'type_id': str(first_training_elt.type_id),
            'image_id': str(first_training_elt.image.id),
            'image_height': first_training_elt.image.height,
            'image_width': first_training_elt.image.width,
            'image_url': f'http://server/{first_training_elt.image.path}',
            'mirrored': False,
            'name': first_training_elt.name,
            'polygon': full_page_polygon,
            'rotation_angle': 0,
        },
    }])