From 8dbb003304aec242a2256538852ac8fae47c39d8 Mon Sep 17 00:00:00 2001 From: Erwan Rouchet <rouchet@teklia.com> Date: Tue, 23 Jul 2024 18:27:24 +0200 Subject: [PATCH] Support restarting tasks in RQ --- arkindex/ponos/api.py | 10 ++++++ arkindex/ponos/tests/test_api.py | 53 ++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/arkindex/ponos/api.py b/arkindex/ponos/api.py index 2ba88a887e..9705f24c74 100644 --- a/arkindex/ponos/api.py +++ b/arkindex/ponos/api.py @@ -1,6 +1,8 @@ import uuid +from functools import partial from textwrap import dedent +from django.conf import settings from django.db import transaction from django.shortcuts import get_object_or_404, redirect from drf_spectacular.utils import extend_schema, extend_schema_view @@ -18,6 +20,7 @@ from arkindex.ponos.permissions import ( IsAssignedAgentOrTaskOrReadOnly, ) from arkindex.ponos.serializers import ArtifactSerializer, TaskSerializer +from arkindex.ponos.tasks import run_task_rq from arkindex.project.mixins import ProcessACLMixin from arkindex.project.permissions import IsVerified from arkindex.users.models import Role @@ -237,4 +240,11 @@ class TaskRestart(ProcessACLMixin, CreateAPIView): # Move all tasks depending on the restarted task to the copy Task.children.through.objects.filter(to_task_id=pk).update(to_task_id=copy.id) + # Trigger a new RQ task execution when in RQ mode + # This does not handle dependencies, as all the parents had to be in a final state for this task to have a final state, + # and run_task_rq handles tasks with parents in error states. + if settings.PONOS_RQ_EXECUTION: + # Use on_commit so that the task is triggered only after the task has been created, outside of a transaction + transaction.on_commit(partial(run_task_rq.delay, copy)) + return Response(TaskSerializer(copy).data, status=status.HTTP_201_CREATED) diff --git a/arkindex/ponos/tests/test_api.py b/arkindex/ponos/tests/test_api.py index a041d7a244..5bbcdf05bf 100644 --- a/arkindex/ponos/tests/test_api.py +++ b/arkindex/ponos/tests/test_api.py @@ -713,6 +713,7 @@ class TestAPI(FixtureAPITestCase): ["This task has already been restarted."], ) + @override_settings(PONOS_RQ_EXECUTION=False) @patch("arkindex.project.aws.s3") def test_restart_task(self, s3_mock): """ @@ -807,6 +808,7 @@ class TestAPI(FixtureAPITestCase): Task.objects.filter(id__in=[self.task3.id, task4.id]), ) + @override_settings(PONOS_RQ_EXECUTION=False) @patch("arkindex.project.aws.s3") def test_restart_task_no_previous_restart(self, s3_mock): """ @@ -872,3 +874,54 @@ class TestAPI(FixtureAPITestCase): restarted_task.children.all(), Task.objects.filter(id__in=[self.task3.id, task4.id]), ) + + @override_settings(PONOS_RQ_EXECUTION=True) + @patch("arkindex.ponos.tasks.run_task_rq.delay") + @patch("arkindex.project.aws.s3") + def test_restart_task_rq(self, s3_mock, delay_mock): + s3_mock.Object.return_value.bucket_name = "ponos" + s3_mock.Object.return_value.key = "somelog" + s3_mock.Object.return_value.get.return_value = { + "Body": BytesIO(b"Task has been restarted") + } + s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere" + seal(s3_mock) + + delay_mock.return_value = None + seal(delay_mock) + + self.task1.state = State.Failed.value + self.task1.save() + + self.client.force_login(self.user) + + with self.assertNumQueries(13), self.captureOnCommitCallbacks(execute=True): + response = self.client.post( + reverse("api:task-restart", kwargs={"pk": str(self.task1.id)}) + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + self.assertEqual(self.process.tasks.count(), 4) + restarted_task = self.process.tasks.latest("created") + self.assertDictEqual( + response.json(), + { + "id": str(restarted_task.id), + "depth": 0, + "agent": None, + "extra_files": {}, + "full_log": "http://somewhere", + "gpu": None, + "logs": "Task has been restarted", + "original_task_id": str(self.task1.id), + "parents": [], + "run": 0, + "shm_size": None, + "slug": self.task1.slug, + "state": "pending", + "requires_gpu": False, + }, + ) + + self.assertEqual(delay_mock.call_count, 1) + self.assertEqual(delay_mock.call_args, call(restarted_task)) -- GitLab