From 0806509f39f0d822e92b7cc4d3b7acf3f3d24bcc Mon Sep 17 00:00:00 2001 From: mlbonhomme <bonhomme@teklia.com> Date: Thu, 7 Mar 2024 18:34:30 +0100 Subject: [PATCH] WIP fix API endpoints :'( --- arkindex/documents/fixtures/data.json | 12 +- arkindex/project/api_v1.py | 4 +- arkindex/training/api.py | 63 ++--- .../migrations/0007_datasetset_model.py | 6 +- arkindex/training/models.py | 8 + arkindex/training/serializers.py | 21 +- arkindex/training/tests/test_datasets_api.py | 220 +++++++++++------- 7 files changed, 197 insertions(+), 137 deletions(-) diff --git a/arkindex/documents/fixtures/data.json b/arkindex/documents/fixtures/data.json index ac55bc4e72..0037a3d48a 100644 --- a/arkindex/documents/fixtures/data.json +++ b/arkindex/documents/fixtures/data.json @@ -4022,7 +4022,7 @@ "model": "training.datasetset", "pk": "00e3b37b-f1ed-4adb-a1de-a1c103edaa24", "fields": { - "name": "Test", + "name": "test", "dataset": "170bc276-0e11-483a-9486-88a4bf2079b2" } }, @@ -4030,7 +4030,7 @@ "model": "training.datasetset", "pk": "95255ff4-7bca-424c-8e7e-d8b5c33c585f", "fields": { - "name": "Train", + "name": "training", "dataset": "170bc276-0e11-483a-9486-88a4bf2079b2" } }, @@ -4038,7 +4038,7 @@ "model": "training.datasetset", "pk": "b21e4b31-1dc1-4015-9933-3a8e180cf2e0", "fields": { - "name": "Validation", + "name": "validation", "dataset": "170bc276-0e11-483a-9486-88a4bf2079b2" } }, @@ -4046,7 +4046,7 @@ "model": "training.datasetset", "pk": "b76d919c-ce7e-4896-94c3-9b47862b997a", "fields": { - "name": "Validation", + "name": "validation", "dataset": "6b32dab0-e3f2-4fff-b0ad-2ea1ee550810" } }, @@ -4054,7 +4054,7 @@ "model": "training.datasetset", "pk": "d5f4d410-0588-4d19-8d08-b57a11bad67f", "fields": { - "name": "Train", + "name": "training", "dataset": "6b32dab0-e3f2-4fff-b0ad-2ea1ee550810" } }, @@ -4062,7 +4062,7 @@ "model": "training.datasetset", "pk": "db70ad8a-8d6b-44c6-b8d3-837bf5a4ab8e", "fields": { - "name": "Test", + "name": "test", "dataset": "6b32dab0-e3f2-4fff-b0ad-2ea1ee550810" } } diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py index 17c92d941b..3c893fc12f 100644 --- a/arkindex/project/api_v1.py +++ b/arkindex/project/api_v1.py @@ -109,7 +109,7 @@ from arkindex.training.api import ( DatasetElementDestroy, DatasetElements, DatasetUpdate, - ElementDatasets, + ElementDatasetSets, MetricValueBulkCreate, MetricValueCreate, ModelCompatibleWorkerManage, @@ -184,7 +184,7 @@ api = [ # Datasets path("corpus/<uuid:pk>/datasets/", CorpusDataset.as_view(), name="corpus-datasets"), path("corpus/<uuid:pk>/datasets/selection/", CreateDatasetElementsSelection.as_view(), name="dataset-elements-selection"), - path("element/<uuid:pk>/datasets/", ElementDatasets.as_view(), name="element-datasets"), + path("element/<uuid:pk>/datasets/", ElementDatasetSets.as_view(), name="element-datasets"), path("datasets/<uuid:pk>/", DatasetUpdate.as_view(), name="dataset-update"), path("datasets/<uuid:pk>/clone/", DatasetClone.as_view(), name="dataset-clone"), path("datasets/<uuid:pk>/elements/", DatasetElements.as_view(), name="dataset-elements"), diff --git a/arkindex/training/api.py b/arkindex/training/api.py index 06034116e9..240282d88e 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -29,6 +29,7 @@ from arkindex.project.tools import BulkMap from arkindex.training.models import ( Dataset, DatasetElement, + DatasetSet, DatasetState, MetricValue, Model, @@ -40,7 +41,7 @@ from arkindex.training.serializers import ( DatasetElementInfoSerializer, DatasetElementSerializer, 
DatasetSerializer,
-    ElementDatasetSerializer,
+    ElementDatasetSetSerializer,
     MetricValueBulkSerializer,
     MetricValueCreateSerializer,
     ModelCompatibleWorkerSerializer,
