Skip to content
Snippets Groups Projects
Commit 39bd4dbd authored by ml bonhomme :bee:
Browse files

fixed datasets API

parent ca9ea129
No related branches found
No related tags found
1 merge request: !2256 "New DatasetSet model"
This commit is part of merge request !2256. Comments created here will be created in the context of that merge request.
...@@ -72,18 +72,18 @@ def _fetch_datasetelement_neighbors(datasetelements): ...@@ -72,18 +72,18 @@ def _fetch_datasetelement_neighbors(datasetelements):
SELECT SELECT
n.id, n.id,
lag(element_id) OVER ( lag(element_id) OVER (
partition BY (n.dataset_id, n.set) partition BY (n.set_id)
order by order by
n.element_id n.element_id
) as previous, ) as previous,
lead(element_id) OVER ( lead(element_id) OVER (
partition BY (n.dataset_id, n.set) partition BY (n.set_id)
order by order by
n.element_id n.element_id
) as next ) as next
FROM training_datasetelement as n FROM training_datasetelement as n
WHERE (dataset_id, set) IN ( WHERE set_id IN (
SELECT dataset_id, set SELECT set_id
FROM training_datasetelement FROM training_datasetelement
WHERE id IN %(ids)s WHERE id IN %(ids)s
) )
...@@ -688,7 +688,7 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView): ...@@ -688,7 +688,7 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView):
serializer_class = DatasetSerializer serializer_class = DatasetSerializer
def get_queryset(self): def get_queryset(self):
queryset = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user)).prefetch_related("sets") queryset = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user))
return queryset.select_related("corpus", "creator") return queryset.select_related("corpus", "creator")
def check_object_permissions(self, request, obj): def check_object_permissions(self, request, obj):
...@@ -910,7 +910,8 @@ class ElementDatasetSets(CorpusACLMixin, ListAPIView): ...@@ -910,7 +910,8 @@ class ElementDatasetSets(CorpusACLMixin, ListAPIView):
qs = ( qs = (
self.element.dataset_elements self.element.dataset_elements
.select_related("set__dataset__creator") .select_related("set__dataset__creator")
.order_by("set__name", "id") .prefetch_related("set__dataset__sets")
.order_by("set__dataset__name", "set__name")
) )
with_neighbors = self.request.query_params.get("with_neighbors", "false") with_neighbors = self.request.query_params.get("with_neighbors", "false")
...@@ -953,7 +954,9 @@ class DatasetClone(CorpusACLMixin, CreateAPIView): ...@@ -953,7 +954,9 @@ class DatasetClone(CorpusACLMixin, CreateAPIView):
serializer_class = DatasetSerializer serializer_class = DatasetSerializer
def get_queryset(self): def get_queryset(self):
return Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user)) return (
Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user))
)
def check_object_permissions(self, request, dataset): def check_object_permissions(self, request, dataset):
if not self.has_write_access(dataset.corpus): if not self.has_write_access(dataset.corpus):
......
...@@ -512,7 +512,11 @@ class DatasetSerializer(serializers.ModelSerializer): ...@@ -512,7 +512,11 @@ class DatasetSerializer(serializers.ModelSerializer):
help_text="Display name of the user who created the dataset.", help_text="Display name of the user who created the dataset.",
) )
set_names = serializers.ListField(child=serializers.CharField(max_length=50), write_only=True, required=False) set_names = serializers.ListField(
child=serializers.CharField(max_length=50),
write_only=True,
default=serializers.CreateOnlyDefault(["training", "validation", "test"])
)
sets = DatasetSetSerializer(many=True, read_only=True) sets = DatasetSetSerializer(many=True, read_only=True)
# When creating the dataset, the dataset's corpus comes from the URL, so the APIView passes it through # When creating the dataset, the dataset's corpus comes from the URL, so the APIView passes it through
...@@ -587,10 +591,7 @@ class DatasetSerializer(serializers.ModelSerializer): ...@@ -587,10 +591,7 @@ class DatasetSerializer(serializers.ModelSerializer):
@transaction.atomic @transaction.atomic
def create(self, validated_data): def create(self, validated_data):
if "set_names" not in validated_data: sets = validated_data.pop("set_names")
sets = ["training", "validation", "test"]
else:
sets = validated_data.pop("set_names")
dataset = Dataset.objects.create(**validated_data) dataset = Dataset.objects.create(**validated_data)
DatasetSet.objects.bulk_create( DatasetSet.objects.bulk_create(
DatasetSet( DatasetSet(
......
...@@ -981,7 +981,7 @@ class TestDatasetsAPI(FixtureAPITestCase): ...@@ -981,7 +981,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
def test_retrieve(self): def test_retrieve(self):
self.client.force_login(self.user) self.client.force_login(self.user)
with self.assertNumQueries(5): with self.assertNumQueries(6):
response = self.client.get( response = self.client.get(
reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}) reverse("api:dataset-update", kwargs={"pk": self.dataset.pk})
) )
...@@ -1010,7 +1010,7 @@ class TestDatasetsAPI(FixtureAPITestCase): ...@@ -1010,7 +1010,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
self.client.force_login(self.user) self.client.force_login(self.user)
self.dataset.task = self.task self.dataset.task = self.task
self.dataset.save() self.dataset.save()
with self.assertNumQueries(5): with self.assertNumQueries(6):
response = self.client.get( response = self.client.get(
reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}) reverse("api:dataset-update", kwargs={"pk": self.dataset.pk})
) )
...@@ -1431,7 +1431,7 @@ class TestDatasetsAPI(FixtureAPITestCase): ...@@ -1431,7 +1431,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
def test_add_element_wrong_element(self): def test_add_element_wrong_element(self):
element = self.private_corpus.elements.create(type=self.private_corpus.types.create(slug="folder")) element = self.private_corpus.elements.create(type=self.private_corpus.types.create(slug="folder"))
self.client.force_login(self.user) self.client.force_login(self.user)
with self.assertNumQueries(4): with self.assertNumQueries(5):
response = self.client.post( response = self.client.post(
reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}), reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}),
data={"set": "test", "element_id": str(element.id)}, data={"set": "test", "element_id": str(element.id)},
...@@ -1724,7 +1724,7 @@ class TestDatasetsAPI(FixtureAPITestCase): ...@@ -1724,7 +1724,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
"created": self.dataset.created.isoformat().replace("+00:00", "Z"), "created": self.dataset.created.isoformat().replace("+00:00", "Z"),
"updated": self.dataset.updated.isoformat().replace("+00:00", "Z"), "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
}, },
"set": "train", "set": "training",
"previous": None, "previous": None,
"next": None "next": None
}, { }, {
...@@ -1770,7 +1770,7 @@ class TestDatasetsAPI(FixtureAPITestCase): ...@@ -1770,7 +1770,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
"created": self.dataset2.created.isoformat().replace("+00:00", "Z"), "created": self.dataset2.created.isoformat().replace("+00:00", "Z"),
"updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"), "updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"),
}, },
"set": "train", "set": "training",
"previous": None, "previous": None,
"next": None "next": None
}] }]
...@@ -1811,7 +1811,7 @@ class TestDatasetsAPI(FixtureAPITestCase): ...@@ -1811,7 +1811,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
"created": self.dataset.created.isoformat().replace("+00:00", "Z"), "created": self.dataset.created.isoformat().replace("+00:00", "Z"),
"updated": self.dataset.updated.isoformat().replace("+00:00", "Z"), "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
}, },
"set": "train", "set": "training",
"previous": None, "previous": None,
"next": None "next": None
}, { }, {
...@@ -1857,7 +1857,7 @@ class TestDatasetsAPI(FixtureAPITestCase): ...@@ -1857,7 +1857,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
"created": self.dataset2.created.isoformat().replace("+00:00", "Z"), "created": self.dataset2.created.isoformat().replace("+00:00", "Z"),
"updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"), "updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"),
}, },
"set": "train", "set": "training",
"previous": None, "previous": None,
"next": None "next": None
}] }]
...@@ -1880,9 +1880,10 @@ class TestDatasetsAPI(FixtureAPITestCase): ...@@ -1880,9 +1880,10 @@ class TestDatasetsAPI(FixtureAPITestCase):
sorted_dataset2_elements = sorted([str(self.page1.id), str(self.page3.id)]) sorted_dataset2_elements = sorted([str(self.page1.id), str(self.page3.id)])
page1_index_2 = sorted_dataset2_elements.index(str(self.page1.id)) page1_index_2 = sorted_dataset2_elements.index(str(self.page1.id))
with self.assertNumQueries(7): with self.assertNumQueries(8):
response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.page1.id)}), {"with_neighbors": True}) response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.page1.id)}), {"with_neighbors": True})
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.maxDiff = None
self.assertDictEqual(response.json(), { self.assertDictEqual(response.json(), {
"count": 3, "count": 3,
"next": None, "next": None,
...@@ -1908,7 +1909,7 @@ class TestDatasetsAPI(FixtureAPITestCase): ...@@ -1908,7 +1909,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
"created": self.dataset.created.isoformat().replace("+00:00", "Z"), "created": self.dataset.created.isoformat().replace("+00:00", "Z"),
"updated": self.dataset.updated.isoformat().replace("+00:00", "Z"), "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"),
}, },
"set": "train", "set": "training",
"previous": ( "previous": (
sorted_dataset_elements[page1_index_1 - 1] sorted_dataset_elements[page1_index_1 - 1]
if page1_index_1 - 1 >= 0 if page1_index_1 - 1 >= 0
...@@ -1952,7 +1953,7 @@ class TestDatasetsAPI(FixtureAPITestCase): ...@@ -1952,7 +1953,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
"id": str(ds.id), "id": str(ds.id),
"name": ds.name "name": ds.name
} }
for ds in self.dataset.sets.all() for ds in self.dataset2.sets.all()
], ],
"set_elements": None, "set_elements": None,
"state": "open", "state": "open",
...@@ -1962,7 +1963,7 @@ class TestDatasetsAPI(FixtureAPITestCase): ...@@ -1962,7 +1963,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
"created": self.dataset2.created.isoformat().replace("+00:00", "Z"), "created": self.dataset2.created.isoformat().replace("+00:00", "Z"),
"updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"), "updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"),
}, },
"set": "train", "set": "training",
"previous": ( "previous": (
sorted_dataset2_elements[page1_index_2 - 1] sorted_dataset2_elements[page1_index_2 - 1]
if page1_index_2 == 1 if page1_index_2 == 1
...@@ -2101,7 +2102,7 @@ class TestDatasetsAPI(FixtureAPITestCase): ...@@ -2101,7 +2102,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
def test_clone_existing_name(self): def test_clone_existing_name(self):
self.corpus.datasets.create(name="Clone of First Dataset", creator=self.user) self.corpus.datasets.create(name="Clone of First Dataset", creator=self.user)
self.client.force_login(self.user) self.client.force_login(self.user)
with self.assertNumQueries(11): with self.assertNumQueries(15):
response = self.client.post( response = self.client.post(
reverse("api:dataset-clone", kwargs={"pk": self.dataset.id}), reverse("api:dataset-clone", kwargs={"pk": self.dataset.id}),
format="json", format="json",
...@@ -2115,12 +2116,14 @@ class TestDatasetsAPI(FixtureAPITestCase): ...@@ -2115,12 +2116,14 @@ class TestDatasetsAPI(FixtureAPITestCase):
]) ])
data = response.json() data = response.json()
data.pop("id")
data.pop("created") data.pop("created")
data.pop("updated") data.pop("updated")
cloned_dataset = Dataset.objects.get(id=data["id"])
self.maxDiff = None
self.assertDictEqual( self.assertDictEqual(
response.json(), response.json(),
{ {
"id": str(cloned_dataset.id),
"name": "Clone of First Dataset 1", "name": "Clone of First Dataset 1",
"description": self.dataset.description, "description": self.dataset.description,
"creator": self.user.display_name, "creator": self.user.display_name,
...@@ -2130,9 +2133,9 @@ class TestDatasetsAPI(FixtureAPITestCase): ...@@ -2130,9 +2133,9 @@ class TestDatasetsAPI(FixtureAPITestCase):
"id": str(ds.id), "id": str(ds.id),
"name": ds.name "name": ds.name
} }
for ds in self.dataset.sets.all() for ds in cloned_dataset.sets.all()
], ],
"set_elements": {k: 0 for k in self.dataset.sets.all()}, "set_elements": {str(k.name): 0 for k in self.dataset.sets.all()},
"state": DatasetState.Open.value, "state": DatasetState.Open.value,
"task_id": None, "task_id": None,
}, },
...@@ -2141,7 +2144,7 @@ class TestDatasetsAPI(FixtureAPITestCase): ...@@ -2141,7 +2144,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
def test_clone_name_too_long(self): def test_clone_name_too_long(self):
dataset = self.corpus.datasets.create(name="A" * 99, creator=self.user) dataset = self.corpus.datasets.create(name="A" * 99, creator=self.user)
self.client.force_login(self.user) self.client.force_login(self.user)
with self.assertNumQueries(11): with self.assertNumQueries(14):
response = self.client.post( response = self.client.post(
reverse("api:dataset-clone", kwargs={"pk": dataset.id}), reverse("api:dataset-clone", kwargs={"pk": dataset.id}),
format="json", format="json",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment