diff --git a/arkindex/documents/serializers/ml.py b/arkindex/documents/serializers/ml.py
index e6429abdd1cf3abfe0f5f9fbd0c56b9f01cf6dee..40ef2f086adb0e6e821644c6bace3ee20713f607 100644
--- a/arkindex/documents/serializers/ml.py
+++ b/arkindex/documents/serializers/ml.py
@@ -112,6 +112,8 @@ class ClassificationCreateSerializer(serializers.ModelSerializer):
     """
     Serializer to create a single classification, defaulting to manual
     """
+    element = serializers.PrimaryKeyRelatedField(queryset=Element.objects.using('default'))
+    ml_class = serializers.PrimaryKeyRelatedField(queryset=MLClass.objects.using('default'))
     worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all(), default=None)
     confidence = serializers.FloatField(
         min_value=0,
@@ -143,11 +145,11 @@ class ClassificationCreateSerializer(serializers.ModelSerializer):
         read_only_fields = ('id', 'state')
         validators = [
             UniqueTogetherValidator(
-                queryset=Classification.objects.filter(worker_version__isnull=False, source_id__isnull=True),
+                queryset=Classification.objects.using('default').filter(worker_version__isnull=False, source_id__isnull=True),
                 fields=['element', 'worker_version', 'ml_class']
             ),
             UniqueTogetherValidator(
-                queryset=Classification.objects.filter(worker_version__isnull=True, source_id__isnull=True),
+                queryset=Classification.objects.using('default').filter(worker_version__isnull=True, source_id__isnull=True),
                 fields=['element', 'ml_class']
             )
         ]