@@ -58,7 +59,7 @@ def _fetch_datasetelement_neighbors(datasetelements):
     """
     Retrieve the neighbors for a list of DatasetElements, and annotate these DatasetElements
     with next and previous attributes.
-    The ElementDatasets endpoint uses arkindex.project.tools.BulkMap to apply this method and
+    The ElementDatasetSets endpoint uses arkindex.project.tools.BulkMap to apply this method and
     perform the second request *after* DRF's pagination, because there is no way to perform
     post-processing after pagination in Django without having to use Django private methods.
     """
@@ -708,7 +709,9 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView):
             raise ValidationError(detail="This dataset is in complete state and cannot be modified anymore.")
 
     def perform_destroy(self, dataset):
-        dataset.dataset_elements.all().delete()
+        # Delete the elements before the sets, as DatasetElement.set uses on_delete=DO_NOTHING
+        DatasetElement.objects.filter(set__dataset_id=dataset.id).delete()
+        dataset.sets.all().delete()
         super().perform_destroy(dataset)
 
 
@@ -769,13 +771,13 @@ class DatasetElements(CorpusACLMixin, ListCreateAPIView):
 
     def get_queryset(self):
         qs = (
-            self.dataset.dataset_elements
-            .prefetch_related("element")
+            DatasetElement.objects.filter(set__dataset_id=self.dataset.id)
+            .prefetch_related("element", "set")
             .select_related("element__type", "element__corpus", "element__image__server")
             .order_by("element_id", "id")
         )
         if "set" in self.request.query_params:
-            qs = qs.filter(set=self.request.query_params["set"])
+            qs = qs.filter(set__name=self.request.query_params["set"])
         return qs
 
     def get_serializer_context(self):
@@ -788,20 +790,20 @@ class DatasetElements(CorpusACLMixin, ListCreateAPIView):
 @extend_schema_view(
     delete=extend_schema(
         operation_id="DestroyDatasetElement",
         parameters=[
             OpenApiParameter(
                 "set",
                 type=str,
                 description="Name of the set from which to remove the element.",
                 required=True,
             )
         ],
         tags=["datasets"]
     )
 )
 class DatasetElementDestroy(CorpusACLMixin, DestroyAPIView):
     """
-    Remove an element from a dataset.
+    Remove an element from a dataset set.
 
     Elements can only be removed from **open** datasets.
@@ -812,17 +806,15 @@ class DatasetElementDestroy(CorpusACLMixin, DestroyAPIView):
     lookup_url_kwarg = "element"
 
     def destroy(self, request, *args, **kwargs):
-        if not self.request.query_params.get("set"):
+        if not (set_name := self.request.query_params.get("set")):
             raise ValidationError({"set": ["This field is required."]})
         dataset_element = get_object_or_404(
-            DatasetElement.objects.select_related("dataset__corpus"),
-            dataset_id=self.kwargs["dataset"],
-            element_id=self.kwargs["element"],
-            set=self.request.query_params.get("set")
+            DatasetElement.objects.select_related("set__dataset__corpus").filter(set__dataset_id=self.kwargs["dataset"], set__name=set_name),
+            element_id=self.kwargs["element"]
         )
-        if dataset_element.dataset.state != DatasetState.Open:
+        if dataset_element.set.dataset.state != DatasetState.Open:
             raise ValidationError({"dataset": ["Elements can only be removed from open Datasets."]})
-        if not self.has_write_access(dataset_element.dataset.corpus):
+        if not self.has_write_access(dataset_element.set.dataset.corpus):
             raise PermissionDenied(detail="You need a Contributor access to the dataset to perform this action.")
         dataset_element.delete()
         return Response(status=status.HTTP_204_NO_CONTENT)
@@ -897,14 +889,14 @@ class CreateDatasetElementsSelection(CorpusACLMixin, CreateAPIView):
         )
     ],
 )
-class ElementDatasets(CorpusACLMixin, ListAPIView):
+class ElementDatasetSets(CorpusACLMixin, ListAPIView):
     """
