diff --git a/arkindex/documents/tasks.py b/arkindex/documents/tasks.py
index ee4c680a850683f409fda99ca42e28d3d7e8efc7..16719f502a1e5d342f84b6cf24c55fa7b65206cc 100644
--- a/arkindex/documents/tasks.py
+++ b/arkindex/documents/tasks.py
@@ -23,7 +23,7 @@ from arkindex.documents.models import (
     TranscriptionEntity,
 )
 from arkindex.ponos.models import Task
-from arkindex.process.models import Process, ProcessDataset, ProcessElement, WorkerActivity, WorkerRun
+from arkindex.process.models import Process, ProcessDatasetSet, ProcessElement, WorkerActivity, WorkerRun
 from arkindex.training.models import DatasetElement, DatasetSet
 from arkindex.users.models import User
@@ -70,9 +70,9 @@ def corpus_delete(corpus_id: str) -> None:
         Selection.objects.filter(element__corpus_id=corpus_id),
         corpus.memberships.all(),
         corpus.exports.all(),
-        # ProcessDataset M2M
-        ProcessDataset.objects.filter(dataset__corpus_id=corpus_id),
-        ProcessDataset.objects.filter(process__corpus_id=corpus_id),
+        # ProcessDatasetSet M2M
+        ProcessDatasetSet.objects.filter(set__dataset__corpus_id=corpus_id),
+        ProcessDatasetSet.objects.filter(process__corpus_id=corpus_id),
         DatasetElement.objects.filter(set__dataset__corpus_id=corpus_id),
         DatasetSet.objects.filter(dataset__corpus_id=corpus_id),
         corpus.datasets.all(),
diff --git a/arkindex/documents/tests/tasks/test_corpus_delete.py b/arkindex/documents/tests/tasks/test_corpus_delete.py
index cea863ab3ce0effaa6e0250a92181e466774b196..83567cafec34adda395f063680d3941ccf0ca454 100644
--- a/arkindex/documents/tests/tasks/test_corpus_delete.py
+++ b/arkindex/documents/tests/tasks/test_corpus_delete.py
@@ -3,7 +3,14 @@ from django.db.models.signals import pre_delete
 from arkindex.documents.models import Corpus, Element, EntityType, MetaType, Transcription
 from arkindex.documents.tasks import corpus_delete
 from arkindex.ponos.models import Farm, State, Task
-from arkindex.process.models import CorpusWorkerVersion, Process, ProcessDataset, ProcessMode, Repository, WorkerVersion
+from arkindex.process.models import (
+    CorpusWorkerVersion,
+    Process,
+    ProcessDatasetSet,
+    ProcessMode,
+    Repository,
+    WorkerVersion,
+)
 from arkindex.project.tests import FixtureTestCase, force_constraints_immediate
 from arkindex.training.models import Dataset, DatasetSet
@@ -123,16 +130,16 @@ class TestDeleteCorpus(FixtureTestCase):
                 name=set_name
             ) for set_name in ["test", "training", "validation"]
         )
-        # Process on cls.corpus and with a dataset from cls.corpus
+        # Process on cls.corpus with a set from cls.corpus
         dataset_process1 = cls.corpus.processes.create(creator=cls.user, mode=ProcessMode.Dataset)
-        ProcessDataset.objects.create(process=dataset_process1, dataset=dataset1, sets=list(dataset1.sets.values_list("name", flat=True)))
-        # Process on cls.corpus with a dataset from another corpus
+        ProcessDatasetSet.objects.create(process=dataset_process1, set=test_set_1)
+        # Process on cls.corpus with a set from another corpus
         dataset_process2 = cls.corpus.processes.create(creator=cls.user, mode=ProcessMode.Dataset)
-        ProcessDataset.objects.create(process=dataset_process2, dataset=dataset1, sets=list(dataset1.sets.values_list("name", flat=True)))
-        ProcessDataset.objects.create(process=dataset_process2, dataset=cls.dataset2, sets=list(cls.dataset2.sets.values_list("name", flat=True)))
-        # Process on another corpus with a dataset from another corpus and none from cls.corpus
+        ProcessDatasetSet.objects.create(process=dataset_process2, set=test_set_1)
+        ProcessDatasetSet.objects.create(process=dataset_process2, 
set=cls.dataset2.sets.get(name="training")) + # Process on another corpus with a set from another corpus and none from cls.corpus cls.dataset_process3 = cls.corpus2.processes.create(creator=cls.user, mode=ProcessMode.Dataset) - ProcessDataset.objects.create(process=cls.dataset_process3, dataset=cls.dataset2, sets=list(cls.dataset2.sets.values_list("name", flat=True))) + ProcessDatasetSet.objects.create(process=cls.dataset_process3, set=cls.dataset2.sets.get(name="validation")) cls.rev = cls.repo.revisions.create( hash="42", diff --git a/arkindex/process/api.py b/arkindex/process/api.py index f3eb182c9454e29f7f18afdc4205d0903875c7db..c0b61314091252adaabfad79ba772c1b011e3d80 100644 --- a/arkindex/process/api.py +++ b/arkindex/process/api.py @@ -46,7 +46,6 @@ from rest_framework.generics import ( RetrieveDestroyAPIView, RetrieveUpdateAPIView, RetrieveUpdateDestroyAPIView, - UpdateAPIView, ) from rest_framework.response import Response from rest_framework.serializers import Serializer @@ -61,7 +60,7 @@ from arkindex.process.models import ( GitRef, GitRefType, Process, - ProcessDataset, + ProcessDatasetSet, ProcessMode, Revision, Worker, @@ -87,7 +86,7 @@ from arkindex.process.serializers.imports import ( StartProcessSerializer, ) from arkindex.process.serializers.ingest import BucketSerializer, S3ImportSerializer -from arkindex.process.serializers.training import ProcessDatasetSerializer +from arkindex.process.serializers.training import ProcessDatasetSetSerializer from arkindex.process.serializers.worker_runs import ( CorpusWorkerRunSerializer, UserWorkerRunSerializer, @@ -126,7 +125,7 @@ from arkindex.project.pagination import CountCursorPagination from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly from arkindex.project.tools import PercentileCont from arkindex.project.triggers import process_delete -from arkindex.training.models import Model +from arkindex.training.models import DatasetSet, Model from arkindex.users.models import Role, Scope logger = logging.getLogger(__name__) @@ -565,7 +564,7 @@ class StartProcess(CorpusACLMixin, CreateAPIView): "model_version__model", "configuration", ))) - .prefetch_related("datasets") + .prefetch_related(Prefetch("sets", queryset=DatasetSet.objects.select_related("dataset"))) # Uses Exists() for has_tasks and not a __isnull because we are not joining on tasks and do not need to fetch them .annotate(has_tasks=Exists(Task.objects.filter(process=OuterRef("pk")))) ) @@ -677,20 +676,20 @@ class DataFileCreate(CreateAPIView): @extend_schema(tags=["process"]) @extend_schema_view( get=extend_schema( - operation_id="ListProcessDatasets", + operation_id="ListProcessSets", description=dedent( """ - List all datasets on a process. + List all dataset sets on a process. Requires a **guest** access to the process. 
""" ), ), ) -class ProcessDatasets(ProcessACLMixin, ListAPIView): +class ProcessDatasetSets(ProcessACLMixin, ListAPIView): permission_classes = (IsVerified, ) - serializer_class = ProcessDatasetSerializer - queryset = ProcessDataset.objects.none() + serializer_class = ProcessDatasetSetSerializer + queryset = ProcessDatasetSet.objects.none() @cached_property def process(self): @@ -704,10 +703,10 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView): def get_queryset(self): return ( - ProcessDataset.objects.filter(process_id=self.process.id) - .select_related("process__creator", "dataset__creator") - .prefetch_related("dataset__sets") - .order_by("dataset__name") + ProcessDatasetSet.objects.filter(process_id=self.process.id) + .select_related("process__creator", "set__dataset__creator") + .prefetch_related(Prefetch("set__dataset__sets", queryset=DatasetSet.objects.order_by("name"))) + .order_by("set__dataset__name", "set__name") ) def get_serializer_context(self): @@ -722,51 +721,52 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView): @extend_schema(tags=["process"]) @extend_schema_view( post=extend_schema( - operation_id="CreateProcessDataset", + operation_id="CreateProcessSet", description=dedent( """ - Add a dataset to a process. + Add a dataset set to a process. Requires an **admin** access to the process and a **guest** access to the dataset's corpus. """ ), ), delete=extend_schema( - operation_id="DestroyProcessDataset", + operation_id="DestroyProcessSet", description=dedent( """ - Remove a dataset from a process. + Remove a dataset set from a process. Requires an **admin** access to the process. """ ), ), ) -class ProcessDatasetManage(CreateAPIView, UpdateAPIView, DestroyAPIView): +class ProcessDatasetSetManage(CreateAPIView, DestroyAPIView): permission_classes = (IsVerified, ) - serializer_class = ProcessDatasetSerializer + serializer_class = ProcessDatasetSetSerializer def get_object(self): - process_dataset = get_object_or_404( - ProcessDataset.objects - .select_related("dataset__creator", "process__corpus") - .prefetch_related("dataset__sets") + qs = ( + ProcessDatasetSet.objects + .select_related("set__dataset__creator", "process__corpus") # Required to check for a process that have already started - .annotate(process_has_tasks=Exists(Task.objects.filter(process_id=self.kwargs["process"]))), - dataset_id=self.kwargs["dataset"], process_id=self.kwargs["process"] + .annotate(process_has_tasks=Exists(Task.objects.filter(process_id=self.kwargs["process"]))) + ) + # Only prefetch the dataset sets when creating + if self.request.method != "DELETE": + qs.prefetch_related(Prefetch("set__dataset__sets", queryset=DatasetSet.objects.order_by("name"))) + process_set = get_object_or_404( + qs, + set_id=self.kwargs["set"], process_id=self.kwargs["process"] ) # Copy the has_tasks annotation onto the process - process_dataset.process.has_tasks = process_dataset.process_has_tasks - return process_dataset + process_set.process.has_tasks = process_set.process_has_tasks + return process_set def destroy(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) - # Ignore the sets when retrieving the ProcessDataset instance, as there cannot be - # two ProcessDatasets with the same dataset and process, whatever the sets - validated_data = serializer.validated_data - del validated_data["sets"] - get_object_or_404(ProcessDataset, **validated_data).delete() + get_object_or_404(ProcessDatasetSet, **serializer.validated_data).delete() 
        return Response(status=status.HTTP_204_NO_CONTENT)
diff --git a/arkindex/process/migrations/0032_processdatasetset_model.py b/arkindex/process/migrations/0032_processdatasetset_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6088bd88bcfa2ab8c82fa429871e0ae842a236f
--- /dev/null
+++ b/arkindex/process/migrations/0032_processdatasetset_model.py
@@ -0,0 +1,52 @@
+# Generated by Django 4.1.7 on 2024-03-21 12:00
+
+import uuid
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("training", "0007_datasetset_model"),
+        ("process", "0031_process_corpus_check_and_remove_revision_field"),
+    ]
+
+    operations = [
+        migrations.CreateModel(
+            name="ProcessDatasetSet",
+            fields=[
+                ("id", models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)),
+                ("process", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name="process_sets", to="process.process")),
+                ("set", models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, related_name="process_sets", to="training.datasetset")),
+            ],
+        ),
+        migrations.AddConstraint(
+            model_name="processdatasetset",
+            constraint=models.UniqueConstraint(fields=("process", "set"), name="unique_process_set"),
+        ),
+        migrations.RunSQL(
+            """
+            INSERT INTO process_processdatasetset (id, process_id, set_id)
+            SELECT gen_random_uuid(), p.process_id, dss.id
+            FROM (
+                SELECT DISTINCT process_id, dataset_id, unnest(sets) AS set_name
+                FROM process_processdataset
+            ) p
+            INNER JOIN training_datasetset AS dss ON (p.dataset_id = dss.dataset_id AND p.set_name = dss.name)
+            """,
+        ),
+        migrations.RemoveField(
+            model_name="process",
+            name="datasets",
+        ),
+        migrations.AddField(
+            model_name="process",
+            name="sets",
+            field=models.ManyToManyField(related_name="processes", through="process.ProcessDatasetSet", to="training.datasetset"),
+        ),
+        migrations.DeleteModel(
+            name="ProcessDataset",
+        ),
+    ]
diff --git a/arkindex/process/models.py b/arkindex/process/models.py
index e32dd410fe26622c3d7ea6506dac24cea163457b..ba269d422b4b8fccdabd9242b006139ddf561a01 100644
--- a/arkindex/process/models.py
+++ b/arkindex/process/models.py
@@ -76,7 +76,7 @@ class Process(IndexableModel):
     corpus = models.ForeignKey("documents.Corpus", on_delete=models.CASCADE, related_name="processes", null=True)
     mode = EnumField(ProcessMode, max_length=30)
     files = models.ManyToManyField("process.DataFile", related_name="processes")
-    datasets = models.ManyToManyField("training.Dataset", related_name="processes", through="process.ProcessDataset")
+    sets = models.ManyToManyField("training.DatasetSet", related_name="processes", through="process.ProcessDatasetSet")
     versions = models.ManyToManyField("process.WorkerVersion", through="process.WorkerRun", related_name="processes")
     activity_state = EnumField(ActivityState, max_length=32, default=ActivityState.Disabled)
@@ -465,26 +465,19 @@ class ProcessElement(models.Model):
     )
 
 
-class ProcessDataset(models.Model):
+class ProcessDatasetSet(models.Model):
     """
-    Link between Processes and Datasets.
+    Link between Processes and Dataset Sets.
""" id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - process = models.ForeignKey(Process, on_delete=models.CASCADE, related_name="process_datasets") - dataset = models.ForeignKey("training.Dataset", on_delete=models.CASCADE, related_name="process_datasets") - sets = ArrayField( - models.CharField(max_length=50, validators=[MinLengthValidator(1)]), - validators=[ - MinLengthValidator(1), - validate_unique_set_names, - ] - ) + process = models.ForeignKey(Process, on_delete=models.CASCADE, related_name="process_sets") + set = models.ForeignKey("training.DatasetSet", on_delete=models.DO_NOTHING, related_name="process_sets") class Meta: constraints = [ models.UniqueConstraint( - fields=["process", "dataset"], - name="unique_process_dataset", + fields=["process", "set"], + name="unique_process_set", ) ] diff --git a/arkindex/process/serializers/imports.py b/arkindex/process/serializers/imports.py index bc071b1f728f41357aaf157d8e69567b3b26c96a..349781accf6c32d2fb65593ec193a9f240dddb99 100644 --- a/arkindex/process/serializers/imports.py +++ b/arkindex/process/serializers/imports.py @@ -415,11 +415,11 @@ class StartProcessSerializer(serializers.Serializer): errors = defaultdict(list) if self.instance.mode == ProcessMode.Dataset: - # Only call .count() and .all() as they will access the prefetched datasets and not cause any extra query - if not self.instance.datasets.count(): - errors["non_field_errors"].append("A dataset process cannot be started if it does not have any associated datasets.") - elif not any(dataset.corpus_id == self.instance.corpus.id for dataset in self.instance.datasets.all()): - errors["non_field_errors"].append("At least one of the process datasets must be from the same corpus as the process.") + # Only call .count() and .all() as they will access the prefetched dataset sets and not cause any extra query + if not self.instance.sets.count(): + errors["non_field_errors"].append("A dataset process cannot be started if it does not have any associated dataset sets.") + elif not any(ds.dataset.corpus_id == self.instance.corpus.id for ds in self.instance.sets.all()): + errors["non_field_errors"].append("At least one of the process sets must be from the same corpus as the process.") if validated_data.get("generate_thumbnails"): errors["thumbnails"].append("Thumbnails generation is not supported on Dataset processes.") diff --git a/arkindex/process/serializers/training.py b/arkindex/process/serializers/training.py index f280a16a650566dfcaab024d23d1d1c3d52faa8a..1eda5f72f60264f03774343c62363dd83a68d316 100644 --- a/arkindex/process/serializers/training.py +++ b/arkindex/process/serializers/training.py @@ -5,47 +5,41 @@ from rest_framework import serializers from rest_framework.exceptions import PermissionDenied, ValidationError from arkindex.documents.models import Corpus -from arkindex.process.models import Process, ProcessDataset, ProcessMode, Task +from arkindex.process.models import Process, ProcessDatasetSet, ProcessMode, Task from arkindex.project.mixins import ProcessACLMixin -from arkindex.training.models import Dataset +from arkindex.training.models import DatasetSet from arkindex.training.serializers import DatasetSerializer from arkindex.users.models import Role -def _dataset_id_from_context(serializer_field): - return serializer_field.context.get("view").kwargs["dataset"] +def _set_id_from_context(serializer_field): + return serializer_field.context.get("view").kwargs["set"] def _process_id_from_context(serializer_field): return 
serializer_field.context.get("view").kwargs["process"]
 
 
-_dataset_id_from_context.requires_context = True
+_set_id_from_context.requires_context = True
 _process_id_from_context.requires_context = True
 
 
-class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer):
+class ProcessDatasetSetSerializer(ProcessACLMixin, serializers.ModelSerializer):
     process_id = serializers.HiddenField(
         write_only=True,
         default=_process_id_from_context
     )
-    dataset_id = serializers.HiddenField(
+    set_id = serializers.HiddenField(
         write_only=True,
-        default=_dataset_id_from_context
-    )
-    dataset = DatasetSerializer(read_only=True)
-    sets = serializers.ListField(
-        child=serializers.CharField(max_length=50),
-        required=False,
-        allow_null=False,
-        allow_empty=False,
-        min_length=1
+        default=_set_id_from_context
     )
+    dataset = DatasetSerializer(read_only=True, source="set.dataset")
+    set_name = serializers.CharField(read_only=True, source="set.name")
 
     class Meta:
-        model = ProcessDataset
-        fields = ("dataset_id", "dataset", "process_id", "id", "sets", )
-        read_only_fields = ("process_id", "id", )
+        model = ProcessDatasetSet
+        fields = ("process_id", "set_id", "id", "dataset", "set_name", )
+        read_only_fields = ("process_id", "id", "dataset", "set_name", )
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -81,44 +75,34 @@ class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer):
         if not access or not (access >= Role.Admin.value):
             raise PermissionDenied(detail="You do not have admin access to this process.")
 
-        # Validate dataset
+        # Validate dataset set
         if not self.instance:
             if request_method == "DELETE":
-                # Allow deleting ProcessDatasets even if the user looses access to the corpus
-                dataset_qs = Dataset.objects.all()
+                # Allow deleting ProcessDatasetSets even if the user loses access to the corpus
+                set_qs = DatasetSet.objects.all()
             else:
-                dataset_qs = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self._user))
+                set_qs = (
+                    DatasetSet.objects.filter(dataset__corpus__in=Corpus.objects.readable(self._user))
+                    .select_related("dataset__creator").prefetch_related("dataset__sets")
+                )
             try:
-                dataset = dataset_qs.select_related("creator").prefetch_related("sets").get(pk=data["dataset_id"])
-            except Dataset.DoesNotExist:
-                raise ValidationError({"dataset": [f'Invalid pk "{str(data["dataset_id"])}" - object does not exist.']})
+                dataset_set = set_qs.get(pk=data["set_id"])
+            except DatasetSet.DoesNotExist:
+                raise ValidationError({"set": [f'Invalid pk "{str(data["set_id"])}" - object does not exist.']})
         else:
-            dataset = self.instance.dataset
-        data["dataset"] = dataset
+            dataset_set = self.instance.set
 
         if process.mode != ProcessMode.Dataset:
-            errors["process"].append('Datasets can only be added to or removed from processes of mode "dataset".')
+            errors["process"].append('Dataset sets can only be added to or removed from processes of mode "dataset".')
 
         if process.has_tasks:
-            errors["process"].append("Datasets cannot be updated on processes that have already started.")
-
-        if self.context["request"].method == "POST" and process.datasets.filter(id=dataset.id).exists():
-            errors["dataset"].append("This dataset is already selected in this process.")
+            errors["process"].append("Dataset sets cannot be updated on processes that have already started.")
 
-        # Validate sets
-        sets = data.get("sets")
-        if not sets or len(sets) == 0:
-            if not self.instance:
-                data["sets"] = [item.name for item in list(dataset.sets.all())]
-            else:
-                errors["sets"].append("This field 
cannot be empty.") - else: - if any(s not in [item.name for item in list(dataset.sets.all())] for s in sets): - errors["sets"].append("The specified sets must all exist in the specified dataset.") - if len(set(sets)) != len(sets): - errors["sets"].append("Sets must be unique.") + if self.context["request"].method == "POST" and process.sets.filter(id=dataset_set.id).exists(): + errors["set"].append("This dataset set is already selected in this process.") if errors: raise ValidationError(errors) + data["set"] = dataset_set return data diff --git a/arkindex/process/tests/test_create_process.py b/arkindex/process/tests/test_create_process.py index 59e454700caae9873e09d5c55644cb57e220f616..e64131810d51a5e8375be7b8fad3ec9c482d85ff 100644 --- a/arkindex/process/tests/test_create_process.py +++ b/arkindex/process/tests/test_create_process.py @@ -13,7 +13,7 @@ from arkindex.process.models import ( ActivityState, FeatureUsage, Process, - ProcessDataset, + ProcessDatasetSet, ProcessMode, Repository, WorkerActivity, @@ -899,8 +899,8 @@ class TestCreateProcess(FixtureAPITestCase): self.client.force_login(self.user) process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) dataset = self.corpus.datasets.first() - test_sets = list(dataset.sets.values_list("name", flat=True)) - ProcessDataset.objects.create(process=process, dataset=dataset, sets=test_sets) + test_set = dataset.sets.get(name="test") + ProcessDatasetSet.objects.create(process=process, set=test_set) process.versions.set([self.version_2, self.version_3]) with self.assertNumQueries(9): @@ -930,8 +930,8 @@ class TestCreateProcess(FixtureAPITestCase): self.worker_1.save() process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) dataset = self.corpus.datasets.first() - test_sets = list(dataset.sets.values_list("name", flat=True)) - ProcessDataset.objects.create(process=process, dataset=dataset, sets=test_sets) + test_set = dataset.sets.get(name="test") + ProcessDatasetSet.objects.create(process=process, set=test_set) process.versions.add(self.version_1) with self.assertNumQueries(9): diff --git a/arkindex/process/tests/test_process_dataset_sets.py b/arkindex/process/tests/test_process_dataset_sets.py new file mode 100644 index 0000000000000000000000000000000000000000..07e0b808f955e3a11d0cdef22c1913eb5d20bb30 --- /dev/null +++ b/arkindex/process/tests/test_process_dataset_sets.py @@ -0,0 +1,444 @@ +import uuid +from unittest.mock import call, patch + +from django.urls import reverse +from rest_framework import status + +from arkindex.documents.models import Corpus +from arkindex.ponos.models import Farm +from arkindex.process.models import Process, ProcessDatasetSet, ProcessMode +from arkindex.project.tests import FixtureAPITestCase +from arkindex.training.models import Dataset, DatasetSet +from arkindex.users.models import Role, User + +# Using the fake DB fixtures creation date when needed +FAKE_CREATED = "2020-02-02T01:23:45.678000Z" + + +class TestProcessDatasetSets(FixtureAPITestCase): + @classmethod + def setUpTestData(cls): + super().setUpTestData() + cls.private_corpus = Corpus.objects.create(name="Private corpus") + with patch("django.utils.timezone.now") as mock_now: + mock_now.return_value = FAKE_CREATED + cls.private_dataset = cls.private_corpus.datasets.create( + name="Dead sea scrolls", + description="Human instrumentality manual", + creator=cls.user + ) + DatasetSet.objects.bulk_create([ + DatasetSet(dataset_id=cls.private_dataset.id, name=set_name) + for set_name in 
["validation", "training", "test"] + ]) + cls.test_user = User.objects.create(email="katsuragi@nerv.co.jp", verified_email=True) + cls.private_corpus.memberships.create(user=cls.test_user, level=Role.Admin.value) + + # Datasets from another corpus + cls.dataset1, cls.dataset2 = Dataset.objects.filter(corpus=cls.corpus).order_by("name") + + cls.dataset_process = Process.objects.create( + creator_id=cls.user.id, + mode=ProcessMode.Dataset, + corpus_id=cls.private_corpus.id, + farm=Farm.objects.get(name="Wheat farm") + ) + ProcessDatasetSet.objects.bulk_create( + ProcessDatasetSet(process=cls.dataset_process, set=dataset_set) + for dataset_set in cls.dataset1.sets.all() + ) + ProcessDatasetSet.objects.bulk_create( + ProcessDatasetSet(process=cls.dataset_process, set=dataset_set) + for dataset_set in cls.private_dataset.sets.all() + ) + + # Control process to check that its datasets are not retrieved + cls.dataset_process_2 = Process.objects.create( + creator_id=cls.user.id, + mode=ProcessMode.Dataset, + corpus_id=cls.corpus.id + ) + ProcessDatasetSet.objects.bulk_create( + ProcessDatasetSet(process=cls.dataset_process_2, set=dataset_set) + for dataset_set in cls.dataset2.sets.all() + ) + + # List process dataset sets + + def test_list_requires_login(self): + with self.assertNumQueries(0): + response = self.client.get(reverse("api:process-sets", kwargs={"pk": self.dataset_process.id})) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_list_process_does_not_exist(self): + self.client.force_login(self.test_user) + with self.assertNumQueries(3): + response = self.client.get(reverse("api:process-sets", kwargs={"pk": str(uuid.uuid4())})) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + @patch("arkindex.project.mixins.get_max_level", return_value=None) + def test_list_process_access_level(self, get_max_level_mock): + self.private_corpus.memberships.filter(user=self.test_user).delete() + self.client.force_login(self.test_user) + with self.assertNumQueries(3): + response = self.client.get(reverse("api:process-sets", kwargs={"pk": self.dataset_process.id})) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertDictEqual(response.json(), {"detail": "You do not have guest access to this process."}) + + self.assertEqual(get_max_level_mock.call_count, 1) + self.assertEqual(get_max_level_mock.call_args, call(self.test_user, self.private_corpus)) + + def test_list(self): + self.client.force_login(self.test_user) + with self.assertNumQueries(6): + response = self.client.get(reverse("api:process-sets", kwargs={"pk": self.dataset_process.id})) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertCountEqual(response.json()["results"], [ + { + "id": str(ProcessDatasetSet.objects.get(process=self.dataset_process, set=dataset_set).id), + "dataset": { + "id": str(self.private_dataset.id), + "name": "Dead sea scrolls", + "description": "Human instrumentality manual", + "creator": "Test user", + "sets": [ + { + "id": str(ds.id), + "name": ds.name + } + for ds in self.private_dataset.sets.order_by("name") + ], + "set_elements": None, + "corpus_id": str(self.private_corpus.id), + "state": "open", + "task_id": None, + "created": FAKE_CREATED, + "updated": FAKE_CREATED + }, + "set_name": dataset_set.name + } + for dataset_set in self.private_dataset.sets.order_by("dataset__name", "name") + ] + [ + { + "id": str(ProcessDatasetSet.objects.get(process=self.dataset_process, set=dataset_set).id), + "dataset": { + "id": 
str(self.dataset1.id), + "name": "First Dataset", + "description": "dataset number one", + "creator": "Test user", + "sets": [ + { + "id": str(ds.id), + "name": ds.name + } + for ds in self.dataset1.sets.order_by("name") + ], + "set_elements": None, + "corpus_id": str(self.corpus.id), + "state": "open", + "task_id": None, + "created": FAKE_CREATED, + "updated": FAKE_CREATED + }, + "set_name": dataset_set.name + } + for dataset_set in self.dataset1.sets.order_by("dataset__name", "name") + ]) + + # Create process dataset set + + def test_create_requires_login(self): + test_set = self.dataset2.sets.get(name="test") + with self.assertNumQueries(0): + response = self.client.post( + reverse("api:process-set", kwargs={"process": self.dataset_process.id, "set": test_set.id}), + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_create_requires_verified(self): + unverified_user = User.objects.create(email="email@mail.com") + test_set = self.dataset2.sets.get(name="test") + self.client.force_login(unverified_user) + with self.assertNumQueries(2): + response = self.client.post( + reverse("api:process-set", kwargs={"process": self.dataset_process.id, "set": test_set.id}) + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + @patch("arkindex.project.mixins.get_max_level") + def test_create_access_level(self, get_max_level_mock): + cases = [None, Role.Guest.value, Role.Contributor.value] + test_set = self.dataset2.sets.get(name="test") + for level in cases: + with self.subTest(level=level): + get_max_level_mock.reset_mock() + get_max_level_mock.return_value = level + self.client.force_login(self.test_user) + + with self.assertNumQueries(3): + response = self.client.post( + reverse("api:process-set", kwargs={"process": self.dataset_process.id, "set": test_set.id}) + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + self.assertEqual(response.json(), {"detail": "You do not have admin access to this process."}) + + self.assertEqual(get_max_level_mock.call_count, 1) + self.assertEqual(get_max_level_mock.call_args, call(self.test_user, self.private_corpus)) + + def test_create_process_mode(self): + cases = set(ProcessMode) - {ProcessMode.Dataset, ProcessMode.Local} + test_set = self.dataset2.sets.get(name="test") + for mode in cases: + with self.subTest(mode=mode): + self.dataset_process.mode = mode + self.dataset_process.save() + self.client.force_login(self.test_user) + + with self.assertNumQueries(6): + response = self.client.post( + reverse("api:process-set", kwargs={"process": self.dataset_process.id, "set": test_set.id}) + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertEqual(response.json(), {"process": ['Dataset sets can only be added to or removed from processes of mode "dataset".']}) + + def test_create_process_mode_local(self): + self.client.force_login(self.user) + test_set = self.dataset2.sets.get(name="test") + local_process = Process.objects.get(creator=self.user, mode=ProcessMode.Local) + with self.assertNumQueries(3): + response = self.client.post( + reverse("api:process-set", kwargs={"process": local_process.id, "set": test_set.id}) + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.json(), {"detail": "You do not have admin access to this process."}) + + def test_create_wrong_process_uuid(self): + self.client.force_login(self.test_user) + test_set = self.dataset2.sets.get(name="test") + wrong_id = uuid.uuid4() + with 
self.assertNumQueries(3): + response = self.client.post( + reverse("api:process-set", kwargs={"process": wrong_id, "set": test_set.id}) + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), {"process": [f'Invalid pk "{str(wrong_id)}" - object does not exist.']}) + + def test_create_wrong_set_uuid(self): + self.client.force_login(self.test_user) + wrong_id = uuid.uuid4() + with self.assertNumQueries(4): + response = self.client.post( + reverse("api:process-set", kwargs={"process": self.dataset_process.id, "set": wrong_id}) + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), {"set": [f'Invalid pk "{str(wrong_id)}" - object does not exist.']}) + + @patch("arkindex.users.managers.BaseACLManager.filter_rights", return_value=Corpus.objects.none()) + def test_create_dataset_access(self, filter_rights_mock): + new_corpus = Corpus.objects.create(name="NERV") + new_dataset = new_corpus.datasets.create(name="Eva series", description="We created the Evas from Adam", creator=self.user) + test_set = new_dataset.sets.create(name="test") + self.client.force_login(self.test_user) + + with self.assertNumQueries(3): + response = self.client.post( + reverse("api:process-set", kwargs={"process": self.dataset_process.id, "set": test_set.id}) + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertEqual(response.json(), {"set": [f'Invalid pk "{str(test_set.id)}" - object does not exist.']}) + + self.assertEqual(filter_rights_mock.call_count, 1) + self.assertEqual(filter_rights_mock.call_args, call(self.test_user, Corpus, Role.Guest.value)) + + def test_create_unique(self): + self.client.force_login(self.test_user) + test_set = self.dataset1.sets.get(name="test") + self.assertTrue(self.dataset_process.sets.filter(id=test_set.id).exists()) + + with self.assertNumQueries(6): + response = self.client.post( + reverse("api:process-set", kwargs={"process": self.dataset_process.id, "set": test_set.id}) + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertDictEqual(response.json(), {"set": ["This dataset set is already selected in this process."]}) + + def test_create_started(self): + self.client.force_login(self.test_user) + self.dataset_process.tasks.create(run=0, depth=0, slug="makrout") + test_set = self.dataset2.sets.get(name="test") + + with self.assertNumQueries(6): + response = self.client.post( + reverse("api:process-set", kwargs={"process": self.dataset_process.id, "set": test_set.id}) + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertDictEqual(response.json(), {"process": ["Dataset sets cannot be updated on processes that have already started."]}) + + def test_create(self): + self.client.force_login(self.test_user) + test_set = self.dataset2.sets.get(name="test") + self.assertEqual(ProcessDatasetSet.objects.count(), 9) + self.assertFalse(ProcessDatasetSet.objects.filter(process=self.dataset_process.id, set=test_set.id).exists()) + with self.assertNumQueries(7): + response = self.client.post( + reverse("api:process-set", kwargs={"process": self.dataset_process.id, "set": test_set.id}) + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(ProcessDatasetSet.objects.count(), 10) + self.assertTrue(ProcessDatasetSet.objects.filter(process=self.dataset_process.id, set=test_set.id).exists()) + 
self.assertQuerysetEqual(self.dataset_process.sets.order_by("dataset__name", "name"), [ + *self.private_dataset.sets.order_by("name"), + *self.dataset1.sets.order_by("name"), + test_set, + ]) + created = ProcessDatasetSet.objects.get(process=self.dataset_process.id, set=test_set.id) + self.assertDictEqual(response.json(), { + "id": str(created.id), + "dataset": { + "id": str(self.dataset2.id), + "name": "Second Dataset", + "description": "dataset number two", + "creator": "Test user", + "sets": [ + { + "id": str(ds.id), + "name": ds.name + } + for ds in self.dataset2.sets.all() + ], + "set_elements": None, + "corpus_id": str(self.corpus.id), + "state": "open", + "task_id": None, + "created": FAKE_CREATED, + "updated": FAKE_CREATED + }, + "set_name": "test" + }) + + # Destroy process dataset set + + def test_destroy_requires_login(self): + train_set = self.dataset1.sets.get(name="training") + with self.assertNumQueries(0): + response = self.client.delete( + reverse("api:process-set", kwargs={"process": self.dataset_process.id, "set": train_set.id}), + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_destroy_process_does_not_exist(self): + train_set = self.dataset1.sets.get(name="training") + self.client.force_login(self.test_user) + wrong_id = uuid.uuid4() + with self.assertNumQueries(3): + response = self.client.delete( + reverse("api:process-set", kwargs={"process": wrong_id, "set": train_set.id}) + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), {"process": [f'Invalid pk "{str(wrong_id)}" - object does not exist.']}) + + def test_destroy_set_does_not_exist(self): + self.client.force_login(self.test_user) + wrong_id = uuid.uuid4() + with self.assertNumQueries(4): + response = self.client.delete( + reverse("api:process-set", kwargs={"process": self.dataset_process.id, "set": wrong_id}) + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), {"set": [f'Invalid pk "{str(wrong_id)}" - object does not exist.']}) + + def test_destroy_not_found(self): + train_set = self.dataset2.sets.get(name="training") + self.assertFalse(ProcessDatasetSet.objects.filter(process=self.dataset_process, set=train_set).exists()) + self.client.force_login(self.test_user) + with self.assertNumQueries(5): + response = self.client.delete( + reverse("api:process-set", kwargs={"process": self.dataset_process.id, "set": train_set.id}), + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + @patch("arkindex.project.mixins.get_max_level", return_value=None) + def test_destroy_process_access_level(self, get_max_level_mock): + train_set = self.dataset1.sets.get(name="training") + self.client.force_login(self.test_user) + with self.assertNumQueries(3): + response = self.client.delete( + reverse("api:process-set", kwargs={"process": self.dataset_process.id, "set": train_set.id}) + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + self.assertDictEqual(response.json(), {"detail": "You do not have admin access to this process."}) + + self.assertEqual(get_max_level_mock.call_count, 1) + self.assertEqual(get_max_level_mock.call_args, call(self.test_user, self.private_corpus)) + + def test_destroy_no_dataset_access_requirement(self): + new_corpus = Corpus.objects.create(name="NERV") + new_dataset = new_corpus.datasets.create(name="Eva series", description="We created the Evas from Adam", creator=self.user) + test_set = 
new_dataset.sets.create(name="test") + ProcessDatasetSet.objects.create(process=self.dataset_process, set=test_set) + self.assertTrue(ProcessDatasetSet.objects.filter(process=self.dataset_process, set=test_set).exists()) + self.client.force_login(self.test_user) + with self.assertNumQueries(6): + response = self.client.delete( + reverse("api:process-set", kwargs={"process": self.dataset_process.id, "set": test_set.id}), + ) + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertFalse(ProcessDatasetSet.objects.filter(process=self.dataset_process, set=test_set).exists()) + + def test_destroy_process_mode(self): + train_set = self.dataset1.sets.get(name="training") + cases = set(ProcessMode) - {ProcessMode.Dataset, ProcessMode.Local} + for mode in cases: + with self.subTest(mode=mode): + self.dataset_process.mode = mode + self.dataset_process.save() + self.client.force_login(self.test_user) + + with self.assertNumQueries(4): + response = self.client.delete( + reverse("api:process-set", kwargs={"process": self.dataset_process.id, "set": train_set.id}), + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertEqual(response.json(), {"process": ['Dataset sets can only be added to or removed from processes of mode "dataset".']}) + + def test_destroy_process_mode_local(self): + train_set = self.dataset1.sets.get(name="training") + self.client.force_login(self.user) + local_process = Process.objects.get(creator=self.user, mode=ProcessMode.Local) + with self.assertNumQueries(3): + response = self.client.delete( + reverse("api:process-set", kwargs={"process": local_process.id, "set": train_set.id}), + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.json(), {"detail": "You do not have admin access to this process."}) + + def test_destroy_started(self): + train_set = self.dataset1.sets.get(name="training") + self.client.force_login(self.test_user) + self.dataset_process.tasks.create(run=0, depth=0, slug="makrout") + + with self.assertNumQueries(4): + response = self.client.delete( + reverse("api:process-set", kwargs={"process": self.dataset_process.id, "set": train_set.id}), + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertDictEqual(response.json(), {"process": ["Dataset sets cannot be updated on processes that have already started."]}) + + def test_destroy(self): + train_set = self.dataset1.sets.get(name="training") + self.assertTrue(ProcessDatasetSet.objects.filter(process=self.dataset_process, set=train_set).exists()) + self.client.force_login(self.test_user) + with self.assertNumQueries(6): + response = self.client.delete( + reverse("api:process-set", kwargs={"process": self.dataset_process.id, "set": train_set.id}), + ) + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertFalse(ProcessDatasetSet.objects.filter(process=self.dataset_process, set=train_set).exists()) diff --git a/arkindex/process/tests/test_process_datasets.py b/arkindex/process/tests/test_process_datasets.py deleted file mode 100644 index 550c2968dae2480d62631aa0bf7e8990d582a30b..0000000000000000000000000000000000000000 --- a/arkindex/process/tests/test_process_datasets.py +++ /dev/null @@ -1,836 +0,0 @@ -import uuid -from datetime import datetime, timezone -from unittest.mock import call, patch - -from django.urls import reverse -from rest_framework import status - -from arkindex.documents.models import Corpus -from arkindex.ponos.models import Farm -from 
arkindex.process.models import Process, ProcessDataset, ProcessMode -from arkindex.project.tests import FixtureAPITestCase -from arkindex.training.models import Dataset, DatasetSet -from arkindex.users.models import Role, User - -# Using the fake DB fixtures creation date when needed -FAKE_CREATED = "2020-02-02T01:23:45.678000Z" - - -class TestProcessDatasets(FixtureAPITestCase): - @classmethod - def setUpTestData(cls): - super().setUpTestData() - cls.private_corpus = Corpus.objects.create(name="Private corpus") - with patch("django.utils.timezone.now") as mock_now: - mock_now.return_value = FAKE_CREATED - cls.private_dataset = cls.private_corpus.datasets.create( - name="Dead sea scrolls", - description="Human instrumentality manual", - creator=cls.user - ) - DatasetSet.objects.bulk_create([ - DatasetSet(dataset_id=cls.private_dataset.id, name=set_name) - for set_name in ["validation", "training", "test"] - ]) - cls.test_user = User.objects.create(email="katsuragi@nerv.co.jp", verified_email=True) - cls.private_corpus.memberships.create(user=cls.test_user, level=Role.Admin.value) - - # Datasets from another corpus - cls.dataset1, cls.dataset2 = Dataset.objects.filter(corpus=cls.corpus).order_by("name") - - cls.dataset_process = Process.objects.create( - creator_id=cls.user.id, - mode=ProcessMode.Dataset, - corpus_id=cls.private_corpus.id, - farm=Farm.objects.get(name="Wheat farm") - ) - cls.process_dataset_1 = ProcessDataset.objects.create(process=cls.dataset_process, dataset=cls.dataset1, sets=list(cls.dataset1.sets.values_list("name", flat=True))) - cls.process_dataset_2 = ProcessDataset.objects.create(process=cls.dataset_process, dataset=cls.private_dataset, sets=list(cls.private_dataset.sets.values_list("name", flat=True))) - - # Control process to check that its datasets are not retrieved - cls.dataset_process_2 = Process.objects.create( - creator_id=cls.user.id, - mode=ProcessMode.Dataset, - corpus_id=cls.corpus.id - ) - ProcessDataset.objects.create(process=cls.dataset_process_2, dataset=cls.dataset2, sets=list(cls.dataset2.sets.values_list("name", flat=True))) - - # List process datasets - - def test_list_requires_login(self): - with self.assertNumQueries(0): - response = self.client.get(reverse("api:process-datasets", kwargs={"pk": self.dataset_process.id})) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - - def test_list_process_does_not_exist(self): - self.client.force_login(self.test_user) - with self.assertNumQueries(3): - response = self.client.get(reverse("api:process-datasets", kwargs={"pk": str(uuid.uuid4())})) - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - - @patch("arkindex.project.mixins.get_max_level", return_value=None) - def test_list_process_access_level(self, get_max_level_mock): - self.private_corpus.memberships.filter(user=self.test_user).delete() - self.client.force_login(self.test_user) - with self.assertNumQueries(3): - response = self.client.get(reverse("api:process-datasets", kwargs={"pk": self.dataset_process.id})) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - self.assertDictEqual(response.json(), {"detail": "You do not have guest access to this process."}) - - self.assertEqual(get_max_level_mock.call_count, 1) - self.assertEqual(get_max_level_mock.call_args, call(self.test_user, self.private_corpus)) - - def test_list(self): - self.client.force_login(self.test_user) - with self.assertNumQueries(6): - response = self.client.get(reverse("api:process-datasets", kwargs={"pk": 
self.dataset_process.id})) - self.assertEqual(response.status_code, status.HTTP_200_OK) - sets_0 = response.json()["results"][0].pop("sets") - self.assertCountEqual(sets_0, ["validation", "training", "test"]) - self.assertDictEqual(response.json()["results"][0], { - "id": str(self.process_dataset_2.id), - "dataset": { - "id": str(self.private_dataset.id), - "name": "Dead sea scrolls", - "description": "Human instrumentality manual", - "creator": "Test user", - "sets": [ - { - "id": str(ds.id), - "name": ds.name - } - for ds in self.private_dataset.sets.all() - ], - "set_elements": None, - "corpus_id": str(self.private_corpus.id), - "state": "open", - "task_id": None, - "created": FAKE_CREATED, - "updated": FAKE_CREATED - } - }) - sets_1 = response.json()["results"][1].pop("sets") - self.assertCountEqual(sets_1, ["validation", "training", "test"]) - self.assertDictEqual(response.json()["results"][1], { - "id": str(self.process_dataset_1.id), - "dataset": { - "id": str(self.dataset1.id), - "name": "First Dataset", - "description": "dataset number one", - "creator": "Test user", - "sets": [ - { - "id": str(ds.id), - "name": ds.name - } - for ds in self.dataset1.sets.all() - ], - "set_elements": None, - "corpus_id": str(self.corpus.id), - "state": "open", - "task_id": None, - "created": FAKE_CREATED, - "updated": FAKE_CREATED - } - }) - - # Create process dataset - - def test_create_requires_login(self): - test_sets = list(self.dataset2.sets.values_list("name", flat=True)) - with self.assertNumQueries(0): - response = self.client.post( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), - data={"sets": test_sets} - ) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - - def test_create_requires_verified(self): - unverified_user = User.objects.create(email="email@mail.com") - test_sets = list(self.dataset2.sets.values_list("name", flat=True)) - self.client.force_login(unverified_user) - with self.assertNumQueries(2): - response = self.client.post( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), - data={"sets": test_sets} - ) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - - @patch("arkindex.project.mixins.get_max_level") - def test_create_access_level(self, get_max_level_mock): - cases = [None, Role.Guest.value, Role.Contributor.value] - test_sets = list(self.dataset2.sets.values_list("name", flat=True)) - for level in cases: - with self.subTest(level=level): - get_max_level_mock.reset_mock() - get_max_level_mock.return_value = level - self.client.force_login(self.test_user) - - with self.assertNumQueries(3): - response = self.client.post( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), - data={"sets": test_sets} - ) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - - self.assertEqual(response.json(), {"detail": "You do not have admin access to this process."}) - - self.assertEqual(get_max_level_mock.call_count, 1) - self.assertEqual(get_max_level_mock.call_args, call(self.test_user, self.private_corpus)) - - def test_create_process_mode(self): - cases = set(ProcessMode) - {ProcessMode.Dataset, ProcessMode.Local} - test_sets = list(self.dataset2.sets.values_list("name", flat=True)) - for mode in cases: - with self.subTest(mode=mode): - self.dataset_process.mode = mode - self.dataset_process.save() - self.client.force_login(self.test_user) - - with 
self.assertNumQueries(6): - response = self.client.post( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), - data={"sets": test_sets} - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - self.assertEqual(response.json(), {"process": ['Datasets can only be added to or removed from processes of mode "dataset".']}) - - def test_create_process_mode_local(self): - self.client.force_login(self.user) - test_sets = list(self.dataset2.sets.values_list("name", flat=True)) - local_process = Process.objects.get(creator=self.user, mode=ProcessMode.Local) - with self.assertNumQueries(3): - response = self.client.post( - reverse("api:process-dataset", kwargs={"process": local_process.id, "dataset": self.dataset2.id}), - data={"sets": test_sets} - ) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - self.assertEqual(response.json(), {"detail": "You do not have admin access to this process."}) - - def test_create_wrong_process_uuid(self): - self.client.force_login(self.test_user) - test_sets = list(self.dataset2.sets.values_list("name", flat=True)) - wrong_id = uuid.uuid4() - with self.assertNumQueries(3): - response = self.client.post( - reverse("api:process-dataset", kwargs={"process": wrong_id, "dataset": self.dataset2.id}), - data={"sets": test_sets} - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEqual(response.json(), {"process": [f'Invalid pk "{str(wrong_id)}" - object does not exist.']}) - - def test_create_wrong_dataset_uuid(self): - self.client.force_login(self.test_user) - wrong_id = uuid.uuid4() - with self.assertNumQueries(4): - response = self.client.post( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": wrong_id}), - data={"sets": ["test"]} - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEqual(response.json(), {"dataset": [f'Invalid pk "{str(wrong_id)}" - object does not exist.']}) - - @patch("arkindex.users.managers.BaseACLManager.filter_rights", return_value=Corpus.objects.none()) - def test_create_dataset_access(self, filter_rights_mock): - new_corpus = Corpus.objects.create(name="NERV") - new_dataset = new_corpus.datasets.create(name="Eva series", description="We created the Evas from Adam", creator=self.user) - test_sets = list(new_dataset.sets.values_list("name", flat=True)) - self.client.force_login(self.test_user) - - with self.assertNumQueries(3): - response = self.client.post( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": new_dataset.id}), - data={"sets": test_sets} - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - self.assertEqual(response.json(), {"dataset": [f'Invalid pk "{str(new_dataset.id)}" - object does not exist.']}) - - self.assertEqual(filter_rights_mock.call_count, 1) - self.assertEqual(filter_rights_mock.call_args, call(self.test_user, Corpus, Role.Guest.value)) - - def test_create_unique(self): - self.client.force_login(self.test_user) - test_sets = list(self.dataset1.sets.values_list("name", flat=True)) - self.assertTrue(self.dataset_process.datasets.filter(id=self.dataset1.id).exists()) - - with self.assertNumQueries(6): - response = self.client.post( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), - data={"sets": test_sets} - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - 
self.assertDictEqual(response.json(), {"dataset": ["This dataset is already selected in this process."]}) - - def test_create_started(self): - self.client.force_login(self.test_user) - self.dataset_process.tasks.create(run=0, depth=0, slug="makrout") - test_sets = list(self.dataset2.sets.values_list("name", flat=True)) - - with self.assertNumQueries(6): - response = self.client.post( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), - data={"sets": test_sets} - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - self.assertDictEqual(response.json(), {"process": ["Datasets cannot be updated on processes that have already started."]}) - - def test_create_default_sets(self): - self.client.force_login(self.test_user) - self.assertEqual(ProcessDataset.objects.count(), 3) - self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists()) - with self.assertNumQueries(7): - response = self.client.post( - reverse( - "api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id} - ), - ) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertEqual(ProcessDataset.objects.count(), 4) - self.assertTrue(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists()) - self.assertQuerysetEqual(self.dataset_process.datasets.order_by("name"), [ - self.private_dataset, - self.dataset1, - self.dataset2 - ]) - created = ProcessDataset.objects.get(process=self.dataset_process.id, dataset=self.dataset2.id) - process_sets = response.json().pop("sets") - self.assertCountEqual(process_sets, ["validation", "training", "test"]) - self.assertDictEqual(response.json(), { - "id": str(created.id), - "dataset": { - "id": str(self.dataset2.id), - "name": "Second Dataset", - "description": "dataset number two", - "creator": "Test user", - "sets": [ - { - "id": str(ds.id), - "name": ds.name - } - for ds in self.dataset2.sets.all() - ], - "set_elements": None, - "corpus_id": str(self.corpus.id), - "state": "open", - "task_id": None, - "created": FAKE_CREATED, - "updated": FAKE_CREATED - } - }) - - def test_create(self): - self.client.force_login(self.test_user) - self.assertEqual(ProcessDataset.objects.count(), 3) - self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists()) - with self.assertNumQueries(7): - response = self.client.post( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), - data={"sets": ["validation", "test"]} - ) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertEqual(ProcessDataset.objects.count(), 4) - self.assertTrue(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists()) - self.assertQuerysetEqual(self.dataset_process.datasets.order_by("name"), [ - self.private_dataset, - self.dataset1, - self.dataset2 - ]) - created = ProcessDataset.objects.get(process=self.dataset_process.id, dataset=self.dataset2.id) - self.assertDictEqual(response.json(), { - "id": str(created.id), - "dataset": { - "id": str(self.dataset2.id), - "name": "Second Dataset", - "description": "dataset number two", - "creator": "Test user", - "sets": [ - { - "id": str(ds.id), - "name": ds.name - } - for ds in self.dataset2.sets.all() - ], - "set_elements": None, - "corpus_id": str(self.corpus.id), - "state": "open", - "task_id": None, - 
"created": FAKE_CREATED, - "updated": FAKE_CREATED - }, - "sets": ["validation", "test"] - }) - - def test_create_wrong_sets(self): - self.client.force_login(self.test_user) - self.assertEqual(ProcessDataset.objects.count(), 3) - self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists()) - with self.assertNumQueries(6): - response = self.client.post( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), - data={"sets": ["Unit-01"]} - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists()) - self.assertDictEqual(response.json(), {"sets": ["The specified sets must all exist in the specified dataset."]}) - - # Update process dataset - - def test_update_requires_login(self): - with self.assertNumQueries(0): - response = self.client.put( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), - data={"sets": ["test"]} - ) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - - def test_update_requires_verified(self): - unverified_user = User.objects.create(email="email@mail.com") - self.client.force_login(unverified_user) - with self.assertNumQueries(2): - response = self.client.put( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), - data={"sets": ["test"]} - ) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - - def test_update_access_level(self): - cases = [None, Role.Guest, Role.Contributor] - for level in cases: - with self.subTest(level=level): - self.private_corpus.memberships.filter(user=self.test_user).delete() - if level: - self.private_corpus.memberships.create(user=self.test_user, level=level.value) - self.client.force_login(self.test_user) - with self.assertNumQueries(5): - response = self.client.put( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), - data={"sets": ["test"]} - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - - def test_update_process_does_not_exist(self): - self.client.force_login(self.test_user) - with self.assertNumQueries(3): - response = self.client.put( - reverse("api:process-dataset", kwargs={"process": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", "dataset": self.dataset1.id}), - data={"sets": ["test"]} - ) - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - - def test_update_dataset_does_not_exist(self): - self.client.force_login(self.test_user) - with self.assertNumQueries(3): - response = self.client.put( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}), - data={"sets": ["test"]} - ) - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - - def test_update_m2m_does_not_exist(self): - self.client.force_login(self.test_user) - with self.assertNumQueries(3): - response = self.client.put( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), - data={"sets": ["test"]} - ) - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - - def test_update(self): - self.client.force_login(self.test_user) - with self.assertNumQueries(5): - response = self.client.put( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, 
"dataset": self.dataset1.id}), - data={"sets": ["test"]} - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertDictEqual(response.json(), { - "id": str(self.process_dataset_1.id), - "dataset": { - "id": str(self.dataset1.id), - "name": "First Dataset", - "description": "dataset number one", - "creator": "Test user", - "sets": [ - { - "id": str(ds.id), - "name": ds.name - } - for ds in self.dataset1.sets.all() - ], - "set_elements": None, - "corpus_id": str(self.corpus.id), - "state": "open", - "task_id": None, - "created": FAKE_CREATED, - "updated": FAKE_CREATED - }, - "sets": ["test"] - }) - - def test_update_wrong_sets(self): - self.client.force_login(self.test_user) - with self.assertNumQueries(4): - response = self.client.put( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), - data={"sets": ["Unit-01", "Unit-02"]} - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertDictEqual(response.json(), {"sets": ["The specified sets must all exist in the specified dataset."]}) - - def test_update_unique_sets(self): - self.client.force_login(self.test_user) - with self.assertNumQueries(4): - response = self.client.put( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), - data={"sets": ["test", "test"]} - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertDictEqual(response.json(), {"sets": ["Sets must be unique."]}) - - def test_update_started_process(self): - self.dataset_process.tasks.create( - run=0, - depth=0, - slug="task", - expiry=datetime(1970, 1, 1, tzinfo=timezone.utc), - ) - self.client.force_login(self.test_user) - with self.assertNumQueries(4): - response = self.client.put( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), - data={"sets": ["test"]} - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertDictEqual(response.json(), {"process": ["Datasets cannot be updated on processes that have already started."]}) - - def test_update_only_sets(self): - """ - Non "sets" fields in the update request are ignored - """ - self.client.force_login(self.test_user) - with self.assertNumQueries(5): - response = self.client.put( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), - data={"process": str(self.dataset_process_2.id), "dataset": str(self.dataset2.id), "sets": ["test"]} - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertDictEqual(response.json(), { - "id": str(self.process_dataset_1.id), - "dataset": { - "id": str(self.dataset1.id), - "name": "First Dataset", - "description": "dataset number one", - "creator": "Test user", - "sets": [ - { - "id": str(ds.id), - "name": ds.name - } - for ds in self.dataset1.sets.all() - ], - "set_elements": None, - "corpus_id": str(self.corpus.id), - "state": "open", - "task_id": None, - "created": FAKE_CREATED, - "updated": FAKE_CREATED - }, - "sets": ["test"] - }) - - # Partial update process dataset - - def test_partial_update_requires_login(self): - with self.assertNumQueries(0): - response = self.client.patch( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), - data={"sets": ["test"]} - ) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - - def test_partial_update_requires_verified(self): - 
-        unverified_user = User.objects.create(email="email@mail.com")
-        self.client.force_login(unverified_user)
-        with self.assertNumQueries(2):
-            response = self.client.patch(
-                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
-                data={"sets": ["test"]}
-            )
-            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
-    def test_partial_update_access_level(self):
-        cases = [None, Role.Guest, Role.Contributor]
-        for level in cases:
-            with self.subTest(level=level):
-                self.private_corpus.memberships.filter(user=self.test_user).delete()
-                if level:
-                    self.private_corpus.memberships.create(user=self.test_user, level=level.value)
-                self.client.force_login(self.test_user)
-                with self.assertNumQueries(5):
-                    response = self.client.patch(
-                        reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
-                        data={"sets": ["test"]}
-                    )
-                    self.assertEqual(response.status_code, status.HTTP_200_OK)
-
-    def test_partial_update_process_does_not_exist(self):
-        self.client.force_login(self.test_user)
-        with self.assertNumQueries(3):
-            response = self.client.patch(
-                reverse("api:process-dataset", kwargs={"process": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", "dataset": self.dataset1.id}),
-                data={"sets": ["test"]}
-            )
-            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
-
-    def test_partial_update_dataset_does_not_exist(self):
-        self.client.force_login(self.test_user)
-        with self.assertNumQueries(3):
-            response = self.client.patch(
-                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}),
-                data={"sets": ["test"]}
-            )
-            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
-
-    def test_partial_update_m2m_does_not_exist(self):
-        self.client.force_login(self.test_user)
-        with self.assertNumQueries(3):
-            response = self.client.patch(
-                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
-                data={"sets": ["test"]}
-            )
-            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
-
-    def test_partial_update(self):
-        self.client.force_login(self.test_user)
-        with self.assertNumQueries(5):
-            response = self.client.patch(
-                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
-                data={"sets": ["test"]}
-            )
-            self.assertEqual(response.status_code, status.HTTP_200_OK)
-        self.assertDictEqual(response.json(), {
-            "id": str(self.process_dataset_1.id),
-            "dataset": {
-                "id": str(self.dataset1.id),
-                "name": "First Dataset",
-                "description": "dataset number one",
-                "creator": "Test user",
-                "sets": [
-                    {
-                        "id": str(ds.id),
-                        "name": ds.name
-                    }
-                    for ds in self.dataset1.sets.all()
-                ],
-                "set_elements": None,
-                "corpus_id": str(self.corpus.id),
-                "state": "open",
-                "task_id": None,
-                "created": FAKE_CREATED,
-                "updated": FAKE_CREATED
-            },
-            "sets": ["test"]
-        })
-
-    def test_partial_update_wrong_sets(self):
-        self.client.force_login(self.test_user)
-        with self.assertNumQueries(4):
-            response = self.client.patch(
-                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
-                data={"sets": ["Unit-01", "Unit-02"]}
-            )
-            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-        self.assertDictEqual(response.json(), {"sets": ["The specified sets must all exist in the specified dataset."]})
-
-    def test_partial_update_unique_sets(self):
-        self.client.force_login(self.test_user)
-        with self.assertNumQueries(4):
-            response = self.client.patch(
-                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
-                data={"sets": ["test", "test"]}
-            )
-            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-        self.assertDictEqual(response.json(), {"sets": ["Sets must be unique."]})
-
-    def test_partial_update_started_process(self):
-        self.dataset_process.tasks.create(
-            run=0,
-            depth=0,
-            slug="task",
-            expiry=datetime(1970, 1, 1, tzinfo=timezone.utc),
-        )
-        self.client.force_login(self.test_user)
-        with self.assertNumQueries(4):
-            response = self.client.patch(
-                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
-                data={"sets": ["test"]}
-            )
-            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-        self.assertDictEqual(response.json(), {"process": ["Datasets cannot be updated on processes that have already started."]})
-
-    def test_partial_update_only_sets(self):
-        """
-        Non "sets" fields in the partial update request are ignored
-        """
-        self.client.force_login(self.test_user)
-        with self.assertNumQueries(5):
-            response = self.client.patch(
-                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}),
-                data={"process": str(self.dataset_process_2.id), "dataset": str(self.dataset2.id), "sets": ["test"]}
-            )
-            self.assertEqual(response.status_code, status.HTTP_200_OK)
-        self.assertDictEqual(response.json(), {
-            "id": str(self.process_dataset_1.id),
-            "dataset": {
-                "id": str(self.dataset1.id),
-                "name": "First Dataset",
-                "description": "dataset number one",
-                "creator": "Test user",
-                "sets": [
-                    {
-                        "id": str(ds.id),
-                        "name": ds.name
-                    }
-                    for ds in self.dataset1.sets.all()
-                ],
-                "set_elements": None,
-                "corpus_id": str(self.corpus.id),
-                "state": "open",
-                "task_id": None,
-                "created": FAKE_CREATED,
-                "updated": FAKE_CREATED
-            },
-            "sets": ["test"]
-        })
-
-    # Destroy process dataset
-
-    def test_destroy_requires_login(self):
-        with self.assertNumQueries(0):
-            response = self.client.delete(
-                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.private_dataset.id}),
-            )
-            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
-    def test_destroy_process_does_not_exist(self):
-        self.client.force_login(self.test_user)
-        wrong_id = uuid.uuid4()
-        with self.assertNumQueries(3):
-            response = self.client.delete(
-                reverse("api:process-dataset", kwargs={"process": wrong_id, "dataset": self.private_dataset.id})
-            )
-            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-        self.assertDictEqual(response.json(), {"process": [f'Invalid pk "{str(wrong_id)}" - object does not exist.']})
-
-    def test_destroy_dataset_does_not_exist(self):
-        self.client.force_login(self.test_user)
-        wrong_id = uuid.uuid4()
-        with self.assertNumQueries(4):
-            response = self.client.delete(
-                reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": wrong_id})
-            )
-            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-        self.assertDictEqual(response.json(), {"dataset": [f'Invalid pk "{str(wrong_id)}" - object does not exist.']})
-
-    def test_destroy_not_found(self):
-        self.assertFalse(self.dataset_process.datasets.filter(id=self.dataset2.id).exists())
-        self.client.force_login(self.test_user)
-        with self.assertNumQueries(6):
-            response = self.client.delete(
kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), - ) - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - - @patch("arkindex.project.mixins.get_max_level", return_value=None) - def test_destroy_process_access_level(self, get_max_level_mock): - self.client.force_login(self.test_user) - with self.assertNumQueries(3): - response = self.client.delete( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.private_dataset.id}) - ) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - - self.assertDictEqual(response.json(), {"detail": "You do not have admin access to this process."}) - - self.assertEqual(get_max_level_mock.call_count, 1) - self.assertEqual(get_max_level_mock.call_args, call(self.test_user, self.private_corpus)) - - def test_destroy_no_dataset_access_requirement(self): - new_corpus = Corpus.objects.create(name="NERV") - new_dataset = new_corpus.datasets.create(name="Eva series", description="We created the Evas from Adam", creator=self.user) - test_sets = list(new_dataset.sets.values_list("name", flat=True)) - ProcessDataset.objects.create(process=self.dataset_process, dataset=new_dataset, sets=test_sets) - self.assertTrue(ProcessDataset.objects.filter(process=self.dataset_process, dataset=new_dataset).exists()) - self.client.force_login(self.test_user) - with self.assertNumQueries(7): - response = self.client.delete( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": new_dataset.id}), - ) - self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) - self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process, dataset=new_dataset).exists()) - - def test_destroy_process_mode(self): - cases = set(ProcessMode) - {ProcessMode.Dataset, ProcessMode.Local} - for mode in cases: - with self.subTest(mode=mode): - self.dataset_process.mode = mode - self.dataset_process.save() - self.client.force_login(self.test_user) - - with self.assertNumQueries(5): - response = self.client.delete( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - self.assertEqual(response.json(), {"process": ['Datasets can only be added to or removed from processes of mode "dataset".']}) - - def test_destroy_process_mode_local(self): - self.client.force_login(self.user) - local_process = Process.objects.get(creator=self.user, mode=ProcessMode.Local) - with self.assertNumQueries(3): - response = self.client.delete( - reverse("api:process-dataset", kwargs={"process": local_process.id, "dataset": self.dataset2.id}), - ) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - self.assertEqual(response.json(), {"detail": "You do not have admin access to this process."}) - - def test_destroy_started(self): - self.client.force_login(self.test_user) - self.dataset_process.tasks.create(run=0, depth=0, slug="makrout") - - with self.assertNumQueries(5): - response = self.client.delete( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - self.assertDictEqual(response.json(), {"process": ["Datasets cannot be updated on processes that have already started."]}) - - def test_destroy(self): - self.client.force_login(self.test_user) - with self.assertNumQueries(7): - response = self.client.delete( - 
reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), - ) - self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) - self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process, dataset=self.dataset1).exists()) - - def test_destroy_sets_agnostic(self): - """ - When deleting a process dataset, it doesn't matter what its sets are as there cannot be two process datasets - with the same process and dataset, whatever the sets are. - """ - self.process_dataset_1.sets = ["test"] - self.process_dataset_1.save() - self.client.force_login(self.test_user) - with self.assertNumQueries(7): - response = self.client.delete( - reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), - ) - self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) - self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process, dataset=self.dataset1).exists()) diff --git a/arkindex/process/tests/test_processes.py b/arkindex/process/tests/test_processes.py index b897d3250d189a0e50e3a99634d317c3522bcc3d..84d6dcb9253147c24bcd704bde68579ee5e4b458 100644 --- a/arkindex/process/tests/test_processes.py +++ b/arkindex/process/tests/test_processes.py @@ -16,7 +16,7 @@ from arkindex.process.models import ( ActivityState, DataFile, Process, - ProcessDataset, + ProcessDatasetSet, ProcessMode, WorkerActivity, WorkerActivityState, @@ -43,6 +43,7 @@ class TestProcesses(FixtureAPITestCase): description="Human instrumentality manual", creator=cls.user ) + cls.private_dataset.sets.create(name="test") cls.img_df = cls.corpus.files.create( name="test.jpg", size=42, @@ -2319,12 +2320,13 @@ class TestProcesses(FixtureAPITestCase): ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), { - "non_field_errors": ["A dataset process cannot be started if it does not have any associated datasets."] + "non_field_errors": ["A dataset process cannot be started if it does not have any associated dataset sets."] }) def test_start_process_dataset_requires_dataset_in_same_corpus(self): process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) - ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=list(self.private_dataset.sets.values_list("name", flat=True))) + test_set = self.private_dataset.sets.get(name="test") + ProcessDatasetSet.objects.create(process=process2, set=test_set) process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None) self.assertFalse(process2.tasks.exists()) @@ -2336,13 +2338,15 @@ class TestProcesses(FixtureAPITestCase): ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), { - "non_field_errors": ["At least one of the process datasets must be from the same corpus as the process."] + "non_field_errors": ["At least one of the process sets must be from the same corpus as the process."] }) def test_start_process_dataset_unsupported_parameters(self): process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) - ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=list(self.dataset1.sets.values_list("name", flat=True))) - ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=list(self.dataset2.sets.values_list("name", flat=True))) + test_set_1 = self.dataset1.sets.get(name="test") + test_set_2 = self.dataset2.sets.get(name="test") 
diff --git a/arkindex/process/tests/test_processes.py b/arkindex/process/tests/test_processes.py
index b897d3250d189a0e50e3a99634d317c3522bcc3d..84d6dcb9253147c24bcd704bde68579ee5e4b458 100644
--- a/arkindex/process/tests/test_processes.py
+++ b/arkindex/process/tests/test_processes.py
@@ -16,7 +16,7 @@ from arkindex.process.models import (
     ActivityState,
     DataFile,
     Process,
-    ProcessDataset,
+    ProcessDatasetSet,
     ProcessMode,
     WorkerActivity,
     WorkerActivityState,
@@ -43,6 +43,7 @@ class TestProcesses(FixtureAPITestCase):
             description="Human instrumentality manual",
             creator=cls.user
         )
+        cls.private_dataset.sets.create(name="test")
         cls.img_df = cls.corpus.files.create(
             name="test.jpg",
             size=42,
@@ -2319,12 +2320,13 @@ class TestProcesses(FixtureAPITestCase):
         )
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertDictEqual(response.json(), {
-            "non_field_errors": ["A dataset process cannot be started if it does not have any associated datasets."]
+            "non_field_errors": ["A dataset process cannot be started if it does not have any associated dataset sets."]
         })

     def test_start_process_dataset_requires_dataset_in_same_corpus(self):
         process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
-        ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=list(self.private_dataset.sets.values_list("name", flat=True)))
+        test_set = self.private_dataset.sets.get(name="test")
+        ProcessDatasetSet.objects.create(process=process2, set=test_set)
         process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)

         self.assertFalse(process2.tasks.exists())
@@ -2336,13 +2338,15 @@ class TestProcesses(FixtureAPITestCase):
         )
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertDictEqual(response.json(), {
-            "non_field_errors": ["At least one of the process datasets must be from the same corpus as the process."]
+            "non_field_errors": ["At least one of the process sets must be from the same corpus as the process."]
         })

     def test_start_process_dataset_unsupported_parameters(self):
         process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
-        ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=list(self.dataset1.sets.values_list("name", flat=True)))
-        ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=list(self.dataset2.sets.values_list("name", flat=True)))
+        test_set_1 = self.dataset1.sets.get(name="test")
+        test_set_2 = self.dataset2.sets.get(name="test")
+        ProcessDatasetSet.objects.create(process=process2, set=test_set_1)
+        ProcessDatasetSet.objects.create(process=process2, set=test_set_2)
         process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)

         self.client.force_login(self.user)
@@ -2366,8 +2370,10 @@ class TestProcesses(FixtureAPITestCase):

     def test_start_process_dataset(self):
         process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
-        ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=list(self.dataset1.sets.values_list("name", flat=True)))
-        ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=list(self.private_dataset.sets.values_list("name", flat=True)))
+        test_set_1 = self.dataset1.sets.get(name="test")
+        test_set_2 = self.private_dataset.sets.get(name="test")
+        ProcessDatasetSet.objects.create(process=process2, set=test_set_1)
+        ProcessDatasetSet.objects.create(process=process2, set=test_set_2)
         run = process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)

         self.assertFalse(process2.tasks.exists())
@@ -2387,8 +2393,8 @@ class TestProcesses(FixtureAPITestCase):
         self.assertEqual(process2.tasks.count(), 1)
         task = process2.tasks.get()
         self.assertEqual(task.slug, run.task_slug)
-        self.assertQuerysetEqual(process2.datasets.order_by("name"), [
-            self.private_dataset, self.dataset1
+        self.assertQuerysetEqual(process2.sets.order_by("dataset__name"), [
+            test_set_2, test_set_1
         ])

     def test_start_process_from_docker_image(self):
@@ -2562,8 +2568,10 @@ class TestProcesses(FixtureAPITestCase):
         It should be possible to pass chunks when starting a dataset process
         """
         process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
-        ProcessDataset.objects.create(process=process, dataset=self.dataset1, sets=list(self.dataset1.sets.values_list("name", flat=True)))
-        ProcessDataset.objects.create(process=process, dataset=self.dataset2, sets=list(self.dataset2.sets.values_list("name", flat=True)))
+        test_set_1 = self.dataset1.sets.get(name="test")
+        test_set_2 = self.dataset2.sets.get(name="test")
+        ProcessDatasetSet.objects.create(process=process, set=test_set_1)
+        ProcessDatasetSet.objects.create(process=process, set=test_set_2)

         # Add a worker run to this process
         run = process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
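The start-time validations exercised above are not shown in this diff (they presumably live in the process start serializer); the asserted error messages imply checks roughly like the following. A sketch only, with assumed names; `sets` is the Process-to-DatasetSet relation used elsewhere in this change:

    # Sketch of the checks implied by the two error messages above; the actual
    # implementation is not part of this diff.
    def validate_dataset_process(process):
        if not process.sets.exists():
            raise ValidationError("A dataset process cannot be started if it does not have any associated dataset sets.")
        if not process.sets.filter(dataset__corpus_id=process.corpus_id).exists():
            raise ValidationError("At least one of the process sets must be from the same corpus as the process.")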
name="process-set"), # ML models training path("modelversion/<uuid:pk>/", ModelVersionsRetrieve.as_view(), name="model-version-retrieve"), diff --git a/arkindex/sql_validation/corpus_delete.sql b/arkindex/sql_validation/corpus_delete.sql index 766566825d6557a396a95ec21749ab734761f9b4..6a7fc5c3fda81498cc5b5583bd671abcf811023c 100644 --- a/arkindex/sql_validation/corpus_delete.sql +++ b/arkindex/sql_validation/corpus_delete.sql @@ -165,18 +165,19 @@ FROM "documents_corpusexport" WHERE "documents_corpusexport"."corpus_id" = '{corpus_id}'::uuid; DELETE -FROM "process_processdataset" -WHERE "process_processdataset"."id" IN +FROM "process_processdatasetset" +WHERE "process_processdatasetset"."id" IN (SELECT U0."id" - FROM "process_processdataset" U0 - INNER JOIN "training_dataset" U1 ON (U0."dataset_id" = U1."id") - WHERE U1."corpus_id" = '{corpus_id}'::uuid); + FROM "process_processdatasetset" U0 + INNER JOIN "training_datasetset" U1 ON (U0."set_id" = U1."id") + INNER JOIN "training_dataset" U2 ON (U1."dataset_id" = U2."id") + WHERE U2."corpus_id" = '{corpus_id}'::uuid); DELETE -FROM "process_processdataset" -WHERE "process_processdataset"."id" IN +FROM "process_processdatasetset" +WHERE "process_processdatasetset"."id" IN (SELECT U0."id" - FROM "process_processdataset" U0 + FROM "process_processdatasetset" U0 INNER JOIN "process_process" U1 ON (U0."process_id" = U1."id") WHERE U1."corpus_id" = '{corpus_id}'::uuid); diff --git a/arkindex/sql_validation/corpus_delete_top_level_type.sql b/arkindex/sql_validation/corpus_delete_top_level_type.sql index d64cf0bb8b2eaabeb9a56b71612ba54223d8f0d2..693445cf0af7500ff68f5333f2dca9a72ef4f5e7 100644 --- a/arkindex/sql_validation/corpus_delete_top_level_type.sql +++ b/arkindex/sql_validation/corpus_delete_top_level_type.sql @@ -169,18 +169,19 @@ FROM "documents_corpusexport" WHERE "documents_corpusexport"."corpus_id" = '{corpus_id}'::uuid; DELETE -FROM "process_processdataset" -WHERE "process_processdataset"."id" IN +FROM "process_processdatasetset" +WHERE "process_processdatasetset"."id" IN (SELECT U0."id" - FROM "process_processdataset" U0 - INNER JOIN "training_dataset" U1 ON (U0."dataset_id" = U1."id") - WHERE U1."corpus_id" = '{corpus_id}'::uuid); + FROM "process_processdatasetset" U0 + INNER JOIN "training_datasetset" U1 ON (U0."set_id" = U1."id") + INNER JOIN "training_dataset" U2 ON (U1."dataset_id" = U2."id") + WHERE U2."corpus_id" = '{corpus_id}'::uuid); DELETE -FROM "process_processdataset" -WHERE "process_processdataset"."id" IN +FROM "process_processdatasetset" +WHERE "process_processdatasetset"."id" IN (SELECT U0."id" - FROM "process_processdataset" U0 + FROM "process_processdatasetset" U0 INNER JOIN "process_process" U1 ON (U0."process_id" = U1."id") WHERE U1."corpus_id" = '{corpus_id}'::uuid); diff --git a/arkindex/training/api.py b/arkindex/training/api.py index a2b21048eb3ff3dc8be2aba8b6ce2c03a9868ece..d8e952f15758132a581024beaa56d4c724855cba 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -22,6 +22,7 @@ from rest_framework.generics import ( from rest_framework.response import Response from arkindex.documents.models import Corpus, Element +from arkindex.process.models import ProcessDatasetSet from arkindex.project.mixins import ACLMixin, CorpusACLMixin, TrainingModelMixin from arkindex.project.pagination import CountCursorPagination from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly @@ -683,6 +684,8 @@ class CorpusDataset(CorpusACLMixin, ListCreateAPIView): """ Delete a dataset. 
diff --git a/arkindex/training/api.py b/arkindex/training/api.py
index a2b21048eb3ff3dc8be2aba8b6ce2c03a9868ece..d8e952f15758132a581024beaa56d4c724855cba 100644
--- a/arkindex/training/api.py
+++ b/arkindex/training/api.py
@@ -22,6 +22,7 @@ from rest_framework.generics import (
 from rest_framework.response import Response

 from arkindex.documents.models import Corpus, Element
+from arkindex.process.models import ProcessDatasetSet
 from arkindex.project.mixins import ACLMixin, CorpusACLMixin, TrainingModelMixin
 from arkindex.project.pagination import CountCursorPagination
 from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly
@@ -683,6 +684,8 @@ class CorpusDataset(CorpusACLMixin, ListCreateAPIView):
         """
         Delete a dataset.
         Only datasets in an `open` state can be deleted.
+        A dataset cannot be deleted if one or more of its sets is selected in a process.
+
         Requires an **admin** access to the dataset's corpus.
         """
     )
@@ -708,13 +711,14 @@ class DatasetUpdate(ACLMixin, RetrieveUpdateDestroyAPIView):
         if request.method in permissions.SAFE_METHODS:
             return

-        role = Role.Contributor
-        if request.method == "DELETE":
-            role = Role.Admin
-            if obj.state != DatasetState.Open:
-                raise PermissionDenied(detail="Only datasets in an open state can be deleted.")
+        role = Role.Admin if request.method == "DELETE" else Role.Contributor
         if not self.has_access(obj.corpus, role.value):
             raise PermissionDenied(detail=f"You do not have {str(role).lower()} access to corpus {obj.corpus.name}.")
+        if self.request.method == "DELETE":
+            if obj.state != DatasetState.Open:
+                raise ValidationError(detail="Only datasets in an open state can be deleted.")
+            if ProcessDatasetSet.objects.filter(set__in=obj.sets.all()).exists():
+                raise ValidationError(detail="This dataset cannot be deleted because at least one of its sets is used in a process.")

         # Prevent editing anything on a complete dataset
         if obj.state == DatasetState.Complete:
diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index 85b11b1eee1620690a121675a7c41021b8a3ee42..f6b3831216e4494641dd03e28318a14ddcbe36ab 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -581,8 +581,8 @@ class DatasetSerializer(serializers.ModelSerializer):
             raise ValidationError({
                 "state": ["Ponos task authentication is required to update the state of a Dataset."]
             })
-        # Dataset's state update is only allowed on tasks of Dataset processes, that have this dataset included
-        if not request.auth.process.datasets.filter(id=self.instance.id).exists():
+        # Dataset's state update is only allowed on tasks of Dataset processes, that have sets from this dataset included
+        if not request.auth.process.sets.filter(dataset_id=self.instance.id).exists():
             raise ValidationError({"state": ["A task can only update the state of one of the datasets of its process."]})
         # Link a completed dataset to the current task which generated its artifacts
         if state == DatasetState.Complete:
diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py
index b24708aa7bf82907d5bb4ff62da120cf6ded5073..4446928dd0ec6ab25e5e41b2f8ba3f5b89b25441 100644
--- a/arkindex/training/tests/test_datasets_api.py
+++ b/arkindex/training/tests/test_datasets_api.py
@@ -6,7 +6,7 @@ from django.utils import timezone as DjangoTimeZone
 from rest_framework import status

 from arkindex.documents.models import Corpus
-from arkindex.process.models import Process, ProcessDataset, ProcessMode
+from arkindex.process.models import Process, ProcessDatasetSet, ProcessMode
 from arkindex.project.tests import FixtureAPITestCase
 from arkindex.project.tools import fake_now
 from arkindex.training.models import Dataset, DatasetElement, DatasetSet, DatasetState
@@ -30,8 +30,11 @@ class TestDatasetsAPI(FixtureAPITestCase):
         cls.write_user = User.objects.get(email="user2@user.fr")
         cls.dataset = Dataset.objects.get(name="First Dataset")
         cls.dataset2 = Dataset.objects.get(name="Second Dataset")
-        ProcessDataset.objects.create(process=cls.process, dataset=cls.dataset, sets=["training", "test", "validation"])
-        ProcessDataset.objects.create(process=cls.process, dataset=cls.dataset2, sets=["test"])
+        ProcessDatasetSet.objects.bulk_create(
+            ProcessDatasetSet(process=cls.process, set=ds)
+            for ds in cls.dataset.sets.all()
+        )
+        ProcessDatasetSet.objects.create(process=cls.process, set=cls.dataset2.sets.get(name="test"))
         cls.private_dataset = Dataset.objects.create(name="Private Dataset", description="Dead Sea Scrolls", corpus=cls.private_corpus, creator=cls.dataset_creator)
         cls.private_dataset_set = DatasetSet.objects.create(dataset=cls.private_dataset, name="Private set")
         cls.vol = cls.corpus.elements.get(name="Volume 1")
@@ -680,7 +683,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
         )

     def test_update_ponos_task_state_requires_dataset_in_process(self):
-        self.process.process_datasets.all().delete()
+        self.process.process_sets.all().delete()
         with self.assertNumQueries(5):
             response = self.client.put(
                 reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}),
@@ -841,7 +844,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.assertEqual(self.dataset.state, DatasetState.Building)

     def test_partial_update_ponos_task_state_requires_dataset_in_process(self):
-        self.process.process_datasets.all().delete()
+        self.process.process_sets.all().delete()
         with self.assertNumQueries(5):
             response = self.client.patch(
                 reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}),
@@ -1059,8 +1062,20 @@ class TestDatasetsAPI(FixtureAPITestCase):
         self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
         self.assertDictEqual(response.json(), {"detail": "Not found."})

+    def test_delete_dataset_in_process_forbidden(self):
+        self.client.force_login(self.user)
+        with self.assertNumQueries(4):
+            response = self.client.delete(
+                reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}),
+            )
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertEqual(response.json(), ["This dataset cannot be deleted because at least one of its sets is used in a process."])
+        self.dataset.refresh_from_db()
+
     def test_delete(self):
         self.client.force_login(self.user)
+        # Remove dataset sets from process
+        ProcessDatasetSet.objects.filter(process_id=self.process.id, set__dataset_id=self.dataset.id).delete()
         with self.assertNumQueries(7):
             response = self.client.delete(
                 reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}),
@@ -1069,8 +1084,9 @@ class TestDatasetsAPI(FixtureAPITestCase):
         with self.assertRaises(Dataset.DoesNotExist):
             self.dataset.refresh_from_db()

-    def test_delete_not_open(self):
-        self.client.force_login(self.user)
+    @patch("arkindex.project.mixins.has_access", return_value=True)
+    def test_delete_not_open(self, has_access_mock):
+        self.client.force_login(self.write_user)
         cases = [DatasetState.Building, DatasetState.Complete, DatasetState.Error]
         for state in cases:
             with self.subTest(state=state):
@@ -1080,8 +1096,11 @@ class TestDatasetsAPI(FixtureAPITestCase):
                 response = self.client.delete(
                     reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}),
                 )
-                self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-                self.assertDictEqual(response.json(), {"detail": "Only datasets in an open state can be deleted."})
+                self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+                self.assertEqual(response.json(), ["Only datasets in an open state can be deleted."])
+                self.assertEqual(has_access_mock.call_count, 1)
+                self.assertEqual(has_access_mock.call_args, call(self.write_user, self.corpus, Role.Admin.value, skip_public=False))
+                has_access_mock.reset_mock()

     def test_delete_elements(self):
         """
@@ -1093,6 +1112,8 @@ class TestDatasetsAPI(FixtureAPITestCase):
         train_set.set_elements.create(element_id=self.page1.id, set="training")
validation_set.set_elements.create(element_id=self.page2.id, set="validation") validation_set.set_elements.create(element_id=self.page3.id, set="validation") + # Remove dataset sets from process + ProcessDatasetSet.objects.filter(process_id=self.process.id, set__dataset_id=self.dataset.id).delete() self.client.force_login(self.user) with self.assertNumQueries(7):