From 0b9d79bf0f759f84e35a5c0490f12ace7abbd162 Mon Sep 17 00:00:00 2001 From: mlbonhomme <bonhomme@teklia.com> Date: Tue, 12 Mar 2024 13:19:19 +0100 Subject: [PATCH] temp dataset process fix --- arkindex/process/api.py | 2 + arkindex/process/serializers/training.py | 6 +- arkindex/process/tests/test_create_process.py | 6 +- .../process/tests/test_process_datasets.py | 164 ++++++++++++------ arkindex/process/tests/test_processes.py | 14 +- arkindex/training/api.py | 4 +- arkindex/training/tests/test_datasets_api.py | 28 +-- 7 files changed, 149 insertions(+), 75 deletions(-) diff --git a/arkindex/process/api.py b/arkindex/process/api.py index e6f76cf645..7d1fd99a82 100644 --- a/arkindex/process/api.py +++ b/arkindex/process/api.py @@ -706,6 +706,7 @@ class ProcessDatasets(ProcessACLMixin, ListAPIView): return ( ProcessDataset.objects.filter(process_id=self.process.id) .select_related("process__creator", "dataset__creator") + .prefetch_related("dataset__sets") .order_by("dataset__name") ) @@ -751,6 +752,7 @@ class ProcessDatasetManage(CreateAPIView, UpdateAPIView, DestroyAPIView): process_dataset = get_object_or_404( ProcessDataset.objects .select_related("dataset__creator", "process__corpus") + .prefetch_related("dataset__sets") # Required to check for a process that have already started .annotate(process_has_tasks=Exists(Task.objects.filter(process_id=self.kwargs["process"]))), dataset_id=self.kwargs["dataset"], process_id=self.kwargs["process"] diff --git a/arkindex/process/serializers/training.py b/arkindex/process/serializers/training.py index e7c25f1ec7..f280a16a65 100644 --- a/arkindex/process/serializers/training.py +++ b/arkindex/process/serializers/training.py @@ -89,7 +89,7 @@ class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer): else: dataset_qs = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self._user)) try: - dataset = dataset_qs.select_related("creator").get(pk=data["dataset_id"]) + dataset = dataset_qs.select_related("creator").prefetch_related("sets").get(pk=data["dataset_id"]) except Dataset.DoesNotExist: raise ValidationError({"dataset": [f'Invalid pk "{str(data["dataset_id"])}" - object does not exist.']}) else: @@ -109,11 +109,11 @@ class ProcessDatasetSerializer(ProcessACLMixin, serializers.ModelSerializer): sets = data.get("sets") if not sets or len(sets) == 0: if not self.instance: - data["sets"] = dataset.sets + data["sets"] = [item.name for item in list(dataset.sets.all())] else: errors["sets"].append("This field cannot be empty.") else: - if any(s not in dataset.sets for s in sets): + if any(s not in [item.name for item in list(dataset.sets.all())] for s in sets): errors["sets"].append("The specified sets must all exist in the specified dataset.") if len(set(sets)) != len(sets): errors["sets"].append("Sets must be unique.") diff --git a/arkindex/process/tests/test_create_process.py b/arkindex/process/tests/test_create_process.py index 34264e1a04..59e454700c 100644 --- a/arkindex/process/tests/test_create_process.py +++ b/arkindex/process/tests/test_create_process.py @@ -899,7 +899,8 @@ class TestCreateProcess(FixtureAPITestCase): self.client.force_login(self.user) process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) dataset = self.corpus.datasets.first() - ProcessDataset.objects.create(process=process, dataset=dataset, sets=dataset.sets) + test_sets = list(dataset.sets.values_list("name", flat=True)) + ProcessDataset.objects.create(process=process, dataset=dataset, sets=test_sets) process.versions.set([self.version_2, self.version_3]) with self.assertNumQueries(9): @@ -929,7 +930,8 @@ class TestCreateProcess(FixtureAPITestCase): self.worker_1.save() process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) dataset = self.corpus.datasets.first() - ProcessDataset.objects.create(process=process, dataset=dataset, sets=dataset.sets) + test_sets = list(dataset.sets.values_list("name", flat=True)) + ProcessDataset.objects.create(process=process, dataset=dataset, sets=test_sets) process.versions.add(self.version_1) with self.assertNumQueries(9): diff --git a/arkindex/process/tests/test_process_datasets.py b/arkindex/process/tests/test_process_datasets.py index 5f35cff0b9..1785250e9e 100644 --- a/arkindex/process/tests/test_process_datasets.py +++ b/arkindex/process/tests/test_process_datasets.py @@ -9,7 +9,7 @@ from arkindex.documents.models import Corpus from arkindex.ponos.models import Farm from arkindex.process.models import Process, ProcessDataset, ProcessMode from arkindex.project.tests import FixtureAPITestCase -from arkindex.training.models import Dataset +from arkindex.training.models import Dataset, DatasetSet from arkindex.users.models import Role, User # Using the fake DB fixtures creation date when needed @@ -28,6 +28,10 @@ class TestProcessDatasets(FixtureAPITestCase): description="Human instrumentality manual", creator=cls.user ) + DatasetSet.objects.bulk_create([ + DatasetSet(dataset_id=cls.private_dataset.id, name=set_name) + for set_name in ["validation", "training", "test"] + ]) cls.test_user = User.objects.create(email="katsuragi@nerv.co.jp", verified_email=True) cls.private_corpus.memberships.create(user=cls.test_user, level=Role.Admin.value) @@ -40,8 +44,8 @@ class TestProcessDatasets(FixtureAPITestCase): corpus_id=cls.private_corpus.id, farm=Farm.objects.get(name="Wheat farm") ) - cls.process_dataset_1 = ProcessDataset.objects.create(process=cls.dataset_process, dataset=cls.dataset1, sets=cls.dataset1.sets) - cls.process_dataset_2 = ProcessDataset.objects.create(process=cls.dataset_process, dataset=cls.private_dataset, sets=cls.private_dataset.sets) + cls.process_dataset_1 = ProcessDataset.objects.create(process=cls.dataset_process, dataset=cls.dataset1, sets=list(cls.dataset1.sets.values_list("name", flat=True))) + cls.process_dataset_2 = ProcessDataset.objects.create(process=cls.dataset_process, dataset=cls.private_dataset, sets=list(cls.private_dataset.sets.values_list("name", flat=True))) # Control process to check that its datasets are not retrieved cls.dataset_process_2 = Process.objects.create( @@ -49,7 +53,7 @@ class TestProcessDatasets(FixtureAPITestCase): mode=ProcessMode.Dataset, corpus_id=cls.corpus.id ) - ProcessDataset.objects.create(process=cls.dataset_process_2, dataset=cls.dataset2, sets=cls.dataset2.sets) + ProcessDataset.objects.create(process=cls.dataset_process_2, dataset=cls.dataset2, sets=list(cls.dataset2.sets.values_list("name", flat=True))) # List process datasets @@ -78,9 +82,10 @@ class TestProcessDatasets(FixtureAPITestCase): def test_list(self): self.client.force_login(self.test_user) - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.get(reverse("api:process-datasets", kwargs={"pk": self.dataset_process.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) + self.maxDiff = None self.assertEqual(response.json()["results"], [ { "id": str(self.process_dataset_2.id), @@ -89,7 +94,13 @@ class TestProcessDatasets(FixtureAPITestCase): "name": "Dead sea scrolls", "description": "Human instrumentality manual", "creator": "Test user", - "sets": ["training", "test", "validation"], + "sets": [ + { + "id": str(ds.id), + "name": ds.name + } + for ds in self.private_dataset.sets.all() + ], "set_elements": None, "corpus_id": str(self.private_corpus.id), "state": "open", @@ -97,7 +108,7 @@ class TestProcessDatasets(FixtureAPITestCase): "created": FAKE_CREATED, "updated": FAKE_CREATED }, - "sets": ["training", "test", "validation"] + "sets": ["validation", "training", "test"] }, { "id": str(self.process_dataset_1.id), @@ -106,7 +117,13 @@ class TestProcessDatasets(FixtureAPITestCase): "name": "First Dataset", "description": "dataset number one", "creator": "Test user", - "sets": ["training", "test", "validation"], + "sets": [ + { + "id": str(ds.id), + "name": ds.name + } + for ds in self.dataset1.sets.all() + ], "set_elements": None, "corpus_id": str(self.corpus.id), "state": "open", @@ -114,33 +131,36 @@ class TestProcessDatasets(FixtureAPITestCase): "created": FAKE_CREATED, "updated": FAKE_CREATED }, - "sets": ["training", "test", "validation"] + "sets": ["validation", "training", "test"] } ]) # Create process dataset def test_create_requires_login(self): + test_sets = list(self.dataset2.sets.values_list("name", flat=True)) with self.assertNumQueries(0): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), - data={"sets": self.dataset2.sets} + data={"sets": test_sets} ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_create_requires_verified(self): unverified_user = User.objects.create(email="email@mail.com") + test_sets = list(self.dataset2.sets.values_list("name", flat=True)) self.client.force_login(unverified_user) with self.assertNumQueries(2): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), - data={"sets": self.dataset2.sets} + data={"sets": test_sets} ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @patch("arkindex.project.mixins.get_max_level") def test_create_access_level(self, get_max_level_mock): cases = [None, Role.Guest.value, Role.Contributor.value] + test_sets = list(self.dataset2.sets.values_list("name", flat=True)) for level in cases: with self.subTest(level=level): get_max_level_mock.reset_mock() @@ -150,7 +170,7 @@ class TestProcessDatasets(FixtureAPITestCase): with self.assertNumQueries(3): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), - data={"sets": self.dataset2.sets} + data={"sets": test_sets} ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @@ -161,16 +181,17 @@ class TestProcessDatasets(FixtureAPITestCase): def test_create_process_mode(self): cases = set(ProcessMode) - {ProcessMode.Dataset, ProcessMode.Local} + test_sets = list(self.dataset2.sets.values_list("name", flat=True)) for mode in cases: with self.subTest(mode=mode): self.dataset_process.mode = mode self.dataset_process.save() self.client.force_login(self.test_user) - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), - data={"sets": self.dataset2.sets} + data={"sets": test_sets} ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @@ -178,22 +199,24 @@ class TestProcessDatasets(FixtureAPITestCase): def test_create_process_mode_local(self): self.client.force_login(self.user) + test_sets = list(self.dataset2.sets.values_list("name", flat=True)) local_process = Process.objects.get(creator=self.user, mode=ProcessMode.Local) with self.assertNumQueries(3): response = self.client.post( reverse("api:process-dataset", kwargs={"process": local_process.id, "dataset": self.dataset2.id}), - data={"sets": self.dataset2.sets} + data={"sets": test_sets} ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.json(), {"detail": "You do not have admin access to this process."}) def test_create_wrong_process_uuid(self): self.client.force_login(self.test_user) + test_sets = list(self.dataset2.sets.values_list("name", flat=True)) wrong_id = uuid.uuid4() with self.assertNumQueries(3): response = self.client.post( reverse("api:process-dataset", kwargs={"process": wrong_id, "dataset": self.dataset2.id}), - data={"sets": self.dataset2.sets} + data={"sets": test_sets} ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.json(), {"process": [f'Invalid pk "{str(wrong_id)}" - object does not exist.']}) @@ -213,12 +236,13 @@ class TestProcessDatasets(FixtureAPITestCase): def test_create_dataset_access(self, filter_rights_mock): new_corpus = Corpus.objects.create(name="NERV") new_dataset = new_corpus.datasets.create(name="Eva series", description="We created the Evas from Adam", creator=self.user) + test_sets = list(new_dataset.sets.values_list("name", flat=True)) self.client.force_login(self.test_user) with self.assertNumQueries(3): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": new_dataset.id}), - data={"sets": new_dataset.sets} + data={"sets": test_sets} ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @@ -229,12 +253,13 @@ class TestProcessDatasets(FixtureAPITestCase): def test_create_unique(self): self.client.force_login(self.test_user) + test_sets = list(self.dataset1.sets.values_list("name", flat=True)) self.assertTrue(self.dataset_process.datasets.filter(id=self.dataset1.id).exists()) - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), - data={"sets": self.dataset1.sets} + data={"sets": test_sets} ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @@ -243,11 +268,12 @@ class TestProcessDatasets(FixtureAPITestCase): def test_create_started(self): self.client.force_login(self.test_user) self.dataset_process.tasks.create(run=0, depth=0, slug="makrout") + test_sets = list(self.dataset2.sets.values_list("name", flat=True)) - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), - data={"sets": self.dataset2.sets} + data={"sets": test_sets} ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @@ -257,7 +283,7 @@ class TestProcessDatasets(FixtureAPITestCase): self.client.force_login(self.test_user) self.assertEqual(ProcessDataset.objects.count(), 3) self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists()) - with self.assertNumQueries(6): + with self.assertNumQueries(7): response = self.client.post( reverse( "api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id} @@ -272,6 +298,7 @@ class TestProcessDatasets(FixtureAPITestCase): self.dataset2 ]) created = ProcessDataset.objects.get(process=self.dataset_process.id, dataset=self.dataset2.id) + self.maxDiff = None self.assertDictEqual(response.json(), { "id": str(created.id), "dataset": { @@ -279,7 +306,13 @@ class TestProcessDatasets(FixtureAPITestCase): "name": "Second Dataset", "description": "dataset number two", "creator": "Test user", - "sets": ["training", "test", "validation"], + "sets": [ + { + "id": str(ds.id), + "name": ds.name + } + for ds in self.dataset2.sets.all() + ], "set_elements": None, "corpus_id": str(self.corpus.id), "state": "open", @@ -287,14 +320,14 @@ class TestProcessDatasets(FixtureAPITestCase): "created": FAKE_CREATED, "updated": FAKE_CREATED }, - "sets": ["training", "test", "validation"] + "sets": ["test", "training", "validation"] }) def test_create(self): self.client.force_login(self.test_user) self.assertEqual(ProcessDataset.objects.count(), 3) self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists()) - with self.assertNumQueries(6): + with self.assertNumQueries(7): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), data={"sets": ["validation", "test"]} @@ -315,7 +348,13 @@ class TestProcessDatasets(FixtureAPITestCase): "name": "Second Dataset", "description": "dataset number two", "creator": "Test user", - "sets": ["training", "test", "validation"], + "sets": [ + { + "id": str(ds.id), + "name": ds.name + } + for ds in self.dataset2.sets.all() + ], "set_elements": None, "corpus_id": str(self.corpus.id), "state": "open", @@ -330,7 +369,7 @@ class TestProcessDatasets(FixtureAPITestCase): self.client.force_login(self.test_user) self.assertEqual(ProcessDataset.objects.count(), 3) self.assertFalse(ProcessDataset.objects.filter(process=self.dataset_process.id, dataset=self.dataset2.id).exists()) - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.post( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), data={"sets": ["Unit-01"]} @@ -367,7 +406,7 @@ class TestProcessDatasets(FixtureAPITestCase): if level: self.private_corpus.memberships.create(user=self.test_user, level=level.value) self.client.force_login(self.test_user) - with self.assertNumQueries(4): + with self.assertNumQueries(5): response = self.client.put( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"sets": ["test"]} @@ -403,7 +442,7 @@ class TestProcessDatasets(FixtureAPITestCase): def test_update(self): self.client.force_login(self.test_user) - with self.assertNumQueries(4): + with self.assertNumQueries(5): response = self.client.put( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"sets": ["test"]} @@ -416,7 +455,13 @@ class TestProcessDatasets(FixtureAPITestCase): "name": "First Dataset", "description": "dataset number one", "creator": "Test user", - "sets": ["training", "test", "validation"], + "sets": [ + { + "id": str(ds.id), + "name": ds.name + } + for ds in self.dataset1.sets.all() + ], "set_elements": None, "corpus_id": str(self.corpus.id), "state": "open", @@ -429,7 +474,7 @@ class TestProcessDatasets(FixtureAPITestCase): def test_update_wrong_sets(self): self.client.force_login(self.test_user) - with self.assertNumQueries(3): + with self.assertNumQueries(4): response = self.client.put( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"sets": ["Unit-01", "Unit-02"]} @@ -439,7 +484,7 @@ class TestProcessDatasets(FixtureAPITestCase): def test_update_unique_sets(self): self.client.force_login(self.test_user) - with self.assertNumQueries(3): + with self.assertNumQueries(4): response = self.client.put( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"sets": ["test", "test"]} @@ -455,7 +500,7 @@ class TestProcessDatasets(FixtureAPITestCase): expiry=datetime(1970, 1, 1, tzinfo=timezone.utc), ) self.client.force_login(self.test_user) - with self.assertNumQueries(3): + with self.assertNumQueries(4): response = self.client.put( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"sets": ["test"]} @@ -468,7 +513,7 @@ class TestProcessDatasets(FixtureAPITestCase): Non "sets" fields in the update request are ignored """ self.client.force_login(self.test_user) - with self.assertNumQueries(4): + with self.assertNumQueries(5): response = self.client.put( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"process": str(self.dataset_process_2.id), "dataset": str(self.dataset2.id), "sets": ["test"]} @@ -481,7 +526,13 @@ class TestProcessDatasets(FixtureAPITestCase): "name": "First Dataset", "description": "dataset number one", "creator": "Test user", - "sets": ["training", "test", "validation"], + "sets": [ + { + "id": str(ds.id), + "name": ds.name + } + for ds in self.dataset1.sets.all() + ], "set_elements": None, "corpus_id": str(self.corpus.id), "state": "open", @@ -520,7 +571,7 @@ class TestProcessDatasets(FixtureAPITestCase): if level: self.private_corpus.memberships.create(user=self.test_user, level=level.value) self.client.force_login(self.test_user) - with self.assertNumQueries(4): + with self.assertNumQueries(5): response = self.client.patch( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"sets": ["test"]} @@ -556,7 +607,7 @@ class TestProcessDatasets(FixtureAPITestCase): def test_partial_update(self): self.client.force_login(self.test_user) - with self.assertNumQueries(4): + with self.assertNumQueries(5): response = self.client.patch( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"sets": ["test"]} @@ -569,7 +620,13 @@ class TestProcessDatasets(FixtureAPITestCase): "name": "First Dataset", "description": "dataset number one", "creator": "Test user", - "sets": ["training", "test", "validation"], + "sets": [ + { + "id": str(ds.id), + "name": ds.name + } + for ds in self.dataset1.sets.all() + ], "set_elements": None, "corpus_id": str(self.corpus.id), "state": "open", @@ -582,7 +639,7 @@ class TestProcessDatasets(FixtureAPITestCase): def test_partial_update_wrong_sets(self): self.client.force_login(self.test_user) - with self.assertNumQueries(3): + with self.assertNumQueries(4): response = self.client.patch( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"sets": ["Unit-01", "Unit-02"]} @@ -592,7 +649,7 @@ class TestProcessDatasets(FixtureAPITestCase): def test_partial_update_unique_sets(self): self.client.force_login(self.test_user) - with self.assertNumQueries(3): + with self.assertNumQueries(4): response = self.client.patch( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"sets": ["test", "test"]} @@ -608,7 +665,7 @@ class TestProcessDatasets(FixtureAPITestCase): expiry=datetime(1970, 1, 1, tzinfo=timezone.utc), ) self.client.force_login(self.test_user) - with self.assertNumQueries(3): + with self.assertNumQueries(4): response = self.client.patch( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"sets": ["test"]} @@ -621,7 +678,7 @@ class TestProcessDatasets(FixtureAPITestCase): Non "sets" fields in the partial update request are ignored """ self.client.force_login(self.test_user) - with self.assertNumQueries(4): + with self.assertNumQueries(5): response = self.client.patch( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), data={"process": str(self.dataset_process_2.id), "dataset": str(self.dataset2.id), "sets": ["test"]} @@ -634,7 +691,13 @@ class TestProcessDatasets(FixtureAPITestCase): "name": "First Dataset", "description": "dataset number one", "creator": "Test user", - "sets": ["training", "test", "validation"], + "sets": [ + { + "id": str(ds.id), + "name": ds.name + } + for ds in self.dataset1.sets.all() + ], "set_elements": None, "corpus_id": str(self.corpus.id), "state": "open", @@ -677,7 +740,7 @@ class TestProcessDatasets(FixtureAPITestCase): def test_destroy_not_found(self): self.assertFalse(self.dataset_process.datasets.filter(id=self.dataset2.id).exists()) self.client.force_login(self.test_user) - with self.assertNumQueries(5): + with self.assertNumQueries(6): response = self.client.delete( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), ) @@ -700,10 +763,11 @@ class TestProcessDatasets(FixtureAPITestCase): def test_destroy_no_dataset_access_requirement(self): new_corpus = Corpus.objects.create(name="NERV") new_dataset = new_corpus.datasets.create(name="Eva series", description="We created the Evas from Adam", creator=self.user) - ProcessDataset.objects.create(process=self.dataset_process, dataset=new_dataset, sets=new_dataset.sets) + test_sets = list(new_dataset.sets.values_list("name", flat=True)) + ProcessDataset.objects.create(process=self.dataset_process, dataset=new_dataset, sets=test_sets) self.assertTrue(ProcessDataset.objects.filter(process=self.dataset_process, dataset=new_dataset).exists()) self.client.force_login(self.test_user) - with self.assertNumQueries(6): + with self.assertNumQueries(7): response = self.client.delete( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": new_dataset.id}), ) @@ -718,7 +782,7 @@ class TestProcessDatasets(FixtureAPITestCase): self.dataset_process.save() self.client.force_login(self.test_user) - with self.assertNumQueries(4): + with self.assertNumQueries(5): response = self.client.delete( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset2.id}), ) @@ -740,7 +804,7 @@ class TestProcessDatasets(FixtureAPITestCase): self.client.force_login(self.test_user) self.dataset_process.tasks.create(run=0, depth=0, slug="makrout") - with self.assertNumQueries(4): + with self.assertNumQueries(5): response = self.client.delete( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), ) @@ -750,7 +814,7 @@ class TestProcessDatasets(FixtureAPITestCase): def test_destroy(self): self.client.force_login(self.test_user) - with self.assertNumQueries(6): + with self.assertNumQueries(7): response = self.client.delete( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), ) @@ -765,7 +829,7 @@ class TestProcessDatasets(FixtureAPITestCase): self.process_dataset_1.sets = ["test"] self.process_dataset_1.save() self.client.force_login(self.test_user) - with self.assertNumQueries(6): + with self.assertNumQueries(7): response = self.client.delete( reverse("api:process-dataset", kwargs={"process": self.dataset_process.id, "dataset": self.dataset1.id}), ) diff --git a/arkindex/process/tests/test_processes.py b/arkindex/process/tests/test_processes.py index 8a4d715cd2..b897d3250d 100644 --- a/arkindex/process/tests/test_processes.py +++ b/arkindex/process/tests/test_processes.py @@ -2324,7 +2324,7 @@ class TestProcesses(FixtureAPITestCase): def test_start_process_dataset_requires_dataset_in_same_corpus(self): process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) - ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.private_dataset.sets) + ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=list(self.private_dataset.sets.values_list("name", flat=True))) process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None) self.assertFalse(process2.tasks.exists()) @@ -2341,8 +2341,8 @@ class TestProcesses(FixtureAPITestCase): def test_start_process_dataset_unsupported_parameters(self): process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) - ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=self.dataset1.sets) - ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.dataset2.sets) + ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=list(self.dataset1.sets.values_list("name", flat=True))) + ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=list(self.dataset2.sets.values_list("name", flat=True))) process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None) self.client.force_login(self.user) @@ -2366,8 +2366,8 @@ class TestProcesses(FixtureAPITestCase): def test_start_process_dataset(self): process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) - ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=self.dataset1.sets) - ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=self.private_dataset.sets) + ProcessDataset.objects.create(process=process2, dataset=self.dataset1, sets=list(self.dataset1.sets.values_list("name", flat=True))) + ProcessDataset.objects.create(process=process2, dataset=self.private_dataset, sets=list(self.private_dataset.sets.values_list("name", flat=True))) run = process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None) self.assertFalse(process2.tasks.exists()) @@ -2562,8 +2562,8 @@ class TestProcesses(FixtureAPITestCase): It should be possible to pass chunks when starting a dataset process """ process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset) - ProcessDataset.objects.create(process=process, dataset=self.dataset1, sets=self.dataset1.sets) - ProcessDataset.objects.create(process=process, dataset=self.dataset2, sets=self.dataset2.sets) + ProcessDataset.objects.create(process=process, dataset=self.dataset1, sets=list(self.dataset1.sets.values_list("name", flat=True))) + ProcessDataset.objects.create(process=process, dataset=self.dataset2, sets=list(self.dataset2.sets.values_list("name", flat=True))) # Add a worker run to this process run = process.worker_runs.create(version=self.recognizer, parents=[], configuration=None) diff --git a/arkindex/training/api.py b/arkindex/training/api.py index 530b759921..f0b71f16bc 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -954,9 +954,7 @@ class DatasetClone(CorpusACLMixin, CreateAPIView): serializer_class = DatasetSerializer def get_queryset(self): - return ( - Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user)) - ) + return Dataset.objects.filter(corpus__in=Corpus.objects.readable(self.request.user)) def check_object_permissions(self, request, dataset): if not self.has_write_access(dataset.corpus): diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py index 20b854a193..72149bfc3c 100644 --- a/arkindex/training/tests/test_datasets_api.py +++ b/arkindex/training/tests/test_datasets_api.py @@ -2118,28 +2118,36 @@ class TestDatasetsAPI(FixtureAPITestCase): data = response.json() data.pop("created") data.pop("updated") - cloned_dataset = Dataset.objects.get(id=data["id"]) - self.maxDiff = None + clone = Dataset.objects.get(id=data["id"]) + test_clone, train_clone, val_clone = clone.sets.all().order_by("name") + cloned_sets = data.pop("sets") self.assertDictEqual( response.json(), { - "id": str(cloned_dataset.id), + "id": str(clone.id), "name": "Clone of First Dataset 1", "description": self.dataset.description, "creator": self.user.display_name, "corpus_id": str(self.corpus.id), - "sets": [ - { - "id": str(ds.id), - "name": ds.name - } - for ds in cloned_dataset.sets.all() - ], "set_elements": {str(k.name): 0 for k in self.dataset.sets.all()}, "state": DatasetState.Open.value, "task_id": None, }, ) + self.assertCountEqual(cloned_sets, [ + { + "name": "training", + "id": str(train_clone.id) + }, + { + "name": "test", + "id": str(test_clone.id) + }, + { + "name": "validation", + "id": str(val_clone.id) + } + ]) def test_clone_name_too_long(self): dataset = self.corpus.datasets.create(name="A" * 99, creator=self.user) -- GitLab