From 42d879437c676a21a848e564048cd0f3b50ffa63 Mon Sep 17 00:00:00 2001 From: ml bonhomme <bonhomme@teklia.com> Date: Wed, 21 Feb 2024 10:29:50 +0000 Subject: [PATCH] Allow selecting dataset sets in dataset processes --- .../tests/tasks/test_corpus_delete.py | 15 +- arkindex/process/api.py | 45 +- .../migrations/0029_processdataset_sets.py | 43 ++ arkindex/process/models.py | 9 +- arkindex/process/serializers/training.py | 109 ++-- arkindex/process/tests/test_create_process.py | 8 +- .../process/tests/test_process_datasets.py | 467 ++++++++++++++++-- arkindex/process/tests/test_processes.py | 12 +- arkindex/training/serializers.py | 7 +- arkindex/training/tests/test_datasets_api.py | 54 +- 10 files changed, 665 insertions(+), 104 deletions(-) create mode 100644 arkindex/process/migrations/0029_processdataset_sets.py diff --git a/arkindex/documents/tests/tasks/test_corpus_delete.py b/arkindex/documents/tests/tasks/test_corpus_delete.py index a17590abe7..de915af8ff 100644 --- a/arkindex/documents/tests/tasks/test_corpus_delete.py +++ b/arkindex/documents/tests/tasks/test_corpus_delete.py @@ -3,7 +3,7 @@ from django.db.models.signals import pre_delete from arkindex.documents.models import Corpus, Element, EntityType, MetaType, Transcription from arkindex.documents.tasks import corpus_delete from arkindex.ponos.models import Farm, State, Task -from arkindex.process.models import CorpusWorkerVersion, ProcessMode, Repository, WorkerVersion +from arkindex.process.models import CorpusWorkerVersion, ProcessDataset, ProcessMode, Repository, WorkerVersion from arkindex.project.tests import FixtureTestCase, force_constraints_immediate from arkindex.training.models import Dataset @@ -118,13 +118,14 @@ class TestDeleteCorpus(FixtureTestCase): cls.dataset2 = Dataset.objects.create(name="Dead Sea Scrolls", description="How to trigger a Third Impact", creator=cls.user, corpus=cls.corpus2) # Process on cls.corpus and with a dataset from cls.corpus dataset_process1 = cls.corpus.processes.create(creator=cls.user, mode=ProcessMode.Dataset) - dataset_process1.datasets.set([dataset1]) + ProcessDataset.objects.create(process=dataset_process1, dataset=dataset1, sets=dataset1.sets) # Process on cls.corpus with a dataset from another corpus dataset_process2 = cls.corpus.processes.create(creator=cls.user, mode=ProcessMode.Dataset) - dataset_process2.datasets.set([dataset1, cls.dataset2]) + ProcessDataset.objects.create(process=dataset_process2, dataset=dataset1, sets=dataset1.sets) + ProcessDataset.objects.create(process=dataset_process2, dataset=cls.dataset2, sets=cls.dataset2.sets) # Process on another corpus with a dataset from another corpus and none from cls.corpus - cls.dataset_process2 = cls.corpus2.processes.create(creator=cls.user, mode=ProcessMode.Dataset) - cls.dataset_process2.datasets.set([cls.dataset2]) + cls.dataset_process3 = cls.corpus2.processes.create(creator=cls.user, mode=ProcessMode.Dataset) + ProcessDataset.objects.create(process=cls.dataset_process3, dataset=cls.dataset2, sets=cls.dataset2.sets) cls.rev = cls.repo.revisions.create( hash="42", @@ -200,14 +201,14 @@ class TestDeleteCorpus(FixtureTestCase): self.df.refresh_from_db() self.vol.refresh_from_db() self.page.refresh_from_db() - self.dataset_process2.refresh_from_db() + self.dataset_process3.refresh_from_db() self.assertTrue(self.repo.revisions.filter(id=self.rev.id).exists()) self.assertEqual(self.process.revision, self.rev) self.assertEqual(self.process.files.get(), self.df) 
self.assertTrue(Element.objects.get_descending(self.vol.id).filter(id=self.page.id).exists())
 
         self.assertTrue(self.corpus2.datasets.filter(id=self.dataset2.id).exists())
-        self.assertTrue(self.corpus2.processes.filter(id=self.dataset_process2.id).exists())
+        self.assertTrue(self.corpus2.processes.filter(id=self.dataset_process3.id).exists())
 
         md = self.vol.metadatas.get()
         self.assertEqual(md.name, "meta")
diff --git a/arkindex/process/api.py b/arkindex/process/api.py
index 26001aaab9..db75eac89a 100644
--- a/arkindex/process/api.py
+++ b/arkindex/process/api.py
@@ -46,6 +46,7 @@ from rest_framework.generics import (
     RetrieveDestroyAPIView,
     RetrieveUpdateAPIView,
     RetrieveUpdateDestroyAPIView,
+    UpdateAPIView,
 )
 from rest_framework.response import Response
 from rest_framework.serializers import Serializer
@@ -125,8 +126,7 @@ from arkindex.project.pagination import CountCursorPagination
 from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly
 from arkindex.project.tools import PercentileCont
 from arkindex.project.triggers import process_delete
-from arkindex.training.models import Dataset, Model
-from arkindex.training.serializers import DatasetSerializer
+from arkindex.training.models import Model
 from arkindex.users.models import Role, Scope
 
 logger = logging.getLogger(__name__)
@@ -695,8 +695,8 @@ class DataFileCreate(CreateAPIView):
 )
 class ProcessDatasets(ProcessACLMixin, ListAPIView):
     permission_classes = (IsVerified, )