-    List all datasets containing a specific element.
+    List all dataset sets containing a specific element.
 
     Requires a **guest** access to the element's corpus.
     """
     permission_classes = (IsVerifiedOrReadOnly, )
-    serializer_class = ElementDatasetSerializer
+    serializer_class = ElementDatasetSetSerializer
 
     @cached_property
     def element(self):
@@ -916,9 +908,9 @@ class ElementDatasets(CorpusACLMixin, ListAPIView):
 
     def get_queryset(self):
         qs = (
-            self.element.dataset_elements.all()
-            .select_related("dataset__creator")
-            .order_by("dataset__name", "set", "dataset_id")
+            self.element.dataset_elements
+            .select_related("set__dataset__creator")
+            .order_by("set__name", "id")
         )
 
         with_neighbors = self.request.query_params.get("with_neighbors", "false")
@@ -996,11 +990,18 @@ class DatasetClone(CorpusACLMixin, CreateAPIView):
         clone.creator = request.user
         clone.save()
 
+        # Clone dataset sets
+        cloned_sets = DatasetSet.objects.bulk_create([
+            DatasetSet(dataset_id=clone.id, name=set.name)
+            for set in dataset.sets.all()
+        ])
         # Associate all elements to the clone
         DatasetElement.objects.bulk_create([
-            DatasetElement(element_id=elt_id, dataset_id=clone.id, set=set_name)
-            for elt_id, set_name in dataset.dataset_elements.values_list("element_id", "set")
+            DatasetElement(element_id=elt_id, set=next(new_set for new_set in cloned_sets if new_set.name == set_name))
+            for elt_id, set_name in DatasetElement.objects.filter(set__dataset_id=dataset.id)
+            .values_list("element_id", "set__name")
         ])
+
         return Response(
             DatasetSerializer(clone).data,
             status=status.HTTP_201_CREATED,
diff --git a/arkindex/training/migrations/0007_datasetset_model.py b/arkindex/training/migrations/0007_datasetset_model.py
index 
8630baa593..c3052f691c 100644 --- a/arkindex/training/migrations/0007_datasetset_model.py +++ b/arkindex/training/migrations/0007_datasetset_model.py @@ -59,7 +59,7 @@ class Migration(migrations.Migration): UPDATE training_datasetelement de SET set_id_id = ds.id FROM training_datasetset ds - WHERE de.dataset_id = ds.dataset_id + WHERE de.dataset_id = ds.dataset_id AND de.set = ds.name """, reverse_sql=migrations.RunSQL.noop, ), @@ -85,6 +85,10 @@ class Migration(migrations.Migration): name="set", field=models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, related_name="set_elements", to="training.datasetset"), ), + migrations.AddConstraint( + model_name="datasetelement", + constraint=models.UniqueConstraint(fields=("element_id", "set_id"), name="unique_set_element"), + ), migrations.RemoveField( model_name="dataset", name="sets" diff --git a/arkindex/training/models.py b/arkindex/training/models.py index 06f4f84c5f..c37ea35b09 100644 --- a/arkindex/training/models.py +++ b/arkindex/training/models.py @@ -309,3 +309,11 @@ class DatasetElement(models.Model): related_name="set_elements", on_delete=models.DO_NOTHING, ) + + class Meta: + constraints = [ + models.UniqueConstraint( + fields=["element_id", "set_id"], + name="unique_set_element", + ), + ] diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py index 4acb03a663..c65d048fc1 100644 --- a/arkindex/training/serializers.py +++ b/arkindex/training/serializers.py @@ -512,7 +512,7 @@ class DatasetSerializer(serializers.ModelSerializer): help_text="Display name of the user who created the dataset.", ) - set_names = serializers.ListField(child=serializers.CharField(max_length=50), write_only=True, default=["Training", "Validation", "Test"]) + set_names = serializers.ListField(child=serializers.CharField(max_length=50), write_only=True, default=["training", "validation", "test"]) sets = DatasetSetSerializer(many=True, read_only=True) # When creating the dataset, the dataset's corpus comes from the URL, so the APIView passes it through @@ -657,6 +657,7 @@ class DatasetElementSerializer(serializers.ModelSerializer): default=_dataset_from_context, write_only=True, ) + set = serializers.SlugRelatedField(queryset=DatasetSet.objects.none(), slug_field="name") class Meta: model = DatasetElement @@ -665,7 +666,7 @@ class DatasetElementSerializer(serializers.ModelSerializer): validators = [ UniqueTogetherValidator( queryset=DatasetElement.objects.all(), - fields=["dataset", "element_id", "set"], + fields=["element_id", "set"], message="This element is already part of this set.", ) ] @@ -674,13 +675,12 @@ class DatasetElementSerializer(serializers.ModelSerializer): super().__init__(*args, **kwargs) if dataset := self.context.get("dataset"): self.fields["element_id"].queryset = Element.objects.filter(corpus=dataset.corpus) + self.fields["set"].queryset = dataset.sets.all() - def validate_set(self, value): - # The set must match the `sets` array defined at the dataset level - dataset = self.context["dataset"] - if dataset and value not in dataset.sets: - raise ValidationError(f"This dataset has no set named {value}.") - return value + def validate(self, data): + data = super().validate(data) + data.pop("dataset") + return data class DatasetElementInfoSerializer(DatasetElementSerializer): @@ -698,10 +698,11 @@ class DatasetElementInfoSerializer(DatasetElementSerializer): fields = DatasetElementSerializer.Meta.fields + ("dataset",) -class ElementDatasetSerializer(serializers.ModelSerializer): - dataset = 
DatasetSerializer()
+class ElementDatasetSetSerializer(serializers.ModelSerializer):
+    dataset = DatasetSerializer(source="set.dataset")
     previous = serializers.UUIDField(allow_null=True, read_only=True)
     next = serializers.UUIDField(allow_null=True, read_only=True)
