diff --git a/arkindex/process/tests/test_process_dataset_sets.py b/arkindex/process/tests/test_process_dataset_sets.py index 07e0b808f955e3a11d0cdef22c1913eb5d20bb30..2e4f841c5acba86503a5ae385e480790afc22c00 100644 --- a/arkindex/process/tests/test_process_dataset_sets.py +++ b/arkindex/process/tests/test_process_dataset_sets.py @@ -112,6 +112,7 @@ class TestProcessDatasetSets(FixtureAPITestCase): "corpus_id": str(self.private_corpus.id), "state": "open", "task_id": None, + "unique_elements": True, "created": FAKE_CREATED, "updated": FAKE_CREATED }, @@ -137,6 +138,7 @@ class TestProcessDatasetSets(FixtureAPITestCase): "corpus_id": str(self.corpus.id), "state": "open", "task_id": None, + "unique_elements": True, "created": FAKE_CREATED, "updated": FAKE_CREATED }, @@ -315,6 +317,7 @@ class TestProcessDatasetSets(FixtureAPITestCase): "corpus_id": str(self.corpus.id), "state": "open", "task_id": None, + "unique_elements": True, "created": FAKE_CREATED, "updated": FAKE_CREATED }, diff --git a/arkindex/training/admin.py b/arkindex/training/admin.py index 66f3626575649c6e4a5616165d4a5073079ba93f..f8cc278929d0c29dd3f53a18fd10a91aab19d5b6 100644 --- a/arkindex/training/admin.py +++ b/arkindex/training/admin.py @@ -36,13 +36,21 @@ class DatasetSetInLine(admin.StackedInline): class DatasetAdmin(admin.ModelAdmin): - list_display = ("name", "corpus", "state") + list_display = ("name", "corpus", "state", "unique_elements") list_filter = (("state", EnumFieldListFilter), "corpus") search_fields = ("name", "description") - fields = ("id", "name", "created", "updated", "description", "corpus", "creator", "task") + fields = ("id", "name", "created", "updated", "description", "corpus", "creator", "task", "unique_elements") readonly_fields = ("id", "created", "updated", "task") inlines = [DatasetSetInLine, ] + def get_form(self, request, obj=None, **kwargs): + # Prevent editing the `unique_elements` attribute + self.readonly_fields = self.__class__.readonly_fields + if obj is not None: 
+            self.readonly_fields += ("unique_elements",)
+        form = super().get_form(request, obj=obj, **kwargs)
+        return form
+
 
 admin.site.register(Model, ModelAdmin)
 admin.site.register(ModelVersion, ModelVersionAdmin)