-    serializer_class = DatasetSerializer
-    queryset = Dataset.objects.none()
+    serializer_class = ProcessDatasetSerializer
+    queryset = ProcessDataset.objects.none()
 
     @cached_property
     def process(self):
@@ -709,7 +709,11 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView):
         return process
 
     def get_queryset(self):
-        return self.process.datasets.select_related("creator").order_by("name")
+        return (
+            ProcessDataset.objects.filter(process_id=self.process.id)
+            .select_related("process__creator", "dataset__creator")
+            .order_by("dataset__name")
+        )
 
     def get_serializer_context(self):
         context = super().get_serializer_context()
@@ -717,6 +721,7 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView):
         if not self.kwargs:
             return context
         context["process"] = self.process
+        # Disable set elements counts in serialized dataset
         context["sets_count"] = False
         return context
 
@@ -744,24 +749,30 @@
         ),
     ),
 )
-class ProcessDatasetManage(CreateAPIView, DestroyAPIView):
+class ProcessDatasetManage(CreateAPIView, UpdateAPIView, DestroyAPIView):
     permission_classes = (IsVerified, )
     serializer_class = ProcessDatasetSerializer
 
-    def get_serializer_from_params(self, process=None, dataset=None, **kwargs):
-        data = {"process": process, "dataset": dataset}
-        kwargs["context"] = self.get_serializer_context()
-        return ProcessDatasetSerializer(data=data, **kwargs)
+    def get_object(self):
+        process_dataset = get_object_or_404(
+            ProcessDataset.objects
+            .select_related("dataset__creator", "process__corpus")
+            # Required to check for a process that has already started
+            .annotate(process_has_tasks=Exists(Task.objects.filter(process_id=self.kwargs["process"]))),
+            dataset_id=self.kwargs["dataset"], process_id=self.kwargs["process"]
+        )
+        # Copy the has_tasks annotation onto the process
+        process_dataset.process.has_tasks = process_dataset.process_has_tasks
+        return process_dataset
 
-    def create(self, request, *args, **kwargs):
-        serializer = self.get_serializer_from_params(**kwargs)
-        serializer.is_valid(raise_exception=True)
-
serializer.create(serializer.validated_data) - headers = self.get_success_headers(serializer.data) - return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + def get_serializer_context(self): + context = super().get_serializer_context() + # Disable set elements counts in serialized dataset + context["sets_count"] = False + return context def destroy(self, request, *args, **kwargs): - serializer = self.get_serializer_from_params(**kwargs) + serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) get_object_or_404(ProcessDataset, **serializer.validated_data).delete() return Response(status=status.HTTP_204_NO_CONTENT) diff --git a/arkindex/process/migrations/0029_processdataset_sets.py b/arkindex/process/migrations/0029_processdataset_sets.py new file mode 100644 index 0000000000..868c1cc29f --- /dev/null +++ b/arkindex/process/migrations/0029_processdataset_sets.py @@ -0,0 +1,43 @@ +import django.core.validators +from django.db import migrations, models + +import arkindex.project.fields +import arkindex.training.models + + +class Migration(migrations.Migration): + dependencies = [ + ("process", "0028_remove_process_model_remove_process_test_folder_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="processdataset", + name="sets", + field=arkindex.project.fields.ArrayField(base_field=models.CharField(max_length=50), blank=True, default=list, size=None), + ), + migrations.RunSQL( + [ + """ + UPDATE process_processdataset p + SET sets = d.sets + FROM training_dataset d + WHERE p.dataset_id = d.id + """ + ], + reverse_sql=migrations.RunSQL.noop, + elidable=True, + ), + migrations.AlterField( + model_name="processdataset", + name="sets", + field=arkindex.project.fields.ArrayField( + base_field=models.CharField( + max_length=50, + validators=[django.core.validators.MinLengthValidator(1)] + ), + size=None, + validators=[django.core.validators.MinLengthValidator(1), arkindex.training.models.validate_unique_set_names] + ), + ), + ] diff --git a/arkindex/process/models.py b/arkindex/process/models.py index 8f1db92144..ac26165f47 100644 --- a/arkindex/process/models.py +++ b/arkindex/process/models.py @@ -27,7 +27,7 @@ from arkindex.project.aws import S3FileMixin, S3FileStatus from arkindex.project.fields import ArrayField, MD5HashField from arkindex.project.models import IndexableModel from arkindex.project.validators import MaxValueValidator -from arkindex.training.models import ModelVersion, ModelVersionState +from arkindex.training.models import ModelVersion, ModelVersionState, validate_unique_set_names from arkindex.users.models import Role @@ -466,6 +466,13 @@ class ProcessDataset(models.Model): id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) process = models.ForeignKey(Process, on_delete=models.CASCADE, related_name="process_datasets") dataset = models.ForeignKey("training.Dataset", on_delete=models.CASCADE, related_name="process_datasets") + sets = ArrayField( + models.CharField(max_length=50, validators=[MinLengthValidator(1)]), + validators=[ + MinLengthValidator(1), + validate_unique_set_names, + ] + ) class Meta: constraints = [ diff --git a/arkindex/process/serializers/training.py b/arkindex/process/serializers/training.py index 96db960576..f56fc916a9 100644 --- a/arkindex/process/serializers/training.py +++ b/arkindex/process/serializers/training.py @@ -5,62 +5,96 @@ from rest_framework import serializers from rest_framework.exceptions import PermissionDenied, ValidationError 
from arkindex.documents.models import Corpus
 from arkindex.ponos.models import Task
 from arkindex.process.models import Process, ProcessDataset, ProcessMode
 from arkindex.project.mixins import ProcessACLMixin
 from arkindex.training.models import Dataset
+from arkindex.training.serializers import DatasetSerializer
 from arkindex.users.models import Role
 
 
+def _dataset_id_from_context(serializer_field):
+    return serializer_field.context.get("view").kwargs["dataset"]
+
+
+def _process_id_from_context(serializer_field):
+    return serializer_field.context.get("view").kwargs["process"]
+
+
+_dataset_id_from_context.requires_context = True
+_process_id_from_context.requires_context = True
+
+
 class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer):
-    process = serializers.PrimaryKeyRelatedField(
-        queryset=(
-            Process
-            .objects
-            # Avoid a stale read when adding a dataset right after creating a process
-            .using("default")
-            # Required for ACL checks
-            .select_related("corpus")
-            # Required to check for a process that have already started
-            .annotate(has_tasks=Exists(Task.objects.filter(process=OuterRef("pk"))))
-        ),
-        style={"base_template": "input.html"},
+    process_id = serializers.HiddenField(
+        write_only=True,
+        default=_process_id_from_context
     )
-    dataset = serializers.PrimaryKeyRelatedField(
-        queryset=Dataset.objects.none(),
-        style={"base_template": "input.html"},
+    dataset_id = serializers.HiddenField(
+        write_only=True,
+        default=_dataset_id_from_context
+    )
+    dataset = DatasetSerializer(read_only=True)
+    sets = serializers.ListField(
+        child=serializers.CharField(max_length=50),
+        required=False,
+        allow_null=False,
+        allow_empty=False,
+        min_length=1
     )
 
     class Meta:
         model = ProcessDataset
-        fields = ("dataset", "process", "id", )
-        read_only_fields = ("process", "id", )
+        fields = ("dataset_id", "dataset", "process_id", "id", "sets", )
+        read_only_fields = ("process_id", "id", )
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if not self.context.get("request"):
             # Do not raise Error in order to create OpenAPI schema
             return
-
-        request_method = self.context["request"].method
         # Required for the ProcessACLMixin and readable corpora
         self._user = self.context["request"].user
-        if request_method == "DELETE":
-            # Allow deleting ProcessDatasets even if the user looses access to the corpus
-            self.fields["dataset"].queryset = Dataset.objects.all()
-        else:
-            self.fields["dataset"].queryset = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self._user))
 
+    def validate(self, data):
+        request_method = self.context["request"].method
+        errors = defaultdict(list)
 
-    def validate_process(self, process):
+        # Validate process
+        process_qs = (
+            Process
+            .objects
+            # Avoid a stale read when adding a dataset right after creating a process
+            .using("default")
+            # Required for ACL checks
+            .select_related("corpus", "revision__repo")
+            # Required to check for a process that has already started
+            .annotate(has_tasks=Exists(Task.objects.filter(process=OuterRef("pk"))))
+        )
+        if not self.instance:
+            try:
+                process = process_qs.get(pk=data["process_id"])
+            except Process.DoesNotExist:
+                raise ValidationError({"process": [f'Invalid pk "{str(data["process_id"])}" - object does not exist.']})
+        else:
+            process = self.instance.process
         access = self.process_access_level(process)
         if not access or not (access >= Role.Admin.value):
             raise PermissionDenied(detail="You do not have admin access to
this process.")
-        return process
 
-    def validate(self, data):
-        process, dataset = data["process"], data["dataset"]
-        errors = defaultdict(list)
+        # Validate dataset
+        if not self.instance:
+            if request_method == "DELETE":
+                # Allow deleting ProcessDatasets even if the user loses access to the corpus
+                dataset_qs = Dataset.objects.all()
+            else:
+                dataset_qs = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self._user))
+            try:
+                dataset = dataset_qs.select_related("creator").get(pk=data["dataset_id"])
+            except Dataset.DoesNotExist:
+                raise ValidationError({"dataset": [f'Invalid pk "{str(data["dataset_id"])}" - object does not exist.']})
+        else:
+            dataset = self.instance.dataset
+        data["dataset"] = dataset
 
         if process.mode != ProcessMode.Dataset:
             errors["process"].append('Datasets can only be added to or removed from processes of mode "dataset".')
@@ -71,6 +105,19 @@ class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer):
         if self.context["request"].method == "POST" and process.datasets.filter(id=dataset.id).exists():
             errors["dataset"].append("This dataset is already selected in this process.")
 
+        # Validate sets
+        sets = data.get("sets")
+        if not sets or len(sets) == 0:
+            if not self.instance:
+                data["sets"] = dataset.sets
+            else:
+                errors["sets"].append("This field cannot be empty.")
+        else:
+            if any(s not in dataset.sets for s in sets):
+                errors["sets"].append("The specified sets must all exist in the specified dataset.")
+            if len(set(sets)) != len(sets):
+                errors["sets"].append("Sets must be unique.")
+
         if errors:
             raise ValidationError(errors)
 
