Skip to content
Snippets Groups Projects
Commit 8dbb0033 authored by Erwan Rouchet's avatar Erwan Rouchet Committed by Bastien Abadie
Browse files

Support restarting tasks in RQ

parent fb85ced6
No related branches found
No related tags found
1 merge request!2394Support restarting tasks in RQ
import uuid
from functools import partial
from textwrap import dedent
from django.conf import settings
from django.db import transaction
from django.shortcuts import get_object_or_404, redirect
from drf_spectacular.utils import extend_schema, extend_schema_view
......@@ -18,6 +20,7 @@ from arkindex.ponos.permissions import (
IsAssignedAgentOrTaskOrReadOnly,
)
from arkindex.ponos.serializers import ArtifactSerializer, TaskSerializer
from arkindex.ponos.tasks import run_task_rq
from arkindex.project.mixins import ProcessACLMixin
from arkindex.project.permissions import IsVerified
from arkindex.users.models import Role
......@@ -237,4 +240,11 @@ class TaskRestart(ProcessACLMixin, CreateAPIView):
# Move all tasks depending on the restarted task to the copy
Task.children.through.objects.filter(to_task_id=pk).update(to_task_id=copy.id)
# Trigger a new RQ task execution when in RQ mode
# This does not handle dependencies, as all the parents had to be in a final state for this task to have a final state,
# and run_task_rq handles tasks with parents in error states.
if settings.PONOS_RQ_EXECUTION:
# Use on_commit so that the task is triggered only after the task has been created, outside of a transaction
transaction.on_commit(partial(run_task_rq.delay, copy))
return Response(TaskSerializer(copy).data, status=status.HTTP_201_CREATED)
......@@ -713,6 +713,7 @@ class TestAPI(FixtureAPITestCase):
["This task has already been restarted."],
)
@override_settings(PONOS_RQ_EXECUTION=False)
@patch("arkindex.project.aws.s3")
def test_restart_task(self, s3_mock):
"""
......@@ -807,6 +808,7 @@ class TestAPI(FixtureAPITestCase):
Task.objects.filter(id__in=[self.task3.id, task4.id]),
)
@override_settings(PONOS_RQ_EXECUTION=False)
@patch("arkindex.project.aws.s3")
def test_restart_task_no_previous_restart(self, s3_mock):
"""
......@@ -872,3 +874,54 @@ class TestAPI(FixtureAPITestCase):
restarted_task.children.all(),
Task.objects.filter(id__in=[self.task3.id, task4.id]),
)
@override_settings(PONOS_RQ_EXECUTION=True)
@patch("arkindex.ponos.tasks.run_task_rq.delay")
@patch("arkindex.project.aws.s3")
def test_restart_task_rq(self, s3_mock, delay_mock):
s3_mock.Object.return_value.bucket_name = "ponos"
s3_mock.Object.return_value.key = "somelog"
s3_mock.Object.return_value.get.return_value = {
"Body": BytesIO(b"Task has been restarted")
}
s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
seal(s3_mock)
delay_mock.return_value = None
seal(delay_mock)
self.task1.state = State.Failed.value
self.task1.save()
self.client.force_login(self.user)
with self.assertNumQueries(13), self.captureOnCommitCallbacks(execute=True):
response = self.client.post(
reverse("api:task-restart", kwargs={"pk": str(self.task1.id)})
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(self.process.tasks.count(), 4)
restarted_task = self.process.tasks.latest("created")
self.assertDictEqual(
response.json(),
{
"id": str(restarted_task.id),
"depth": 0,
"agent": None,
"extra_files": {},
"full_log": "http://somewhere",
"gpu": None,
"logs": "Task has been restarted",
"original_task_id": str(self.task1.id),
"parents": [],
"run": 0,
"shm_size": None,
"slug": self.task1.slug,
"state": "pending",
"requires_gpu": False,
},
)
self.assertEqual(delay_mock.call_count, 1)
self.assertEqual(delay_mock.call_args, call(restarted_task))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment