diff --git a/arkindex/ponos/serializers.py b/arkindex/ponos/serializers.py index 31c19601f2ebc9b40773293304866ca2c4e6b0c4..a0409723655aade93a86f4fa6e3c3d547afafe6b 100644 --- a/arkindex/ponos/serializers.py +++ b/arkindex/ponos/serializers.py @@ -106,6 +106,7 @@ class TaskLightSerializer(serializers.ModelSerializer): └⟶ Error ├⟶ Failed └⟶ Error Stopping ⟶ Stopped + └⟶ Error """).strip(), ) @@ -133,8 +134,8 @@ class TaskLightSerializer(serializers.ModelSerializer): allowed_transitions = { State.Unscheduled: [State.Pending], State.Pending: [State.Running, State.Error], - State.Running: [State.Completed, State.Failed, State.Stopping, State.Error], - State.Stopping: [State.Stopped], + State.Running: [State.Completed, State.Failed, State.Error], + State.Stopping: [State.Stopped, State.Error], } if self.instance and state not in allowed_transitions.get(self.instance.state, []): raise ValidationError(f"Transition from state {self.instance.state} to state {state} is forbidden.") diff --git a/arkindex/ponos/tests/test_api.py b/arkindex/ponos/tests/test_api.py index a416be308fb2226956345c1949b09d7e852096fd..7456f22a268188beb0cf40df7f82cb31b3fc1fe0 100644 --- a/arkindex/ponos/tests/test_api.py +++ b/arkindex/ponos/tests/test_api.py @@ -1275,25 +1275,93 @@ class TestAPI(FixtureAPITestCase): self.assertEqual(self.task1.state, state) @patch("arkindex.ponos.models.TaskLogs.latest", new_callable=PropertyMock) - def test_partial_update_task_from_agent(self, short_logs_mock): + @patch("arkindex.ponos.tasks.notify_process_completion.delay") + def test_partial_update_task_from_agent_allowed_states(self, notify_mock, short_logs_mock): short_logs_mock.return_value = "" - self.task1.state = State.Pending self.task1.agent = self.agent self.task1.save() - with self.assertNumQueries(5): - resp = self.client.patch( - reverse("api:task-details", args=[self.task1.id]), - data={"state": State.Running.value}, - HTTP_AUTHORIZATION=f"Bearer {self.agent.token.access_token}", - ) - self.assertEqual(resp.status_code, status.HTTP_200_OK) + cases = [ + (State.Unscheduled, State.Pending, 5), + (State.Pending, State.Running, 5), + (State.Pending, State.Error, 11), + (State.Running, State.Completed, 9), + (State.Running, State.Failed, 9), + (State.Running, State.Error, 9), + (State.Stopping, State.Stopped, 9), + (State.Stopping, State.Error, 9), + ] + + for from_state, to_state, query_count in cases: + with self.subTest(from_state=from_state, to_state=to_state): + self.task1.state = from_state + self.task1.save() - data = resp.json() - self.assertEqual(data["id"], str(self.task1.id)) - self.assertEqual(data["state"], State.Running.value) - self.task1.refresh_from_db() - self.assertEqual(self.task1.state, State.Running) + with self.assertNumQueries(query_count): + resp = self.client.patch( + reverse("api:task-details", args=[self.task1.id]), + data={"state": to_state.value}, + HTTP_AUTHORIZATION=f"Bearer {self.agent.token.access_token}", + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + + data = resp.json() + self.assertEqual(data["id"], str(self.task1.id)) + self.assertEqual(data["state"], to_state.value) + self.task1.refresh_from_db() + self.assertEqual(self.task1.state, to_state) + + def test_partial_update_task_from_agent_forbidden_states(self): + self.task1.agent = self.agent + self.task1.save() + + cases = [ + (State.Unscheduled, State.Running), + (State.Unscheduled, State.Completed), + (State.Unscheduled, State.Failed), + (State.Unscheduled, State.Error), + (State.Unscheduled, State.Stopping), + (State.Unscheduled, State.Stopped), + (State.Pending, State.Unscheduled), + (State.Pending, State.Completed), + (State.Pending, State.Failed), + (State.Pending, State.Stopping), + (State.Pending, State.Stopped), + (State.Running, State.Unscheduled), + (State.Running, State.Pending), + (State.Running, State.Stopping), + (State.Running, State.Stopped), + (State.Stopping, State.Unscheduled), + (State.Stopping, State.Pending), + (State.Stopping, State.Running), + (State.Stopping, State.Completed), + (State.Stopping, State.Failed), + # Cannot go from one state to the same state + *((state, state) for state in State), + # Cannot go from a final state to anywhere + *((final_state, state) for final_state in FINAL_STATES for state in State), + ] + + for from_state, to_state in cases: + with self.subTest(from_state=from_state, to_state=to_state): + self.task1.state = from_state + self.task1.save() + + with self.assertNumQueries(2): + resp = self.client.put( + reverse("api:task-details", args=[self.task1.id]), + data={"state": to_state.value}, + HTTP_AUTHORIZATION=f"Bearer {self.agent.token.access_token}", + ) + self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertDictEqual( + resp.json(), + {"state": [f"Transition from state {from_state} to state {to_state} is forbidden."]}, + ) + + self.task1.refresh_from_db() + self.assertEqual(self.task1.state, from_state) def test_partial_update_task_from_agent_requires_login(self): with self.assertNumQueries(0):