diff --git a/arkindex/process/tests/test_create_process.py b/arkindex/process/tests/test_create_process.py
index 79b0e98ba4..008c963eaa 100644
--- a/arkindex/process/tests/test_create_process.py
+++ b/arkindex/process/tests/test_create_process.py
@@ -11,6 +11,7 @@ from arkindex.ponos.models import Farm, State
 from arkindex.process.models import (
     ActivityState,
     Process,
+    ProcessDataset,
     ProcessMode,
     Repository,
     WorkerActivity,
@@ -826,7 +827,8 @@ class TestCreateProcess(FixtureAPITestCase):
     def test_dataset_gpu_required(self):
         self.client.force_login(self.user)
         process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
-        process.datasets.add(self.corpus.datasets.first())
+        dataset = self.corpus.datasets.first()
+        ProcessDataset.objects.create(process=process, dataset=dataset, sets=dataset.sets)
         process.versions.set([self.version_2, self.version_3])
 
         with self.assertNumQueries(15):
@@ -854,9 +856,9 @@
         self.client.force_login(self.user)
         self.worker_1.archived = datetime.now(timezone.utc)
         self.worker_1.save()
-
         process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
-        process.datasets.add(self.corpus.datasets.first())
+        dataset = self.corpus.datasets.first()
+        ProcessDataset.objects.create(process=process, dataset=dataset, sets=dataset.sets)
         process.versions.add(self.version_1)
 
         with self.assertNumQueries(15):
diff --git a/arkindex/process/tests/test_process_datasets.py b/arkindex/process/tests/test_process_datasets.py
index 142708b126..16005a9513 100644
--- a/arkindex/process/tests/test_process_datasets.py
+++ b/arkindex/process/tests/test_process_datasets.py
@@ -1,10 +1,12 @@
 import uuid
+from datetime import datetime, timezone
 from unittest.mock import patch
 
 from django.urls import reverse
 from rest_framework import status
 
 from arkindex.documents.models import Corpus
+from arkindex.ponos.models import Farm
 from
arkindex.process.models import Process, ProcessDataset, ProcessMode, Repository
 from arkindex.project.tests import FixtureAPITestCase
 from arkindex.training.models import Dataset
@@ -35,9 +37,11 @@ class TestProcessDatasets(FixtureAPITestCase):
         cls.dataset_process = Process.objects.create(
             creator_id=cls.user.id,
             mode=ProcessMode.Dataset,
-            corpus_id=cls.private_corpus.id
+            corpus_id=cls.private_corpus.id,
+            farm=Farm.objects.get(name="Wheat farm")
         )
-        cls.dataset_process.datasets.set([cls.dataset1, cls.private_dataset])
+        cls.process_dataset_1 = ProcessDataset.objects.create(process=cls.dataset_process, dataset=cls.dataset1, sets=cls.dataset1.sets)
+        cls.process_dataset_2 = ProcessDataset.objects.create(process=cls.dataset_process, dataset=cls.private_dataset, sets=cls.private_dataset.sets)
 
         # Control process to check that its datasets are not retrieved
         cls.dataset_process_2 = Process.objects.create(
@@ -45,6 +49,6 @@
             mode=ProcessMode.Dataset,
             corpus_id=cls.corpus.id
         )
-        cls.dataset_process_2.datasets.set([cls.dataset2])
+        ProcessDataset.objects.create(process=cls.dataset_process_2, dataset=cls.dataset2, sets=cls.dataset2.sets)
 
         # For repository process
@@ -80,30 +85,38 @@
         self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertEqual(response.json()["results"], [
             {
-                "id": str(self.private_dataset.id),
-                "name": "Dead sea scrolls",
-                "description": "Human instrumentality manual",
-                "creator": "Test user",
-                "sets": ["training", "test", "validation"],
-                "set_elements": None,
-                "corpus_id": str(self.private_corpus.id),
-                "state": "open",
-                "task_id": None,
-                "created": FAKE_CREATED,
-                "updated": FAKE_CREATED
+                "id": str(self.process_dataset_2.id),
+                "dataset": {
+                    "id": str(self.private_dataset.id),
+                    "name": "Dead sea scrolls",
+                    "description": "Human instrumentality manual",
+                    "creator": "Test user",
+                    "sets": ["training", "test", "validation"],
+                    "set_elements": None,
+                    "corpus_id": str(self.private_corpus.id),
+                    "state": "open",
+                    "task_id": None,
+                    "created": FAKE_CREATED,
+                    "updated": FAKE_CREATED
+                },
+                "sets": ["training", "test", "validation"]
             },
             {
-                "id": str(self.dataset1.id),
-                "name": "First Dataset",
-                "description": "dataset number one",
-                "creator": "Test user",
-                "sets": ["training", "test", "validation"],
-                "set_elements": None,
-                "corpus_id": str(self.corpus.id),
-                "state": "open",
-                "task_id": None,
-                "created": FAKE_CREATED,
-                "updated": FAKE_CREATED
+                "id": str(self.process_dataset_1.id),
+                "dataset": {
+                    "id": str(self.dataset1.id),
+                    "name": "First Dataset",
+                    "description": "dataset number one",
+                    "creator": "Test user",
+                    "sets": ["training", "test", "validation"],
+                    "set_elements": None,
+                    "corpus_id": str(self.corpus.id),
+                    "state": "open",
+                    "task_id": None,
+                    "created": FAKE_CREATED,
+                    "updated": FAKE_CREATED
+                },
+                "sets": ["training", "test", "validation"]
             }
         ])
 
@@ -113,6 +126,7 @@
         with self.assertNumQueries(0):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
+                data={"sets": self.dataset2.sets}
             )
         self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
 
@@ -122,6 +136,7 @@
         with self.assertNumQueries(2):
             response = self.client.post(
                 reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}),
