diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index b7bd7f911074af0ac1ef9d80ec9c0063c88a65cd..2f658f8b730a1bc6b6bd25c05350c0625ad0595e 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -717,7 +717,7 @@ class DatasetElementSerializer(serializers.ModelSerializer):
         super().__init__(*args, **kwargs)
         if dataset := self.context.get("dataset"):
             self.fields["element_id"].queryset = Element.objects.filter(corpus=dataset.corpus)
-            self.fields["set"].queryset = dataset.sets.all()
+            self.fields["set"].queryset = dataset.sets.using("default")
 
     def validate_element_id(self, element):
         dataset = self.context.get("dataset")