diff --git a/arkindex/documents/serializers/ml.py b/arkindex/documents/serializers/ml.py index 58aff0acc7113a4494245d55a9047b5b0d79334e..dae717afba2c9ec867afc28aff76ffca7364c11c 100644 --- a/arkindex/documents/serializers/ml.py +++ b/arkindex/documents/serializers/ml.py @@ -477,6 +477,15 @@ class ClassificationsSerializer(serializers.Serializer): 'worker_version': ['This field XOR classifier field must be set to create classifications'] }) + ml_class_names = [ + classification['ml_class'] + for classification in data['classifications'] + ] + if len(ml_class_names) != len(set(ml_class_names)): + raise ValidationError({ + 'classifications': ['Duplicated ML classes are not allowed from the same source or worker version.'] + }) + return data def create(self, validated_data): diff --git a/arkindex/documents/tests/test_bulk_classification.py b/arkindex/documents/tests/test_bulk_classification.py index 42df6fa6d0eb0aa9a521f477f360f241ebced08c..f1203fdd5ba39b33959d7cd842b4106c337ed41a 100644 --- a/arkindex/documents/tests/test_bulk_classification.py +++ b/arkindex/documents/tests/test_bulk_classification.py @@ -248,3 +248,22 @@ class TestBulkClassification(FixtureAPITestCase): ('catte', 0.85, True), ], ) + + def test_bulk_classification_no_duplicates(self): + """ + Test the bulk classification API prevents creating classifications with duplicate ML classes + """ + self.client.force_login(self.user) + with self.assertNumQueries(4): + response = self.client.post( + reverse('api:classification-bulk'), + format='json', + data=self.create_classifications_data([ + {"class_name": 'dog', "confidence": 0.99}, + {"class_name": 'dog', "confidence": 0.99}, + ]) + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + 'classifications': ['Duplicated ML classes are not allowed from the same source or worker version.'] + })