diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py index 8a6060ea2cdf960c2051ec44b059ac78d334e76a..b994211562c0709f8f4b3718d51bc43923034b75 100644 --- a/arkindex/training/tests/test_datasets_api.py +++ b/arkindex/training/tests/test_datasets_api.py @@ -275,7 +275,7 @@ class TestDatasetsAPI(FixtureAPITestCase): with self.assertNumQueries(8): response = self.client.post( reverse("api:corpus-datasets", kwargs={"pk": self.corpus.pk}), - data={"name": "My dataset", "description": "My dataset for my experiments."}, + data={"name": "My dataset", "description": "My dataset for my experiments.", "unique_elements": False}, format="json" ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -300,7 +300,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "creator": "Test user", "task_id": None, "corpus_id": str(self.corpus.id), - "unique_elements": True, + "unique_elements": False, "created": created_dataset.created.isoformat().replace("+00:00", "Z"), "updated": created_dataset.updated.isoformat().replace("+00:00", "Z"), }) @@ -584,6 +584,38 @@ class TestDatasetsAPI(FixtureAPITestCase): "set_names": ["This API endpoint does not allow updating a dataset's sets."] }) + def test_update(self): + self.client.force_login(self.user) + with self.assertNumQueries(6): + response = self.client.put( + reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}), + data={"name": "a", "description": "a", "unique_elements": False}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.dataset.refresh_from_db() + self.assertEqual(self.dataset.name, "a") + self.assertEqual(self.dataset.description, "a") + self.assertEqual(self.dataset.unique_elements, False) + + def test_update_forbidden_unique_elements(self): + self.dataset.unique_elements = False + self.dataset.save() + test_set, train_set, validation_set = self.dataset.sets.all().order_by("name") + train_set.set_elements.create(element=self.vol) + test_set.set_elements.create(element=self.vol) + self.client.force_login(self.user) + with self.assertNumQueries(5): + response = self.client.put( + reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}), + data={"name": "a", "description": "a", "unique_elements": True}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + "unique_elements": ["Elements are currently contained by multiple sets."] + }) + def test_update_state_requires_ponos_auth(self): self.client.force_login(self.user) self.dataset.state = DatasetState.Building @@ -793,9 +825,7 @@ class TestDatasetsAPI(FixtureAPITestCase): with self.assertNumQueries(6): response = self.client.patch( reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}), - data={ - "description": "Omedeto!", - }, + data={"description": "Omedeto!", "unique_elements": False}, format="json" ) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -803,6 +833,7 @@ class TestDatasetsAPI(FixtureAPITestCase): self.assertEqual(self.dataset.state, DatasetState.Open) self.assertEqual(self.dataset.name, "First Dataset") self.assertEqual(self.dataset.description, "Omedeto!") + self.assertEqual(self.dataset.unique_elements, False) self.assertCountEqual(list(self.dataset.sets.values_list("name", flat=True)), ["training", "test", "validation"]) def test_partial_update_empty_or_blank_description_or_name(self): @@ -819,6 +850,24 @@ class TestDatasetsAPI(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), {"name": ["This field may not be blank."], "description": ["This field may not be blank."]}) + def test_partial_update_forbidden_unique_elements(self): + self.dataset.unique_elements = False + self.dataset.save() + test_set, train_set, validation_set = self.dataset.sets.all().order_by("name") + train_set.set_elements.create(element=self.vol) + test_set.set_elements.create(element=self.vol) + self.client.force_login(self.user) + with self.assertNumQueries(5): + response = self.client.patch( + reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}), + data={"name": "a", "description": "a", "unique_elements": True}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + "unique_elements": ["Elements are currently contained by multiple sets."] + }) + def test_partial_update_requires_ponos_auth(self): self.client.force_login(self.user) with self.assertNumQueries(5): @@ -1524,6 +1573,47 @@ class TestDatasetsAPI(FixtureAPITestCase): [("training", "Volume 1, page 1r")] ) + def test_add_element_forbid_dupes(self): + test_set, train_set, validation_set = self.dataset.sets.all().order_by("name") + train_set.set_elements.create(element=self.page1) + self.client.force_login(self.user) + with self.assertNumQueries(6): + response = self.client.post( + reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}), + data={"set": "test", "element_id": str(self.page1.id)}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + "element_id": ["The dataset prevent duplication and this element is already present in set training."] + }) + + def test_add_element_allow_dupes(self): + self.dataset.unique_elements = False + self.dataset.save() + test_set, train_set, validation_set = self.dataset.sets.all().order_by("name") + train_set.set_elements.create(element=self.page1) + self.client.force_login(self.user) + with self.assertNumQueries(11): + response = self.client.post( + reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}), + data={"set": "test", "element_id": str(self.page1.id)}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertQuerysetEqual( + ( + DatasetElement.objects + .filter(set__dataset=self.dataset) + .values_list("set__name", "element__name") + .order_by("set__name", "element__name") + ), + [ + ("test", "Volume 1, page 1r"), + ("training", "Volume 1, page 1r"), + ] + ) + # CreateDatasetElementSelection def test_add_from_selection_requires_login(self): @@ -1657,6 +1747,53 @@ class TestDatasetsAPI(FixtureAPITestCase): ] ) + def test_add_from_selection_forbid_dupes(self): + test_set, train_set, validation_set = self.dataset.sets.all().order_by("name") + train_set.set_elements.create(element=self.vol) + self.user.selected_elements.set([self.vol, self.page1]) + + self.client.force_login(self.user) + with self.assertNumQueries(5): + response = self.client.post( + reverse("api:dataset-elements-selection", kwargs={"pk": self.corpus.id}), + data={"set_id": str(test_set.id)}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + "set_id": ["The dataset prevent duplication and this element is already present in set training."] + }) + + def test_add_from_selection_allow_dupes(self): + test_set, train_set, validation_set = self.dataset.sets.all().order_by("name") + train_set.set_elements.create(element=self.vol) + self.user.selected_elements.set([self.vol, self.page1]) + test_set = self.dataset.sets.get(name="test") + self.dataset.unique_elements = False + self.dataset.save() + + self.client.force_login(self.user) + with self.assertNumQueries(6): + response = self.client.post( + reverse("api:dataset-elements-selection", kwargs={"pk": self.corpus.id}), + data={"set_id": str(test_set.id)}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertQuerysetEqual( + ( + DatasetElement.objects + .filter(set__dataset=self.dataset) + .values_list("set__name", "element__name") + .order_by("set__name", "element__name") + ), + [ + ("test", "Volume 1"), + ("test", "Volume 1, page 1r"), + ("training", "Volume 1"), + ] + ) + @patch("arkindex.users.managers.BaseACLManager.filter_rights", return_value=Corpus.objects.none()) def test_element_datasets_requires_read_access(self, filter_rights_mock): self.client.force_login(self.user) @@ -2074,6 +2211,7 @@ class TestDatasetsAPI(FixtureAPITestCase): def test_clone(self): self.dataset.creator = self.superuser self.dataset.state = DatasetState.Error + self.dataset.unique_elements = False self.dataset.task = self.task self.dataset.save() test_set, _, validation_set = self.dataset.sets.all().order_by("name") @@ -2098,6 +2236,7 @@ class TestDatasetsAPI(FixtureAPITestCase): clone = self.corpus.datasets.get(id=data["id"]) test_clone, train_clone, val_clone = clone.sets.all().order_by("name") self.assertEqual(clone.creator, self.user) + self.assertEqual(clone.unique_elements, False) data.pop("created") data.pop("updated") cloned_sets = data.pop("sets") @@ -2112,7 +2251,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "set_elements": {"test": 1, "training": 0, "validation": 2}, "state": DatasetState.Open.value, "task_id": str(self.task.id), - "unique_elements": True, + "unique_elements": False, }, ) self.assertCountEqual(cloned_sets, [