From b9b314c7be41367afef02105e2fa29b95e6854b1 Mon Sep 17 00:00:00 2001 From: mlbonhomme <bonhomme@teklia.com> Date: Tue, 12 Mar 2024 19:20:26 +0100 Subject: [PATCH] fix api schema --- arkindex/project/serializer_fields.py | 5 +++-- arkindex/training/serializers.py | 3 +++ arkindex/training/tests/test_datasets_api.py | 3 +-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/arkindex/project/serializer_fields.py b/arkindex/project/serializer_fields.py index b7bf642958..d18c647bc3 100644 --- a/arkindex/project/serializer_fields.py +++ b/arkindex/project/serializer_fields.py @@ -284,9 +284,10 @@ class DatasetSetsCountField(serializers.DictField): def get_attribute(self, instance): if not self.context.get("sets_count", True): return None - elts_count = {k.name: 0 for k in instance.sets.all()} + dataset_sets = instance.sets.all() + elts_count = {k.name: 0 for k in dataset_sets} elts_count.update( - DatasetElement.objects.filter(set_id__in=instance.sets.values_list("id")) + DatasetElement.objects.filter(set_id__in=[ds.id for ds in dataset_sets]) .values("set__name") .annotate(count=Count("id")) .values_list("set__name", "count") diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py index 2765dcfaef..819042749e 100644 --- a/arkindex/training/serializers.py +++ b/arkindex/training/serializers.py @@ -727,6 +727,9 @@ class SelectionDatasetElementSerializer(serializers.Serializer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + # for openAPI schema generation + if "request" not in self.context: + return self.fields["set_id"].queryset = DatasetSet.objects.filter( dataset__corpus_id__in=Corpus.objects.readable(self.context["request"].user) ).select_related("dataset") diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py index 72149bfc3c..9dcef03ffc 100644 --- a/arkindex/training/tests/test_datasets_api.py +++ b/arkindex/training/tests/test_datasets_api.py @@ -1166,7 +1166,6 @@ class TestDatasetsAPI(FixtureAPITestCase): self.assertDictEqual(data, {"count": 1, "next": None, "previous": None}) self.assertEqual(len(results), 1) dataset_element = results[0] - print(dataset_element) self.assertEqual(dataset_element["element"]["id"], str(self.page2.id)) self.assertEqual(dataset_element["set"], "training") @@ -2152,7 +2151,7 @@ class TestDatasetsAPI(FixtureAPITestCase): def test_clone_name_too_long(self): dataset = self.corpus.datasets.create(name="A" * 99, creator=self.user) self.client.force_login(self.user) - with self.assertNumQueries(14): + with self.assertNumQueries(13): response = self.client.post( reverse("api:dataset-clone", kwargs={"pk": dataset.id}), format="json", -- GitLab