diff --git a/arkindex/process/api.py b/arkindex/process/api.py index 42cefcecb51537547d5dc4513b98885a86283f7b..259053083f27ad8ba795569215d7ac8c01e9b0f5 100644 --- a/arkindex/process/api.py +++ b/arkindex/process/api.py @@ -712,6 +712,7 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView): return ( ProcessDataset.objects.filter(process_id=self.process.id) .select_related("process__creator", "dataset__creator") + .prefetch_related("dataset__sets") .order_by("dataset__name") ) @@ -757,6 +758,7 @@ class ProcessDatasetManage(CreateAPIView, UpdateAPIView, DestroyAPIView): process_dataset = get_object_or_404( ProcessDataset.objects .select_related("dataset__creator", "process__corpus") + .prefetch_related("dataset__sets") # Required to check for a process that have already started .annotate(process_has_tasks=Exists(Task.objects.filter(process_id=self.kwargs["process"]))), dataset_id=self.kwargs["dataset"], process_id=self.kwargs["process"] diff --git a/arkindex/process/serializers/training.py b/arkindex/process/serializers/training.py index f56fc916a99a1821d24d64cc354fd40b2ac412c9..478c7a9f61e4cd89a66294f41de189a3c97e2864 100644 --- a/arkindex/process/serializers/training.py +++ b/arkindex/process/serializers/training.py @@ -89,7 +89,7 @@ class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer): else: dataset_qs = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self._user)) try: - dataset = dataset_qs.select_related("creator").get(pk=data["dataset_id"]) + dataset = dataset_qs.select_related("creator").prefetch_related("sets").get(pk=data["dataset_id"]) except Dataset.DoesNotExist: raise ValidationError({"dataset": [f'Invalid pk "{str(data["dataset_id"])}" - object does not exist.']}) else: @@ -109,11 +109,11 @@ class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer): sets = data.get("sets") if not sets or len(sets) == 0: if not self.instance: - data["sets"] = dataset.sets + data["sets"] = [item.name for item in list(dataset.sets.all())] else: errors["sets"].append("This field cannot be empty.") else: - if any(s not in dataset.sets for s in sets): + if any(s not in [item.name for item in list(dataset.sets.all())] for s in sets): errors["sets"].append("The specified sets must all exist in the specified dataset.") if len(set(sets)) != len(sets): errors["sets"].append("Sets must be unique.") diff --git a/arkindex/process/tests/test_create_process.py b/arkindex/process/tests/test_create_process.py index cfb1f59958dd66a00fe54fd99f65bce7fdf9e388..c2aff4188959ba167e010b3154fa2d88875d9b92 100644 --- a/arkindex/process/tests/test_create_process.py +++ b/arkindex/process/tests/test_create_process.py @@ -907,7 +907,8 @@ class TestCreateProcess(FixtureAPITestCase): self.client.force_login(self.user) process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) dataset = self.corpus.datasets.first() - ProcessDataset.objects.create(process=process, dataset=dataset, sets=dataset.sets) + test_sets = list(dataset.sets.values_list("name", flat=True)) + ProcessDataset.objects.create(process=process, dataset=dataset, sets=test_sets) process.versions.set([self.version_2, self.version_3]) with self.assertNumQueries(9): @@ -937,7 +938,8 @@ class TestCreateProcess(FixtureAPITestCase): self.worker_1.save() process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) dataset = self.corpus.datasets.first() - ProcessDataset.objects.create(process=process, dataset=dataset, sets=dataset.sets) + test_sets = list(dataset.sets.values_list("name", flat=True)) + ProcessDataset.objects.create(process=process, dataset=dataset, sets=test_sets) process.versions.add(self.version_1) with self.assertNumQueries(9): diff --git a/arkindex/process/tests/test_process_datasets.py b/arkindex/process/tests/test_process_datasets.py index af98d7b25395164688b614c6dcc26b31818fcaca..158c819ec34ca2f3ca410d7fafecad85b5a62420 100644 --- a/arkindex/process/tests/test_process_datasets.py +++ b/arkindex/process/tests/test_process_datasets.py @@ -9,7 +9,7 @@ from arkindex.documents.models import Corpus from arkindex.ponos.models import Farm from arkindex.process.models import Process, ProcessDataset, ProcessMode, Repository from arkindex.project.tests import FixtureAPITestCase -from arkindex.training.models import Dataset +from arkindex.training.models import Dataset, DatasetSet from arkindex.users.models import Role, User # Using the fake DB fixtures creation date when needed @@ -28,6 +28,10 @@ class TestProcessDatasets(FixtureAPITestCase): description="Human instrumentality manual", creator=cls.user ) + DatasetSet.objects.bulk_create([ + DatasetSet(dataset_id=cls.private_dataset.id, name=set_name) + for set_name in ["validation", "training", "test"] + ]) cls.test_user = User.objects.create(email="katsuragi@nerv.co.jp", verified_email=True) cls.private_corpus.memberships.create(user=cls.test_user, level=Role.Admin.value) @@ -40,8 +44,8 @@ class TestProcessDatasets(FixtureAPITestCase): corpus_id=cls.private_corpus.id, farm=Farm.objects.get(name="Wheat farm") ) - cls.process_dataset_1 = ProcessDataset.objects.create(process=cls.dataset_process, dataset=cls.dataset1, sets=cls.dataset1.sets) - cls.process_dataset_2 = ProcessDataset.objects.create(process=cls.dataset_process, dataset=cls.private_dataset, sets=cls.private_dataset.sets) + cls.process_dataset_1 = ProcessDataset.objects.create(process=cls.dataset_process, dataset=cls.dataset1, sets=list(cls.dataset1.sets.values_list("name", flat=True))) + cls.process_dataset_2 = ProcessDataset.objects.create(process=cls.dataset_process, dataset=cls.private_dataset, sets=list(cls.private_dataset.sets.values_list("name", flat=True))) # Control process to check that its datasets are not retrieved cls.dataset_process_2 = Process.objects.create( @@ -49,7 +53,7 @@ class TestProcessDatasets(FixtureAPITestCase): mode=ProcessMode.Dataset, corpus_id=cls.corpus.id ) - ProcessDataset.objects.create(process=cls.dataset_process_2, dataset=cls.dataset2, sets=cls.dataset2.sets) + ProcessDataset.objects.create(process=cls.dataset_process_2, dataset=cls.dataset2, sets=list(cls.dataset2.sets.values_list("name", flat=True))) # For repository process cls.repo = Repository.objects.get(url="http://my_repo.fake/workers/worker") @@ -83,9 +87,10 @@ class TestProcessDatasets(FixtureAPITestCase): def test_list(self): self.client.force_login(self.test_user) - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.get(reverse("api:process-datasets", kwargs={"pk": self.dataset_process.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) + self.maxDiff = None self.assertEqual(response.json()["results"], [ { "id": str(self.process_dataset_2.id), @@ -94,7 +99,13 @@ class TestProcessDatasets(FixtureAPITestCase): "name": "Dead sea scrolls", "description": "Human instrumentality manual", "creator": "Test user", - "sets": ["training", "test", "validation"], + "sets": [ + { + "id": str(ds.id), + "name": ds.name + } + for ds in self.private_dataset.sets.all() + ], "set_elements": None, "corpus_id": str(self.private_corpus.id), "state": "open", @@ -102,7 +113,7 @@ class TestProcessDatasets(FixtureAPITestCase): "created": FAKE_CREATED, "updated": FAKE_CREATED }, - "sets": ["training", "test", "validation"] + "sets": ["validation", "training", "test"] }, { "id": str(self.process_dataset_1.id), @@ -111,7 +122,13 @@ class TestProcessDatasets(FixtureAPITestCase): "name": "First Dataset", "description": "dataset number one", "creator": "Test user", - "sets": ["training", "test", "validation"], + "sets": [ + { + "id": str(ds.id), + "name": ds.name + } + for ds in self.dataset1.sets.all() + ], "set_elements": None, "corpus_id": str(self.corpus.id), "state": "open", @@ -119,33 +136,36 @@ class TestProcessDatasets(FixtureAPITestCase): "created": FAKE_CREATED, "updated": FAKE_CREATED }, - "sets": ["training", "test", "validation"] + "sets": ["validation", "training", "test"] } ]) # Create process dataset def test_create_requires_login(self): + test_sets = list(self.dataset2.sets.values_list("name", flat=True)) with self.assertNumQueries(0): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), - data={"sets": self.dataset2.sets} + data={"sets": test_sets} ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_create_requires_verified(self): unverified_user = User.objects.create(email="email@mail.com") + test_sets = list(self.dataset2.sets.values_list("name", flat=True)) self.client.force_login(unverified_user) with self.assertNumQueries(2): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), - data={"sets": self.dataset2.sets} + data={"sets": test_sets} ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @patch("arkindex.project.mixins.get_max_level") def test_create_access_level(self, get_max_level_mock): cases = [None, Role.Guest.value, Role.Contributor.value] + test_sets = list(self.dataset2.sets.values_list("name", flat=True)) for level in cases: with self.subTest(level=level): get_max_level_mock.reset_mock() @@ -155,7 +175,7 @@ class TestProcessDatasets(FixtureAPITestCase): with self.assertNumQueries(3): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), - data={"sets": self.dataset2.sets} + data={"sets": test_sets} ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @@ -166,16 +186,17 @@ class TestProcessDatasets(FixtureAPITestCase): def test_create_process_mode(self): cases = set(ProcessMode) - {ProcessMode.Dataset, ProcessMode.Local, ProcessMode.Repository} + test_sets = list(self.dataset2.sets.values_list("name", flat=True)) for mode in cases: with self.subTest(mode=mode): self.dataset_process.mode = mode self.dataset_process.save() self.client.force_login(self.test_user) - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), - data={"sets": self.dataset2.sets} + data={"sets": test_sets} ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @@ -183,22 +204,24 @@ class TestProcessDatasets(FixtureAPITestCase): def test_create_process_mode_local(self): self.client.force_login(self.user) + test_sets = list(self.dataset2.sets.values_list("name", flat=True)) local_process = Process.objects.get(creator=self.user, mode=ProcessMode.Local) with self.assertNumQueries(3): response = self.client.post( reverse("api:process-dataset", kwargs={"process": local_process.id, "dataset": self.dataset2.id}), - data={"sets": self.dataset2.sets} + data={"sets": test_sets} ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.json(), {"detail": "You do not have admin access to this process."}) def test_create_wrong_process_uuid(self): self.client.force_login(self.test_user) + test_sets = list(self.dataset2.sets.values_list("name", flat=True)) wrong_id = uuid.uuid4() with self.assertNumQueries(3): response = self.client.post( reverse("api:process-dataset", kwargs={"process": wrong_id, "dataset": self.dataset2.id}), - data={"sets": self.dataset2.sets} + data={"sets": test_sets} ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.json(), {"process": [f'Invalid pk "{str(wrong_id)}" - object does not exist.']}) @@ -218,12 +241,13 @@ class TestProcessDatasets(FixtureAPITestCase): def test_create_dataset_access(self, filter_rights_mock): new_corpus = Corpus.objects.create(name="NERV") new_dataset = new_corpus.datasets.create(name="Eva series", description="We created the Evas from Adam", creator=self.user) + test_sets = list(new_dataset.sets.values_list("name", flat=True)) self.client.force_login(self.test_user) with self.assertNumQueries(3): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": new_dataset.id}), - data={"sets": new_dataset.sets} + data={"sets": test_sets} ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @@ -234,12 +258,13 @@ class TestProcessDatasets(FixtureAPITestCase): def test_create_unique(self): self.client.force_login(self.test_user) + test_sets = list(self.dataset1.sets.values_list("name", flat=True)) self.assertTrue(self.dataset_process.datasets.filter(id=self.dataset1.id).exists()) - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), - data={"sets": self.dataset1.sets} + data={"sets": test_sets} ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @@ -248,11 +273,12 @@ class TestProcessDatasets(FixtureAPITestCase): def test_create_started(self): self.client.force_login(self.test_user) self.dataset_process.tasks.create(run=0, depth=0, slug="makrout") + test_sets = list(self.dataset2.sets.values_list("name", flat=True)) - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), - data={"sets": self.dataset2.sets} + data={"sets": test_sets} ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @@ -262,7 +288,7 @@ class TestProcessDatasets(FixtureAPITestCase): self.client.force_login(self.test_user) self.assertEqual(ProcessDataset.objects.count(), 3) self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists()) - with self.assertNumQueries(6): + with self.assertNumQueries(7): response = self.client.post( reverse( "api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id} @@ -277,6 +303,7 @@ class TestProcessDatasets(FixtureAPITestCase): self.dataset2 ]) created = ProcessDataset.objects.get(process=self.dataset_process.id, dataset=self.dataset2.id) + self.maxDiff = None self.assertDictEqual(response.json(), { "id": str(created.id), "dataset": { @@ -284,7 +311,13 @@ class TestProcessDatasets(FixtureAPITestCase): "name": "Second Dataset", "description": "dataset number two", "creator": "Test user", - "sets": ["training", "test", "validation"], + "sets": [ + { + "id": str(ds.id), + "name": ds.name + } + for ds in self.dataset2.sets.all() + ], "set_elements": None, "corpus_id": str(self.corpus.id), "state": "open", @@ -292,14 +325,14 @@ class TestProcessDatasets(FixtureAPITestCase): "created": FAKE_CREATED, "updated": FAKE_CREATED }, - "sets": ["training", "test", "validation"] + "sets": ["test", "training", "validation"] }) def test_create(self): self.client.force_login(self.test_user) self.assertEqual(ProcessDataset.objects.count(), 3) self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists()) - with self.assertNumQueries(6): + with self.assertNumQueries(7): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), data={"sets": ["validation", "test"]} @@ -320,7 +353,13 @@ class TestProcessDatasets(FixtureAPITestCase): "name": "Second Dataset", "description": "dataset number two", "creator": "Test user", - "sets": ["training", "test", "validation"], + "sets": [ + { + "id": str(ds.id), + "name": ds.name + } + for ds in self.dataset2.sets.all() + ], "set_elements": None, "corpus_id": str(self.corpus.id), "state": "open", @@ -335,7 +374,7 @@ class TestProcessDatasets(FixtureAPITestCase): self.client.force_login(self.test_user) self.assertEqual(ProcessDataset.objects.count(), 3) self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists()) - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), data={"sets": ["Unit-01"]} @@ -372,7 +411,7 @@ class TestProcessDatasets(FixtureAPITestCase): if level: self.private_corpus.memberships.create(user=self.test_user, level=level.value) self.client.force_login(self.test_user) - with self.assertNumQueries(4): + with self.assertNumQueries(5): response = self.client.put( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"sets": ["test"]} @@ -408,7 +447,7 @@ class TestProcessDatasets(FixtureAPITestCase): def test_update(self): self.client.force_login(self.test_user) - with self.assertNumQueries(4): + with self.assertNumQueries(5): response = self.client.put( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"sets": ["test"]} @@ -421,7 +460,13 @@ class TestProcessDatasets(FixtureAPITestCase): "name": "First Dataset", "description": "dataset number one", "creator": "Test user", - "sets": ["training", "test", "validation"], + "sets": [ + { + "id": str(ds.id), + "name": ds.name + } + for ds in self.dataset1.sets.all() + ], "set_elements": None, "corpus_id": str(self.corpus.id), "state": "open", @@ -434,7 +479,7 @@ class TestProcessDatasets(FixtureAPITestCase): def test_update_wrong_sets(self): self.client.force_login(self.test_user) - with self.assertNumQueries(3): + with self.assertNumQueries(4): response = self.client.put( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"sets": ["Unit-01", "Unit-02"]} @@ -444,7 +489,7 @@ class TestProcessDatasets(FixtureAPITestCase): def test_update_unique_sets(self): self.client.force_login(self.test_user) - with self.assertNumQueries(3): + with self.assertNumQueries(4): response = self.client.put( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"sets": ["test", "test"]} @@ -460,7 +505,7 @@ class TestProcessDatasets(FixtureAPITestCase): expiry=datetime(1970, 1, 1, tzinfo=timezone.utc), ) self.client.force_login(self.test_user) - with self.assertNumQueries(3): + with self.assertNumQueries(4): response = self.client.put( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"sets": ["test"]} @@ -473,7 +518,7 @@ class TestProcessDatasets(FixtureAPITestCase): Non "sets" fields in the update request are ignored """ self.client.force_login(self.test_user) - with self.assertNumQueries(4): + with self.assertNumQueries(5): response = self.client.put( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"process": str(self.dataset_process_2.id), "dataset": str(self.dataset2.id), "sets": ["test"]} @@ -486,7 +531,13 @@ class TestProcessDatasets(FixtureAPITestCase): "name": "First Dataset", "description": "dataset number one", "creator": "Test user", - "sets": ["training", "test", "validation"], + "sets": [ + { + "id": str(ds.id), + "name": ds.name + } + for ds in self.dataset1.sets.all() + ], "set_elements": None, "corpus_id": str(self.corpus.id), "state": "open", @@ -525,7 +576,7 @@ class TestProcessDatasets(FixtureAPITestCase): if level: self.private_corpus.memberships.create(user=self.test_user, level=level.value) self.client.force_login(self.test_user) - with self.assertNumQueries(4): + with self.assertNumQueries(5): response = self.client.patch( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"sets": ["test"]} @@ -561,7 +612,7 @@ class TestProcessDatasets(FixtureAPITestCase): def test_partial_update(self): self.client.force_login(self.test_user) - with self.assertNumQueries(4): + with self.assertNumQueries(5): response = self.client.patch( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"sets": ["test"]} @@ -574,7 +625,13 @@ class TestProcessDatasets(FixtureAPITestCase): "name": "First Dataset", "description": "dataset number one", "creator": "Test user", - "sets": ["training", "test", "validation"], + "sets": [ + { + "id": str(ds.id), + "name": ds.name + } + for ds in self.dataset1.sets.all() + ], "set_elements": None, "corpus_id": str(self.corpus.id), "state": "open", @@ -587,7 +644,7 @@ class TestProcessDatasets(FixtureAPITestCase): def test_partial_update_wrong_sets(self): self.client.force_login(self.test_user) - with self.assertNumQueries(3): + with self.assertNumQueries(4): response = self.client.patch( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"sets": ["Unit-01", "Unit-02"]} @@ -597,7 +654,7 @@ class TestProcessDatasets(FixtureAPITestCase): def test_partial_update_unique_sets(self): self.client.force_login(self.test_user) - with self.assertNumQueries(3): + with self.assertNumQueries(4): response = self.client.patch( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"sets": ["test", "test"]} @@ -613,7 +670,7 @@ class TestProcessDatasets(FixtureAPITestCase): expiry=datetime(1970, 1, 1, tzinfo=timezone.utc), ) self.client.force_login(self.test_user) - with self.assertNumQueries(3): + with self.assertNumQueries(4): response = self.client.patch( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"sets": ["test"]} @@ -626,7 +683,7 @@ class TestProcessDatasets(FixtureAPITestCase): Non "sets" fields in the partial update request are ignored """ self.client.force_login(self.test_user) - with self.assertNumQueries(4): + with self.assertNumQueries(5): response = self.client.patch( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"process": str(self.dataset_process_2.id), "dataset": str(self.dataset2.id), "sets": ["test"]} @@ -639,7 +696,13 @@ class TestProcessDatasets(FixtureAPITestCase): "name": "First Dataset", "description": "dataset number one", "creator": "Test user", - "sets": ["training", "test", "validation"], + "sets": [ + { + "id": str(ds.id), + "name": ds.name + } + for ds in self.dataset1.sets.all() + ], "set_elements": None, "corpus_id": str(self.corpus.id), "state": "open", @@ -682,7 +745,7 @@ class TestProcessDatasets(FixtureAPITestCase): def test_destroy_not_found(self): self.assertFalse(self.dataset_process.datasets.filter(id=self.dataset2.id).exists()) self.client.force_login(self.test_user) - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.delete( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), ) @@ -705,10 +768,11 @@ class TestProcessDatasets(FixtureAPITestCase): def test_destroy_no_dataset_access_requirement(self): new_corpus = Corpus.objects.create(name="NERV") new_dataset = new_corpus.datasets.create(name="Eva series", description="We created the Evas from Adam", creator=self.user) - ProcessDataset.objects.create(process=self.dataset_process, dataset=new_dataset, sets=new_dataset.sets) + test_sets = list(new_dataset.sets.values_list("name", flat=True)) + ProcessDataset.objects.create(process=self.dataset_process, dataset=new_dataset, sets=test_sets) self.assertTrue(ProcessDataset.objects.filter(process=self.dataset_process, dataset=new_dataset).exists()) self.client.force_login(self.test_user) - with self.assertNumQueries(6): + with self.assertNumQueries(7): response = self.client.delete( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": new_dataset.id}), ) @@ -723,7 +787,7 @@ class TestProcessDatasets(FixtureAPITestCase): self.dataset_process.save() self.client.force_login(self.test_user) - with self.assertNumQueries(4): + with self.assertNumQueries(5): response = self.client.delete( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), ) @@ -745,7 +809,7 @@ class TestProcessDatasets(FixtureAPITestCase): self.client.force_login(self.test_user) self.dataset_process.tasks.create(run=0, depth=0, slug="makrout") - with self.assertNumQueries(4): + with self.assertNumQueries(5): response = self.client.delete( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), ) @@ -755,7 +819,7 @@ class TestProcessDatasets(FixtureAPITestCase): def test_destroy(self): self.client.force_login(self.test_user) - with self.assertNumQueries(6): + with self.assertNumQueries(7): response = self.client.delete( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), ) @@ -770,7 +834,7 @@ class TestProcessDatasets(FixtureAPITestCase): self.process_dataset_1.sets = ["test"] self.process_dataset_1.save() self.client.force_login(self.test_user) - with self.assertNumQueries(6): + with self.assertNumQueries(7): response = self.client.delete( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), ) diff --git a/arkindex/process/tests/test_processes.py b/arkindex/process/tests/test_processes.py index 1af5d378662ac38aa8ecf1b4cc4a65b6d2a206a2..8683abd689f7a64048ebd136466978be8e21b3c4 100644 --- a/arkindex/process/tests/test_processes.py +++ b/arkindex/process/tests/test_processes.py @@ -2396,7 +2396,7 @@ class TestProcesses(FixtureAPITestCase): def test_start_process_dataset_requires_dataset_in_same_corpus(self): process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) - ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.private_dataset.sets) + ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=list(self.private_dataset.sets.values_list("name", flat=True))) process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None) self.assertFalse(process2.tasks.exists()) @@ -2413,8 +2413,8 @@ class TestProcesses(FixtureAPITestCase): def test_start_process_dataset_unsupported_parameters(self): process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) - ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=self.dataset1.sets) - ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.dataset2.sets) + ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=list(self.dataset1.sets.values_list("name", flat=True))) + ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=list(self.dataset2.sets.values_list("name", flat=True))) process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None) self.client.force_login(self.user) @@ -2438,8 +2438,8 @@ class TestProcesses(FixtureAPITestCase): def test_start_process_dataset(self): process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) - ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=self.dataset1.sets) - ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.private_dataset.sets) + ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=list(self.dataset1.sets.values_list("name", flat=True))) + ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=list(self.private_dataset.sets.values_list("name", flat=True))) run = process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None) self.assertFalse(process2.tasks.exists()) @@ -2634,8 +2634,8 @@ class TestProcesses(FixtureAPITestCase): It should be possible to pass chunks when starting a dataset process """ process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) - ProcessDataset.objects.create(process=process, dataset=self.dataset1, sets=self.dataset1.sets) - ProcessDataset.objects.create(process=process, dataset=self.dataset2, sets=self.dataset2.sets) + ProcessDataset.objects.create(process=process, dataset=self.dataset1, sets=list(self.dataset1.sets.values_list("name", flat=True))) + ProcessDataset.objects.create(process=process, dataset=self.dataset2, sets=list(self.dataset2.sets.values_list("name", flat=True))) # Add a worker run to this process run = process.worker_runs.create(version=self.recognizer, parents=[], configuration=None) diff --git a/arkindex/training/api.py b/arkindex/training/api.py index 530b7599210b9e9dcd26e0af5949b006b6689027..f0b71f16bcd8564738f18a24be8b37903dd8e35f 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -954,9 +954,7 @@ class DatasetClone(CorpusACLMixin, CreateAPIView): serializer_class = DatasetSerializer def get_queryset(self): - return ( - Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user)) - ) + return Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user)) def check_object_permissions(self, request, dataset): if not self.has_write_access(dataset.corpus): diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py index 6db551191b984e564c6eaa382ae73a2a3e95b8f5..58730286b8590dd94e1f207184a9421441c28b41 100644 --- a/arkindex/training/tests/test_datasets_api.py +++ b/arkindex/training/tests/test_datasets_api.py @@ -2118,28 +2118,36 @@ class TestDatasetsAPI(FixtureAPITestCase): data = response.json() data.pop("created") data.pop("updated") - cloned_dataset = Dataset.objects.get(id=data["id"]) - self.maxDiff = None + clone = Dataset.objects.get(id=data["id"]) + test_clone, train_clone, val_clone = clone.sets.all().order_by("name") + cloned_sets = data.pop("sets") self.assertDictEqual( response.json(), { - "id": str(cloned_dataset.id), + "id": str(clone.id), "name": "Clone of First Dataset 1", "description": self.dataset.description, "creator": self.user.display_name, "corpus_id": str(self.corpus.id), - "sets": [ - { - "id": str(ds.id), - "name": ds.name - } - for ds in cloned_dataset.sets.all() - ], "set_elements": {str(k.name): 0 for k in self.dataset.sets.all()}, "state": DatasetState.Open.value, "task_id": None, }, ) + self.assertCountEqual(cloned_sets, [ + { + "name": "training", + "id": str(train_clone.id) + }, + { + "name": "test", + "id": str(test_clone.id) + }, + { + "name": "validation", + "id": str(val_clone.id) + } + ]) def test_clone_name_too_long(self): dataset = self.corpus.datasets.create(name="A" * 99, creator=self.user)