diff --git a/arkindex/project/serializer_fields.py b/arkindex/project/serializer_fields.py
index b7bf6429586f4d9f737ee86382b6f2126953b223..d18c647bc38353257fe0f31005a64d7acd7f179e 100644
--- a/arkindex/project/serializer_fields.py
+++ b/arkindex/project/serializer_fields.py
@@ -284,9 +284,10 @@ class DatasetSetsCountField(serializers.DictField):
     def get_attribute(self, instance):
         if not self.context.get("sets_count", True):
             return None
-        elts_count = {k.name: 0 for k in instance.sets.all()}
+        dataset_sets = instance.sets.all()
+        elts_count = {k.name: 0 for k in dataset_sets}
         elts_count.update(
-            DatasetElement.objects.filter(set_id__in=instance.sets.values_list("id"))
+            DatasetElement.objects.filter(set_id__in=[ds.id for ds in dataset_sets])
             .values("set__name")
             .annotate(count=Count("id"))
             .values_list("set__name", "count")
diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index 2765dcfaef005aa756b42dd0784da3a08214e065..819042749e6ee622e2ad2ad52b5f9b1d25147216 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -727,6 +727,9 @@ class SelectionDatasetElementSerializer(serializers.Serializer):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        # for openAPI schema generation
+        if "request" not in self.context:
+            return
         self.fields["set_id"].queryset = DatasetSet.objects.filter(
             dataset__corpus_id__in=Corpus.objects.readable(self.context["request"].user)
         ).select_related("dataset")
diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py
index 72149bfc3c478e57b340965038ffcb2bab372c3e..9dcef03ffc9c263e32a26714646fa0c29543c1a6 100644
--- a/arkindex/training/tests/test_datasets_api.py
+++ b/arkindex/training/tests/test_datasets_api.py
@@ -1166,7 +1166,6 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.assertDictEqual(data, {"count": 1, "next": None, "previous": None})
         self.assertEqual(len(results), 1)
         dataset_element = results[0]
-        print(dataset_element)
         self.assertEqual(dataset_element["element"]["id"], str(self.page2.id))
         self.assertEqual(dataset_element["set"], "training")
 
@@ -2152,7 +2151,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
     def test_clone_name_too_long(self):
         dataset = self.corpus.datasets.create(name="A" * 99, creator=self.user)
         self.client.force_login(self.user)
-        with self.assertNumQueries(14):
+        with self.assertNumQueries(13):
             response = self.client.post(
                 reverse("api:dataset-clone", kwargs={"pk": dataset.id}),
                 format="json",