+                data={"sets": self.dataset2.sets}
             )
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @@ -133,9 +148,10 @@ class TestProcessDatasets(FixtureAPITestCase): if level: self.private_corpus.memberships.create(user=self.test_user, level=level.value) self.client.force_login(self.test_user) - with self.assertNumQueries(6): + with self.assertNumQueries(5): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), + data={"sets": self.dataset2.sets} ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.json(), {"detail": "You do not have admin access to this process."}) @@ -151,6 +167,7 @@ class TestProcessDatasets(FixtureAPITestCase): with self.assertNumQueries(8): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), + data={"sets": self.dataset2.sets} ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @@ -159,9 +176,10 @@ class TestProcessDatasets(FixtureAPITestCase): def test_create_process_mode_local(self): self.client.force_login(self.user) local_process = Process.objects.get(creator=self.user, mode=ProcessMode.Local) - with self.assertNumQueries(6): + with self.assertNumQueries(3): response = self.client.post( reverse("api:process-dataset", kwargs={"process": local_process.id, "dataset": self.dataset2.id}), + data={"sets": self.dataset2.sets} ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.json(), {"detail": "You do not have admin access to this process."}) @@ -169,9 +187,10 @@ class TestProcessDatasets(FixtureAPITestCase): def test_create_process_mode_repository(self): self.client.force_login(self.user) process = Process.objects.create(creator=self.user, mode=ProcessMode.Repository, revision=self.rev) - with self.assertNumQueries(10): + with self.assertNumQueries(6): response = self.client.post( reverse("api:process-dataset", kwargs={"process": process.id, "dataset": self.dataset2.id}), + data={"sets": self.dataset2.sets} ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.json(), {"detail": "You do not have admin access to this process."}) @@ -179,9 +198,10 @@ class TestProcessDatasets(FixtureAPITestCase): def test_create_wrong_process_uuid(self): self.client.force_login(self.test_user) wrong_id = uuid.uuid4() - with self.assertNumQueries(6): + with self.assertNumQueries(3): response = self.client.post( reverse("api:process-dataset", kwargs={"process": wrong_id, "dataset": self.dataset2.id}), + data={"sets": self.dataset2.sets} ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.json(), {"process": [f'Invalid pk "{str(wrong_id)}" - object does not exist.']}) @@ -192,6 +212,7 @@ class TestProcessDatasets(FixtureAPITestCase): with self.assertNumQueries(7): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": wrong_id}), + data={"sets": ["test"]} ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.json(), {"dataset": [f'Invalid pk "{str(wrong_id)}" - object does not exist.']}) @@ -203,6 +224,7 @@ class TestProcessDatasets(FixtureAPITestCase): with self.assertNumQueries(7): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": new_dataset.id}), + data={"sets": new_dataset.sets} ) 
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.json(), {"dataset": [f'Invalid pk "{str(new_dataset.id)}" - object does not exist.']}) @@ -214,6 +236,7 @@ class TestProcessDatasets(FixtureAPITestCase): with self.assertNumQueries(8): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), + data={"sets": self.dataset1.sets} ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @@ -226,11 +249,49 @@ class TestProcessDatasets(FixtureAPITestCase): with self.assertNumQueries(8): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), + data={"sets": self.dataset2.sets} ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), {"process": ["Datasets cannot be updated on processes that have already started."]}) + def test_create_default_sets(self): + self.client.force_login(self.test_user) + self.assertEqual(ProcessDataset.objects.count(), 3) + self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists()) + with self.assertNumQueries(9): + response = self.client.post( + reverse( + "api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id} + ), + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(ProcessDataset.objects.count(), 4) + self.assertTrue(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists()) + self.assertQuerysetEqual(self.dataset_process.datasets.order_by("name"), [ + self.private_dataset, + self.dataset1, + self.dataset2 + ]) + created = ProcessDataset.objects.get(process=self.dataset_process.id, dataset=self.dataset2.id) + self.assertDictEqual(response.json(), { + "id": str(created.id), + "dataset": { + "id": str(self.dataset2.id), + "name": "Second Dataset", + "description": "dataset number two", + "creator": "Test user", + "sets": ["training", "test", "validation"], + "set_elements": None, + "corpus_id": str(self.corpus.id), + "state": "open", + "task_id": None, + "created": FAKE_CREATED, + "updated": FAKE_CREATED + }, + "sets": ["training", "test", "validation"] + }) + def test_create(self): self.client.force_login(self.test_user) self.assertEqual(ProcessDataset.objects.count(), 3) @@ -238,6 +299,7 @@ class TestProcessDatasets(FixtureAPITestCase): with self.assertNumQueries(9): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), + data={"sets": ["validation", "test"]} ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(ProcessDataset.objects.count(), 4) @@ -247,6 +309,345 @@ class TestProcessDatasets(FixtureAPITestCase): self.dataset1, self.dataset2 ]) + created = ProcessDataset.objects.get(process=self.dataset_process.id, dataset=self.dataset2.id) + self.assertDictEqual(response.json(), { + "id": str(created.id), + "dataset": { + "id": str(self.dataset2.id), + "name": "Second Dataset", + "description": "dataset number two", + "creator": "Test user", + "sets": ["training", "test", "validation"], + "set_elements": None, + "corpus_id": str(self.corpus.id), + "state": "open", + "task_id": None, + "created": FAKE_CREATED, + "updated": FAKE_CREATED + }, + "sets": ["validation", "test"] + }) + + def test_create_wrong_sets(self): + 
self.client.force_login(self.test_user) + self.assertEqual(ProcessDataset.objects.count(), 3) + self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists()) + with self.assertNumQueries(8): + response = self.client.post( + reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), + data={"sets": ["Unit-01"]} + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists()) + self.assertDictEqual(response.json(), {"sets": ["The specified sets must all exist in the specified dataset."]}) + + # Update process dataset + + def test_update_requires_login(self): + with self.assertNumQueries(0): + response = self.client.put( + reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), + data={"sets": ["test"]} + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_update_requires_verified(self): + unverified_user = User.objects.create(email="email@mail.com") + self.client.force_login(unverified_user) + with self.assertNumQueries(2): + response = self.client.put( + reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), + data={"sets": ["test"]} + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_update_access_level(self): + cases = [None, Role.Guest, Role.Contributor] + for level in cases: + with self.subTest(level=level): + self.private_corpus.memberships.filter(user=self.test_user).delete() + if level: + self.private_corpus.memberships.create(user=self.test_user, level=level.value) + self.client.force_login(self.test_user) + with self.assertNumQueries(5): + response = self.client.put( + reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), + data={"sets": ["test"]} + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.json(), {"detail": "You do not have admin access to this process."}) + + def test_update_process_does_not_exist(self): + self.client.force_login(self.test_user) + with self.assertNumQueries(3): + response = self.client.put( + reverse("api:process-dataset", kwargs={"process": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", "dataset": self.dataset1.id}), + data={"sets": ["test"]} + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_update_dataset_does_not_exist(self): + self.client.force_login(self.test_user) + with self.assertNumQueries(3): + response = self.client.put( + reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}), + data={"sets": ["test"]} + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_update_m2m_does_not_exist(self): + self.client.force_login(self.test_user) + with self.assertNumQueries(3): + response = self.client.put( + reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), + data={"sets": ["test"]} + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_update(self): + self.client.force_login(self.test_user) + with self.assertNumQueries(7): + response = self.client.put( + reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), + 
data={"sets": ["test"]} + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual(response.json(), { + "id": str(self.process_dataset_1.id), + "dataset": { + "id": str(self.dataset1.id), + "name": "First Dataset", + "description": "dataset number one", + "creator": "Test user", + "sets": ["training", "test", "validation"], + "set_elements": None, + "corpus_id": str(self.corpus.id), + "state": "open", + "task_id": None, + "created": FAKE_CREATED, + "updated": FAKE_CREATED + }, + "sets": ["test"] + }) + + def test_update_wrong_sets(self): + self.client.force_login(self.test_user) + with self.assertNumQueries(6): + response = self.client.put( + reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), + data={"sets": ["Unit-01", "Unit-02"]} + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), {"sets": ["The specified sets must all exist in the specified dataset."]}) + + def test_update_unique_sets(self): + self.client.force_login(self.test_user) + with self.assertNumQueries(6): + response = self.client.put( + reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), + data={"sets": ["test", "test"]} + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), {"sets": ["Sets must be unique."]}) + + def test_update_started_process(self): + self.dataset_process.tasks.create( + run=0, + depth=0, + slug="task", + expiry=datetime(1970, 1, 1, tzinfo=timezone.utc), + ) + self.client.force_login(self.test_user) + with self.assertNumQueries(6): + response = self.client.put( + reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), + data={"sets": ["test"]} + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), {"process": ["Datasets cannot be updated on processes that have already started."]}) + + def test_update_only_sets(self): + """ + Non "sets" fields in the update request are ignored + """ + self.client.force_login(self.test_user) + with self.assertNumQueries(7): + response = self.client.put( + reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), + data={"process": str(self.dataset_process_2.id), "dataset": str(self.dataset2.id), "sets": ["test"]} + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual(response.json(), { + "id": str(self.process_dataset_1.id), + "dataset": { + "id": str(self.dataset1.id), + "name": "First Dataset", + "description": "dataset number one", + "creator": "Test user", + "sets": ["training", "test", "validation"], + "set_elements": None, + "corpus_id": str(self.corpus.id), + "state": "open", + "task_id": None, + "created": FAKE_CREATED, + "updated": FAKE_CREATED + }, + "sets": ["test"] + }) + + # Partial update process dataset + + def test_partial_update_requires_login(self): + with self.assertNumQueries(0): + response = self.client.patch( + reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), + data={"sets": ["test"]} + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_partial_update_requires_verified(self): + unverified_user = User.objects.create(email="email@mail.com") + self.client.force_login(unverified_user) + with self.assertNumQueries(2): + response = 
self.client.patch( + reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), + data={"sets": ["test"]} + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_partial_update_access_level(self): + cases = [None, Role.Guest, Role.Contributor] + for level in cases: + with self.subTest(level=level): + self.private_corpus.memberships.filter(user=self.test_user).delete() + if level: + self.private_corpus.memberships.create(user=self.test_user, level=level.value) + self.client.force_login(self.test_user) + with self.assertNumQueries(5): + response = self.client.patch( + reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), + data={"sets": ["test"]} + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.json(), {"detail": "You do not have admin access to this process."}) + + def test_partial_update_process_does_not_exist(self): + self.client.force_login(self.test_user) + with self.assertNumQueries(3): + response = self.client.patch( + reverse("api:process-dataset", kwargs={"process": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", "dataset": self.dataset1.id}), + data={"sets": ["test"]} + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_partial_update_dataset_does_not_exist(self): + self.client.force_login(self.test_user) + with self.assertNumQueries(3): + response = self.client.patch( + reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}), + data={"sets": ["test"]} + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_partial_update_m2m_does_not_exist(self): + self.client.force_login(self.test_user) + with self.assertNumQueries(3): + response = self.client.patch( + reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), + data={"sets": ["test"]} + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_partial_update(self): + self.client.force_login(self.test_user) + with self.assertNumQueries(7): + response = self.client.patch( + reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), + data={"sets": ["test"]} + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual(response.json(), { + "id": str(self.process_dataset_1.id), + "dataset": { + "id": str(self.dataset1.id), + "name": "First Dataset", + "description": "dataset number one", + "creator": "Test user", + "sets": ["training", "test", "validation"], + "set_elements": None, + "corpus_id": str(self.corpus.id), + "state": "open", + "task_id": None, + "created": FAKE_CREATED, + "updated": FAKE_CREATED + }, + "sets": ["test"] + }) + + def test_partial_update_wrong_sets(self): + self.client.force_login(self.test_user) + with self.assertNumQueries(6): + response = self.client.patch( + reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), + data={"sets": ["Unit-01", "Unit-02"]} + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), {"sets": ["The specified sets must all exist in the specified dataset."]}) + + def test_partial_update_unique_sets(self): + self.client.force_login(self.test_user) + with self.assertNumQueries(6): + response = self.client.patch( + 
reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), + data={"sets": ["test", "test"]} + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), {"sets": ["Sets must be unique."]}) + + def test_partial_update_started_process(self): + self.dataset_process.tasks.create( + run=0, + depth=0, + slug="task", + expiry=datetime(1970, 1, 1, tzinfo=timezone.utc), + ) + self.client.force_login(self.test_user) + with self.assertNumQueries(6): + response = self.client.patch( + reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), + data={"sets": ["test"]} + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), {"process": ["Datasets cannot be updated on processes that have already started."]}) + + def test_partial_update_only_sets(self): + """ + Non "sets" fields in the partial update request are ignored + """ + self.client.force_login(self.test_user) + with self.assertNumQueries(7): + response = self.client.patch( + reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), + data={"process": str(self.dataset_process_2.id), "dataset": str(self.dataset2.id), "sets": ["test"]} + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual(response.json(), { + "id": str(self.process_dataset_1.id), + "dataset": { + "id": str(self.dataset1.id), + "name": "First Dataset", + "description": "dataset number one", + "creator": "Test user", + "sets": ["training", "test", "validation"], + "set_elements": None, + "corpus_id": str(self.corpus.id), + "state": "open", + "task_id": None, + "created": FAKE_CREATED, + "updated": FAKE_CREATED + }, + "sets": ["test"] + }) # Destroy process dataset @@ -260,7 +661,7 @@ class TestProcessDatasets(FixtureAPITestCase): def test_destroy_process_does_not_exist(self): self.client.force_login(self.test_user) wrong_id = uuid.uuid4() - with self.assertNumQueries(4): + with self.assertNumQueries(3): response = self.client.delete( reverse("api:process-dataset", kwargs={"process": wrong_id, "dataset": self.private_dataset.id}) ) @@ -289,7 +690,7 @@ class TestProcessDatasets(FixtureAPITestCase): def test_destroy_process_access_level(self): self.private_corpus.memberships.filter(user=self.test_user).delete() self.client.force_login(self.test_user) - with self.assertNumQueries(6): + with self.assertNumQueries(5): response = self.client.delete( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.private_dataset.id}) ) @@ -299,7 +700,7 @@ class TestProcessDatasets(FixtureAPITestCase): def test_destroy_no_dataset_access_requirement(self): new_corpus = Corpus.objects.create(name="NERV") new_dataset = new_corpus.datasets.create(name="Eva series", description="We created the Evas from Adam", creator=self.user) - self.dataset_process.datasets.add(new_dataset) + ProcessDataset.objects.create(process=self.dataset_process, dataset=new_dataset, sets=new_dataset.sets) self.assertTrue(ProcessDataset.objects.filter(process=self.dataset_process, dataset=new_dataset).exists()) self.client.force_login(self.test_user) with self.assertNumQueries(9): @@ -328,7 +729,7 @@ class TestProcessDatasets(FixtureAPITestCase): def test_destroy_process_mode_local(self): self.client.force_login(self.user) local_process = Process.objects.get(creator=self.user, mode=ProcessMode.Local) - with 
self.assertNumQueries(4):
+        with self.assertNumQueries(3):
             response = self.client.delete(
                 reverse("api:process-dataset", kwargs={"process": local_process.id, "dataset": self.dataset2.id}),
             )