diff --git a/arkindex/training/migrations/0008_dataset_unique_elements.py b/arkindex/training/migrations/0008_dataset_unique_elements.py
new file mode 100644
index 0000000000000000000000000000000000000000..229afdda5432d1dbc4c5d0519407d8c27d397d64
--- /dev/null
+++ b/arkindex/training/migrations/0008_dataset_unique_elements.py
@@ -0,0 +1,38 @@
+# Generated by Django 4.1.7 on 2024-03-28 14:54
+
+from django.db import migrations, models
+
+
+def update_unique_elements(apps, schema_editor):
+    """Update unique_elements to False when some elements are already duplicated"""
+    Dataset = apps.get_model("training", "Dataset")
+    DatasetElement = apps.get_model("training", "DatasetElement")
+    Dataset.objects.filter(
+        models.Exists(
+            DatasetElement.objects
+            .filter(set__dataset_id=models.OuterRef("pk"))
+            .values("element_id")
+            .annotate(dupes=models.Count("element_id"))
+            .filter(dupes__gte=2)
+        )
+    ).update(unique_elements=False)
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("training", "0007_datasetset_model"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="dataset",
+            name="unique_elements",
+            field=models.BooleanField(default=True),
+        ),
+        migrations.RunPython(
+            update_unique_elements,
+            reverse_code=migrations.RunPython.noop,
+            elidable=True,
+        ),
+    ]
diff --git a/arkindex/training/models.py b/arkindex/training/models.py
index c37ea35b09193f087793b2c48c194cc5f3c64e48..ea84bd5c8af8588881188e9580dba7bd64b519ac 100644
--- a/arkindex/training/models.py
+++ b/arkindex/training/models.py
@@ -266,6 +266,7 @@ class Dataset(models.Model):
     name = models.CharField(max_length=100, validators=[MinLengthValidator(1)])
     description = models.TextField(validators=[MinLengthValidator(1)])
     state = EnumField(DatasetState, 
default=DatasetState.Open, max_length=50) + unique_elements = models.BooleanField(default=True) class Meta: constraints = [ diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py index ce877bff5f83794c9ded6578cd0cab47f593b66c..b7bd7f911074af0ac1ef9d80ec9c0063c88a65cd 100644 --- a/arkindex/training/serializers.py +++ b/arkindex/training/serializers.py @@ -542,6 +542,10 @@ class DatasetSerializer(serializers.ModelSerializer): set_elements = DatasetSetsCountField( help_text="Distribution of elements in sets. This value is set to null when listing multiple datasets.", ) + unique_elements = serializers.BooleanField( + default=True, + help_text="Ensures that an element is only present in a single set at a time.", + ) def validate_state(self, state): """ @@ -575,6 +579,20 @@ class DatasetSerializer(serializers.ModelSerializer): raise ValidationError("Either do not specify set names to use the default values, or specify a non-empty list of names.") return set_names + def validate_unique_elements(self, unique): + # When updating a dataset to switch unique_elements from False to True, + # check that it does not contain duplicates. 
+ if unique is True and self.instance and not self.instance.unique_elements and ( + DatasetElement.objects + .filter(set__dataset_id=self.instance.pk) + .values("element_id") + .annotate(dupes=Count("element_id")) + .filter(dupes__gte=2) + .exists() + ): + raise ValidationError("Some elements are currently contained by multiple sets.") + return unique + def validate(self, data): data = super().validate(data) @@ -636,6 +654,7 @@ class DatasetSerializer(serializers.ModelSerializer): # Hidden field to set the creator as the authenticated user "default_creator", "task_id", + "unique_elements", "created", "updated", ) @@ -700,6 +719,21 @@ class DatasetElementSerializer(serializers.ModelSerializer): self.fields["element_id"].queryset = Element.objects.filter(corpus=dataset.corpus) self.fields["set"].queryset = dataset.sets.all() + def validate_element_id(self, element): + dataset = self.context.get("dataset") + if dataset and dataset.unique_elements and ( + existing_set := ( + dataset.sets + .filter(set_elements__element=element) + .values_list("name", flat=True) + .first() + ) + ): + raise ValidationError([ + f"The dataset requires unique elements and this element is already present in set {existing_set}." 
+ ]) + return element + def validate(self, data): data = super().validate(data) data.pop("dataset") @@ -759,6 +793,20 @@ class SelectionDatasetElementSerializer(serializers.Serializer): raise ValidationError(f"Dataset {set.dataset.id} is not part of corpus {corpus.name}.") if set.dataset.state == DatasetState.Complete: raise ValidationError(f"Dataset {set.dataset.id} is marked as completed.") + # Ensure adding elements to the dataset does not break uniqueness constraint + selection = self.context["request"].user.selected_elements.filter(corpus=corpus) + if set.dataset.unique_elements and ( + existing_set := ( + set.dataset.sets + .exclude(id=set.id) + .filter(set_elements__element_id__in=selection.values_list("id", flat=True)) + .values_list("name", flat=True) + .first() + ) + ): + raise ValidationError([ + f"The dataset requires unique elements and some elements are already present in set {existing_set}." + ]) return set def create(self, validated_data): diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py index 4446928dd0ec6ab25e5e41b2f8ba3f5b89b25441..ce1d83ce0a9e6ca83840bd7e2c6d91815cd3a19d 100644 --- a/arkindex/training/tests/test_datasets_api.py +++ b/arkindex/training/tests/test_datasets_api.py @@ -99,6 +99,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "creator": "Test user", "task_id": None, "corpus_id": str(self.corpus.id), + "unique_elements": True, "created": FAKE_CREATED, "updated": FAKE_CREATED, }, @@ -118,6 +119,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "creator": "Test user", "task_id": None, "corpus_id": str(self.corpus.id), + "unique_elements": True, "created": FAKE_CREATED, "updated": FAKE_CREATED, } @@ -273,7 +275,7 @@ class TestDatasetsAPI(FixtureAPITestCase): with self.assertNumQueries(8): response = self.client.post( reverse("api:corpus-datasets", kwargs={"pk": self.corpus.pk}), - data={"name": "My dataset", "description": "My dataset for my experiments."}, + data={"name": "My 
dataset", "description": "My dataset for my experiments.", "unique_elements": False}, format="json" ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -298,6 +300,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "creator": "Test user", "task_id": None, "corpus_id": str(self.corpus.id), + "unique_elements": False, "created": created_dataset.created.isoformat().replace("+00:00", "Z"), "updated": created_dataset.updated.isoformat().replace("+00:00", "Z"), }) @@ -333,6 +336,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "creator": "Test user", "task_id": None, "corpus_id": str(self.corpus.id), + "unique_elements": True, "created": created_dataset.created.isoformat().replace("+00:00", "Z"), "updated": created_dataset.updated.isoformat().replace("+00:00", "Z"), }) @@ -363,6 +367,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "creator": "Test user", "task_id": None, "corpus_id": str(self.corpus.id), + "unique_elements": True, "created": created_dataset.created.isoformat().replace("+00:00", "Z"), "updated": created_dataset.updated.isoformat().replace("+00:00", "Z"), }) @@ -579,6 +584,38 @@ class TestDatasetsAPI(FixtureAPITestCase): "set_names": ["This API endpoint does not allow updating a dataset's sets."] }) + def test_update(self): + self.client.force_login(self.user) + with self.assertNumQueries(6): + response = self.client.put( + reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}), + data={"name": "a", "description": "a", "unique_elements": False}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.dataset.refresh_from_db() + self.assertEqual(self.dataset.name, "a") + self.assertEqual(self.dataset.description, "a") + self.assertEqual(self.dataset.unique_elements, False) + + def test_update_forbidden_unique_elements(self): + self.dataset.unique_elements = False + self.dataset.save() + test_set, train_set, validation_set = self.dataset.sets.all().order_by("name") + 
train_set.set_elements.create(element=self.vol) + test_set.set_elements.create(element=self.vol) + self.client.force_login(self.user) + with self.assertNumQueries(5): + response = self.client.put( + reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}), + data={"name": "a", "description": "a", "unique_elements": True}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + "unique_elements": ["Some elements are currently contained by multiple sets."] + }) + def test_update_state_requires_ponos_auth(self): self.client.force_login(self.user) self.dataset.state = DatasetState.Building @@ -788,9 +825,7 @@ class TestDatasetsAPI(FixtureAPITestCase): with self.assertNumQueries(6): response = self.client.patch( reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}), - data={ - "description": "Omedeto!", - }, + data={"description": "Omedeto!", "unique_elements": False}, format="json" ) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -798,6 +833,7 @@ class TestDatasetsAPI(FixtureAPITestCase): self.assertEqual(self.dataset.state, DatasetState.Open) self.assertEqual(self.dataset.name, "First Dataset") self.assertEqual(self.dataset.description, "Omedeto!") + self.assertEqual(self.dataset.unique_elements, False) self.assertCountEqual(list(self.dataset.sets.values_list("name", flat=True)), ["training", "test", "validation"]) def test_partial_update_empty_or_blank_description_or_name(self): @@ -814,6 +850,24 @@ class TestDatasetsAPI(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), {"name": ["This field may not be blank."], "description": ["This field may not be blank."]}) + def test_partial_update_forbidden_unique_elements(self): + self.dataset.unique_elements = False + self.dataset.save() + test_set, train_set, validation_set = self.dataset.sets.all().order_by("name") + 
train_set.set_elements.create(element=self.vol) + test_set.set_elements.create(element=self.vol) + self.client.force_login(self.user) + with self.assertNumQueries(5): + response = self.client.patch( + reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}), + data={"name": "a", "description": "a", "unique_elements": True}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + "unique_elements": ["Some elements are currently contained by multiple sets."] + }) + def test_partial_update_requires_ponos_auth(self): self.client.force_login(self.user) with self.assertNumQueries(5): @@ -1005,6 +1059,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "creator": "Test user", "task_id": None, "corpus_id": str(self.corpus.id), + "unique_elements": True, "created": FAKE_CREATED, "updated": FAKE_CREATED }) @@ -1464,7 +1519,7 @@ class TestDatasetsAPI(FixtureAPITestCase): def test_add_element_wrong_set(self): self.client.force_login(self.user) - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.post( reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}), data={"set": "aaaaaaaaaaa", "element_id": str(self.vol.id)}, @@ -1491,6 +1546,8 @@ class TestDatasetsAPI(FixtureAPITestCase): def test_add_element_already_exists(self): test_set = self.dataset.sets.order_by("name").first() test_set.set_elements.create(element=self.page1) + self.dataset.unique_elements = False + self.dataset.save() self.client.force_login(self.user) with self.assertNumQueries(6): response = self.client.post( @@ -1504,7 +1561,7 @@ class TestDatasetsAPI(FixtureAPITestCase): def test_add_element(self): train_set = self.dataset.sets.get(name="training") self.client.force_login(self.user) - with self.assertNumQueries(11): + with self.assertNumQueries(12): response = self.client.post( reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}), data={"set": "training", "element_id": 
str(self.page1.id)}, @@ -1516,6 +1573,47 @@ class TestDatasetsAPI(FixtureAPITestCase): [("training", "Volume 1, page 1r")] ) + def test_add_element_forbid_dupes(self): + test_set, train_set, validation_set = self.dataset.sets.all().order_by("name") + train_set.set_elements.create(element=self.page1) + self.client.force_login(self.user) + with self.assertNumQueries(6): + response = self.client.post( + reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}), + data={"set": "test", "element_id": str(self.page1.id)}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + "element_id": ["The dataset requires unique elements and this element is already present in set training."] + }) + + def test_add_element_allow_dupes(self): + self.dataset.unique_elements = False + self.dataset.save() + test_set, train_set, validation_set = self.dataset.sets.all().order_by("name") + train_set.set_elements.create(element=self.page1) + self.client.force_login(self.user) + with self.assertNumQueries(11): + response = self.client.post( + reverse("api:dataset-elements", kwargs={"pk": self.dataset.id}), + data={"set": "test", "element_id": str(self.page1.id)}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertQuerysetEqual( + ( + DatasetElement.objects + .filter(set__dataset=self.dataset) + .values_list("set__name", "element__name") + .order_by("set__name", "element__name") + ), + [ + ("test", "Volume 1, page 1r"), + ("training", "Volume 1, page 1r"), + ] + ) + # CreateDatasetElementSelection def test_add_from_selection_requires_login(self): @@ -1633,7 +1731,7 @@ class TestDatasetsAPI(FixtureAPITestCase): self.user.selected_elements.set([self.vol, self.page1, self.page2]) self.client.force_login(self.user) - with self.assertNumQueries(6): + with self.assertNumQueries(7): response = self.client.post( reverse("api:dataset-elements-selection", kwargs={"pk": 
self.corpus.id}), data={"set_id": str(train_set.id)}, @@ -1649,6 +1747,53 @@ class TestDatasetsAPI(FixtureAPITestCase): ] ) + def test_add_from_selection_forbid_dupes(self): + test_set, train_set, validation_set = self.dataset.sets.all().order_by("name") + train_set.set_elements.create(element=self.vol) + self.user.selected_elements.set([self.vol, self.page1]) + + self.client.force_login(self.user) + with self.assertNumQueries(5): + response = self.client.post( + reverse("api:dataset-elements-selection", kwargs={"pk": self.corpus.id}), + data={"set_id": str(test_set.id)}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + "set_id": ["The dataset requires unique elements and some elements are already present in set training."] + }) + + def test_add_from_selection_allow_dupes(self): + test_set, train_set, validation_set = self.dataset.sets.all().order_by("name") + train_set.set_elements.create(element=self.vol) + self.user.selected_elements.set([self.vol, self.page1]) + test_set = self.dataset.sets.get(name="test") + self.dataset.unique_elements = False + self.dataset.save() + + self.client.force_login(self.user) + with self.assertNumQueries(6): + response = self.client.post( + reverse("api:dataset-elements-selection", kwargs={"pk": self.corpus.id}), + data={"set_id": str(test_set.id)}, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertQuerysetEqual( + ( + DatasetElement.objects + .filter(set__dataset=self.dataset) + .values_list("set__name", "element__name") + .order_by("set__name", "element__name") + ), + [ + ("test", "Volume 1"), + ("test", "Volume 1, page 1r"), + ("training", "Volume 1"), + ] + ) + @patch("arkindex.users.managers.BaseACLManager.filter_rights", return_value=Corpus.objects.none()) def test_element_datasets_requires_read_access(self, filter_rights_mock): self.client.force_login(self.user) @@ -1700,6 +1845,7 @@ 
class TestDatasetsAPI(FixtureAPITestCase): "corpus_id": str(self.corpus.id), "creator": "Test user", "task_id": None, + "unique_elements": True, "created": self.dataset.created.isoformat().replace("+00:00", "Z"), "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"), }, @@ -1741,6 +1887,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "corpus_id": str(self.corpus.id), "creator": "Test user", "task_id": None, + "unique_elements": True, "created": self.dataset.created.isoformat().replace("+00:00", "Z"), "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"), }, @@ -1764,6 +1911,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "corpus_id": str(self.corpus.id), "creator": "Test user", "task_id": None, + "unique_elements": True, "created": self.dataset.created.isoformat().replace("+00:00", "Z"), "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"), }, @@ -1787,6 +1935,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "corpus_id": str(self.corpus.id), "creator": "Test user", "task_id": None, + "unique_elements": True, "created": self.dataset2.created.isoformat().replace("+00:00", "Z"), "updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"), }, @@ -1828,6 +1977,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "corpus_id": str(self.corpus.id), "creator": "Test user", "task_id": None, + "unique_elements": True, "created": self.dataset.created.isoformat().replace("+00:00", "Z"), "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"), }, @@ -1851,6 +2001,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "corpus_id": str(self.corpus.id), "creator": "Test user", "task_id": None, + "unique_elements": True, "created": self.dataset.created.isoformat().replace("+00:00", "Z"), "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"), }, @@ -1874,6 +2025,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "corpus_id": str(self.corpus.id), "creator": "Test user", "task_id": None, + "unique_elements": True, "created": 
self.dataset2.created.isoformat().replace("+00:00", "Z"), "updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"), }, @@ -1926,6 +2078,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "corpus_id": str(self.corpus.id), "creator": "Test user", "task_id": None, + "unique_elements": True, "created": self.dataset.created.isoformat().replace("+00:00", "Z"), "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"), }, @@ -1957,6 +2110,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "corpus_id": str(self.corpus.id), "creator": "Test user", "task_id": None, + "unique_elements": True, "created": self.dataset.created.isoformat().replace("+00:00", "Z"), "updated": self.dataset.updated.isoformat().replace("+00:00", "Z"), }, @@ -1980,6 +2134,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "corpus_id": str(self.corpus.id), "creator": "Test user", "task_id": None, + "unique_elements": True, "created": self.dataset2.created.isoformat().replace("+00:00", "Z"), "updated": self.dataset2.updated.isoformat().replace("+00:00", "Z"), }, @@ -2056,6 +2211,7 @@ class TestDatasetsAPI(FixtureAPITestCase): def test_clone(self): self.dataset.creator = self.superuser self.dataset.state = DatasetState.Error + self.dataset.unique_elements = False self.dataset.task = self.task self.dataset.save() test_set, _, validation_set = self.dataset.sets.all().order_by("name") @@ -2080,6 +2236,7 @@ class TestDatasetsAPI(FixtureAPITestCase): clone = self.corpus.datasets.get(id=data["id"]) test_clone, train_clone, val_clone = clone.sets.all().order_by("name") self.assertEqual(clone.creator, self.user) + self.assertEqual(clone.unique_elements, False) data.pop("created") data.pop("updated") cloned_sets = data.pop("sets") @@ -2094,6 +2251,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "set_elements": {"test": 1, "training": 0, "validation": 2}, "state": DatasetState.Open.value, "task_id": str(self.task.id), + "unique_elements": False, }, ) self.assertCountEqual(cloned_sets, [ @@ -2152,6 
+2310,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "set_elements": {str(k.name): 0 for k in self.dataset.sets.all()}, "state": DatasetState.Open.value, "task_id": None, + "unique_elements": True, }, ) self.assertCountEqual(cloned_sets, [