diff --git a/arkindex/ponos/serializers.py b/arkindex/ponos/serializers.py index 0b28c7801122872aa022581f49245adebcf7b6c3..932cc7c0c9586702826cabe665afc0adba0c47f9 100644 --- a/arkindex/ponos/serializers.py +++ b/arkindex/ponos/serializers.py @@ -165,15 +165,7 @@ class TaskTinySerializer(TaskSerializer): state = EnumField( State, - help_text=dedent(""" - Allowed transitions for the state of a task by a user are defined below: - - Completed ⟶ Pending - Failed ⟶ Pending - Error ⟶ Pending - Stopped ⟶ Pending - Running ⟶ Stopping - """).strip(), + help_text="The state can only be updated from `running` to `stopping`, to manually stop a task.", ) class Meta: @@ -183,32 +175,19 @@ class TaskTinySerializer(TaskSerializer): def validate_state(self, state): """ - Only allow a user to manually stop or retry a task + Only allow a user to manually stop a task """ - allowed_transitions = { - state: [State.Pending] for state in FINAL_STATES - } - allowed_transitions.update({State.Running: [State.Stopping]}) - if self.instance and state not in allowed_transitions.get(self.instance.state, []): - raise ValidationError(f"Transition from state {self.instance.state} to state {state} is forbidden.") + if ( + self.instance + and self.instance.state != state + and (self.instance.state != State.Running or state != State.Stopping) + ): + raise ValidationError("State can only be updated from running to stopping") return state def update(self, instance: Task, validated_data) -> Task: new_state = validated_data.get("state") - if new_state == State.Pending: - if instance.state == State.Unscheduled: - # Prevent a user from restarting a task that was never assigned to an agent. - raise ValidationError({ - "state": [f"Transition from state {State.Unscheduled} to state {State.Pending} is forbidden."] - }) - # Restart the task - instance.agent = None - instance.gpu = None - # Un-finish the process since a task will run again - instance.process.finished = None - instance.process.save() - if new_state in FINAL_STATES and new_state != State.Completed: task_failure.send_robust(self.__class__, task=instance) diff --git a/arkindex/ponos/tests/test_api.py b/arkindex/ponos/tests/test_api.py index 5c05f001402e29cb25ac3cecc57a90ba413738c5..dfd46c1d74984f64404bdb633c1290e49d0ab5f2 100644 --- a/arkindex/ponos/tests/test_api.py +++ b/arkindex/ponos/tests/test_api.py @@ -9,7 +9,7 @@ from django.urls import reverse from rest_framework import status from arkindex.documents.models import Corpus -from arkindex.ponos.models import FINAL_STATES, State, Task +from arkindex.ponos.models import State, Task from arkindex.process.models import Process, ProcessMode, Revision, WorkerVersion from arkindex.project.tests import FixtureAPITestCase from arkindex.users.models import Right, Role, User @@ -287,6 +287,7 @@ class TestAPI(FixtureAPITestCase): def test_update_non_running_task_state_stopping(self): states = list(State) states.remove(State.Running) + states.remove(State.Stopping) self.client.force_login(self.superuser) for state in states: @@ -302,51 +303,46 @@ class TestAPI(FixtureAPITestCase): self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual( resp.json(), - {"state": [f"Transition from state {state} to state Stopping is forbidden."]}, + {"state": ["State can only be updated from running to stopping"]}, ) self.task1.refresh_from_db() self.assertEqual(self.task1.state, state) - def test_update_final_task_state_pending(self): - self.task1.state = State.Completed + def test_update_running_task_non_pending(self): + states = list(State) + states.remove(State.Running) + states.remove(State.Stopping) + self.task1.state = State.Running self.task1.save() - self.client.force_login(self.superuser) - - with self.assertNumQueries(5): - resp = self.client.put( - reverse("api:task-update", args=[self.task1.id]), - data={"state": State.Pending.value}, - ) - self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.task1.refresh_from_db() - self.assertEqual(self.task1.state, State.Pending) - self.assertIsNone(self.task1.agent) - self.assertIsNone(self.task1.gpu) - - def test_update_non_final_task_state_pending(self): - states = set(State) - set(FINAL_STATES) self.client.force_login(self.superuser) - for state in states: with self.subTest(state=state): - self.task1.state = state - self.task1.save() - with self.assertNumQueries(3): resp = self.client.put( reverse("api:task-update", args=[self.task1.id]), - data={"state": State.Pending.value}, + data={"state": state.value}, ) self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) - self.assertDictEqual( resp.json(), - {"state": [f"Transition from state {state} to state Pending is forbidden."]} + {"state": ["State can only be updated from running to stopping"]}, ) - self.task1.refresh_from_db() - self.assertEqual(self.task1.state, state) + + def test_update_same_state(self): + states = list(State) + + self.client.force_login(self.superuser) + for state in states: + with self.subTest(state=state): + self.task1.state = state + self.task1.save() + resp = self.client.put( + reverse("api:task-update", args=[self.task1.id]), + data={"state": state.value}, + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) def test_partial_update_task_from_agent_requires_login(self): with self.assertNumQueries(0): @@ -424,29 +420,10 @@ class TestAPI(FixtureAPITestCase): ) self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN) - def test_partial_update_running_task_state_stopping(self): - self.task1.state = State.Running - self.task1.save() - self.client.force_login(self.superuser) - - with self.assertNumQueries(4): - resp = self.client.patch( - reverse("api:task-update", args=[self.task1.id]), - data={"state": State.Stopping.value}, - ) - self.assertEqual(resp.status_code, status.HTTP_200_OK) - - self.assertDictEqual(resp.json(), { - "id": str(self.task1.id), - "state": State.Stopping.value, - }) - self.task1.refresh_from_db() - self.assertEqual(self.task1.state, State.Stopping) - def test_partial_update_non_running_task_state_stopping(self): states = list(State) states.remove(State.Running) - self.task1.save() + states.remove(State.Stopping) self.client.force_login(self.superuser) for state in states: @@ -462,51 +439,46 @@ class TestAPI(FixtureAPITestCase): self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual( resp.json(), - {"state": [f"Transition from state {state} to state Stopping is forbidden."]} + {"state": ["State can only be updated from running to stopping"]}, ) self.task1.refresh_from_db() self.assertEqual(self.task1.state, state) - def test_partial_update_final_task_state_pending(self): - self.task1.state = State.Completed + def test_partial_update_running_task_non_pending(self): + states = list(State) + states.remove(State.Running) + states.remove(State.Stopping) + self.task1.state = State.Running self.task1.save() - self.client.force_login(self.superuser) - - with self.assertNumQueries(5): - resp = self.client.patch( - reverse("api:task-update", args=[self.task1.id]), - data={"state": State.Pending.value}, - ) - self.assertEqual(resp.status_code, status.HTTP_200_OK) - - self.task1.refresh_from_db() - self.assertEqual(self.task1.state, State.Pending) - self.assertIsNone(self.task1.agent) - def test_partial_update_non_final_task_state_pending(self): - states = set(State) - set(FINAL_STATES) - self.task1.save() self.client.force_login(self.superuser) - for state in states: with self.subTest(state=state): - self.task1.state = state - self.task1.save() - with self.assertNumQueries(3): resp = self.client.patch( reverse("api:task-update", args=[self.task1.id]), - data={"state": State.Pending.value}, + data={"state": state.value}, ) self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) - self.assertDictEqual( resp.json(), - {"state": [f"Transition from state {state} to state Pending is forbidden."]} + {"state": ["State can only be updated from running to stopping"]}, ) - self.task1.refresh_from_db() - self.assertEqual(self.task1.state, state) + + def test_partial_update_same_state(self): + states = list(State) + + self.client.force_login(self.superuser) + for state in states: + with self.subTest(state=state): + self.task1.state = state + self.task1.save() + resp = self.client.patch( + reverse("api:task-update", args=[self.task1.id]), + data={"state": state.value}, + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) @patch("arkindex.project.aws.s3") def test_task_logs_unicode_error(self, s3_mock):