From 68ab41b01492fd500dfc124d19be4c503ee75800 Mon Sep 17 00:00:00 2001
From: Erwan Rouchet <rouchet@teklia.com>
Date: Tue, 20 Feb 2024 10:52:20 +0100
Subject: [PATCH] Allow agents to update tasks from Stopping to Error

---
 arkindex/ponos/serializers.py    |  5 +-
 arkindex/ponos/tests/test_api.py | 96 +++++++++++++++++++++++++++-----
 2 files changed, 85 insertions(+), 16 deletions(-)

diff --git a/arkindex/ponos/serializers.py b/arkindex/ponos/serializers.py
index 31c19601f2..a040972365 100644
--- a/arkindex/ponos/serializers.py
+++ b/arkindex/ponos/serializers.py
@@ -106,6 +106,7 @@ class TaskLightSerializer(serializers.ModelSerializer):
                   └⟶ Error   ├⟶ Failed
                              └⟶ Error
                 Stopping ⟶ Stopped
+                  └⟶ Error
         """).strip(),
     )
 
@@ -133,8 +134,8 @@ class TaskLightSerializer(serializers.ModelSerializer):
         allowed_transitions = {
             State.Unscheduled: [State.Pending],
             State.Pending: [State.Running, State.Error],
-            State.Running: [State.Completed, State.Failed, State.Stopping, State.Error],
-            State.Stopping: [State.Stopped],
+            State.Running: [State.Completed, State.Failed, State.Error],
+            State.Stopping: [State.Stopped, State.Error],
         }
         if self.instance and state not in allowed_transitions.get(self.instance.state, []):
             raise ValidationError(f"Transition from state {self.instance.state} to state {state} is forbidden.")
diff --git a/arkindex/ponos/tests/test_api.py b/arkindex/ponos/tests/test_api.py
index a416be308f..7456f22a26 100644
--- a/arkindex/ponos/tests/test_api.py
+++ b/arkindex/ponos/tests/test_api.py
@@ -1275,25 +1275,93 @@ class TestAPI(FixtureAPITestCase):
                 self.assertEqual(self.task1.state, state)
 
     @patch("arkindex.ponos.models.TaskLogs.latest", new_callable=PropertyMock)
-    def test_partial_update_task_from_agent(self, short_logs_mock):
+    @patch("arkindex.ponos.tasks.notify_process_completion.delay")
+    def test_partial_update_task_from_agent_allowed_states(self, notify_mock, short_logs_mock):
         short_logs_mock.return_value = ""
-        self.task1.state = State.Pending
         self.task1.agent = self.agent
         self.task1.save()
 
-        with self.assertNumQueries(5):
-            resp = self.client.patch(
-                reverse("api:task-details", args=[self.task1.id]),
-                data={"state": State.Running.value},
-                HTTP_AUTHORIZATION=f"Bearer {self.agent.token.access_token}",
-            )
-            self.assertEqual(resp.status_code, status.HTTP_200_OK)
+        cases = [
+            (State.Unscheduled, State.Pending, 5),
+            (State.Pending, State.Running, 5),
+            (State.Pending, State.Error, 11),
+            (State.Running, State.Completed, 9),
+            (State.Running, State.Failed, 9),
+            (State.Running, State.Error, 9),
+            (State.Stopping, State.Stopped, 9),
+            (State.Stopping, State.Error, 9),
+        ]
+
+        for from_state, to_state, query_count in cases:
+            with self.subTest(from_state=from_state, to_state=to_state):
+                self.task1.state = from_state
+                self.task1.save()
 
-        data = resp.json()
-        self.assertEqual(data["id"], str(self.task1.id))
-        self.assertEqual(data["state"], State.Running.value)
-        self.task1.refresh_from_db()
-        self.assertEqual(self.task1.state, State.Running)
+                with self.assertNumQueries(query_count):
+                    resp = self.client.patch(
+                        reverse("api:task-details", args=[self.task1.id]),
+                        data={"state": to_state.value},
+                        HTTP_AUTHORIZATION=f"Bearer {self.agent.token.access_token}",
+                    )
+                    self.assertEqual(resp.status_code, status.HTTP_200_OK)
+
+                data = resp.json()
+                self.assertEqual(data["id"], str(self.task1.id))
+                self.assertEqual(data["state"], to_state.value)
+                self.task1.refresh_from_db()
+                self.assertEqual(self.task1.state, to_state)
+
+    def test_partial_update_task_from_agent_forbidden_states(self):
+        self.task1.agent = self.agent
+        self.task1.save()
+
+        cases = [
+            (State.Unscheduled, State.Running),
+            (State.Unscheduled, State.Completed),
+            (State.Unscheduled, State.Failed),
+            (State.Unscheduled, State.Error),
+            (State.Unscheduled, State.Stopping),
+            (State.Unscheduled, State.Stopped),
+            (State.Pending, State.Unscheduled),
+            (State.Pending, State.Completed),
+            (State.Pending, State.Failed),
+            (State.Pending, State.Stopping),
+            (State.Pending, State.Stopped),
+            (State.Running, State.Unscheduled),
+            (State.Running, State.Pending),
+            (State.Running, State.Stopping),
+            (State.Running, State.Stopped),
+            (State.Stopping, State.Unscheduled),
+            (State.Stopping, State.Pending),
+            (State.Stopping, State.Running),
+            (State.Stopping, State.Completed),
+            (State.Stopping, State.Failed),
+            # Cannot go from one state to the same state
+            *((state, state) for state in State),
+            # Cannot go from a final state to anywhere
+            *((final_state, state) for final_state in FINAL_STATES for state in State),
+        ]
+
+        for from_state, to_state in cases:
+            with self.subTest(from_state=from_state, to_state=to_state):
+                self.task1.state = from_state
+                self.task1.save()
+
+                with self.assertNumQueries(2):
+                    resp = self.client.put(
+                        reverse("api:task-details", args=[self.task1.id]),
+                        data={"state": to_state.value},
+                        HTTP_AUTHORIZATION=f"Bearer {self.agent.token.access_token}",
+                    )
+                    self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
+
+                self.assertDictEqual(
+                    resp.json(),
+                    {"state": [f"Transition from state {from_state} to state {to_state} is forbidden."]},
+                )
+
+                self.task1.refresh_from_db()
+                self.assertEqual(self.task1.state, from_state)
 
     def test_partial_update_task_from_agent_requires_login(self):
         with self.assertNumQueries(0):
-- 
GitLab