diff --git a/arkindex/process/models.py b/arkindex/process/models.py index bd963b51ac17a520f0d0631c0cfde391643654b2..d9932105751a09889d8101c41d1d60731cb594e6 100644 --- a/arkindex/process/models.py +++ b/arkindex/process/models.py @@ -314,24 +314,28 @@ class Process(IndexableModel): # Keep the relation between template_worker_runs and new_process_worker_runs ids new_runs = {} - # First we copy each worker runs + # First we copy each worker runs, excluding elements initialisation worker runs for run in self.worker_runs.all(): - # Create a new WorkerRun with same version, configuration and parents. - new_run = WorkerRun( - process=new_process, - version_id=run.version_id, - model_version_id=run.model_version_id, - parents=run.parents, - configuration_id=run.configuration_id, - summary=run.summary - ) - # Save the correspondence between this process' worker_run and the new one - new_runs[run.id] = new_run + if run.version_id != WorkerVersion.objects.init_elements_version.id: + # Create a new WorkerRun with same version, configuration and parents. + new_run = WorkerRun( + process=new_process, + version_id=run.version_id, + model_version_id=run.model_version_id, + parents=run.parents, + configuration_id=run.configuration_id, + summary=run.summary + ) + # Save the correspondence between this process' worker_run and the new one + new_runs[run.id] = new_run # Remap parent ids correctly for run in new_runs.values(): # If there are parents, we need to use the correct worker_run id, the newly created one - run.parents = [new_runs[parent_id].id for parent_id in run.parents] + run.parents = [ + new_runs[parent_id].id for parent_id in run.parents + if parent_id in new_runs + ] WorkerRun.objects.bulk_create(new_runs.values()) def list_elements(self): diff --git a/arkindex/process/tests/test_templates.py b/arkindex/process/tests/test_templates.py index 8aeb10ea91a482419b09f728267c081371da931a..287af1b50ecedb592e7dc01bfa3ef1cfb0165291 100644 --- a/arkindex/process/tests/test_templates.py +++ b/arkindex/process/tests/test_templates.py @@ -6,7 +6,14 @@ from rest_framework import status from rest_framework.reverse import reverse from arkindex.documents.models import Corpus -from arkindex.process.models import ProcessMode, WorkerConfiguration, WorkerRun, WorkerVersion, WorkerVersionState +from arkindex.process.models import ( + Process, + ProcessMode, + WorkerConfiguration, + WorkerRun, + WorkerVersion, + WorkerVersionState, +) from arkindex.project.tests import FixtureAPITestCase from arkindex.training.models import Model, ModelVersionState from arkindex.users.models import Role, User @@ -54,23 +61,23 @@ class TestTemplates(FixtureAPITestCase): name="A config", configuration={"param1": "value1"}, ) - run_1 = cls.process_template.worker_runs.create( + cls.run_1 = cls.process_template.worker_runs.create( version=cls.version_1, parents=[], configuration=cls.worker_configuration ) cls.process_template.worker_runs.create( version=cls.version_2, - parents=[run_1.id], + parents=[cls.run_1.id], ) cls.model = Model.objects.create(name="moo") cls.model_version = cls.model.versions.create(state=ModelVersionState.Available) - run_1 = cls.template.worker_runs.create( + cls.template_run_1 = cls.template.worker_runs.create( version=cls.version_1, parents=[], configuration=cls.worker_configuration ) cls.template.worker_runs.create( version=cls.version_2, - parents=[run_1.id], + parents=[cls.template_run_1.id], model_version=cls.model_version, ) @@ -83,7 +90,7 @@ class TestTemplates(FixtureAPITestCase): def test_create(self): self.client.force_login(self.user) - with self.assertNumQueries(8): + with self.assertNumQueries(9): response = self.client.post( reverse( "api:create-process-template", kwargs={"pk": str(self.process_template.id)} @@ -102,6 +109,41 @@ class TestTemplates(FixtureAPITestCase): self.assertTrue(self.process_template.worker_runs.filter(version=parent_run.version).exists()) self.assertTrue(self.process_template.worker_runs.filter(version=child_run.version).exists()) + def test_create_excludes_init_elements(self): + init_run = self.process_template.worker_runs.create(version=WorkerVersion.objects.init_elements_version) + self.run_1.parents = [init_run.id] + self.run_1.save() + + self.client.force_login(self.user) + with self.assertNumQueries(8): + response = self.client.post( + reverse( + "api:create-process-template", kwargs={"pk": str(self.process_template.id)} + ), + {"name": "test_template"}, + ) + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.json()["mode"], "template") + + new_process_id = response.json()["id"] + + # Only two worker runs for the created template + self.assertEqual(self.process_template.worker_runs.count(), 3) + self.assertEqual(Process.objects.get(id=new_process_id).worker_runs.count(), 2) + + # No elements initialisation run in the created template + self.assertFalse(WorkerRun.objects.filter(process_id=new_process_id, version_id=WorkerVersion.objects.init_elements_version).exists()) + + child_run, parent_run = WorkerRun.objects.select_related("version__worker").filter(process__id=new_process_id).order_by("version__worker__slug").all() + # Check dependencies + self.assertListEqual(parent_run.parents, []) + self.assertListEqual(child_run.parents, [parent_run.id]) + + # Check that every new worker_run is the same as one of the template's + self.assertTrue(self.process_template.worker_runs.filter(version=parent_run.version).exists()) + self.assertTrue(self.process_template.worker_runs.filter(version=child_run.version).exists()) + def test_create_requires_authentication(self): response = self.client.post( reverse("api:create-process-template", kwargs={"pk": str(self.process_template.id)}) @@ -310,7 +352,7 @@ class TestTemplates(FixtureAPITestCase): def test_apply(self): self.assertIsNotNone(self.version_2.docker_image_iid) self.client.force_login(self.user) - with self.assertNumQueries(9): + with self.assertNumQueries(10): response = self.client.post( reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}), data=json.dumps({"process_id": str(self.process.id)}), @@ -333,6 +375,44 @@ class TestTemplates(FixtureAPITestCase): self.assertIsNone(child_run.configuration_id) self.assertListEqual(child_run.parents, [parent_run.id]) + def test_apply_excludes_init_elements(self): + init_run = self.template.worker_runs.create(version=WorkerVersion.objects.init_elements_version) + self.template_run_1.parents = [init_run.id] + self.template_run_1.save() + + self.client.force_login(self.user) + with self.assertNumQueries(9): + response = self.client.post( + reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}), + data=json.dumps({"process_id": str(self.process.id)}), + content_type="application/json", + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.json()["template_id"], str(self.template.id)) + + created_process_id = response.json()["id"] + + # Only two workers runs in the created process + self.assertEqual(self.template.worker_runs.count(), 3) + self.assertEqual(Process.objects.get(id=created_process_id).worker_runs.count(), 2) + + child_run, parent_run = WorkerRun.objects.select_related("version__worker").filter(process__id=created_process_id).order_by("version__worker__slug").all() + + self.assertEqual(parent_run.process_id, self.process.id) + self.assertEqual(parent_run.version_id, self.version_1.id) + self.assertIsNone(parent_run.model_version_id) + self.assertEqual(parent_run.configuration_id, self.worker_configuration.id) + self.assertListEqual(parent_run.parents, []) + + self.assertEqual(child_run.process_id, self.process.id) + self.assertEqual(child_run.version_id, self.version_2.id) + self.assertEqual(child_run.model_version_id, self.model_version.id) + self.assertIsNone(child_run.configuration_id) + self.assertListEqual(child_run.parents, [parent_run.id]) + + # No elements initialisation run in the created process + self.assertFalse(WorkerRun.objects.filter(process_id=created_process_id, version_id=WorkerVersion.objects.init_elements_version)) + def test_apply_delete_previous_worker_runs(self): self.client.force_login(self.user) # Create a process with one worker run already @@ -344,7 +424,7 @@ class TestTemplates(FixtureAPITestCase): parents=[], ) # Apply a template that has two other worker runs - with self.assertNumQueries(11): + with self.assertNumQueries(12): response = self.client.post( reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}), data=json.dumps({"process_id": str(process.id)}),