+    set = serializers.SlugRelatedField(slug_field="name", read_only=True)
 
     class Meta:
         model = DatasetElement
diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py
index 9b19becc27..6a7387c464 100644
--- a/arkindex/training/tests/test_datasets_api.py
+++ b/arkindex/training/tests/test_datasets_api.py
@@ -9,7 +9,7 @@ from arkindex.documents.models import Corpus
 from arkindex.process.models import Process, ProcessDataset, ProcessMode
 from arkindex.project.tests import FixtureAPITestCase
 from arkindex.project.tools import fake_now
-from arkindex.training.models import Dataset, DatasetSet, DatasetState
+from arkindex.training.models import Dataset, DatasetElement, DatasetSet, DatasetState
 from arkindex.users.models import Role, User
 
 # Using the fake DB fixtures creation date when needed
@@ -267,7 +267,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
 
     def test_create(self):
         self.client.force_login(self.user)
-        with self.assertNumQueries(10):
+        with self.assertNumQueries(11):
             response = self.client.post(
                 reverse("api:corpus-datasets", kwargs={"pk": self.corpus.pk}),
                 data={"name": "My dataset", "description": "My dataset for my experiments."},
@@ -279,11 +279,17 @@
             "id": str(created_dataset.id),
             "name": "My dataset",
             "description": "My dataset for my experiments.",
-            "sets": {},
+            "sets": [
+                {
+                    "id": str(ds.id),
+                    "name": ds.name
+                }
+                for ds in created_dataset.sets.all()
+            ],
             "set_elements": {
-                "Training": 0,
-                "Test": 0,
-                "Validation": 0,
+                "training": 0,
+                "test": 0,
+                "validation": 0,
             },
             "state": "open",
             "creator": "Test user",
@@ -1066,14 +1072,15 @@ class TestDatasetsAPI(FixtureAPITestCase):
         """
         DestroyDataset also deletes DatasetElements
        """
-        self.dataset.dataset_elements.create(element_id=self.vol.id, set="test")
-        self.dataset.dataset_elements.create(element_id=self.vol.id, set="training")
-        self.dataset.dataset_elements.create(element_id=self.page1.id, set="training")
-        self.dataset.dataset_elements.create(element_id=self.page2.id, set="validation")
-        self.dataset.dataset_elements.create(element_id=self.page3.id, set="validation")
+        test_set, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        test_set.set_elements.create(element_id=self.vol.id)
+        train_set.set_elements.create(element_id=self.vol.id)
+        train_set.set_elements.create(element_id=self.page1.id)
+        validation_set.set_elements.create(element_id=self.page2.id)
+        validation_set.set_elements.create(element_id=self.page3.id)
         self.client.force_login(self.user)
-        with self.assertNumQueries(6):
+        with self.assertNumQueries(7):
             response = self.client.delete(
                 reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}),
             )
@@ -1082,7 +1089,7 @@
 
         with self.assertRaises(Dataset.DoesNotExist):
             self.dataset.refresh_from_db()
-        self.assertFalse(self.dataset.dataset_elements.exists())
+        self.assertFalse(DatasetElement.objects.filter(set__dataset_id=self.dataset.id).exists())
 
         # No elements should have been deleted
         self.vol.refresh_from_db()
@@ -1112,12 +1119,13 @@
         self.assertEqual(filter_rights_mock.call_args, 
call(self.user, Corpus, Role.Guest.value))
 
     def test_list_elements_set_filter_wrong_set(self):
-        self.dataset.dataset_elements.create(element_id=self.page1.id, set="test")
+        test_set = self.dataset.sets.order_by("name").first()
+        test_set.set_elements.create(element_id=self.page1.id)
         self.client.force_login(self.user)
         with self.assertNumQueries(4):
             response = self.client.get(
                 reverse("api:dataset-elements", kwargs={"pk": str(self.dataset.id)}),
-                data={"set": "aaaaa"}
+                data={"set": "training"}
             )
         self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertDictEqual(response.json(), {
@@ -1128,10 +1136,11 @@
         })
 
     def test_list_elements_set_filter(self):
