From 3a49212f330f9930e3db357517748423189741b5 Mon Sep 17 00:00:00 2001 From: vrigal <rigal@teklia.com> Date: Fri, 13 Dec 2019 10:06:47 +0100 Subject: [PATCH] Update best to high_confidence --- arkindex/documents/api/elements.py | 8 +++--- arkindex/documents/api/entities.py | 2 +- arkindex/documents/api/ml.py | 13 ++++++---- arkindex/documents/models.py | 4 +-- arkindex/documents/serializers/ml.py | 6 ++--- .../tests/test_bulk_classification.py | 10 +++---- arkindex/documents/tests/test_classes.py | 18 ++++++------- arkindex/documents/tests/test_moderation.py | 26 +++++++++---------- 8 files changed, 45 insertions(+), 42 deletions(-) diff --git a/arkindex/documents/api/elements.py b/arkindex/documents/api/elements.py index ef615a93f6..35640b80e0 100644 --- a/arkindex/documents/api/elements.py +++ b/arkindex/documents/api/elements.py @@ -39,7 +39,7 @@ best_classifications_prefetch = Prefetch( 'classifications', queryset=classifications_queryset # Keep best or validated distinct classifications - .filter(Q(state=ClassificationState.Validated) | Q(best=True)) + .filter(Q(state=ClassificationState.Validated) | Q(high_confidence=True)) # Remove rejected classifications .exclude(state=ClassificationState.Rejected).distinct(), to_attr='best_classes' @@ -156,7 +156,7 @@ class ElementsList(CorpusACLMixin, ListAPIView): filtered_queryset = filtered_queryset.filter( classifications__in=Classification.objects.all() .filter( - Q(state=ClassificationState.Validated) | Q(best=True), + Q(state=ClassificationState.Validated) | Q(high_confidence=True), ml_class_id=class_filter ) .exclude(state=ClassificationState.Rejected) @@ -409,7 +409,7 @@ class ElementParents(ListAPIView): class_filter = self.request.query_params.get('best_class') if class_filter is not None: filters['classifications__in'] = Classification.objects.all().filter( - Q(state=ClassificationState.Validated) | Q(best=True), + Q(state=ClassificationState.Validated) | Q(high_confidence=True), ml_class_id=class_filter ).exclude(state=ClassificationState.Rejected) @@ -555,7 +555,7 @@ class ElementChildren(ListAPIView): class_filter = self.request.query_params.get('best_class') if class_filter is not None: filters['classifications__in'] = Classification.objects.all().filter( - Q(state=ClassificationState.Validated) | Q(best=True), + Q(state=ClassificationState.Validated) | Q(high_confidence=True), ml_class_id=class_filter ).exclude(state=ClassificationState.Rejected) diff --git a/arkindex/documents/api/entities.py b/arkindex/documents/api/entities.py index a0d8297c56..7534e6c7b1 100644 --- a/arkindex/documents/api/entities.py +++ b/arkindex/documents/api/entities.py @@ -68,7 +68,7 @@ class CorpusMLClassList(NestedCorpusMixin, ListAPIView): # Keep non rejected best or validated classifications ( ( - Q(classifications__best=True) + Q(classifications__high_confidence=True) & ~Q(classifications__state=ClassificationState.Rejected) ) | Q(classifications__state=ClassificationState.Validated) diff --git a/arkindex/documents/api/ml.py b/arkindex/documents/api/ml.py index 0a86d70102..358a6fd993 100644 --- a/arkindex/documents/api/ml.py +++ b/arkindex/documents/api/ml.py @@ -183,7 +183,7 @@ class ClassificationCreate(CreateAPIView): source=data_source, moderator=self.request.user, state=ClassificationState.Validated, - best=True, + high_confidence=True, confidence=1 ) headers = self.get_success_headers(serializer.data) @@ -208,7 +208,7 @@ class ClassificationBulk(CreateAPIView): source=source, ml_class=cl['ml_class'], confidence=cl['confidence'], - best=cl['best'] + high_confidence=cl['high_confidence'] ) for cl in serializer.validated_data['classifications'] ]) @@ -237,8 +237,11 @@ class ManageClassificationsSelection(SelectionMixin, CreateAPIView): return self.create(corpus, request, *args, **kwargs) elif mode == ClassificationMode.Validate: elements = self.get_selection(corpus_id) - Classification.objects.filter(element__in=elements, state=ClassificationState.Pending, best=True) \ - .update(state=ClassificationState.Validated) + Classification.objects.filter( + element__in=elements, + state=ClassificationState.Pending, + high_confidence=True + ).update(state=ClassificationState.Validated) return Response(status=status.HTTP_201_CREATED) raise NotImplementedError @@ -268,7 +271,7 @@ class ManageClassificationsSelection(SelectionMixin, CreateAPIView): source=data_source, moderator=self.request.user, state=ClassificationState.Validated, - best=False, + high_confidence=False, confidence=1 )) Classification.objects.bulk_create(classifications) diff --git a/arkindex/documents/models.py b/arkindex/documents/models.py index 98b1ad9bc3..532e679452 100644 --- a/arkindex/documents/models.py +++ b/arkindex/documents/models.py @@ -525,8 +525,8 @@ class Classification(models.Model): on_delete=models.CASCADE, related_name='classifications' ) - # Predicted class with highest score - best = models.BooleanField(default=False) + # Predicted class is considered as correct by its creator + high_confidence = models.BooleanField(default=False) state = EnumField(ClassificationState, max_length=16, default=ClassificationState.Pending, db_index=True) confidence = models.FloatField(null=True, blank=True) diff --git a/arkindex/documents/serializers/ml.py b/arkindex/documents/serializers/ml.py index bcc1f84cbd..33c8db0669 100644 --- a/arkindex/documents/serializers/ml.py +++ b/arkindex/documents/serializers/ml.py @@ -81,14 +81,14 @@ class ClassificationSerializer(serializers.ModelSerializer): class Meta: model = Classification - read_only_fields = ('id', 'confidence', 'best') + read_only_fields = ('id', 'confidence', 'high_confidence') fields = ( 'id', 'source', 'ml_class', 'state', 'confidence', - 'best' + 'high_confidence' ) @@ -230,7 +230,7 @@ class ClassificationBulkSerializer(serializers.Serializer): """ class_name = serializers.CharField(source='ml_class') confidence = serializers.FloatField(min_value=0, max_value=1) - best = serializers.BooleanField(default=False) + high_confidence = serializers.BooleanField(default=False) class ClassificationsSerializer(serializers.Serializer): diff --git a/arkindex/documents/tests/test_bulk_classification.py b/arkindex/documents/tests/test_bulk_classification.py index b31604da0c..e5ae673723 100644 --- a/arkindex/documents/tests/test_bulk_classification.py +++ b/arkindex/documents/tests/test_bulk_classification.py @@ -74,7 +74,7 @@ class TestBulkClassification(FixtureAPITestCase): { "class_name": 'dog', "confidence": 0.99, - "best": True + "high_confidence": True }, { "class_name": 'cat', @@ -84,7 +84,7 @@ class TestBulkClassification(FixtureAPITestCase): ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertCountEqual( - list(self.page.classifications.values_list('ml_class__name', 'confidence', 'best')), + list(self.page.classifications.values_list('ml_class__name', 'confidence', 'high_confidence')), [ ('dog', 0.99, True), ('cat', 0.42, False), @@ -127,7 +127,7 @@ class TestBulkClassification(FixtureAPITestCase): { "class_name": 'dog', "confidence": 0.99, - "best": True + "high_confidence": True }, { "class_name": 'cat', @@ -147,13 +147,13 @@ class TestBulkClassification(FixtureAPITestCase): { "class_name": 'catte', "confidence": 0.85, - "best": True + "high_confidence": True } ]) ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertCountEqual( - list(self.page.classifications.values_list('ml_class__name', 'confidence', 'best')), + list(self.page.classifications.values_list('ml_class__name', 'confidence', 'high_confidence')), [ ('doggo', 0.5, False), ('catte', 0.85, True), diff --git a/arkindex/documents/tests/test_classes.py b/arkindex/documents/tests/test_classes.py index c980b90218..e4b8d3d3b2 100644 --- a/arkindex/documents/tests/test_classes.py +++ b/arkindex/documents/tests/test_classes.py @@ -42,7 +42,7 @@ class TestClasses(FixtureAPITestCase): source_id=source.id, ml_class_id=ml_class.id, confidence=score, - best=bool(score == .99) + high_confidence=bool(score == .99) ) def test_list(self): @@ -272,7 +272,7 @@ class TestClasses(FixtureAPITestCase): source_id=data_source.id, ml_class_id=self.text.id, confidence=1, - best=True, + high_confidence=True, ) response = self.client.get( reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)}), @@ -311,7 +311,7 @@ class TestClasses(FixtureAPITestCase): source_id=DataSource.objects.create(type=MLToolType.NER, slug='ner', internal=False).id, ml_class_id=self.text.id, confidence=.1337, - best=True, + high_confidence=True, ) with self.assertNumQueries(5): response = self.client.get( @@ -353,8 +353,8 @@ class TestClasses(FixtureAPITestCase): def test_class_filter_list_elements_distinct(self): self.populate_classified_elements() - self.assertEqual(Classification.objects.filter(best=True).count(), 24) - self.assertEqual(Classification.objects.filter(best=True).distinct('element_id').count(), 12) + self.assertEqual(Classification.objects.filter(high_confidence=True).count(), 24) + self.assertEqual(Classification.objects.filter(high_confidence=True).distinct('element_id').count(), 12) with self.assertNumQueries(5): response = self.client.get( reverse('api:elements'), @@ -369,8 +369,8 @@ class TestClasses(FixtureAPITestCase): def test_class_filter_list_parents_distinct(self): self.populate_classified_elements() - self.assertEqual(Classification.objects.filter(best=True).count(), 24) - self.assertEqual(Classification.objects.filter(best=True).distinct('element_id').count(), 12) + self.assertEqual(Classification.objects.filter(high_confidence=True).count(), 24) + self.assertEqual(Classification.objects.filter(high_confidence=True).distinct('element_id').count(), 12) with self.assertNumQueries(5): response = self.client.get( reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)}), @@ -385,8 +385,8 @@ class TestClasses(FixtureAPITestCase): def test_class_filter_list_children_distinct(self): self.populate_classified_elements() - self.assertEqual(Classification.objects.filter(best=True).count(), 24) - self.assertEqual(Classification.objects.filter(best=True).distinct('element_id').count(), 12) + self.assertEqual(Classification.objects.filter(high_confidence=True).count(), 24) + self.assertEqual(Classification.objects.filter(high_confidence=True).distinct('element_id').count(), 12) with self.assertNumQueries(5): response = self.client.get( reverse('api:elements-children', kwargs={'pk': str(self.parent.id)}), diff --git a/arkindex/documents/tests/test_moderation.py b/arkindex/documents/tests/test_moderation.py index d830e2491d..b361e92713 100644 --- a/arkindex/documents/tests/test_moderation.py +++ b/arkindex/documents/tests/test_moderation.py @@ -42,7 +42,7 @@ class TestClasses(FixtureAPITestCase): }) self.assertEqual(response.data['state'], ClassificationState.Validated.value) self.assertEqual(response.data['confidence'], 1) - self.assertEqual(response.data['best'], True) + self.assertEqual(response.data['high_confidence'], True) def test_classification_exists(self): """ @@ -67,19 +67,19 @@ class TestClasses(FixtureAPITestCase): def test_classification_ignored_params(self): """ - Ensure confidence and best attribute cannot be changed + Ensure confidence and high_confidence attributes cannot be changed """ self.client.force_login(self.user) response = self.client.post(reverse('api:classification-create'), { 'element': str(self.element.id), 'ml_class': str(self.text.id), 'confidence': 0.5, - 'best': False + 'high_confidence': False }) self.assertEqual(response.status_code, status.HTTP_201_CREATED) data = response.json() self.assertEqual(data['confidence'], 1) - self.assertEqual(data['best'], True) + self.assertEqual(data['high_confidence'], True) def test_classification_creation_without_permission(self): """ @@ -118,7 +118,7 @@ class TestClasses(FixtureAPITestCase): }, 'state': ClassificationState.Validated.value, 'confidence': classification.confidence, - 'best': False + 'high_confidence': False }) # Ensure moderator has been set @@ -151,7 +151,7 @@ class TestClasses(FixtureAPITestCase): }, 'state': ClassificationState.Rejected.value, 'confidence': classification.confidence, - 'best': False + 'high_confidence': False }) # Ensure moderator has been set @@ -189,7 +189,7 @@ class TestClasses(FixtureAPITestCase): }, 'state': ClassificationState.Rejected.value, 'confidence': classification.confidence, - 'best': False + 'high_confidence': False }) # Then try to validate @@ -211,7 +211,7 @@ class TestClasses(FixtureAPITestCase): }, 'state': ClassificationState.Validated.value, 'confidence': classification.confidence, - 'best': False + 'high_confidence': False }) def test_classification_selection_requires_login(self): @@ -331,7 +331,7 @@ class TestClasses(FixtureAPITestCase): element=act_x, source=source_1, state=ClassificationState.Pending, - best=True, + high_confidence=True, ml_class=line ) @@ -340,14 +340,14 @@ class TestClasses(FixtureAPITestCase): element=e, source=source_1, state=ClassificationState.Pending, - best=True, + high_confidence=True, ml_class=self.text ) Classification.objects.create( element=e, source=source_2, state=ClassificationState.Pending, - best=False, + high_confidence=False, ml_class=self.text ) @@ -366,11 +366,11 @@ class TestClasses(FixtureAPITestCase): for e in [self.folder, self.element]: classification = e.classifications.get(state=ClassificationState.Validated) - self.assertTrue(classification.best) + self.assertTrue(classification.high_confidence) self.assertEqual(classification.source, source_1) classification = e.classifications.get(state=ClassificationState.Pending) - self.assertFalse(classification.best) + self.assertFalse(classification.high_confidence) self.assertEqual(classification.source, source_2) classification = act_x.classifications.get() -- GitLab