diff --git a/arkindex/documents/api/elements.py b/arkindex/documents/api/elements.py index ef615a93f6efac27d363b1b648327e8703e36ee4..35640b80e040bfa60583a74e92100534183e6757 100644 --- a/arkindex/documents/api/elements.py +++ b/arkindex/documents/api/elements.py @@ -39,7 +39,7 @@ best_classifications_prefetch = Prefetch( 'classifications', queryset=classifications_queryset # Keep best or validated distinct classifications - .filter(Q(state=ClassificationState.Validated) | Q(best=True)) + .filter(Q(state=ClassificationState.Validated) | Q(high_confidence=True)) # Remove rejected classifications .exclude(state=ClassificationState.Rejected).distinct(), to_attr='best_classes' @@ -156,7 +156,7 @@ class ElementsList(CorpusACLMixin, ListAPIView): filtered_queryset = filtered_queryset.filter( classifications__in=Classification.objects.all() .filter( - Q(state=ClassificationState.Validated) | Q(best=True), + Q(state=ClassificationState.Validated) | Q(high_confidence=True), ml_class_id=class_filter ) .exclude(state=ClassificationState.Rejected) @@ -409,7 +409,7 @@ class ElementParents(ListAPIView): class_filter = self.request.query_params.get('best_class') if class_filter is not None: filters['classifications__in'] = Classification.objects.all().filter( - Q(state=ClassificationState.Validated) | Q(best=True), + Q(state=ClassificationState.Validated) | Q(high_confidence=True), ml_class_id=class_filter ).exclude(state=ClassificationState.Rejected) @@ -555,7 +555,7 @@ class ElementChildren(ListAPIView): class_filter = self.request.query_params.get('best_class') if class_filter is not None: filters['classifications__in'] = Classification.objects.all().filter( - Q(state=ClassificationState.Validated) | Q(best=True), + Q(state=ClassificationState.Validated) | Q(high_confidence=True), ml_class_id=class_filter ).exclude(state=ClassificationState.Rejected) diff --git a/arkindex/documents/api/entities.py b/arkindex/documents/api/entities.py index a0d8297c56d71ffc619fec4d518f803d75591246..7534e6c7b12506a00a9d482f67f5c08bd1166e81 100644 --- a/arkindex/documents/api/entities.py +++ b/arkindex/documents/api/entities.py @@ -68,7 +68,7 @@ class CorpusMLClassList(NestedCorpusMixin, ListAPIView): # Keep non rejected best or validated classifications ( ( - Q(classifications__best=True) + Q(classifications__high_confidence=True) & ~Q(classifications__state=ClassificationState.Rejected) ) | Q(classifications__state=ClassificationState.Validated) diff --git a/arkindex/documents/api/ml.py b/arkindex/documents/api/ml.py index 0a86d701022db993f193780f379caf32be699698..358a6fd993e908e7cd7a3664604c84405f69d414 100644 --- a/arkindex/documents/api/ml.py +++ b/arkindex/documents/api/ml.py @@ -183,7 +183,7 @@ class ClassificationCreate(CreateAPIView): source=data_source, moderator=self.request.user, state=ClassificationState.Validated, - best=True, + high_confidence=True, confidence=1 ) headers = self.get_success_headers(serializer.data) @@ -208,7 +208,7 @@ class ClassificationBulk(CreateAPIView): source=source, ml_class=cl['ml_class'], confidence=cl['confidence'], - best=cl['best'] + high_confidence=cl['high_confidence'] ) for cl in serializer.validated_data['classifications'] ]) @@ -237,8 +237,11 @@ class ManageClassificationsSelection(SelectionMixin, CreateAPIView): return self.create(corpus, request, *args, **kwargs) elif mode == ClassificationMode.Validate: elements = self.get_selection(corpus_id) - Classification.objects.filter(element__in=elements, state=ClassificationState.Pending, best=True) \ - .update(state=ClassificationState.Validated) + Classification.objects.filter( + element__in=elements, + state=ClassificationState.Pending, + high_confidence=True + ).update(state=ClassificationState.Validated) return Response(status=status.HTTP_201_CREATED) raise NotImplementedError @@ -268,7 +271,7 @@ class ManageClassificationsSelection(SelectionMixin, CreateAPIView): source=data_source, moderator=self.request.user, state=ClassificationState.Validated, - best=False, + high_confidence=False, confidence=1 )) Classification.objects.bulk_create(classifications) diff --git a/arkindex/documents/models.py b/arkindex/documents/models.py index 98b1ad9bc3ea449311da3764b9537da058d6172b..532e6794525c515d39d60ae34117902681830884 100644 --- a/arkindex/documents/models.py +++ b/arkindex/documents/models.py @@ -525,8 +525,8 @@ class Classification(models.Model): on_delete=models.CASCADE, related_name='classifications' ) - # Predicted class with highest score - best = models.BooleanField(default=False) + # Predicted class is considered as correct by its creator + high_confidence = models.BooleanField(default=False) state = EnumField(ClassificationState, max_length=16, default=ClassificationState.Pending, db_index=True) confidence = models.FloatField(null=True, blank=True) diff --git a/arkindex/documents/serializers/ml.py b/arkindex/documents/serializers/ml.py index bcc1f84cbd608fa3791fa3ade5c814490c7d9e22..33c8db066991781a9a6f948c0c30d72d56277ee6 100644 --- a/arkindex/documents/serializers/ml.py +++ b/arkindex/documents/serializers/ml.py @@ -81,14 +81,14 @@ class ClassificationSerializer(serializers.ModelSerializer): class Meta: model = Classification - read_only_fields = ('id', 'confidence', 'best') + read_only_fields = ('id', 'confidence', 'high_confidence') fields = ( 'id', 'source', 'ml_class', 'state', 'confidence', - 'best' + 'high_confidence' ) @@ -230,7 +230,7 @@ class ClassificationBulkSerializer(serializers.Serializer): """ class_name = serializers.CharField(source='ml_class') confidence = serializers.FloatField(min_value=0, max_value=1) - best = serializers.BooleanField(default=False) + high_confidence = serializers.BooleanField(default=False) class ClassificationsSerializer(serializers.Serializer): diff --git a/arkindex/documents/tests/test_bulk_classification.py b/arkindex/documents/tests/test_bulk_classification.py index b31604da0c418ffb2ae8e3cfcf5164afaca3f562..e5ae673723f9690f3aa523c0982d5c65cf142ba5 100644 --- a/arkindex/documents/tests/test_bulk_classification.py +++ b/arkindex/documents/tests/test_bulk_classification.py @@ -74,7 +74,7 @@ class TestBulkClassification(FixtureAPITestCase): { "class_name": 'dog', "confidence": 0.99, - "best": True + "high_confidence": True }, { "class_name": 'cat', @@ -84,7 +84,7 @@ class TestBulkClassification(FixtureAPITestCase): ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertCountEqual( - list(self.page.classifications.values_list('ml_class__name', 'confidence', 'best')), + list(self.page.classifications.values_list('ml_class__name', 'confidence', 'high_confidence')), [ ('dog', 0.99, True), ('cat', 0.42, False), @@ -127,7 +127,7 @@ class TestBulkClassification(FixtureAPITestCase): { "class_name": 'dog', "confidence": 0.99, - "best": True + "high_confidence": True }, { "class_name": 'cat', @@ -147,13 +147,13 @@ class TestBulkClassification(FixtureAPITestCase): { "class_name": 'catte', "confidence": 0.85, - "best": True + "high_confidence": True } ]) ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertCountEqual( - list(self.page.classifications.values_list('ml_class__name', 'confidence', 'best')), + list(self.page.classifications.values_list('ml_class__name', 'confidence', 'high_confidence')), [ ('doggo', 0.5, False), ('catte', 0.85, True), diff --git a/arkindex/documents/tests/test_classes.py b/arkindex/documents/tests/test_classes.py index c980b9021858001d994fd086a64539d7c7acb1d9..e4b8d3d3b22d5314b1719161e757eabcc654c681 100644 --- a/arkindex/documents/tests/test_classes.py +++ b/arkindex/documents/tests/test_classes.py @@ -42,7 +42,7 @@ class TestClasses(FixtureAPITestCase): source_id=source.id, ml_class_id=ml_class.id, confidence=score, - best=bool(score == .99) + high_confidence=bool(score == .99) ) def test_list(self): @@ -272,7 +272,7 @@ class TestClasses(FixtureAPITestCase): source_id=data_source.id, ml_class_id=self.text.id, confidence=1, - best=True, + high_confidence=True, ) response = self.client.get( reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)}), @@ -311,7 +311,7 @@ class TestClasses(FixtureAPITestCase): source_id=DataSource.objects.create(type=MLToolType.NER, slug='ner', internal=False).id, ml_class_id=self.text.id, confidence=.1337, - best=True, + high_confidence=True, ) with self.assertNumQueries(5): response = self.client.get( @@ -353,8 +353,8 @@ class TestClasses(FixtureAPITestCase): def test_class_filter_list_elements_distinct(self): self.populate_classified_elements() - self.assertEqual(Classification.objects.filter(best=True).count(), 24) - self.assertEqual(Classification.objects.filter(best=True).distinct('element_id').count(), 12) + self.assertEqual(Classification.objects.filter(high_confidence=True).count(), 24) + self.assertEqual(Classification.objects.filter(high_confidence=True).distinct('element_id').count(), 12) with self.assertNumQueries(5): response = self.client.get( reverse('api:elements'), @@ -369,8 +369,8 @@ class TestClasses(FixtureAPITestCase): def test_class_filter_list_parents_distinct(self): self.populate_classified_elements() - self.assertEqual(Classification.objects.filter(best=True).count(), 24) - self.assertEqual(Classification.objects.filter(best=True).distinct('element_id').count(), 12) + self.assertEqual(Classification.objects.filter(high_confidence=True).count(), 24) + self.assertEqual(Classification.objects.filter(high_confidence=True).distinct('element_id').count(), 12) with self.assertNumQueries(5): response = self.client.get( reverse('api:elements-parents', kwargs={'pk': str(self.common_children.id)}), @@ -385,8 +385,8 @@ class TestClasses(FixtureAPITestCase): def test_class_filter_list_children_distinct(self): self.populate_classified_elements() - self.assertEqual(Classification.objects.filter(best=True).count(), 24) - self.assertEqual(Classification.objects.filter(best=True).distinct('element_id').count(), 12) + self.assertEqual(Classification.objects.filter(high_confidence=True).count(), 24) + self.assertEqual(Classification.objects.filter(high_confidence=True).distinct('element_id').count(), 12) with self.assertNumQueries(5): response = self.client.get( reverse('api:elements-children', kwargs={'pk': str(self.parent.id)}), diff --git a/arkindex/documents/tests/test_moderation.py b/arkindex/documents/tests/test_moderation.py index d830e2491d870cd732b0fd8534bdbaed2d20e8d7..b361e92713d25de77ba921aa0b0156cc941eeef6 100644 --- a/arkindex/documents/tests/test_moderation.py +++ b/arkindex/documents/tests/test_moderation.py @@ -42,7 +42,7 @@ class TestClasses(FixtureAPITestCase): }) self.assertEqual(response.data['state'], ClassificationState.Validated.value) self.assertEqual(response.data['confidence'], 1) - self.assertEqual(response.data['best'], True) + self.assertEqual(response.data['high_confidence'], True) def test_classification_exists(self): """ @@ -67,19 +67,19 @@ class TestClasses(FixtureAPITestCase): def test_classification_ignored_params(self): """ - Ensure confidence and best attribute cannot be changed + Ensure confidence and high_confidence attributes cannot be changed """ self.client.force_login(self.user) response = self.client.post(reverse('api:classification-create'), { 'element': str(self.element.id), 'ml_class': str(self.text.id), 'confidence': 0.5, - 'best': False + 'high_confidence': False }) self.assertEqual(response.status_code, status.HTTP_201_CREATED) data = response.json() self.assertEqual(data['confidence'], 1) - self.assertEqual(data['best'], True) + self.assertEqual(data['high_confidence'], True) def test_classification_creation_without_permission(self): """ @@ -118,7 +118,7 @@ class TestClasses(FixtureAPITestCase): }, 'state': ClassificationState.Validated.value, 'confidence': classification.confidence, - 'best': False + 'high_confidence': False }) # Ensure moderator has been set @@ -151,7 +151,7 @@ class TestClasses(FixtureAPITestCase): }, 'state': ClassificationState.Rejected.value, 'confidence': classification.confidence, - 'best': False + 'high_confidence': False }) # Ensure moderator has been set @@ -189,7 +189,7 @@ class TestClasses(FixtureAPITestCase): }, 'state': ClassificationState.Rejected.value, 'confidence': classification.confidence, - 'best': False + 'high_confidence': False }) # Then try to validate @@ -211,7 +211,7 @@ class TestClasses(FixtureAPITestCase): }, 'state': ClassificationState.Validated.value, 'confidence': classification.confidence, - 'best': False + 'high_confidence': False }) def test_classification_selection_requires_login(self): @@ -331,7 +331,7 @@ class TestClasses(FixtureAPITestCase): element=act_x, source=source_1, state=ClassificationState.Pending, - best=True, + high_confidence=True, ml_class=line ) @@ -340,14 +340,14 @@ class TestClasses(FixtureAPITestCase): element=e, source=source_1, state=ClassificationState.Pending, - best=True, + high_confidence=True, ml_class=self.text ) Classification.objects.create( element=e, source=source_2, state=ClassificationState.Pending, - best=False, + high_confidence=False, ml_class=self.text ) @@ -366,11 +366,11 @@ class TestClasses(FixtureAPITestCase): for e in [self.folder, self.element]: classification = e.classifications.get(state=ClassificationState.Validated) - self.assertTrue(classification.best) + self.assertTrue(classification.high_confidence) self.assertEqual(classification.source, source_1) classification = e.classifications.get(state=ClassificationState.Pending) - self.assertFalse(classification.best) + self.assertFalse(classification.high_confidence) self.assertEqual(classification.source, source_2) classification = act_x.classifications.get()