-        self.dataset.dataset_elements.create(element_id=self.page1.id, set="test")
-        self.dataset.dataset_elements.create(element_id=self.page2.id, set="training")
+        test_set, train_set, _ = self.dataset.sets.all().order_by("name")
+        test_set.set_elements.create(element_id=self.page1.id)
+        train_set.set_elements.create(element_id=self.page2.id)
         self.client.force_login(self.user)
-        with self.assertNumQueries(5):
+        with self.assertNumQueries(6):
             response = self.client.get(
                 reverse("api:dataset-elements", kwargs={"pk": self.dataset.pk}),
                 data={"set": "training", "with_count": "true"},
@@ -1142,15 +1151,16 @@
         self.assertDictEqual(data, {"count": 1, "next": None, "previous": None})
         self.assertEqual(len(results), 1)
         dataset_element = results[0]
         self.assertEqual(dataset_element["element"]["id"], str(self.page2.id))
         self.assertEqual(dataset_element["set"], "training")
 
     @patch("arkindex.documents.models.Element.thumbnail", MagicMock(s3_url="s3_url"))
     def test_list_elements(self):
-        self.dataset.dataset_elements.create(element_id=self.vol.id, set="test")
-        self.dataset.dataset_elements.create(element_id=self.page1.id, set="training")
-        self.dataset.dataset_elements.create(element_id=self.page2.id, set="validation")
-        self.dataset.dataset_elements.create(element_id=self.page3.id, set="validation")
+        test_set, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        test_set.set_elements.create(element_id=self.vol.id)
+        train_set.set_elements.create(element_id=self.page1.id)
+        validation_set.set_elements.create(element_id=self.page2.id)
+        validation_set.set_elements.create(element_id=self.page3.id)
         self.page3.confidence = 0.42
         self.page3.mirrored = True
         self.page3.rotation_angle = 42
@@ -1339,7 +1350,7 @@
             self.dataset.state = state
             self.dataset.save()
             with self.subTest(state=state):
-                with self.assertNumQueries(4):
+                with self.assertNumQueries(5):
                     response = self.client.get(
                         reverse("api:dataset-elements", kwargs={"pk": self.dataset.pk}),
                         {"page_size": 3},
@@ -1418,7 +1429,7 @@
 
     def test_add_element_wrong_set(self):
         self.client.force_login(self.user)
-        with self.assertNumQueries(4):
+        with self.assertNumQueries(5):
             response = self.client.post(
                 reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}),
                 data={"set": "aaaaaaaaaaa", "element_id": str(self.vol.id)},
@@ -1426,7 +1437,7 @@
             )
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertDictEqual(response.json(), {
-            "set": ["This dataset has no set named aaaaaaaaaaa."],
+            "set": ["Object with name=aaaaaaaaaaa does not exist."],
         })
 
     def test_add_element_dataset_requires_open(self):
