Skip to content
Snippets Groups Projects
Commit c0ec6b2b authored by Erwan Rouchet's avatar Erwan Rouchet
Browse files

Merge branch 'validated-as-best-classes' into 'master'

Validated as best classes

See merge request !550
parents 3bc72cb4 1b71b58f
No related branches found
No related tags found
1 merge request: !550 Validated as best classes
from django.db.models import Prefetch, prefetch_related_objects
from django.db.models import Q
from django_filters import rest_framework as filters
from rest_framework.exceptions import ValidationError
from rest_framework.generics import (
......@@ -8,7 +9,10 @@ from rest_framework.generics import (
from rest_framework import status, response
from rest_framework.response import Response
from arkindex_common.enums import TranscriptionType
from arkindex.documents.models import Classification, Corpus, Element, ElementPath, Right, Transcription, Region
from arkindex.documents.models import (
Corpus, Element, ElementPath, Right,
Classification, ClassificationState, Transcription, Region
)
from arkindex.documents.serializers.elements import (
CorpusSerializer, ElementSerializer, ElementSlimSerializer, BestClassesElementSerializer,
ElementCreateSerializer, ElementNeighborsSerializer, ElementParentSerializer,
......@@ -27,6 +31,16 @@ from uuid import UUID
# Base queryset shared by the classification prefetches: follows the
# `ml_class` and `source` relations and sorts by descending confidence.
classifications_queryset = (
    Classification.objects
    .select_related('ml_class', 'source')
    .order_by('-confidence')
)
# Prefetch storing, under `best_classes` on each element, the distinct
# classifications that are flagged as best or whose state is Validated,
# leaving out every classification whose state is Rejected.
best_classifications_prefetch = Prefetch(
    'classifications',
    queryset=classifications_queryset.filter(
        Q(state=ClassificationState.Validated) | Q(best=True),
    ).exclude(
        state=ClassificationState.Rejected,
    ).distinct(),
    to_attr='best_classes',
)
class ElementsList(CorpusACLMixin, ListAPIView):
"""
......@@ -128,11 +142,6 @@ class ElementsList(CorpusACLMixin, ListAPIView):
best_classes = self.request.query_params.get('best_classes')
if best_classes and best_classes.lower() not in ('false', '0'):
best_classifications_prefetch = Prefetch(
'classifications',
queryset=classifications_queryset.filter(best=True),
to_attr='best_classes'
)
filtered_queryset = filtered_queryset.prefetch_related(best_classifications_prefetch)
# ID is required by postgres to order elements with common corpus, type and name
......@@ -353,11 +362,6 @@ class ElementParents(ListAPIView):
best_classes = self.request.query_params.get('best_classes')
if best_classes and best_classes.lower() not in ('false', '0'):
best_classifications_prefetch = Prefetch(
'classifications',
queryset=classifications_queryset.filter(best=True),
to_attr='best_classes'
)
prefetch_related_lookups += (best_classifications_prefetch, )
if recursive_param is None or recursive_param.lower() in ('false', '0'):
......@@ -473,11 +477,6 @@ class ElementChildren(ListAPIView):
best_classes = self.request.query_params.get('best_classes')
if best_classes and best_classes.lower() not in ('false', '0'):
best_classifications_prefetch = Prefetch(
'classifications',
queryset=classifications_queryset.filter(best=True),
to_attr='best_classes'
)
prefetch_related_lookups += (best_classifications_prefetch, )
if recursive_param is None or recursive_param.lower() in ('false', '0'):
......
from django.urls import reverse
from rest_framework import status
from arkindex.documents.models import Element, Classification, DataSource, MLClass
from arkindex.documents.models import Element, Classification, ClassificationState, DataSource, MLClass
from arkindex_common.ml_tool import MLToolType
from arkindex.project.tests import FixtureAPITestCase
......@@ -256,3 +257,90 @@ class TestClasses(FixtureAPITestCase):
list(map(lambda c: (c['source']['slug'], c['confidence']), elt['best_classes'])),
[('test', .99)]
)
def test_rejected_best_classes(self):
    """
    A machine classification that has been rejected by a human must not appear
    """
    self.populate_classified_elements()
    # Reject every classification of one classified element
    rejected_element = Element.objects.filter(type=self.classified.id).first()
    rejected_element.classifications.all().update(state=ClassificationState.Rejected)

    with self.assertNumQueries(6):
        response = self.client.get(
            reverse('api:elements-children', kwargs={'pk': str(self.parent.id)}),
            data={'type': self.classified.slug, 'best_classes': 'yes'}
        )
    self.assertEqual(response.status_code, status.HTTP_200_OK)

    payload = response.json()
    self.assertEqual(payload['count'], 12)
    rejected_id = str(rejected_element.id)
    for item in payload['results']:
        if item['id'] == rejected_id:
            # Only the element whose classifications were rejected is left without best classes
            self.assertListEqual(item['best_classes'], [])
        else:
            self.assertNotEqual(item['best_classes'], [])
def test_validated_non_best_class(self):
    """
    A non-best class validated by a human is listed among the best classes,
    next to the classification flagged as best
    """
    self.populate_classified_elements()
    # Last element returned by get_ascending for the common child
    parent = Element.objects.get_ascending(self.common_children.id)[-1]
    # Validate the classification with confidence 0.7 on that element
    parent.classifications.filter(confidence=.7).update(state=ClassificationState.Validated)

    response = self.client.get(
        reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)}),
        data={'type': self.classified.slug, 'best_classes': 'yes'}
    )
    # Check the status first, as the sibling tests do, so that a failed request
    # is reported as such instead of as a KeyError on 'results'
    self.assertEqual(response.status_code, status.HTTP_200_OK)

    parent_id = str(parent.id)
    for elt in response.json()['results']:
        if elt['id'] == parent_id:
            # The validated classification is returned after the best one
            expected = [('test', .99), ('test', .7)]
        else:
            expected = [('test', .99)]
        self.assertListEqual(
            [(c['source']['slug'], c['confidence']) for c in elt['best_classes']],
            expected
        )
def test_rejected_human_class(self):
    """
    A manual classification rejected by a human must not appear in best classes
    """
    self.populate_classified_elements()
    manual_source, _ = DataSource.objects.get_or_create(type=MLToolType.NER, slug="manual", internal=False)
    target = Element.objects.filter(type=self.classified.id).first()
    manual_classification = target.classifications.create(
        source_id=manual_source.id,
        ml_class_id=self.text.id,
        confidence=1,
        best=True,
    )

    url = reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)})
    params = {'type': self.classified.slug, 'best_classes': 1}

    def summarize(element_json):
        # Reduce each best class to a (source slug, confidence) pair
        return [(c['source']['slug'], c['confidence']) for c in element_json['best_classes']]

    # While the manual classification is not rejected, it is listed first
    response = self.client.get(url, data=params)
    self.assertEqual(response.status_code, status.HTTP_200_OK)
    target_id = str(target.id)
    for item in response.json()['results']:
        if item['id'] == target_id:
            self.assertListEqual(summarize(item), [('manual', 1.0), ('test', .99)])
        else:
            self.assertListEqual(summarize(item), [('test', .99)])

    # Reject the manual classification
    manual_classification.state = ClassificationState.Rejected
    manual_classification.save()

    # After rejection, no returned element lists the manual classification
    response = self.client.get(url, data=params)
    self.assertEqual(response.status_code, status.HTTP_200_OK)
    for item in response.json()['results']:
        self.assertListEqual(summarize(item), [('test', .99)])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment