diff --git a/arkindex/project/serializer_fields.py b/arkindex/project/serializer_fields.py index b7bf6429586f4d9f737ee86382b6f2126953b223..d18c647bc38353257fe0f31005a64d7acd7f179e 100644 --- a/arkindex/project/serializer_fields.py +++ b/arkindex/project/serializer_fields.py @@ -284,9 +284,10 @@ class DatasetSetsCountField(serializers.DictField): def get_attribute(self, instance): if not self.context.get("sets_count", True): return None - elts_count = {k.name: 0 for k in instance.sets.all()} + dataset_sets = instance.sets.all() + elts_count = {k.name: 0 for k in dataset_sets} elts_count.update( - DatasetElement.objects.filter(set_id__in=instance.sets.values_list("id")) + DatasetElement.objects.filter(set_id__in=[ds.id for ds in dataset_sets]) .values("set__name") .annotate(count=Count("id")) .values_list("set__name", "count") diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py index 2765dcfaef005aa756b42dd0784da3a08214e065..819042749e6ee622e2ad2ad52b5f9b1d25147216 100644 --- a/arkindex/training/serializers.py +++ b/arkindex/training/serializers.py @@ -727,6 +727,9 @@ class SelectionDatasetElementSerializer(serializers.Serializer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + # for openAPI schema generation + if "request" not in self.context: + return self.fields["set_id"].queryset = DatasetSet.objects.filter( dataset__corpus_id__in=Corpus.objects.readable(self.context["request"].user) ).select_related("dataset") diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py index 72149bfc3c478e57b340965038ffcb2bab372c3e..9dcef03ffc9c263e32a26714646fa0c29543c1a6 100644 --- a/arkindex/training/tests/test_datasets_api.py +++ b/arkindex/training/tests/test_datasets_api.py @@ -1166,7 +1166,6 @@ class TestDatasetsAPI(FixtureAPITestCase): self.assertDictEqual(data, {"count": 1, "next": None, "previous": None}) self.assertEqual(len(results), 1) dataset_element = results[0] - print(dataset_element) self.assertEqual(dataset_element["element"]["id"], str(self.page2.id)) self.assertEqual(dataset_element["set"], "training") @@ -2152,7 +2151,7 @@ class TestDatasetsAPI(FixtureAPITestCase): def test_clone_name_too_long(self): dataset = self.corpus.datasets.create(name="A" * 99, creator=self.user) self.client.force_login(self.user) - with self.assertNumQueries(14): + with self.assertNumQueries(13): response = self.client.post( reverse("api:dataset-clone", kwargs={"pk": dataset.id}), format="json",