diff --git a/arkindex/training/api.py b/arkindex/training/api.py index 0a42b292e206b7a48e9b1eb247a31a02e03dad09..530b7599210b9e9dcd26e0af5949b006b6689027 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -72,18 +72,18 @@ def _fetch_datasetelement_neighbors(datasetelements): SELECT n.id, lag(element_id) OVER ( - partition BY (n.dataset_id, n.set) + partition BY (n.set_id) order by n.element_id ) as previous, lead(element_id) OVER ( - partition BY (n.dataset_id, n.set) + partition BY (n.set_id) order by n.element_id ) as next FROM training_datasetelement as n - WHERE (dataset_id, set) IN ( - SELECT dataset_id, set + WHERE set_id IN ( + SELECT set_id FROM training_datasetelement WHERE id IN %(ids)s ) @@ -688,7 +688,7 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView): serializer_class = DatasetSerializer def get_queryset(self): - queryset = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user)).prefetch_related("sets") + queryset = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user)) return queryset.select_related("corpus", "creator") def check_object_permissions(self, request, obj): @@ -910,7 +910,8 @@ class ElementDatasetSets(CorpusACLMixin, ListAPIView): qs = ( self.element.dataset_elements .select_related("set__dataset__creator") - .order_by("set__name", "id") + .prefetch_related("set__dataset__sets") + .order_by("set__dataset__name", "set__name") ) with_neighbors = self.request.query_params.get("with_neighbors", "false") @@ -953,7 +954,9 @@ class DatasetClone(CorpusACLMixin, CreateAPIView): serializer_class = DatasetSerializer def get_queryset(self): - return Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user)) + return ( + Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user)) + ) def check_object_permissions(self, request, dataset): if not self.has_write_access(dataset.corpus): diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py index 37dacc0d76fda21965f198cdefbf847a4f4d8b17..2765dcfaef005aa756b42dd0784da3a08214e065 100644 --- a/arkindex/training/serializers.py +++ b/arkindex/training/serializers.py @@ -512,7 +512,11 @@ class DatasetSerializer(serializers.ModelSerializer): help_text="Display name of the user who created the dataset.", ) - set_names = serializers.ListField(child=serializers.CharField(max_length=50), write_only=True, required=False) + set_names = serializers.ListField( + child=serializers.CharField(max_length=50), + write_only=True, + default=serializers.CreateOnlyDefault(["training", "validation", "test"]) + ) sets = DatasetSetSerializer(many=True, read_only=True) # When creating the dataset, the dataset's corpus comes from the URL, so the APIView passes it through @@ -587,10 +591,7 @@ class DatasetSerializer(serializers.ModelSerializer): @transaction.atomic def create(self, validated_data): - if "set_names" not in validated_data: - sets = ["training", "validation", "test"] - else: - sets = validated_data.pop("set_names") + sets = validated_data.pop("set_names") dataset = Dataset.objects.create(**validated_data) DatasetSet.objects.bulk_create( DatasetSet( diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py index 56f7e3fb418a4d8364a3831c4cc6948630144791..20b854a193c9f38042d3586f238748a95972c835 100644 --- a/arkindex/training/tests/test_datasets_api.py +++ b/arkindex/training/tests/test_datasets_api.py @@ -981,7 +981,7 @@ class TestDatasetsAPI(FixtureAPITestCase): def test_retrieve(self): self.client.force_login(self.user) - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.get( reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}) ) @@ -1010,7 +1010,7 @@ class TestDatasetsAPI(FixtureAPITestCase): self.client.force_login(self.user) self.dataset.task = self.task self.dataset.save() - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.get( reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}) ) @@ -1431,7 +1431,7 @@ class TestDatasetsAPI(FixtureAPITestCase): def test_add_element_wrong_element(self): element = self.private_corpus.elements.create(type=self.private_corpus.types.create(slug="folder")) self.client.force_login(self.user) - with self.assertNumQueries(4): + with self.assertNumQueries(5): response = self.client.post( reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}), data={"set": "test", "element_id": str(element.id)}, @@ -1724,7 +1724,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "created": self.dataset.created.isoformat().replace("+00:00", "Z"), "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"), }, - "set": "train", + "set": "training", "previous": None, "next": None }, { @@ -1770,7 +1770,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "created": self.dataset2.created.isoformat().replace("+00:00", "Z"), "updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"), }, - "set": "train", + "set": "training", "previous": None, "next": None }] @@ -1811,7 +1811,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "created": self.dataset.created.isoformat().replace("+00:00", "Z"), "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"), }, - "set": "train", + "set": "training", "previous": None, "next": None }, { @@ -1857,7 +1857,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "created": self.dataset2.created.isoformat().replace("+00:00", "Z"), "updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"), }, - "set": "train", + "set": "training", "previous": None, "next": None }] @@ -1880,9 +1880,10 @@ class TestDatasetsAPI(FixtureAPITestCase): sorted_dataset2_elements = sorted([str(self.page1.id), str(self.page3.id)]) page1_index_2 = sorted_dataset2_elements.index(str(self.page1.id)) - with self.assertNumQueries(7): + with self.assertNumQueries(8): response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.page1.id)}), {"with_neighbors": True}) self.assertEqual(response.status_code, status.HTTP_200_OK) + self.maxDiff = None self.assertDictEqual(response.json(), { "count": 3, "next": None, @@ -1908,7 +1909,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "created": self.dataset.created.isoformat().replace("+00:00", "Z"), "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"), }, - "set": "train", + "set": "training", "previous": ( sorted_dataset_elements[page1_index_1 - 1] if page1_index_1 - 1 >= 0 @@ -1952,7 +1953,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "id": str(ds.id), "name": ds.name } - for ds in self.dataset.sets.all() + for ds in self.dataset2.sets.all() ], "set_elements": None, "state": "open", @@ -1962,7 +1963,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "created": self.dataset2.created.isoformat().replace("+00:00", "Z"), "updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"), }, - "set": "train", + "set": "training", "previous": ( sorted_dataset2_elements[page1_index_2 - 1] if page1_index_2 == 1 @@ -2101,7 +2102,7 @@ class TestDatasetsAPI(FixtureAPITestCase): def test_clone_existing_name(self): self.corpus.datasets.create(name="Clone of First Dataset", creator=self.user) self.client.force_login(self.user) - with self.assertNumQueries(11): + with self.assertNumQueries(15): response = self.client.post( reverse("api:dataset-clone", kwargs={"pk": self.dataset.id}), format="json", @@ -2115,12 +2116,14 @@ class TestDatasetsAPI(FixtureAPITestCase): ]) data = response.json() - data.pop("id") data.pop("created") data.pop("updated") + cloned_dataset = Dataset.objects.get(id=data["id"]) + self.maxDiff = None self.assertDictEqual( response.json(), { + "id": str(cloned_dataset.id), "name": "Clone of First Dataset 1", "description": self.dataset.description, "creator": self.user.display_name, @@ -2130,9 +2133,9 @@ class TestDatasetsAPI(FixtureAPITestCase): "id": str(ds.id), "name": ds.name } - for ds in self.dataset.sets.all() + for ds in cloned_dataset.sets.all() ], - "set_elements": {k: 0 for k in self.dataset.sets.all()}, + "set_elements": {str(k.name): 0 for k in self.dataset.sets.all()}, "state": DatasetState.Open.value, "task_id": None, }, @@ -2141,7 +2144,7 @@ class TestDatasetsAPI(FixtureAPITestCase): def test_clone_name_too_long(self): dataset = self.corpus.datasets.create(name="A" * 99, creator=self.user) self.client.force_login(self.user) - with self.assertNumQueries(11): + with self.assertNumQueries(14): response = self.client.post( reverse("api:dataset-clone", kwargs={"pk": dataset.id}), format="json",