diff --git a/arkindex/process/serializers/imports.py b/arkindex/process/serializers/imports.py
index a8c5e04c5d39a116c68486a35da4a1bae4818c47..2b3dd034d01ec6b8f26e9d929bae4ef89b4828df 100644
--- a/arkindex/process/serializers/imports.py
+++ b/arkindex/process/serializers/imports.py
@@ -507,6 +507,7 @@ class ApplyProcessTemplateSerializer(ProcessACLMixin, serializers.Serializer):
         # Apply the template by copying all the worker runs on to the new process
         template_process.copy_runs(target_process)
         target_process.template_id = template_process.id
+        target_process.save(update_fields=["template_id"])
         return target_process
 
     def validate_process_id(self, process):
diff --git a/arkindex/process/tests/test_templates.py b/arkindex/process/tests/test_templates.py
index b46a25f507b88d5a042fa5b32f6f0cd319b3ca10..0857e6d34a44fd0610ea737378ba87385dd78e85 100644
--- a/arkindex/process/tests/test_templates.py
+++ b/arkindex/process/tests/test_templates.py
@@ -397,7 +397,7 @@ class TestTemplates(FixtureAPITestCase):
     def test_apply(self):
         self.assertIsNotNone(self.version_2.docker_image_iid)
         self.client.force_login(self.user)
-        with self.assertNumQueries(10):
+        with self.assertNumQueries(11):
             response = self.client.post(
                 reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}),
                 data={"process_id": str(self.process.id)},
@@ -406,7 +406,10 @@ class TestTemplates(FixtureAPITestCase):
             self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertEqual(response.json()["template_id"], str(self.template.id))
 
-        child_run, parent_run = WorkerRun.objects.select_related("version__worker").filter(process__id=response.json()["id"]).order_by("version__worker__slug").all()
+        self.process.refresh_from_db()
+        self.assertEqual(self.process.template, self.template)
+
+        child_run, parent_run = self.process.worker_runs.select_related("version__worker").order_by("version__worker__slug")
 
         self.assertEqual(parent_run.process_id, self.process.id)
         self.assertEqual(parent_run.version_id, self.version_1.id)
@@ -426,22 +429,26 @@ class TestTemplates(FixtureAPITestCase):
         self.template_run_1.save()
 
         self.client.force_login(self.user)
-        with self.assertNumQueries(9):
+        with self.assertNumQueries(10):
             response = self.client.post(
                 reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}),
                 data={"process_id": str(self.process.id)},
                 format="json",
             )
             self.assertEqual(response.status_code, status.HTTP_200_OK)
-        self.assertEqual(response.json()["template_id"], str(self.template.id))
 
-        created_process_id = response.json()["id"]
+        data = response.json()
+        self.assertEqual(data["id"], str(self.process.id))
+        self.assertEqual(data["template_id"], str(self.template.id))
+
+        self.process.refresh_from_db()
+        self.assertEqual(self.process.template, self.template)
 
-        # Only two workers runs in the created process
+        # Only two workers runs in the process
         self.assertEqual(self.template.worker_runs.count(), 3)
-        self.assertEqual(Process.objects.get(id=created_process_id).worker_runs.count(), 2)
+        self.assertEqual(self.process.worker_runs.count(), 2)
 
-        child_run, parent_run = WorkerRun.objects.select_related("version__worker").filter(process__id=created_process_id).order_by("version__worker__slug").all()
+        child_run, parent_run = self.process.worker_runs.select_related("version__worker").order_by("version__worker__slug")
 
         self.assertEqual(parent_run.process_id, self.process.id)
         self.assertEqual(parent_run.version_id, self.version_1.id)
@@ -456,7 +463,7 @@ class TestTemplates(FixtureAPITestCase):
         self.assertListEqual(child_run.parents, [parent_run.id])
 
         # No elements initialisation run in the created process
-        self.assertFalse(WorkerRun.objects.filter(process_id=created_process_id, version_id=WorkerVersion.objects.init_elements_version))
+        self.assertFalse(self.process.worker_runs.filter(version_id=WorkerVersion.objects.init_elements_version).exists())
 
     def test_apply_delete_previous_worker_runs(self):
         self.client.force_login(self.user)
@@ -469,7 +476,7 @@ class TestTemplates(FixtureAPITestCase):
             parents=[],
         )
         # Apply a template that has two other worker runs
-        with self.assertNumQueries(12):
+        with self.assertNumQueries(13):
             response = self.client.post(
                 reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}),
                 data={"process_id": str(process.id)},
@@ -478,8 +485,11 @@ class TestTemplates(FixtureAPITestCase):
             self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertEqual(response.json()["template_id"], str(self.template.id))
 
+        process.refresh_from_db()
+        self.assertEqual(process.template, self.template)
+
         # Assert that the previous worker runs was deleted and the template was correctly applied
-        child_run, parent_run = WorkerRun.objects.select_related("version__worker").filter(process__id=response.json()["id"]).order_by("version__worker__slug").all()
+        child_run, parent_run = process.worker_runs.select_related("version__worker").order_by("version__worker__slug")
 
         self.assertEqual(parent_run.process_id, process.id)
         self.assertEqual(parent_run.version_id, self.version_1.id)
@@ -518,7 +528,7 @@ class TestTemplates(FixtureAPITestCase):
             (FeatureUsage.Supported, True),
         ])
 
-        with self.assertNumQueries(10):
+        with self.assertNumQueries(11):
             response = self.client.post(
                 reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}),
                 {"process_id": str(self.process.id)},