diff --git a/arkindex/documents/admin.py b/arkindex/documents/admin.py index 3a2353cab71e90d673ad20366f362e8f49b3d6bf..f4ea14c6855b695383c7ca21697541b61f423e85 100644 --- a/arkindex/documents/admin.py +++ b/arkindex/documents/admin.py @@ -4,8 +4,8 @@ from django.urls import path, reverse from django.utils.html import format_html from django_admin_hstore_widget.forms import HStoreFormField from arkindex.documents.models import \ - Corpus, Page, Element, ElementType, Act, Transcription, MetaData, InterpretedDate, Classification, DataSource, \ - Entity, EntityRole, EntityLink + Corpus, Page, Element, ElementType, Act, Transcription, MetaData, InterpretedDate, MLClass, Classification, \ + DataSource, Entity, EntityRole, EntityLink from arkindex.documents.views import DumpActs from arkindex.dataimport.models import Event from enumfields.admin import EnumFieldListFilter @@ -99,6 +99,12 @@ class TranscriptionAdmin(admin.ModelAdmin): raw_id_fields = ('element', 'zone', ) +class MLClassAdmin(admin.ModelAdmin): + list_display = ('id', 'name', 'corpus') + list_filter = ('corpus',) + fields = ('name', 'corpus') + + class EntityMetaForm(forms.ModelForm): metas = HStoreFormField() @@ -130,6 +136,7 @@ admin.site.register(Page, PageAdmin) admin.site.register(Element, ElementAdmin) admin.site.register(Act, ActAdmin) admin.site.register(Transcription, TranscriptionAdmin) +admin.site.register(MLClass, MLClassAdmin) admin.site.register(MetaData, MetaDataAdmin) admin.site.register(Entity, EntityAdmin) admin.site.register(EntityRole, EntityRoleAdmin) diff --git a/arkindex/documents/api/elements.py b/arkindex/documents/api/elements.py index 1422462a05ce454c8583bf837eb36c142674268a..064d51f4a0488eb0f927c3563346a61834f72337 100644 --- a/arkindex/documents/api/elements.py +++ b/arkindex/documents/api/elements.py @@ -1,23 +1,34 @@ +from django.db.models import Prefetch, prefetch_related_objects +from django_filters import rest_framework as filters +from rest_framework import status +from 
rest_framework.exceptions import ValidationError from rest_framework.generics import ( - ListAPIView, ListCreateAPIView, RetrieveAPIView, RetrieveUpdateAPIView, - RetrieveUpdateDestroyAPIView, CreateAPIView, + GenericAPIView, CreateAPIView, ListAPIView, ListCreateAPIView, RetrieveAPIView, RetrieveUpdateAPIView, + RetrieveUpdateDestroyAPIView) +from rest_framework.response import Response + +from arkindex.documents.models import ( + Act, Classification, ClassificationState, Corpus, DataSource, Element, ElementType, Page, Right, Transcription, ) -from arkindex.project.pagination import PageNumberPagination -from rest_framework.exceptions import ValidationError -from django.db.models import prefetch_related_objects -from arkindex.documents.serializers.light import PageLightSerializer from arkindex.documents.serializers.elements import ( - ElementSerializer, ElementSlimSerializer, CorpusSerializer, PageSerializer, - ActSerializer, SurfaceSerializer, ElementCreateSerializer -) -from arkindex.documents.serializers.ml import TranscriptionSerializer + ActSerializer, CorpusSerializer, ElementCreateSerializer, ElementSerializer, ElementSlimSerializer, + PageSerializer, SurfaceSerializer) +from arkindex.documents.serializers.light import PageLightSerializer from arkindex_common.enums import TranscriptionType -from arkindex.documents.models import ( - Element, ElementType, Page, Act, Transcription, - Corpus, Right +from arkindex.documents.serializers.ml import ( + ClassificationCreateSerializer, ClassificationSerializer, TranscriptionSerializer) +from arkindex.project.mixins import CorpusACLMixin, NestedCorpusMixin +from arkindex.project.pagination import PageNumberPagination +from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly +from arkindex_common.ml_tool import MLToolType + +from .filters import PageFilter + + +classifications_prefetch = Prefetch( + 'classifications', + queryset=Classification.objects.select_related('ml_class', 
'source').order_by('-confidence') ) -from arkindex.project.mixins import CorpusACLMixin -from arkindex.project.permissions import IsVerifiedOrReadOnly, IsVerified class ElementsList(CorpusACLMixin, ListAPIView): @@ -87,7 +98,7 @@ class ElementPages(ListAPIView): self.kwargs['pk'], corpus__in=Corpus.objects.readable(self.request.user), zone__isnull=False, - ).prefetch_related('zone__image__server') + ).prefetch_related('zone__image__server', classifications_prefetch) class ElementSurfaces(ListAPIView): @@ -131,14 +142,22 @@ class CorpusRetrieve(RetrieveUpdateDestroyAPIView): self.permission_denied(request, message='You do not have write access to this corpus.') -class CorpusPages(CorpusACLMixin, ListAPIView): +class CorpusPages(NestedCorpusMixin, ListAPIView): """ List all pages in a corpus """ serializer_class = PageLightSerializer + filter_backends = (filters.DjangoFilterBackend,) + filterset_class = PageFilter def get_queryset(self): - return Page.objects.filter(corpus=self.get_corpus(self.kwargs['pk'])).select_related('zone__image__server') + return Page.objects.filter(corpus=self.corpus) + + def filter_queryset(self, queryset): + return Page.objects.filter(pk__in=super().filter_queryset(queryset)) \ + .select_related('zone__image__server') \ + .prefetch_related(classifications_prefetch) \ + .order_by('paths__ordering') class PageDetails(RetrieveAPIView): @@ -232,3 +251,53 @@ class ElementsCreate(CreateAPIView): """ permission_classes = (IsVerified, ) serializer_class = ElementCreateSerializer + + +class ClassificationCreate(CreateAPIView): + """ + Create a classification for a specific page + """ + serializer_class = ClassificationCreateSerializer + queryset = Classification.objects.all() + + def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + data_source, __ = DataSource.objects.get_or_create( + type=MLToolType.Classifier, slug='manual', defaults={'revision': '', 
'internal': False}) + serializer.save( + source=data_source, + moderator=self.request.user, + state=ClassificationState.Validated, + confidence=1 + ) + headers = self.get_success_headers(serializer.data) + response_serializer = ClassificationSerializer(instance=serializer.instance) + return Response(response_serializer.data, status=status.HTTP_201_CREATED, headers=headers) + + +class ClassificationModerationActionsMixin(GenericAPIView): + serializer_class = ClassificationSerializer + + def get_queryset(self): + return Classification.objects.filter(page__corpus__in=Corpus.objects.writable(self.request.user)) + + def put(self, request, *args, **kwargs): + instance = self.get_object() + instance.moderator = self.request.user + instance.state = self.moderation_action + instance.save(update_fields=['moderator', 'state']) + serializer = self.get_serializer(instance) + return Response(serializer.data, status=status.HTTP_202_ACCEPTED) + + +class ClassificationValidate(ClassificationModerationActionsMixin): + """Validate a classification""" + moderation_action = ClassificationState.Validated + action = 'Validate' # For OpenAPI operationId -> ValidateClassification + + +class ClassificationReject(ClassificationModerationActionsMixin): + """Reject a classification""" + moderation_action = ClassificationState.Rejected + action = 'Reject' # For OpenAPI operationId -> RejectClassification diff --git a/arkindex/documents/api/entities.py b/arkindex/documents/api/entities.py index 557d0457569174a503f95c66607c2abc70d8543b..16e6ec4005234371f0ce6555390354144c73dea0 100644 --- a/arkindex/documents/api/entities.py +++ b/arkindex/documents/api/entities.py @@ -1,5 +1,5 @@ -from arkindex.documents.models import \ - EntityRole, Entity, Corpus, Element, Transcription, TranscriptionEntity +from django_filters import rest_framework as filters +from arkindex.documents.models import EntityRole, Entity, Corpus, Element, MLClass, Transcription, TranscriptionEntity from arkindex.project.mixins import
CorpusACLMixin from rest_framework.generics import \ ListAPIView, ListCreateAPIView, RetrieveAPIView, CreateAPIView @@ -10,8 +10,12 @@ from arkindex.documents.serializers.entities import ( from arkindex.documents.serializers.elements import ElementSerializer from rest_framework import serializers from django.core.exceptions import ValidationError +from arkindex.documents.serializers.ml import MLClassSerializer +from arkindex.project.mixins import NestedCorpusMixin from arkindex.project.permissions import IsVerified +from .filters import MLClassFilter + class CorpusRoles(CorpusACLMixin, ListCreateAPIView): """ @@ -40,6 +44,27 @@ class CorpusRoles(CorpusACLMixin, ListCreateAPIView): super().perform_create(serializer) +class CorpusMLClassList(NestedCorpusMixin, ListAPIView): + """ + List all classes in a corpus + """ + serializer_class = MLClassSerializer + filter_backends = (filters.DjangoFilterBackend,) + filterset_class = MLClassFilter + action = 'ListCorpus' # For OpenAPI operationId -> ListCorpusMLClass + + def get_queryset(self): + return MLClass.objects.filter(corpus=self.corpus) + + +class MLClassList(ListAPIView): + """ + List all available classes + """ + serializer_class = MLClassSerializer + queryset = MLClass.objects.all() + + class EntityDetails(RetrieveAPIView): """ Get all information about entity diff --git a/arkindex/documents/api/filters.py b/arkindex/documents/api/filters.py new file mode 100644 index 0000000000000000000000000000000000000000..548cde7b42c86ff51cfc5131f57a4097989cb3ad --- /dev/null +++ b/arkindex/documents/api/filters.py @@ -0,0 +1,38 @@ +from django.db.models import Q, Count +from django_filters import rest_framework as filters + +from arkindex.documents.models import MLClass, Element, ElementType, Page, ClassificationState + + +class MLClassVolumeChoiceFilter(filters.ModelChoiceFilter): + def get_queryset(self, request): + return Element.objects.filter(type=ElementType.Volume, corpus=request.corpus) + + +class 
MLClassFilter(filters.FilterSet): + volume = MLClassVolumeChoiceFilter(method='filter_volume', label='Volume') + + def filter_volume(self, queryset, name, value): + return queryset.filter(classifications__page__in=Page.objects.get_descending(value.pk)).distinct() + + class Meta: + model = MLClass + fields = ['volume'] + + +class PageFilter(filters.FilterSet): + ml_class = filters.ModelChoiceFilter( + queryset=MLClass.objects.all(), field_name='classifications__ml_class', label='Class') + ml_class_unvalidated = filters.BooleanFilter( + field_name='classifications__ml_class', method='filter_ml_class_unvalidated') + + def filter_ml_class_unvalidated(self, queryset, name, value): + base_qs = queryset.annotate( + num_validated=Count('classifications', filter=Q(classifications__state=ClassificationState.Validated))) + if value: + return base_qs.filter(num_validated=0) + return base_qs.exclude(num_validated=0) + + class Meta: + model = Page + fields = ['ml_class', 'ml_class_unvalidated'] diff --git a/arkindex/documents/api/ml.py b/arkindex/documents/api/ml.py index 92509bedccac883827e904736441beb3ef7fdbdd..fc057e1e9d4920020318108c2ed44ed77dba81e5 100644 --- a/arkindex/documents/api/ml.py +++ b/arkindex/documents/api/ml.py @@ -166,7 +166,7 @@ class ClassificationBulk(CreateAPIView): Classification( page=parent, source=source, - class_name=cl['class_name'], + ml_class=cl['ml_class'], confidence=cl['confidence'], ) for cl in serializer.validated_data['classifications'] diff --git a/arkindex/documents/migrations/0007_auto_20190513_1508.py b/arkindex/documents/migrations/0007_auto_20190513_1508.py new file mode 100644 index 0000000000000000000000000000000000000000..c33574f75e1246eb3f670f17172a78e99e134631 --- /dev/null +++ b/arkindex/documents/migrations/0007_auto_20190513_1508.py @@ -0,0 +1,70 @@ +# Generated by Django 2.2 on 2019-05-13 15:08 + +import arkindex.documents.models +from django.conf import settings +from django.db import migrations, models +import 
django.db.models.deletion +import enumfields.fields +import uuid + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ("documents", "0006_transcribed_entities_table"), + ] + + operations = [ + migrations.AddField( + model_name="classification", + name="moderator", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="classifications", + to=settings.AUTH_USER_MODEL, + ), + ), + migrations.AddField( + model_name="classification", + name="state", + field=enumfields.fields.EnumField( + db_index=True, default="pending", enum=arkindex.documents.models.ClassificationState, max_length=16 + ), + ), + migrations.CreateModel( + name="MLClass", + fields=[ + ("id", models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)), + ("name", models.CharField(max_length=100)), + ( + "corpus", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, related_name="ml_classes", to="documents.Corpus" + ), + ), + ], + options={ + "verbose_name_plural": "classes", + "unique_together": {("name", "corpus")}, + "ordering": ("corpus", "name"), + }, + ), + migrations.AddField( + model_name="classification", + name="ml_class", + field=models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="classifications", + to="documents.MLClass", + ), + ), + migrations.AlterField( + model_name="classification", + name="class_name", + field=models.CharField(max_length=100, null=True), + ), + ] diff --git a/arkindex/documents/migrations/0008_auto_20190513_1508.py b/arkindex/documents/migrations/0008_auto_20190513_1508.py new file mode 100644 index 0000000000000000000000000000000000000000..36e2011de0b90f48da56ee80cf715a30c7842e24 --- /dev/null +++ b/arkindex/documents/migrations/0008_auto_20190513_1508.py @@ -0,0 +1,42 @@ +# Generated by Django 2.2 on 2019-05-13 15:08 + +from django.db import migrations + + +def 
migrate_classes_to_dedicated_model_forward(apps, schema_editor): + Classification = apps.get_model('documents', 'Classification') + objs = [] + ml_class_cache = {} + for classification in Classification.objects.all().select_related('page__corpus'): + # Cache MLClass to avoid extra requests + try: + ml_class = ml_class_cache[classification.page.corpus.pk, classification.class_name] + except KeyError: + ml_class, __ = apps.get_model('documents', 'MLClass').objects.get_or_create( + name=classification.class_name, + corpus=classification.page.corpus + ) + ml_class_cache[classification.page.corpus.pk, classification.class_name] = ml_class + classification.ml_class = ml_class + objs.append(classification) + Classification.objects.bulk_update(objs, ['ml_class'], batch_size=2000) + + +def migrate_classes_to_dedicated_model_backward(apps, schema_editor): + Classification = apps.get_model('documents', 'Classification') + objs = [] + for classification in Classification.objects.all().select_related('ml_class'): + classification.class_name = classification.ml_class.name + objs.append(classification) + Classification.objects.bulk_update(objs, ['class_name'], batch_size=2000) + + +class Migration(migrations.Migration): + + dependencies = [ + ('documents', '0007_auto_20190513_1508'), + ] + + operations = [ + migrations.RunPython(migrate_classes_to_dedicated_model_forward, migrate_classes_to_dedicated_model_backward) + ] diff --git a/arkindex/documents/migrations/0009_auto_20190513_1510.py b/arkindex/documents/migrations/0009_auto_20190513_1510.py new file mode 100644 index 0000000000000000000000000000000000000000..bf7ada32289f0271d60e2639ad59a05cc14c6d7e --- /dev/null +++ b/arkindex/documents/migrations/0009_auto_20190513_1510.py @@ -0,0 +1,23 @@ +# Generated by Django 2.2 on 2019-05-13 15:10 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ("documents", "0008_auto_20190513_1508"), + ] 
+ + operations = [ + migrations.AlterField( + model_name="classification", + name="ml_class", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, related_name="classifications", to="documents.MLClass" + ), + ), + migrations.AlterUniqueTogether(name="classification", unique_together={("page", "source", "ml_class")}), + migrations.RemoveField(model_name="classification", name="class_name"), + ] diff --git a/arkindex/documents/models.py b/arkindex/documents/models.py index 47629dbd757b703ef7e31a3435099d20611fcf61..4b01bc34ed12999dd0ab307a0f6a78f3ed33b97a 100644 --- a/arkindex/documents/models.py +++ b/arkindex/documents/models.py @@ -417,13 +417,6 @@ class Page(Element): """ return self.transcriptions.filter(type=TranscriptionType.Page) - @property - def best_classes(self): - """ - The three most probable classifications - """ - return self.classifications.order_by('-confidence')[:3] - class Act(Element): """ @@ -662,6 +655,31 @@ class TranscriptionEntity(models.Model): super().clean() +class MLClass(models.Model): + """ + A type of classification + """ + id = models.UUIDField(default=uuid.uuid4, primary_key=True) + name = models.CharField(max_length=100) + corpus = models.ForeignKey('documents.Corpus', on_delete=models.CASCADE, related_name='ml_classes') + + class Meta: + unique_together = ( + ('name', 'corpus'), + ) + verbose_name_plural = 'classes' + ordering = ('corpus', 'name') + + def __str__(self): + return self.name + + +class ClassificationState(Enum): + Pending = 'pending' + Validated = 'validated' + Rejected = 'rejected' + + class Classification(models.Model): """ A result of a classifier on a page @@ -677,12 +695,24 @@ class Classification(models.Model): on_delete=models.CASCADE, related_name='classifications', ) - class_name = models.CharField(max_length=100) + moderator = models.ForeignKey( + 'users.User', + on_delete=models.SET_NULL, + null=True, + blank=True, + related_name='classifications' + ) + ml_class = models.ForeignKey( + 
MLClass, + on_delete=models.CASCADE, + related_name='classifications' + ) + state = EnumField(ClassificationState, max_length=16, default=ClassificationState.Pending, db_index=True) confidence = models.FloatField(null=True, blank=True) class Meta: unique_together = ( - ('page', 'source', 'class_name'), + ('page', 'source', 'ml_class'), ) diff --git a/arkindex/documents/serializers/light.py b/arkindex/documents/serializers/light.py index 59097ebc8c6d6d3f893acc6a4d8da6b380dd4ee0..5c626493b8458a0cfb79f1bf4c1cfb3edd1ea8db 100644 --- a/arkindex/documents/serializers/light.py +++ b/arkindex/documents/serializers/light.py @@ -28,7 +28,7 @@ class PageLightSerializer(serializers.ModelSerializer): page_type = EnumField(PageType) direction = EnumField(PageDirection) image = ImageSerializer(source='zone.image') - best_classes = ClassificationSerializer(many=True) + classifications = ClassificationSerializer(many=True) class Meta: model = Page @@ -39,7 +39,7 @@ class PageLightSerializer(serializers.ModelSerializer): 'direction', 'display_name', 'image', - 'best_classes', + 'classifications', ) diff --git a/arkindex/documents/serializers/ml.py b/arkindex/documents/serializers/ml.py index 3eb3b54bc79432441f426b26f9f9e2d9bca9ba8f..5db346e563f3ddc03e78eb4914863bc09d7343b3 100644 --- a/arkindex/documents/serializers/ml.py +++ b/arkindex/documents/serializers/ml.py @@ -3,7 +3,7 @@ from arkindex_common.ml_tool import MLToolType from arkindex_common.enums import TranscriptionType from arkindex.project.serializer_fields import EnumField, MLToolField, PolygonField from arkindex.documents.models import ( - Corpus, Element, Page, Transcription, DataSource, Classification + Corpus, Element, Page, Transcription, DataSource, MLClass, Classification, ClassificationState ) from arkindex.images.serializers import ZoneSerializer @@ -26,22 +26,55 @@ class DataSourceSerializer(serializers.ModelSerializer): ) +class MLClassSerializer(serializers.ModelSerializer): + """ + Serializer for MLClass 
instances + Used as a nested endpoint below corpus, so exclude corpus from payload + """ + class Meta: + model = MLClass + exclude = ('corpus',) + + class ClassificationSerializer(serializers.ModelSerializer): """ Serialize a classification on a Page """ source = DataSourceSerializer() + ml_class = MLClassSerializer() + state = EnumField(ClassificationState) class Meta: model = Classification fields = ( + 'id', 'source', - 'class_name', + 'ml_class', + 'state', 'confidence', ) +class ClassificationCreateSerializer(serializers.ModelSerializer): + """ + Specific serializer for the manual classification creation + """ + class Meta: + model = Classification + fields = ( + 'page', + 'ml_class', + ) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not self.context.get('request'): # May be None when generating an OpenAPI schema or using from a REPL + return + self.fields['page'].queryset = Page.objects.filter( + corpus__in=Corpus.objects.writable(self.context['request'].user)) + + class TranscriptionSerializer(serializers.ModelSerializer): """ Serialises a Transcription @@ -127,7 +160,7 @@ class ClassificationBulkSerializer(serializers.Serializer): Single classification serializer for bulk insertion Cannot use ModelSerializer as they become read-only when nested """ - class_name = serializers.CharField() + ml_class = serializers.PrimaryKeyRelatedField(queryset=MLClass.objects.all()) confidence = serializers.FloatField(min_value=0, max_value=1) diff --git a/arkindex/documents/tests/commands/test_delete_corpus.py b/arkindex/documents/tests/commands/test_delete_corpus.py index da8a4b8125a188e3181d5f58c2f522c92085cd4a..69c8c54a64de9b9e7b9c81ef6edb2d7928b526c7 100644 --- a/arkindex/documents/tests/commands/test_delete_corpus.py +++ b/arkindex/documents/tests/commands/test_delete_corpus.py @@ -3,8 +3,7 @@ from django.db.models.signals import pre_delete from arkindex_common.enums import TranscriptionType, MetaType, DataImportMode from 
arkindex_common.ml_tool import MLToolType from arkindex.project.tests import FixtureTestCase - -from arkindex.documents.models import Corpus, Element, Page, ElementType, DataSource +from arkindex.documents.models import Corpus, Element, Page, ElementType, DataSource, MLClass from arkindex.dataimport.models import EventType @@ -63,7 +62,10 @@ class TestDeleteCorpus(FixtureTestCase): revision='Early Access', internal=False, ), - class_name='klass', + ml_class=MLClass.objects.create( + name='klass', + corpus=cls.corpus2, + ), confidence=0.5, ) cls.page.transcriptions.create( @@ -140,7 +142,7 @@ class TestDeleteCorpus(FixtureTestCase): cl = self.page.classifications.get() self.assertEqual(cl.source.slug, 'classeur') - self.assertEqual(cl.class_name, 'klass') + self.assertEqual(cl.ml_class.name, 'klass') self.assertEqual(cl.confidence, 0.5) ts = self.page.transcriptions.get() diff --git a/arkindex/documents/tests/test_bulk_classification.py b/arkindex/documents/tests/test_bulk_classification.py index 8071321fa77677f2d2f1c34bc91b68c4350536a3..191a289c5c2f72066bae12fae75b2ec63300b3b9 100644 --- a/arkindex/documents/tests/test_bulk_classification.py +++ b/arkindex/documents/tests/test_bulk_classification.py @@ -2,7 +2,7 @@ from django.urls import reverse from rest_framework import status from unittest.mock import patch from arkindex.project.tests import FixtureAPITestCase -from arkindex.documents.models import DataSource +from arkindex.documents.models import DataSource, MLClass class TestBulkClassification(FixtureAPITestCase): @@ -37,23 +37,25 @@ class TestBulkClassification(FixtureAPITestCase): def test_bulk_classification(self): self.client.force_login(self.user) + dog_class = MLClass.objects.create(name="dog", corpus=self.corpus) + cat_class = MLClass.objects.create(name="cat", corpus=self.corpus) response = self.client.post(reverse('api:classification-bulk'), format='json', data={ "parent": str(self.page.id), "classifier": self.src.slug, "classifications": [ { - 
"class_name": "dog", + "ml_class": str(dog_class.pk), "confidence": 0.99, }, { - "class_name": "cat", + "ml_class": str(cat_class.pk), "confidence": 0.42, } ] }) self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertCountEqual( - list(self.page.classifications.values_list('class_name', 'confidence')), + list(self.page.classifications.values_list('ml_class__name', 'confidence')), [ ('dog', 0.99), ('cat', 0.42), @@ -65,38 +67,42 @@ class TestBulkClassification(FixtureAPITestCase): Test the bulk classification API deletes previous classifications """ self.client.force_login(self.user) + dog_class = MLClass.objects.create(name="dog", corpus=self.corpus) + cat_class = MLClass.objects.create(name="cat", corpus=self.corpus) response = self.client.post(reverse('api:classification-bulk'), format='json', data={ "parent": str(self.page.id), "classifier": self.src.slug, "classifications": [ { - "class_name": "dog", + "ml_class": str(dog_class.pk), "confidence": 0.99, }, { - "class_name": "cat", + "ml_class": str(cat_class.pk), "confidence": 0.42, } ] }) self.assertEqual(response.status_code, status.HTTP_201_CREATED) + doggo_class = MLClass.objects.create(name="doggo", corpus=self.corpus) + catte_class = MLClass.objects.create(name="catte", corpus=self.corpus) response = self.client.post(reverse('api:classification-bulk'), format='json', data={ "parent": str(self.page.id), "classifier": self.src.slug, "classifications": [ { - "class_name": "doggo", + "ml_class": str(doggo_class.pk), "confidence": 0.5, }, { - "class_name": "catte", + "ml_class": str(catte_class.pk), "confidence": 0.85, } ] }) self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertCountEqual( - list(self.page.classifications.values_list('class_name', 'confidence')), + list(self.page.classifications.values_list('ml_class__name', 'confidence')), [ ('doggo', 0.5), ('catte', 0.85), diff --git a/arkindex/documents/tests/test_classes.py b/arkindex/documents/tests/test_classes.py 
new file mode 100644 index 0000000000000000000000000000000000000000..fde4de99713e79b6972c1314ebec0d35d5a533b0 --- /dev/null +++ b/arkindex/documents/tests/test_classes.py @@ -0,0 +1,162 @@ +from django.urls import reverse +from rest_framework import status + +from arkindex.documents.models import Classification, DataSource, Element, ElementType, MLClass, Page +from arkindex.project.tests import FixtureAPITestCase + + +class TestClasses(FixtureAPITestCase): + + @classmethod + def setUpTestData(cls): + super().setUpTestData() + cls.text = MLClass.objects.create(name='text', corpus=cls.corpus) + cls.cover = MLClass.objects.create(name='cover', corpus=cls.corpus) + + def test_list(self): + """ + Test listing results alpha-ordered + """ + self.client.force_login(self.user) + response = self.client.get(reverse('api:corpus-classes', kwargs={'pk': self.corpus.pk})) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual(response.json(), { + "count": 2, + "number": 1, + "next": None, + "previous": None, + "results": [ + { + "id": str(self.cover.pk), + "name": "cover" + }, + { + "id": str(self.text.pk), + "name": "text" + }, + ] + }) + + # Add other classes to ensure ordering is preserved + self.white_page = MLClass.objects.create(name='white-page', corpus=self.corpus) + self.image = MLClass.objects.create(name='image', corpus=self.corpus) + response = self.client.get(reverse('api:corpus-classes', kwargs={'pk': self.corpus.pk})) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual(response.json(), { + "count": 4, + "number": 1, + "next": None, + "previous": None, + "results": [ + { + "id": str(self.cover.pk), + "name": "cover" + }, + { + "id": str(self.image.pk), + "name": "image" + }, + { + "id": str(self.text.pk), + "name": "text" + }, + { + "id": str(self.white_page.pk), + "name": "white-page" + }, + ] + }) + + def test_list_with_volume_filter(self): + """ + Test listing results alpha-ordered with a volume filter 
+ """ + self.client.force_login(self.user) + volume = Element.objects.create(corpus=self.corpus, type=ElementType.Volume, name="Custom volume") + data_source = DataSource.objects.get(slug='test') + + # Test with a newly created volume: must not return any MLClass + response = self.client.get(reverse('api:corpus-classes', kwargs={'pk': self.corpus.pk}), {'volume': volume.id}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual(response.json(), { + "count": 0, + "number": 1, + "next": None, + "previous": None, + "results": [] + }) + + # Attach page to the volume: must not return any MLClass without classification + page = Page.objects.create(corpus=self.corpus, name='Custom page') + page.add_parent(volume) + response = self.client.get(reverse('api:corpus-classes', kwargs={'pk': self.corpus.pk}), {'volume': volume.id}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual(response.json(), { + "count": 0, + "number": 1, + "next": None, + "previous": None, + "results": [] + }) + + # Add classification: must return class + Classification.objects.create(page=page, source=data_source, ml_class=self.text) + response = self.client.get(reverse('api:corpus-classes', kwargs={'pk': self.corpus.pk}), {'volume': volume.id}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual(response.json(), { + "count": 1, + "number": 1, + "next": None, + "previous": None, + "results": [ + { + "id": str(self.text.pk), + "name": "text" + } + ] + }) + + # Add other classifications to ensure ordering is preserved + self.white_page = MLClass.objects.create(name='white-page', corpus=self.corpus) + self.image = MLClass.objects.create(name='image', corpus=self.corpus) + Classification.objects.create(page=page, source=data_source, ml_class=self.white_page) + Classification.objects.create(page=page, source=data_source, ml_class=self.image) + response = self.client.get(reverse('api:corpus-classes', kwargs={'pk': 
self.corpus.pk}), {'volume': volume.id}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual(response.json(), { + "count": 3, + "number": 1, + "next": None, + "previous": None, + "results": [ + { + "id": str(self.image.pk), + "name": "image" + }, + { + "id": str(self.text.pk), + "name": "text" + }, + { + "id": str(self.white_page.pk), + "name": "white-page" + } + ] + }) + + def test_create(self): + self.client.force_login(self.user) + response = self.client.post(reverse('api:corpus-classes', kwargs={'pk': self.corpus.pk}), {}) + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + + def test_update(self): + self.client.force_login(self.user) + response = self.client.put(reverse('api:corpus-classes', kwargs={'pk': self.corpus.pk}), {}) + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + response = self.client.patch(reverse('api:corpus-classes', kwargs={'pk': self.corpus.pk}), {}) + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + + def test_delete(self): + self.client.force_login(self.user) + response = self.client.delete(reverse('api:corpus-classes', kwargs={'pk': self.corpus.pk})) + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) diff --git a/arkindex/documents/tests/test_corpus.py b/arkindex/documents/tests/test_corpus.py index f73df455abed78f67e062d9438411839cf082c80..3f2e93e923dacf823088f2b703eb99c74b3d058d 100644 --- a/arkindex/documents/tests/test_corpus.py +++ b/arkindex/documents/tests/test_corpus.py @@ -1,8 +1,9 @@ +from django.contrib.auth.models import AnonymousUser from django.urls import reverse -from arkindex.project.tests import FixtureAPITestCase from rest_framework import status -from arkindex.documents.models import Corpus, Element, ElementType -from django.contrib.auth.models import AnonymousUser + +from arkindex.documents.models import Classification, Corpus, DataSource, Element, ElementType, MLClass, Page 
+from arkindex.project.tests import FixtureAPITestCase class TestCorpus(FixtureAPITestCase): @@ -173,3 +174,74 @@ class TestCorpus(FixtureAPITestCase): def test_delete_requires_login(self): response = self.client.delete(reverse('api:corpus-retrieve', kwargs={'pk': self.corpus_private.id})) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_pages_list(self): + self.client.force_login(self.user) + response = self.client.get(reverse('api:corpus-pages', kwargs={'pk': self.corpus_public.id})) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 6) + self.assertEqual(len(response.data['results']), 6) + self.assertIsNone(response.data['next']) + self.assertIsNone(response.data['previous']) + + # Add a class to filter with: without classification, must not return any result + text = MLClass.objects.create(name='text', corpus=self.corpus_public) + response = self.client.get( + reverse('api:corpus-pages', kwargs={'pk': self.corpus_public.id}), {'ml_class': text.id}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual(response.json(), { + "count": 0, + "number": 1, + "next": None, + "previous": None, + "results": [] + }) + + # Add a classification: must return one result + classification = Classification.objects.create( + page=Page.objects.get(id='c3f32f55-4167-4b6d-9e7a-8fa1e8bb77f5'), + source=DataSource.objects.get(slug='test'), ml_class=text) + response = self.client.get( + reverse('api:corpus-pages', kwargs={'pk': self.corpus_public.id}), {'ml_class': text.id}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual(response.json(), { + "count": 1, + "number": 1, + "next": None, + "previous": None, + "results": [{ + 'id': 'c3f32f55-4167-4b6d-9e7a-8fa1e8bb77f5', + 'page_type': 'page', + 'nb': 1, + 'direction': 'verso', + 'display_name': 'Page no. 
1, verso', + 'image': { + 'id': 'ed0a53ee-ce6d-4ce2-8203-fbcc1e561541', + 'path': 'img2', + 's3_url': None, + 'width': 1000, + 'height': 1000, + 'url': 'http://server/img2', + 'status': 'unchecked' + }, + 'classifications': [ + { + 'id': str(classification.id), + 'source': { + 'id': str(classification.source.id), + 'type': 'recognizer', + 'slug': 'test', + 'revision': '4.2', + 'internal': False + }, + 'ml_class': { + 'id': str(text.id), + 'name': 'text' + }, + 'state': 'pending', + 'confidence': None + } + ] + } + ] + }) diff --git a/arkindex/documents/tests/test_moderation.py b/arkindex/documents/tests/test_moderation.py new file mode 100644 index 0000000000000000000000000000000000000000..49a5e1707fea8c4cdf65b67939cbb7b287fffa04 --- /dev/null +++ b/arkindex/documents/tests/test_moderation.py @@ -0,0 +1,162 @@ +from django.urls import reverse +from rest_framework import status + +from arkindex.documents.models import Classification, ClassificationState, DataSource, MLClass, Page +from arkindex.project.tests import FixtureAPITestCase + + +class TestClasses(FixtureAPITestCase): + + @classmethod + def setUpTestData(cls): + super().setUpTestData() + cls.text = MLClass.objects.create(name='text', corpus=cls.corpus) + + def _create_classification(self): + return Classification.objects.create( + page=Page.objects.get(id='c3f32f55-4167-4b6d-9e7a-8fa1e8bb77f5'), + source=DataSource.objects.get(slug='test'), ml_class=self.text, confidence=.5) + + def test_classification_creation(self): + """ + Ensure classification creation works and set auto fields correctly + """ + self.client.force_login(self.user) + response = self.client.post(reverse('api:classification-create'), { + 'page': 'c3f32f55-4167-4b6d-9e7a-8fa1e8bb77f5', + 'ml_class': str(self.text.id) + }) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data['source']['type'], 'classifier') + self.assertEqual(response.data['source']['slug'], 'manual') + 
self.assertDictEqual(response.data['ml_class'], { + 'id': str(self.text.id), + 'name': 'text' + }) + self.assertEqual(response.data['state'], ClassificationState.Validated.value) + self.assertEqual(response.data['confidence'], 1) + + def test_classification_creation_without_permission(self): + """ + Without permission, should not raise a 403 Forbidden: indicates that the resource exists + Indicates that page PK does not exist instead + """ + response = self.client.post(reverse('api:classification-create'), { + 'page': 'c3f32f55-4167-4b6d-9e7a-8fa1e8bb77f5', + 'ml_class': str(self.text.id) + }) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.data, { + 'page': [ + 'Invalid pk "c3f32f55-4167-4b6d-9e7a-8fa1e8bb77f5" - object does not exist.' + ] + }) + + def test_classification_validate(self): + self.client.force_login(self.user) + classification = self._create_classification() + response = self.client.put(reverse('api:classification-validate', kwargs={'pk': classification.id})) + self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) + self.assertDictEqual(response.json(), { + 'id': str(classification.id), + 'source': { + 'id': str(classification.source.id), + 'type': classification.source.type.value, + 'slug': classification.source.slug, + 'revision': classification.source.revision, + 'internal': classification.source.internal + }, + 'ml_class': { + 'id': str(classification.ml_class.id), + 'name': classification.ml_class.name + }, + 'state': ClassificationState.Validated.value, + 'confidence': classification.confidence + }) + + # Ensure moderator has been set + classification.refresh_from_db() + self.assertEqual(classification.moderator, self.user) + + def test_classification_validate_without_permissions(self): + classification = self._create_classification() + response = self.client.put(reverse('api:classification-validate', kwargs={'pk': classification.id})) + self.assertEqual(response.status_code, 
status.HTTP_404_NOT_FOUND) + + def test_classification_reject(self): + self.client.force_login(self.user) + classification = self._create_classification() + response = self.client.put(reverse('api:classification-reject', kwargs={'pk': classification.id})) + self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) + self.assertDictEqual(response.json(), { + 'id': str(classification.id), + 'source': { + 'id': str(classification.source.id), + 'type': classification.source.type.value, + 'slug': classification.source.slug, + 'revision': classification.source.revision, + 'internal': classification.source.internal + }, + 'ml_class': { + 'id': str(classification.ml_class.id), + 'name': classification.ml_class.name + }, + 'state': ClassificationState.Rejected.value, + 'confidence': classification.confidence + }) + + # Ensure moderator has been set + classification.refresh_from_db() + self.assertEqual(classification.moderator, self.user) + + def test_classification_reject_without_permissions(self): + classification = self._create_classification() + response = self.client.put(reverse('api:classification-reject', kwargs={'pk': classification.id})) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_classification_can_still_be_moderated(self): + self.client.force_login(self.user) + classification = self._create_classification() + classification.moderator = self.user + classification.state = ClassificationState.Validated.value + classification.save() + + # First try to reject + response = self.client.put(reverse('api:classification-reject', kwargs={'pk': classification.id})) + self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) + self.assertDictEqual(response.json(), { + 'id': str(classification.id), + 'source': { + 'id': str(classification.source.id), + 'type': classification.source.type.value, + 'slug': classification.source.slug, + 'revision': classification.source.revision, + 'internal': classification.source.internal + }, + 
'ml_class': { + 'id': str(classification.ml_class.id), + 'name': classification.ml_class.name + }, + 'state': ClassificationState.Rejected.value, + 'confidence': classification.confidence + }) + + # Then try to validate + response = self.client.put(reverse('api:classification-validate', kwargs={'pk': classification.id})) + self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) + self.assertDictEqual(response.json(), { + 'id': str(classification.id), + 'source': { + 'id': str(classification.source.id), + 'type': classification.source.type.value, + 'slug': classification.source.slug, + 'revision': classification.source.revision, + 'internal': classification.source.internal + }, + 'ml_class': { + 'id': str(classification.ml_class.id), + 'name': classification.ml_class.name + }, + 'state': ClassificationState.Validated.value, + 'confidence': classification.confidence + }) diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py index 7938454b670685261448b8f95edce4bf3c275d6f..35cca6eec33a9da4c169a92aaeb56d83b62613c4 100644 --- a/arkindex/project/api_v1.py +++ b/arkindex/project/api_v1.py @@ -4,14 +4,15 @@ from django.views.generic.base import RedirectView from arkindex.documents.api.elements import ( ElementsList, RelatedElementsList, ElementRetrieve, ElementPages, ElementSurfaces, CorpusList, CorpusRetrieve, CorpusPages, ActEdit, PageDetails, SurfaceDetails, - ElementTranscriptions, ElementsCreate + ElementTranscriptions, ElementsCreate, ClassificationCreate, ClassificationValidate, + ClassificationReject ) from arkindex.documents.api.search import PageSearch, ActSearch, EntitySearch from arkindex.documents.api.ml import \ ClassificationBulk, TranscriptionCreate, TranscriptionBulk, PageXmlTranscriptionsImport from arkindex.documents.api.entities import ( - CorpusRoles, EntityDetails, EntityCreate, EntityElements, EntityLinkCreate, - TranscriptionEntityCreate, TranscriptionEntities, ElementEntities + CorpusRoles, EntityDetails, EntityCreate, 
EntityLinkCreate, EntityElements, MLClassList, + CorpusMLClassList, TranscriptionEntityCreate, TranscriptionEntities, ElementEntities ) from arkindex.documents.api.iiif import ( VolumeManifest, ActManifest, PageAnnotationList, PageActAnnotationList, SurfaceAnnotationList, @@ -50,9 +51,16 @@ api = [ path('surface/<uuid:pk>/', SurfaceDetails.as_view(), name='surface-details'), path('corpus/', CorpusList.as_view(), name='corpus'), path('corpus/<uuid:pk>/', CorpusRetrieve.as_view(), name='corpus-retrieve'), + path('corpus/<uuid:pk>/classes/', CorpusMLClassList.as_view(), name='corpus-classes'), path('corpus/<uuid:pk>/pages/', CorpusPages.as_view(), name='corpus-pages'), path('corpus/<uuid:pk>/roles/', CorpusRoles.as_view(), name='corpus-roles'), + # Moderation + path('ml-classes/', MLClassList.as_view(), name='mlclass-list'), + path('classifications/', ClassificationCreate.as_view(), name='classification-create'), + path('classifications/<uuid:pk>/validate/', ClassificationValidate.as_view(), name='classification-validate'), + path('classifications/<uuid:pk>/reject/', ClassificationReject.as_view(), name='classification-reject'), + # Manifests path('manifest/<uuid:pk>/pages/', VolumeManifest.as_view(), name='volume-manifest'), path('manifest/<uuid:pk>/act/', ActManifest.as_view(), name='act-manifest'), diff --git a/arkindex/project/mixins.py b/arkindex/project/mixins.py index b0005c0c5a76dc9f7896174d1a2117ea1a8707d0..134044f5a6892e330e684ad330173789b37fc343 100644 --- a/arkindex/project/mixins.py +++ b/arkindex/project/mixins.py @@ -27,6 +27,17 @@ class CorpusACLMixin(object): return Right.Admin in corpus.get_acl_rights(self.request.user) +class NestedCorpusMixin(CorpusACLMixin): + """ + Used for corpora nested endpoints. + Fetch the corpus instance once, and attach it to both view and request. 
+ """ + def initialize_request(self, request, *args, **kwargs): + request = super().initialize_request(request, *args, **kwargs) + request.corpus = self.corpus = self.get_corpus(self.kwargs['pk']) + return request + + class SearchAPIMixin(CorpusACLMixin): template_path = None es_source = True diff --git a/arkindex/project/pagination.py b/arkindex/project/pagination.py index c278e2fa9875647202f19dbaee936ac7ccec690e..d430b96abeac139b0d50abbf1bbc524debc6a90a 100644 --- a/arkindex/project/pagination.py +++ b/arkindex/project/pagination.py @@ -7,6 +7,8 @@ class PageNumberPagination(pagination.PageNumberPagination): """ Enhanced PageNumberPagination with extra information to help front-end pagination displays """ + page_size_query_param = 'page_size' + max_page_size = 20 def get_paginated_response(self, data): return Response(OrderedDict([ diff --git a/arkindex/project/settings.py b/arkindex/project/settings.py index 087ccbf3f52bb22eaa71a35fb17c77eb1d1b2aa3..0fbc3118c001fd5356b5e706fc36ae12a0c63390 100644 --- a/arkindex/project/settings.py +++ b/arkindex/project/settings.py @@ -123,6 +123,7 @@ INSTALLED_APPS = [ # Tools 'rest_framework', 'rest_framework.authtoken', + 'django_filters', 'corsheaders', 'ponos', diff --git a/openapi/Dockerfile b/openapi/Dockerfile index f778aee54e9a074f69ea7ae117132f478bc02c0b..f1cc8ccc6415428d8216694e21d243e304881c99 100644 --- a/openapi/Dockerfile +++ b/openapi/Dockerfile @@ -1,6 +1,6 @@ FROM registry.gitlab.com/arkindex/backend:latest -RUN pip uninstall -y djangorestframework +RUN pip uninstall -y djangorestframework django-filter COPY ["patch.py", "run.sh", "requirements.txt", "patch.yml", "/"] RUN pip install -r /requirements.txt && rm /requirements.txt diff --git a/openapi/requirements.txt b/openapi/requirements.txt index ceabbab7b9a2b2666061618f4ad29b263290d53a..b42a59416a0b53cc17cf1134259ee94615957a02 100644 --- a/openapi/requirements.txt +++ b/openapi/requirements.txt @@ -1,3 +1,4 @@ 
git+https://github.com/encode/django-rest-framework.git@37f210a455cc92cb3f61a23e194a1d0de58d149b#egg=djangorestframework +git+https://github.com/n2ygk/django-filter.git@556ce2740f0c7aede2a60e5e94d07565f706f6f5#egg=django-filter coreapi==2.3.3 apistar>=0.7.2 diff --git a/requirements.txt b/requirements.txt index 7f406b55fdcfe49bcef22ffaeaee9fcfa1478ee4..14b4145497534f580853c1ebe1f6bc8111c2e528 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ chardet==3.0.4 django-admin-hstore-widget==1.0.1 django-cors-headers==2.4.0 django-enumfields==1.0.0 +django-filter==2.1.0 djangorestframework==3.9.2 et-xmlfile==1.0.1 gitpython==2.1.11