@@ -338,7 +739,7 @@ class TestProcessDatasets(FixtureAPITestCase):
     def test_destroy_process_mode_repository(self):
         self.client.force_login(self.user)
         process = Process.objects.create(creator=self.user, mode=ProcessMode.Repository, revision=self.rev)
-        with self.assertNumQueries(9):
+        with self.assertNumQueries(6):
             response = self.client.delete(
                 reverse("api:process-dataset", kwargs={"process": process.id, "dataset": self.dataset2.id}),
             )
diff --git a/arkindex/process/tests/test_processes.py b/arkindex/process/tests/test_processes.py
index 3ec9cc6491..d8bdb5cef8 100644
--- a/arkindex/process/tests/test_processes.py
+++ b/arkindex/process/tests/test_processes.py
@@ -15,6 +15,7 @@ from arkindex.process.models import (
     ActivityState,
     DataFile,
     Process,
+    ProcessDataset,
     ProcessMode,
     Repository,
     WorkerActivity,
@@ -2436,7 +2437,7 @@ class TestProcesses(FixtureAPITestCase):
     def test_start_process_dataset_requires_dataset_in_same_corpus(self):
         process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
-        process2.datasets.set([self.private_dataset])
+        ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.private_dataset.sets)
         process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
 
         self.assertFalse(process2.tasks.exists())
