Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • arkindex/backend
1 result
Show changes
Commits on Source (4)
......@@ -477,6 +477,15 @@ class ClassificationsSerializer(serializers.Serializer):
'worker_version': ['This field XOR classifier field must be set to create classifications']
})
ml_class_names = [
classification['ml_class']
for classification in data['classifications']
]
if len(ml_class_names) != len(set(ml_class_names)):
raise ValidationError({
'classifications': ['Duplicated ML classes are not allowed from the same source or worker version.']
})
return data
def create(self, validated_data):
......
......@@ -248,3 +248,22 @@ class TestBulkClassification(FixtureAPITestCase):
('catte', 0.85, True),
],
)
def test_bulk_classification_no_duplicates(self):
"""
Test the bulk classification API prevents creating classifications with duplicate ML classes
"""
self.client.force_login(self.user)
with self.assertNumQueries(4):
response = self.client.post(
reverse('api:classification-bulk'),
format='json',
data=self.create_classifications_data([
{"class_name": 'dog', "confidence": 0.99},
{"class_name": 'dog', "confidence": 0.99},
])
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
'classifications': ['Duplicated ML classes are not allowed from the same source or worker version.']
})
......@@ -19,7 +19,7 @@ urlpatterns = [
# Link sent via email for password resets
path('user/reset/<uidb64>/<token>/', frontend_view.as_view(), name='password_reset_confirm'),
# Redirection URL for successful OAuth2 flows
path('imports/credentials/', frontend_view.as_view(), name='credentials'),
path('process/credentials/', frontend_view.as_view(), name='credentials'),
]
if 'debug_toolbar' in settings.INSTALLED_APPS:
......