Skip to content
Snippets Groups Projects
Commit 0de8d7f0 authored by Valentin Rigal's avatar Valentin Rigal Committed by Erwan Rouchet
Browse files

Rename best classification attribute to high_confidence

parent 7be51908
No related branches found
No related tags found
No related merge requests found
......@@ -31,6 +31,7 @@ class EventInline(admin.TabularInline):
class ClassificationInline(admin.TabularInline):
model = Classification
readonly_fields = ('confidence', 'high_confidence', )
class AllowedMetaDataAdmin(admin.ModelAdmin):
......
......@@ -37,7 +37,7 @@ best_classifications_prefetch = Prefetch(
'classifications',
queryset=classifications_queryset
# Keep best or validated distinct classifications
.filter(Q(state=ClassificationState.Validated) | Q(best=True))
.filter(Q(state=ClassificationState.Validated) | Q(high_confidence=True))
# Remove rejected classifications
.exclude(state=ClassificationState.Rejected).distinct(),
to_attr='best_classes'
......@@ -152,7 +152,7 @@ class ElementsList(CorpusACLMixin, ListAPIView):
filtered_queryset = filtered_queryset.filter(
classifications__in=Classification.objects.all()
.filter(
Q(state=ClassificationState.Validated) | Q(best=True),
Q(state=ClassificationState.Validated) | Q(high_confidence=True),
ml_class_id=class_filter
)
.exclude(state=ClassificationState.Rejected)
......@@ -405,7 +405,7 @@ class ElementParents(ListAPIView):
class_filter = self.request.query_params.get('best_class')
if class_filter is not None:
filters['classifications__in'] = Classification.objects.all().filter(
Q(state=ClassificationState.Validated) | Q(best=True),
Q(state=ClassificationState.Validated) | Q(high_confidence=True),
ml_class_id=class_filter
).exclude(state=ClassificationState.Rejected)
......@@ -549,7 +549,7 @@ class ElementChildren(ListAPIView):
class_filter = self.request.query_params.get('best_class')
if class_filter is not None:
filters['classifications__in'] = Classification.objects.all().filter(
Q(state=ClassificationState.Validated) | Q(best=True),
Q(state=ClassificationState.Validated) | Q(high_confidence=True),
ml_class_id=class_filter
).exclude(state=ClassificationState.Rejected)
......
......@@ -66,7 +66,7 @@ class CorpusMLClassList(NestedCorpusMixin, ListAPIView):
# Keep non rejected best or validated classifications
(
(
Q(classifications__best=True)
Q(classifications__high_confidence=True)
& ~Q(classifications__state=ClassificationState.Rejected)
)
| Q(classifications__state=ClassificationState.Validated)
......
......@@ -183,7 +183,7 @@ class ClassificationCreate(CreateAPIView):
source=data_source,
moderator=self.request.user,
state=ClassificationState.Validated,
best=True,
high_confidence=True,
confidence=1
)
headers = self.get_success_headers(serializer.data)
......@@ -208,7 +208,7 @@ class ClassificationBulk(CreateAPIView):
source=source,
ml_class=cl['ml_class'],
confidence=cl['confidence'],
best=cl['best']
high_confidence=cl['high_confidence']
)
for cl in serializer.validated_data['classifications']
])
......@@ -237,8 +237,11 @@ class ManageClassificationsSelection(SelectionMixin, CreateAPIView):
return self.create(corpus, request, *args, **kwargs)
elif mode == ClassificationMode.Validate:
elements = self.get_selection(corpus_id)
Classification.objects.filter(element__in=elements, state=ClassificationState.Pending, best=True) \
.update(state=ClassificationState.Validated)
Classification.objects.filter(
element__in=elements,
state=ClassificationState.Pending,
high_confidence=True
).update(state=ClassificationState.Validated)
return Response(status=status.HTTP_201_CREATED)
raise NotImplementedError
......@@ -268,7 +271,7 @@ class ManageClassificationsSelection(SelectionMixin, CreateAPIView):
source=data_source,
moderator=self.request.user,
state=ClassificationState.Validated,
best=False,
high_confidence=False,
confidence=1
))
Classification.objects.bulk_create(classifications)
......
# Generated by Django 2.2 on 2019-12-13 08:46
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('documents', '0029_allowedmetadata'),
]
operations = [
migrations.RenameField(
model_name='classification',
old_name='best',
new_name='high_confidence',
),
]
......@@ -525,8 +525,8 @@ class Classification(models.Model):
on_delete=models.CASCADE,
related_name='classifications'
)
# Predicted class with highest score
best = models.BooleanField(default=False)
# Predicted class is considered as correct by its creator
high_confidence = models.BooleanField(default=False)
state = EnumField(ClassificationState, max_length=16, default=ClassificationState.Pending, db_index=True)
confidence = models.FloatField(null=True, blank=True)
......
......@@ -81,14 +81,14 @@ class ClassificationSerializer(serializers.ModelSerializer):
class Meta:
model = Classification
read_only_fields = ('id', 'confidence', 'best')
read_only_fields = ('id', 'confidence', 'high_confidence')
fields = (
'id',
'source',
'ml_class',
'state',
'confidence',
'best'
'high_confidence'
)
......@@ -230,7 +230,7 @@ class ClassificationBulkSerializer(serializers.Serializer):
"""
class_name = serializers.CharField(source='ml_class')
confidence = serializers.FloatField(min_value=0, max_value=1)
best = serializers.BooleanField(default=False)
high_confidence = serializers.BooleanField(default=False)
class ClassificationsSerializer(serializers.Serializer):
......
......@@ -74,7 +74,7 @@ class TestBulkClassification(FixtureAPITestCase):
{
"class_name": 'dog',
"confidence": 0.99,
"best": True
"high_confidence": True
},
{
"class_name": 'cat',
......@@ -84,7 +84,7 @@ class TestBulkClassification(FixtureAPITestCase):
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertCountEqual(
list(self.page.classifications.values_list('ml_class__name', 'confidence', 'best')),
list(self.page.classifications.values_list('ml_class__name', 'confidence', 'high_confidence')),
[
('dog', 0.99, True),
('cat', 0.42, False),
......@@ -127,7 +127,7 @@ class TestBulkClassification(FixtureAPITestCase):
{
"class_name": 'dog',
"confidence": 0.99,
"best": True
"high_confidence": True
},
{
"class_name": 'cat',
......@@ -147,13 +147,13 @@ class TestBulkClassification(FixtureAPITestCase):
{
"class_name": 'catte',
"confidence": 0.85,
"best": True
"high_confidence": True
}
])
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertCountEqual(
list(self.page.classifications.values_list('ml_class__name', 'confidence', 'best')),
list(self.page.classifications.values_list('ml_class__name', 'confidence', 'high_confidence')),
[
('doggo', 0.5, False),
('catte', 0.85, True),
......
......@@ -42,7 +42,7 @@ class TestClasses(FixtureAPITestCase):
source_id=source.id,
ml_class_id=ml_class.id,
confidence=score,
best=bool(score == .99)
high_confidence=bool(score == .99)
)
def test_list(self):
......@@ -272,7 +272,7 @@ class TestClasses(FixtureAPITestCase):
source_id=data_source.id,
ml_class_id=self.text.id,
confidence=1,
best=True,
high_confidence=True,
)
response = self.client.get(
reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)}),
......@@ -311,7 +311,7 @@ class TestClasses(FixtureAPITestCase):
source_id=DataSource.objects.create(type=MLToolType.NER, slug='ner', internal=False).id,
ml_class_id=self.text.id,
confidence=.1337,
best=True,
high_confidence=True,
)
with self.assertNumQueries(5):
response = self.client.get(
......@@ -353,8 +353,8 @@ class TestClasses(FixtureAPITestCase):
def test_class_filter_list_elements_distinct(self):
self.populate_classified_elements()
self.assertEqual(Classification.objects.filter(best=True).count(), 24)
self.assertEqual(Classification.objects.filter(best=True).distinct('element_id').count(), 12)
self.assertEqual(Classification.objects.filter(high_confidence=True).count(), 24)
self.assertEqual(Classification.objects.filter(high_confidence=True).distinct('element_id').count(), 12)
with self.assertNumQueries(5):
response = self.client.get(
reverse('api:elements'),
......@@ -369,8 +369,8 @@ class TestClasses(FixtureAPITestCase):
def test_class_filter_list_parents_distinct(self):
self.populate_classified_elements()
self.assertEqual(Classification.objects.filter(best=True).count(), 24)
self.assertEqual(Classification.objects.filter(best=True).distinct('element_id').count(), 12)
self.assertEqual(Classification.objects.filter(high_confidence=True).count(), 24)
self.assertEqual(Classification.objects.filter(high_confidence=True).distinct('element_id').count(), 12)
with self.assertNumQueries(5):
response = self.client.get(
reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)}),
......@@ -385,8 +385,8 @@ class TestClasses(FixtureAPITestCase):
def test_class_filter_list_children_distinct(self):
self.populate_classified_elements()
self.assertEqual(Classification.objects.filter(best=True).count(), 24)
self.assertEqual(Classification.objects.filter(best=True).distinct('element_id').count(), 12)
self.assertEqual(Classification.objects.filter(high_confidence=True).count(), 24)
self.assertEqual(Classification.objects.filter(high_confidence=True).distinct('element_id').count(), 12)
with self.assertNumQueries(5):
response = self.client.get(
reverse('api:elements-children', kwargs={'pk': str(self.parent.id)}),
......
......@@ -42,7 +42,7 @@ class TestClasses(FixtureAPITestCase):
})
self.assertEqual(response.data['state'], ClassificationState.Validated.value)
self.assertEqual(response.data['confidence'], 1)
self.assertEqual(response.data['best'], True)
self.assertEqual(response.data['high_confidence'], True)
def test_classification_exists(self):
"""
......@@ -67,19 +67,19 @@ class TestClasses(FixtureAPITestCase):
def test_classification_ignored_params(self):
"""
Ensure confidence and best attribute cannot be changed
Ensure confidence and high_confidence attributes cannot be changed
"""
self.client.force_login(self.user)
response = self.client.post(reverse('api:classification-create'), {
'element': str(self.element.id),
'ml_class': str(self.text.id),
'confidence': 0.5,
'best': False
'high_confidence': False
})
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
data = response.json()
self.assertEqual(data['confidence'], 1)
self.assertEqual(data['best'], True)
self.assertEqual(data['high_confidence'], True)
def test_classification_creation_without_permission(self):
"""
......@@ -118,7 +118,7 @@ class TestClasses(FixtureAPITestCase):
},
'state': ClassificationState.Validated.value,
'confidence': classification.confidence,
'best': False
'high_confidence': False
})
# Ensure moderator has been set
......@@ -151,7 +151,7 @@ class TestClasses(FixtureAPITestCase):
},
'state': ClassificationState.Rejected.value,
'confidence': classification.confidence,
'best': False
'high_confidence': False
})
# Ensure moderator has been set
......@@ -189,7 +189,7 @@ class TestClasses(FixtureAPITestCase):
},
'state': ClassificationState.Rejected.value,
'confidence': classification.confidence,
'best': False
'high_confidence': False
})
# Then try to validate
......@@ -211,7 +211,7 @@ class TestClasses(FixtureAPITestCase):
},
'state': ClassificationState.Validated.value,
'confidence': classification.confidence,
'best': False
'high_confidence': False
})
def test_classification_selection_requires_login(self):
......@@ -331,7 +331,7 @@ class TestClasses(FixtureAPITestCase):
element=act_x,
source=source_1,
state=ClassificationState.Pending,
best=True,
high_confidence=True,
ml_class=line
)
......@@ -340,14 +340,14 @@ class TestClasses(FixtureAPITestCase):
element=e,
source=source_1,
state=ClassificationState.Pending,
best=True,
high_confidence=True,
ml_class=self.text
)
Classification.objects.create(
element=e,
source=source_2,
state=ClassificationState.Pending,
best=False,
high_confidence=False,
ml_class=self.text
)
......@@ -366,11 +366,11 @@ class TestClasses(FixtureAPITestCase):
for e in [self.folder, self.element]:
classification = e.classifications.get(state=ClassificationState.Validated)
self.assertTrue(classification.best)
self.assertTrue(classification.high_confidence)
self.assertEqual(classification.source, source_1)
classification = e.classifications.get(state=ClassificationState.Pending)
self.assertFalse(classification.best)
self.assertFalse(classification.high_confidence)
self.assertEqual(classification.source, source_2)
classification = act_x.classifications.get()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment