From 8dbb003304aec242a2256538852ac8fae47c39d8 Mon Sep 17 00:00:00 2001
From: Erwan Rouchet <rouchet@teklia.com>
Date: Tue, 23 Jul 2024 18:27:24 +0200
Subject: [PATCH] Support restarting tasks in RQ

---
 arkindex/ponos/api.py            | 10 ++++++
 arkindex/ponos/tests/test_api.py | 53 ++++++++++++++++++++++++++++++++
 2 files changed, 63 insertions(+)

diff --git a/arkindex/ponos/api.py b/arkindex/ponos/api.py
index 2ba88a887e..9705f24c74 100644
--- a/arkindex/ponos/api.py
+++ b/arkindex/ponos/api.py
@@ -1,6 +1,8 @@
 import uuid
+from functools import partial
 from textwrap import dedent
 
+from django.conf import settings
 from django.db import transaction
 from django.shortcuts import get_object_or_404, redirect
 from drf_spectacular.utils import extend_schema, extend_schema_view
@@ -18,6 +20,7 @@ from arkindex.ponos.permissions import (
     IsAssignedAgentOrTaskOrReadOnly,
 )
 from arkindex.ponos.serializers import ArtifactSerializer, TaskSerializer
+from arkindex.ponos.tasks import run_task_rq
 from arkindex.project.mixins import ProcessACLMixin
 from arkindex.project.permissions import IsVerified
 from arkindex.users.models import Role
@@ -237,4 +240,11 @@ class TaskRestart(ProcessACLMixin, CreateAPIView):
         # Move all tasks depending on the restarted task to the copy
         Task.children.through.objects.filter(to_task_id=pk).update(to_task_id=copy.id)
 
+        # Trigger a new RQ task execution when in RQ mode
+        # This does not handle dependencies, as all the parents had to be in a final state for this task to have a final state,
+        # and run_task_rq handles tasks with parents in error states.
+        if settings.PONOS_RQ_EXECUTION:
+            # Use on_commit so that the task is triggered only after the task has been created, outside of a transaction
+            transaction.on_commit(partial(run_task_rq.delay, copy))
+
         return Response(TaskSerializer(copy).data, status=status.HTTP_201_CREATED)
diff --git a/arkindex/ponos/tests/test_api.py b/arkindex/ponos/tests/test_api.py
index a041d7a244..5bbcdf05bf 100644
--- a/arkindex/ponos/tests/test_api.py
+++ b/arkindex/ponos/tests/test_api.py
@@ -713,6 +713,7 @@ class TestAPI(FixtureAPITestCase):
                 ["This task has already been restarted."],
             )
 
+    @override_settings(PONOS_RQ_EXECUTION=False)
     @patch("arkindex.project.aws.s3")
     def test_restart_task(self, s3_mock):
         """
@@ -807,6 +808,7 @@ class TestAPI(FixtureAPITestCase):
             Task.objects.filter(id__in=[self.task3.id, task4.id]),
         )
 
+    @override_settings(PONOS_RQ_EXECUTION=False)
     @patch("arkindex.project.aws.s3")
     def test_restart_task_no_previous_restart(self, s3_mock):
         """
@@ -872,3 +874,54 @@ class TestAPI(FixtureAPITestCase):
             restarted_task.children.all(),
             Task.objects.filter(id__in=[self.task3.id, task4.id]),
         )
+
+    @override_settings(PONOS_RQ_EXECUTION=True)
+    @patch("arkindex.ponos.tasks.run_task_rq.delay")
+    @patch("arkindex.project.aws.s3")
+    def test_restart_task_rq(self, s3_mock, delay_mock):
+        s3_mock.Object.return_value.bucket_name = "ponos"
+        s3_mock.Object.return_value.key = "somelog"
+        s3_mock.Object.return_value.get.return_value = {
+            "Body": BytesIO(b"Task has been restarted")
+        }
+        s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
+        seal(s3_mock)
+
+        delay_mock.return_value = None
+        seal(delay_mock)
+
+        self.task1.state = State.Failed.value
+        self.task1.save()
+
+        self.client.force_login(self.user)
+
+        with self.assertNumQueries(13), self.captureOnCommitCallbacks(execute=True):
+            response = self.client.post(
+                reverse("api:task-restart", kwargs={"pk": str(self.task1.id)})
+            )
+            self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+
+        self.assertEqual(self.process.tasks.count(), 4)
+        restarted_task = self.process.tasks.latest("created")
+        self.assertDictEqual(
+            response.json(),
+            {
+                "id": str(restarted_task.id),
+                "depth": 0,
+                "agent": None,
+                "extra_files": {},
+                "full_log": "http://somewhere",
+                "gpu": None,
+                "logs": "Task has been restarted",
+                "original_task_id": str(self.task1.id),
+                "parents": [],
+                "run": 0,
+                "shm_size": None,
+                "slug": self.task1.slug,
+                "state": "pending",
+                "requires_gpu": False,
+            },
+        )
+
+        self.assertEqual(delay_mock.call_count, 1)
+        self.assertEqual(delay_mock.call_args, call(restarted_task))
-- 
GitLab