@@ -2453,7 +2454,8 @@
     def test_start_process_dataset_unsupported_parameters(self):
         process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
-        process2.datasets.set([self.dataset1, self.private_dataset])
+        ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=self.dataset1.sets)
+        ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.private_dataset.sets)
         process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
 
         self.client.force_login(self.user)
@@ -2477,7 +2479,8 @@
     def test_start_process_dataset(self):
         process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
-        process2.datasets.set([self.dataset1, self.private_dataset])
+        ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=self.dataset1.sets)
+        ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.private_dataset.sets)
         run = process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
 
         self.assertFalse(process2.tasks.exists())
@@ -2669,7 +2672,8 @@
         It should be possible to pass chunks when starting a dataset process
         """
         process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
-        process.datasets.set([self.dataset1, self.dataset2])
+        ProcessDataset.objects.create(process=process, dataset=self.dataset1, sets=self.dataset1.sets)
+        ProcessDataset.objects.create(process=process, dataset=self.dataset2, sets=self.dataset2.sets)
 
         # Add a worker run to this process
         run = process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index b92c44d1c5..4cff4c8f0f 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -13,7 +13,7 @@ from rest_framework.validators import
UniqueTogetherValidator from arkindex.documents.models import Element from arkindex.documents.serializers.elements import ElementListSerializer from arkindex.ponos.models import Task -from arkindex.process.models import Worker +from arkindex.process.models import ProcessDataset, Worker from arkindex.project.serializer_fields import ArchivedField, DatasetSetsCountField, EnumField from arkindex.training.models import ( Dataset, @@ -560,8 +560,11 @@ class DatasetSerializer(serializers.ModelSerializer): raise ValidationError("Set names must be unique.") removed, added = self.sets_diff(sets) + if removed and ProcessDataset.objects.filter(sets__overlap=removed, dataset_id=self.instance.id).exists(): + # Sets that are used in a ProcessDataset cannot be renamed or deleted + raise ValidationError("These sets cannot be updated because one or more are selected in a dataset process.") if not removed or not added: - # Some sets have been added or removed, do nothing + # Some sets have either been added or removed, but not both; do nothing return sets elif len(removed) == 1 and len(added) == 1: # A single set has been renamed. Move its elements later, while performing the update diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py index 653ffa152b..f6a935b162 100644 --- a/arkindex/training/tests/test_datasets_api.py +++ b/arkindex/training/tests/test_datasets_api.py @@ -6,7 +6,7 @@ from django.utils import timezone as DjangoTimeZone from rest_framework import status from arkindex.documents.models import Corpus -from arkindex.process.models import Process, ProcessMode +from arkindex.process.models import Process, ProcessDataset, ProcessMode from arkindex.project.tests import FixtureAPITestCase from arkindex.project.tools import fake_now from arkindex.training.models import Dataset, DatasetState @@ -30,7 +30,8 @@ class TestDatasetsAPI(FixtureAPITestCase): cls.write_user = User.objects.get(email="user2@user.fr") cls.dataset = Dataset.objects.get(name="First Dataset") cls.dataset2 = Dataset.objects.get(name="Second Dataset") - cls.process.datasets.set((cls.dataset, cls.dataset2)) + ProcessDataset.objects.create(process=cls.process, dataset=cls.dataset, sets=["training", "test", "validation"]) + ProcessDataset.objects.create(process=cls.process, dataset=cls.dataset2, sets=["test"]) cls.private_dataset = Dataset.objects.create(name="Private Dataset", description="Dead Sea Scrolls", corpus=cls.private_corpus, creator=cls.dataset_creator) cls.vol = cls.corpus.elements.get(name="Volume 1") cls.page1 = cls.corpus.elements.get(name="Volume 1, page 1r") @@ -504,9 +505,11 @@ class TestDatasetsAPI(FixtureAPITestCase): """ It is possible to remove many sets, no elements are moved """ + # Remove ProcessDataset relation + ProcessDataset.objects.get(process=self.process, dataset=self.dataset).delete() self.client.force_login(self.user) dataset_elt = self.dataset.dataset_elements.create(element=self.page1, set="training") - with self.assertNumQueries(11): + with self.assertNumQueries(12): response = self.client.put( reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}), data={ @@ -532,13 +535,14 @@ class TestDatasetsAPI(FixtureAPITestCase): def test_update_sets_update_single_set(self): """ - It is possible to rename a single set + It is possible to rename a single set, if it is not referenced by a ProcessDataset """ + ProcessDataset.objects.get(process=self.process, dataset=self.dataset, sets=["training", "test", "validation"]).delete() 
self.client.force_login(self.user) self.dataset.dataset_elements.create(element_id=self.page1.id, set="training") self.dataset.dataset_elements.create(element_id=self.page2.id, set="validation") self.dataset.dataset_elements.create(element_id=self.page3.id, set="validation") - with self.assertNumQueries(12): + with self.assertNumQueries(13): response = self.client.put( reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}), data={ @@ -565,12 +569,50 @@ class TestDatasetsAPI(FixtureAPITestCase): ] ) + def test_update_sets_processdataset_reference(self): + """ + If a dataset's sets are referenced by a ProcessDataset, they cannot be updated + """ + self.client.force_login(self.user) + self.dataset.dataset_elements.create(element_id=self.page1.id, set="training") + self.dataset.dataset_elements.create(element_id=self.page2.id, set="validation") + self.dataset.dataset_elements.create(element_id=self.page3.id, set="validation") + with self.assertNumQueries(7): + response = self.client.put( + reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}), + data={ + "name": "Shin Seiki Evangelion", + "description": "Omedeto!", + # validation set is renamed to AAAAAAA + "sets": ["test", "training", "AAAAAAA"], + }, + format="json" + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), {"sets": ["These sets cannot be updated because one or more are selected in a dataset process."]}) + self.dataset.refresh_from_db() + self.assertEqual(self.dataset.state, DatasetState.Open) + self.assertEqual(self.dataset.name, "First Dataset") + self.assertEqual(self.dataset.description, "dataset number one") + self.assertListEqual(self.dataset.sets, ["training", "test", "validation"]) + self.assertIsNone(self.dataset.task_id) + self.assertQuerysetEqual( + self.dataset.dataset_elements.values_list("set", "element__name").order_by("element__name"), + [ + ("training", "Volume 1, page 1r"), + ("validation", "Volume 1, page 1v"), + ("validation", "Volume 1, page 2r"), + ] + ) + def test_update_sets_ambiguous(self): """ No more than one set can be updated """ + # Remove ProcessDataset relation + ProcessDataset.objects.get(process=self.process, dataset=self.dataset).delete() self.client.force_login(self.user) - with self.assertNumQueries(6): + with self.assertNumQueries(7): response = self.client.put( reverse("api:dataset-update", kwargs={"pk": self.dataset.pk}), data={ -- GitLab