From b9b314c7be41367afef02105e2fa29b95e6854b1 Mon Sep 17 00:00:00 2001
From: mlbonhomme <bonhomme@teklia.com>
Date: Tue, 12 Mar 2024 19:20:26 +0100
Subject: [PATCH] fix api schema

---
 arkindex/project/serializer_fields.py        | 5 +++--
 arkindex/training/serializers.py             | 3 +++
 arkindex/training/tests/test_datasets_api.py | 3 +--
 3 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/arkindex/project/serializer_fields.py b/arkindex/project/serializer_fields.py
index b7bf642958..d18c647bc3 100644
--- a/arkindex/project/serializer_fields.py
+++ b/arkindex/project/serializer_fields.py
@@ -284,9 +284,10 @@ class DatasetSetsCountField(serializers.DictField):
     def get_attribute(self, instance):
         if not self.context.get("sets_count", True):
             return None
-        elts_count = {k.name: 0 for k in instance.sets.all()}
+        dataset_sets = instance.sets.all()
+        elts_count = {k.name: 0 for k in dataset_sets}
         elts_count.update(
-            DatasetElement.objects.filter(set_id__in=instance.sets.values_list("id"))
+            DatasetElement.objects.filter(set_id__in=[ds.id for ds in dataset_sets])
             .values("set__name")
             .annotate(count=Count("id"))
             .values_list("set__name", "count")
diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index 2765dcfaef..819042749e 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -727,6 +727,9 @@ class SelectionDatasetElementSerializer(serializers.Serializer):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        # for openAPI schema generation
+        if "request" not in self.context:
+            return
         self.fields["set_id"].queryset = DatasetSet.objects.filter(
             dataset__corpus_id__in=Corpus.objects.readable(self.context["request"].user)
         ).select_related("dataset")
diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py
index 72149bfc3c..9dcef03ffc 100644
--- a/arkindex/training/tests/test_datasets_api.py
+++ b/arkindex/training/tests/test_datasets_api.py
@@ -1166,7 +1166,6 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.assertDictEqual(data, {"count": 1, "next": None, "previous": None})
         self.assertEqual(len(results), 1)
         dataset_element = results[0]
-        print(dataset_element)
         self.assertEqual(dataset_element["element"]["id"], str(self.page2.id))
         self.assertEqual(dataset_element["set"], "training")
 
@@ -2152,7 +2151,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
     def test_clone_name_too_long(self):
         dataset = self.corpus.datasets.create(name="A" * 99, creator=self.user)
         self.client.force_login(self.user)
-        with self.assertNumQueries(14):
+        with self.assertNumQueries(13):
             response = self.client.post(
                 reverse("api:dataset-clone", kwargs={"pk": dataset.id}),
                 format="json",
-- 
GitLab