@@ -1443,9 +1454,10 @@ class 
TestDatasetsAPI(FixtureAPITestCase): self.assertListEqual(response.json(), ["You can only add elements to a dataset in an open state."]) def test_add_element_already_exists(self): - self.dataset.dataset_elements.create(element=self.page1, set="test") + test_set = self.dataset.sets.order_by("name").first() + test_set.set_elements.create(element=self.page1) self.client.force_login(self.user) - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.post( reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}), data={"set": "test", "element_id": str(self.page1.id)}, @@ -1455,8 +1467,9 @@ class TestDatasetsAPI(FixtureAPITestCase): self.assertDictEqual(response.json(), {"non_field_errors": ["This element is already part of this set."]}) def test_add_element(self): + train_set = self.dataset.sets.get(name="training") self.client.force_login(self.user) - with self.assertNumQueries(10): + with self.assertNumQueries(11): response = self.client.post( reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}), data={"set": "training", "element_id": str(self.page1.id)}, @@ -1464,7 +1477,7 @@ class TestDatasetsAPI(FixtureAPITestCase): ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertQuerysetEqual( - self.dataset.dataset_elements.values_list("set", "element__name").order_by("element__name"), + train_set.set_elements.values_list("set__name", "element__name").order_by("element__name"), [("training", "Volume 1, page 1r")] ) @@ -1575,9 +1588,10 @@ class TestDatasetsAPI(FixtureAPITestCase): }) def test_add_from_selection(self): - self.dataset.dataset_elements.create(element=self.page1, set="training") + train_set = self.dataset.sets.get(name="training") + train_set.set_elements.create(element=self.page1) self.assertQuerysetEqual( - self.dataset.dataset_elements.values_list("set", "element__name").order_by("element__name"), + train_set.set_elements.values_list("set__name", "element__name").order_by("element__name"), [("training", "Volume 1, page 1r")] ) self.user.selected_elements.set([self.vol, self.page1, self.page2]) @@ -1586,12 +1600,12 @@ class TestDatasetsAPI(FixtureAPITestCase): with self.assertNumQueries(6): response = self.client.post( reverse("api:dataset-elements-selection", kwargs={"pk": self.corpus.id}), - data={"set": "training", "dataset_id": self.dataset.id}, + data={"set_id": str(train_set.id)}, format="json", ) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) self.assertQuerysetEqual( - self.dataset.dataset_elements.values_list("set", "element__name").order_by("element__name"), + train_set.set_elements.values_list("set__name", "element__name").order_by("element__name"), [ ("training", "Volume 1"), ("training", "Volume 1, page 1r"), @@ -1623,8 +1637,9 @@ class TestDatasetsAPI(FixtureAPITestCase): """ A non authenticated user can list datasets of a public element """ - self.dataset.dataset_elements.create(element=self.vol, set="train") - with self.assertNumQueries(3): + train_set = self.dataset.sets.get(name="training") + train_set.set_elements.create(element=self.vol) + with self.assertNumQueries(4): response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.vol.id)})) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertDictEqual(response.json(), { @@ -1637,7 +1652,13 @@ class TestDatasetsAPI(FixtureAPITestCase): "id": str(self.dataset.id), "name": "First Dataset", "description": "dataset number one", - "sets": ["training", "test", "validation"], + "sets": [ + { + 
"id": str(ds.id), + "name": ds.name + } + for ds in self.dataset.sets.all() + ], "set_elements": None, "state": "open", "corpus_id": str(self.corpus.id), @@ -1646,7 +1667,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "created": self.dataset.created.isoformat().replace("+00:00", "Z"), "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"), }, - "set": "train", + "set": "training", "previous": None, "next": None }] @@ -1654,10 +1675,12 @@ class TestDatasetsAPI(FixtureAPITestCase): def test_element_datasets(self): self.client.force_login(self.user) - self.dataset.dataset_elements.create(element=self.page1, set="train") - self.dataset.dataset_elements.create(element=self.page1, set="validation") - self.dataset2.dataset_elements.create(element=self.page1, set="train") - with self.assertNumQueries(5): + _, train_set, validation_set = self.dataset.sets.all().order_by("name") + train_set.set_elements.create(element=self.page1, set="train") + validation_set.set_elements.create(element=self.page1, set="validation") + train_set_2 = self.dataset2.sets.get(name="training") + train_set_2.set_elements.create(element=self.page1, set="train") + with self.assertNumQueries(6): response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.page1.id)})) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertDictEqual(response.json(), { @@ -1721,10 +1744,11 @@ class TestDatasetsAPI(FixtureAPITestCase): def test_element_datasets_with_neighbors_false(self): self.client.force_login(self.user) - self.dataset.dataset_elements.create(element=self.page1, set="train") - self.dataset.dataset_elements.create(element=self.page1, set="validation") - self.dataset2.dataset_elements.create(element=self.page1, set="train") - with self.assertNumQueries(5): + _, train_set, validation_set = self.dataset.sets.all().order_by("name") + train_set.set_elements.create(element=self.page1) + validation_set.set_elements.create(element=self.page1) + train_set.set_elements.create(element=self.page1) + with self.assertNumQueries(6): response = self.client.get(reverse("api:element-datasets", kwargs={"pk": str(self.page1.id)}), {"with_neighbors": False}) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertDictEqual(response.json(), { @@ -1788,12 +1812,13 @@ class TestDatasetsAPI(FixtureAPITestCase): def test_element_datasets_with_neighbors(self): self.client.force_login(self.user) - self.dataset.dataset_elements.create(element=self.page1, set="train") - self.dataset.dataset_elements.create(element=self.page2, set="train") - self.dataset.dataset_elements.create(element=self.page3, set="train") - self.dataset.dataset_elements.create(element=self.page1, set="validation") - self.dataset2.dataset_elements.create(element=self.page1, set="train") - self.dataset2.dataset_elements.create(element=self.page3, set="train") + _, train_set, validation_set = self.dataset.sets.all().order_by("name") + train_set.set_elements.create(element=self.page1) + train_set.set_elements.create(element=self.page2) + train_set.set_elements.create(element=self.page3) + validation_set.set_elements.create(element=self.page1) + train_set.set_elements.create(element=self.page1) + train_set.set_elements.create(element=self.page3) # Results are alphabetically ordered and must not depend on the random page UUIDs sorted_dataset_elements = sorted([str(self.page1.id), str(self.page2.id), str(self.page3.id)]) @@ -1940,13 +1965,14 @@ class TestDatasetsAPI(FixtureAPITestCase): self.dataset.state = DatasetState.Error 
self.dataset.task = self.task self.dataset.save() - self.dataset.dataset_elements.create(element=self.page1, set="test") - self.dataset.dataset_elements.create(element=self.page1, set="validation") - self.dataset.dataset_elements.create(element=self.vol, set="validation") + test_set, _, validation_set = self.dataset.sets.all().order_by("name") + test_set.set_elements.create(element=self.page1) + validation_set.set_elements.create(element=self.page1) + validation_set.set_elements.create(element=self.vol) self.assertCountEqual(self.corpus.datasets.values_list("name", flat=True), ["First Dataset", "Second Dataset"]) self.client.force_login(self.user) - with self.assertNumQueries(12): + with self.assertNumQueries(16): response = self.client.post( reverse("api:dataset-clone", kwargs={"pk": self.dataset.id}), format="json", @@ -1959,9 +1985,11 @@ class TestDatasetsAPI(FixtureAPITestCase): ]) data = response.json() clone = self.corpus.datasets.get(id=data["id"]) + test_clone, train_clone, val_clone = clone.sets.all().order_by("name") self.assertEqual(clone.creator, self.user) data.pop("created") data.pop("updated") + cloned_sets = data.pop("sets") self.assertDictEqual( response.json(), { @@ -1970,14 +1998,27 @@ class TestDatasetsAPI(FixtureAPITestCase): "description": self.dataset.description, "creator": self.user.display_name, "corpus_id": str(self.corpus.id), - "sets": ["training", "test", "validation"], "set_elements": {"test": 1, "training": 0, "validation": 2}, "state": DatasetState.Open.value, "task_id": str(self.task.id), }, ) + self.assertCountEqual(cloned_sets, [ + { + "name": "training", + "id": str(train_clone.id) + }, + { + "name": "test", + "id": str(test_clone.id) + }, + { + "name": "validation", + "id": str(val_clone.id) + } + ]) self.assertQuerysetEqual( - clone.dataset_elements.values_list("set", "element__name").order_by("element__name", "set"), + DatasetElement.objects.filter(set__dataset_id=clone.id).values_list("set__name", "element__name").order_by("element__name", "set__name"), [ ("validation", "Volume 1"), ("test", "Volume 1, page 1r"), @@ -2040,7 +2081,7 @@ class TestDatasetsAPI(FixtureAPITestCase): response = self.client.delete(reverse( "api:dataset-element", kwargs={"dataset": str(self.dataset.id), "element": str(self.page1.id)}) - + "?set=train" + + "?set=training" ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @@ -2051,28 +2092,29 @@ class TestDatasetsAPI(FixtureAPITestCase): response = self.client.delete(reverse( "api:dataset-element", kwargs={"dataset": str(self.dataset.id), "element": str(self.page1.id)}) - + "?set=train" + + "?set=training" ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @patch("arkindex.project.mixins.has_access", return_value=False) def test_destroy_dataset_element_requires_contributor(self, has_access_mock): self.client.force_login(self.read_user) - self.dataset.dataset_elements.create(element=self.page1, set="train") - self.dataset.dataset_elements.create(element=self.page1, set="validation") - self.assertEqual(self.dataset.dataset_elements.filter(set="train").count(), 1) - self.assertEqual(self.dataset.dataset_elements.filter(set="validation").count(), 1) + _, train_set, validation_set = self.dataset.sets.all().order_by("name") + train_set.set_elements.create(element=self.page1) + validation_set.set_elements.create(element=self.page1) + self.assertEqual(train_set.set_elements.count(), 1) + self.assertEqual(validation_set.set_elements.count(), 1) with self.assertNumQueries(3): response = 
self.client.delete(reverse(
             "api:dataset-element",
             kwargs={"dataset": str(self.dataset.id), "element": str(self.page1.id)})
-            + "?set=train"
+            + "?set=training"
         )
         self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
         self.assertDictEqual(response.json(), {"detail": "You need a Contributor access to the dataset to perform this action."})
         self.dataset.refresh_from_db()
-        self.assertEqual(self.dataset.dataset_elements.filter(set="train").count(), 1)
-        self.assertEqual(self.dataset.dataset_elements.filter(set="validation").count(), 1)
+        self.assertEqual(train_set.set_elements.count(), 1)
+        self.assertEqual(validation_set.set_elements.count(), 1)
         self.assertEqual(has_access_mock.call_count, 1)
         self.assertEqual(has_access_mock.call_args, call(self.read_user, self.corpus, Role.Contributor.value, skip_public=False))
@@ -2089,23 +2131,24 @@
 
     def test_destroy_dataset_element_requires_open_dataset(self):
         self.client.force_login(self.user)
-        self.dataset.dataset_elements.create(element=self.page1, set="train")
-        self.dataset.dataset_elements.create(element=self.page1, set="validation")
+        _, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        train_set.set_elements.create(element=self.page1)
+        validation_set.set_elements.create(element=self.page1)
         self.dataset.state = DatasetState.Error
         self.dataset.save()
-        self.assertEqual(self.dataset.dataset_elements.filter(set="train").count(), 1)
-        self.assertEqual(self.dataset.dataset_elements.filter(set="validation").count(), 1)
+        self.assertEqual(train_set.set_elements.count(), 1)
+        self.assertEqual(validation_set.set_elements.count(), 1)
         with self.assertNumQueries(3):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": str(self.dataset.id), "element": str(self.page1.id)})
-                + "?set=train"
+                + "?set=training"
             )
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertDictEqual(response.json(), {"dataset": ["Elements can only be removed from open Datasets."]})
         self.dataset.refresh_from_db()
-        self.assertEqual(self.dataset.dataset_elements.filter(set="train").count(), 1)
-        self.assertEqual(self.dataset.dataset_elements.filter(set="validation").count(), 1)
+        self.assertEqual(train_set.set_elements.count(), 1)
+        self.assertEqual(validation_set.set_elements.count(), 1)
 
     def test_destroy_dataset_element_dataset_doesnt_exist(self):
         self.client.force_login(self.user)
@@ -2113,7 +2156,7 @@
         response = self.client.delete(reverse(
             "api:dataset-element",
             kwargs={"dataset": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", "element": str(self.page1.id)})
-            + "?set=train"
+            + "?set=training"
         )
         self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
         self.assertDictEqual(response.json(), {"detail": "Not found."})
@@ -2135,49 +2178,52 @@
         response = self.client.delete(reverse(
             "api:dataset-element",
             kwargs={"dataset": str(self.dataset.id), "element": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"})
-            + "?set=train"
+            + "?set=training"
         )
         self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
         self.assertDictEqual(response.json(), {"detail": "Not found."})
 
     def test_destroy_dataset_element_element_not_in_dataset(self):
-        self.dataset.dataset_elements.create(element=self.page1, set="train")
+        train_set = self.dataset.sets.get(name="training")
+        train_set.set_elements.create(element=self.page1)
         self.client.force_login(self.user)
         with self.assertNumQueries(3):
 
response = self.client.delete(reverse(
             "api:dataset-element",
             kwargs={"dataset": str(self.dataset.id), "element": str(self.page2.id)})
-            + "?set=train"
+            + "?set=training"
         )
         self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
         self.assertDictEqual(response.json(), {"detail": "Not found."})
 
     def test_destroy_dataset_element_wrong_set(self):
-        self.dataset.dataset_elements.create(element=self.page1, set="train")
-        self.dataset.dataset_elements.create(element=self.page2, set="validation")
+        _, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        train_set.set_elements.create(element=self.page1)
+        validation_set.set_elements.create(element=self.page2)
         self.client.force_login(self.user)
         with self.assertNumQueries(3):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": str(self.dataset.id), "element": str(self.page2.id)})
-                + "?set=train"
+                + "?set=training"
             )
         self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
         self.assertDictEqual(response.json(), {"detail": "Not found."})
 
     def test_destroy_dataset_element(self):
         self.client.force_login(self.user)
-        self.dataset.dataset_elements.create(element=self.page1, set="train")
-        self.dataset.dataset_elements.create(element=self.page1, set="validation")
-        self.assertEqual(self.dataset.dataset_elements.filter(set="train").count(), 1)
-        self.assertEqual(self.dataset.dataset_elements.filter(set="validation").count(), 1)
+        _, train_set, validation_set = self.dataset.sets.all().order_by("name")
+        train_set.set_elements.create(element=self.page1)
+        validation_set.set_elements.create(element=self.page1)
+        self.assertEqual(train_set.set_elements.count(), 1)
+        self.assertEqual(validation_set.set_elements.count(), 1)
         with self.assertNumQueries(4):
             response = self.client.delete(reverse(
                 "api:dataset-element",
                 kwargs={"dataset": str(self.dataset.id), "element": str(self.page1.id)})
-                + "?set=train"
+                + "?set=training"
             )
         self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
         self.dataset.refresh_from_db()
-        self.assertEqual(self.dataset.dataset_elements.filter(set="train").count(), 0)
-        self.assertEqual(self.dataset.dataset_elements.filter(set="validation").count(), 1)
+        self.assertEqual(train_set.set_elements.count(), 0)
+        self.assertEqual(validation_set.set_elements.count(), 1)
-- 
GitLab