Skip to content
Snippets Groups Projects
Commit 831e42a4 authored by Valentin Rigal, committed by Erwan Rouchet
Browse files

RestartTask endpoint

parent 4d9d1624
No related branches found
No related tags found
1 merge request: !2266 "RestartTask endpoint"
import uuid
from textwrap import dedent
from django.db import transaction
from django.shortcuts import get_object_or_404, redirect
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import serializers, status
from rest_framework.authentication import SessionAuthentication, TokenAuthentication
from rest_framework.generics import ListCreateAPIView, RetrieveUpdateAPIView, UpdateAPIView
from rest_framework.exceptions import NotFound, ValidationError
from rest_framework.generics import CreateAPIView, ListCreateAPIView, RetrieveUpdateAPIView, UpdateAPIView
from rest_framework.response import Response
from rest_framework.views import APIView
from arkindex.ponos.models import Artifact, Task
from arkindex.ponos.models import FINAL_STATES, Artifact, State, Task, task_token_default
from arkindex.ponos.permissions import (
IsAgentOrArtifactGuest,
IsAgentOrTaskGuest,
......@@ -15,6 +20,9 @@ from arkindex.ponos.permissions import (
IsTaskAdmin,
)
from arkindex.ponos.serializers import ArtifactSerializer, TaskSerializer, TaskTinySerializer
from arkindex.project.mixins import ProcessACLMixin
from arkindex.project.permissions import IsVerified
from arkindex.users.models import Role
@extend_schema(tags=["ponos"])
......@@ -168,3 +176,81 @@ class TaskUpdate(UpdateAPIView):
permission_classes = (IsTaskAdmin, )
queryset = Task.objects.select_related("process__corpus")
serializer_class = TaskTinySerializer
# Endpoint restarting a finished task: the task is duplicated under a new ID
# and an incremented "_restartN" slug, then every task that depended on the
# original is re-linked to the duplicate. The original task row is kept.
@extend_schema_view(
    post=extend_schema(
        operation_id="RestartTask",
        tags=["ponos"],
        description=dedent(
            """
            Restart a task by creating a fresh copy and moving dependent tasks to the new one.
            Scenario restarting `my_worker` task:
            ```
            init_elements → my_worker → other worker
            ```
            ```
            init_elements → my_worker
            my_worker_2 → other worker
            ```
            Requires an **admin** access to the task's process.
            The task must be in a final state to be restarted.
            """
        ),
        responses={201: TaskSerializer},
    ),
)
class TaskRestart(ProcessACLMixin, CreateAPIView):
    permission_classes = (IsVerified,)
    # No request body is expected: an empty serializer is enough.
    serializer_class = serializers.Serializer

    def get_task(self):
        """
        Fetch the task designated by the ``pk`` URL argument and check that
        the current user may restart it.

        :returns: The task, with its parents prefetched and its process and
            corpus joined.
        :raises Http404: No task exists with this ID.
        :raises NotFound: ``process_access_level`` returned ``None`` for the
            task's process (presumably: the user cannot see it at all).
        :raises ValidationError: The user is not admin on the process, the
            task is not in a final state, or a task holding the next restart
            slug already exists in the same run.
        """
        task = get_object_or_404(
            Task.objects.prefetch_related("parents").select_related("process__corpus"),
            pk=self.kwargs["pk"],
        )
        access_level = self.process_access_level(task.process)
        if access_level is None:
            raise NotFound
        if access_level < Role.Admin.value:
            raise ValidationError(
                detail="You do not have an admin access to the process of this task."
            )
        if task.state not in FINAL_STATES:
            raise ValidationError(
                detail="Task's state must be in a final state to be restarted."
            )
        # TODO Check the original_task_id field directly once it is implemented
        # https://gitlab.teklia.com/arkindex/frontend/-/issues/1383
        # Until then, a restart is detected through the slug its copy would get.
        if task.process.tasks.filter(run=task.run, slug=self.increment(task.slug)).exists():
            raise ValidationError(
                detail="This task has already been restarted"
            )
        return task

    def increment(self, name):
        """
        Build the slug given to the restarted copy of a task named ``name``.

        A trailing ``_restart<N>`` counter (ASCII digits only) is bumped to
        ``_restart<N+1>``; any other name gets ``_restart1`` appended:

        - ``task`` → ``task_restart1``
        - ``task_restart42`` → ``task_restart43``
        - ``my_restarter`` → ``my_restarter_restart1``

        Only a purely numeric suffix is treated as a counter, so text that
        merely contains ``_restart`` is never truncated and two different
        slugs never map to the same restart slug.

        :param name: Slug of the task being restarted.
        :returns: Slug for the new copy.
        """
        basename, marker, counter = name.rpartition("_restart")
        # isascii() guards against characters such as "²" that isdigit()
        # accepts but int() cannot parse.
        if marker and counter.isascii() and counter.isdigit():
            return f"{basename}_restart{int(counter) + 1}"
        return f"{name}_restart1"

    @transaction.atomic
    def create(self, request, pk=None, **kwargs):
        """
        Duplicate the task and move its dependents to the duplicate.

        Runs in a single transaction so that the copy, its parent links and
        the re-linked children are committed together.

        :param request: Incoming request; its body is not read.
        :param pk: ID of the task to restart, taken from the URL.
        :returns: HTTP 201 response holding the serialized new task.
        """
        copy = self.get_task()
        # Read the parents while the instance still carries the original ID.
        parents = list(copy.parents.all())
        # Assigning a fresh ID makes save() store the instance as a new row,
        # leaving the original task in the database.
        copy.id = uuid.uuid4()
        copy.state = State.Pending
        copy.token = task_token_default()
        copy.slug = self.increment(copy.slug)
        copy.save()
        # Create links to retried task parents
        copy.parents.add(*parents)
        # Move all tasks depending on the retried task to the copy
        Task.children.through.objects.filter(to_task_id=pk).update(to_task_id=copy.id)
        return Response(TaskSerializer(copy).data, status=status.HTTP_201_CREATED)
import uuid
from io import BytesIO
from unittest import expectedFailure
from unittest.mock import call, patch, seal
......@@ -8,7 +9,7 @@ from django.urls import reverse
from rest_framework import status
from arkindex.documents.models import Corpus
from arkindex.ponos.models import FINAL_STATES, State
from arkindex.ponos.models import FINAL_STATES, State, Task
from arkindex.process.models import Process, ProcessMode, Revision, WorkerVersion
from arkindex.project.tests import FixtureAPITestCase
from arkindex.users.models import Right, Role, User
......@@ -25,7 +26,7 @@ class TestAPI(FixtureAPITestCase):
cls.rev = Revision.objects.first()
cls.process = Process.objects.get(mode=ProcessMode.Workers)
cls.process.run()
cls.task1, cls.task2, cls.task3 = cls.process.tasks.all()
cls.task1, cls.task2, cls.task3 = cls.process.tasks.all().order_by("depth")
# Brand new user and corpus with no preexisting rights
new_user = User.objects.create(email="another@user.com")
......@@ -554,3 +555,133 @@ class TestAPI(FixtureAPITestCase):
resp = self.client.get(reverse("api:task-details", args=[self.task1.id]))
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(resp.json()["logs"], "")
def test_restart_task_requires_login(self):
    """Anonymous requests are refused with a 403 without running any SQL query."""
    restart_url = reverse("api:task-restart", kwargs={"pk": str(uuid.uuid4())})
    with self.assertNumQueries(0):
        resp = self.client.post(restart_url)
    self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN)
def test_restart_task_requires_verified(self):
    """A logged-in user whose email is not verified gets a 403."""
    # Downgrade the fixture user to an unverified account before logging in
    self.user.verified_email = False
    self.user.save()
    self.client.force_login(self.user)

    restart_url = reverse("api:task-restart", kwargs={"pk": str(uuid.uuid4())})
    with self.assertNumQueries(2):
        resp = self.client.post(restart_url)
    self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN)
def test_restart_task_not_found(self):
    """Restarting a random, non-existent task ID yields a 404."""
    self.client.force_login(self.user)
    unknown_id = str(uuid.uuid4())
    restart_url = reverse("api:task-restart", kwargs={"pk": unknown_id})
    with self.assertNumQueries(6):
        resp = self.client.post(restart_url)
    self.assertEqual(resp.status_code, status.HTTP_404_NOT_FOUND)
@patch("arkindex.project.mixins.get_max_level")
def test_restart_task_forbidden(self, get_max_level_mock):
    """An admin access to the process is required"""
    # Force the access level computed for the user down to guest
    get_max_level_mock.return_value = Role.Guest.value
    self.client.force_login(self.user)

    restart_url = reverse("api:task-restart", kwargs={"pk": str(self.task1.id)})
    with self.assertNumQueries(7):
        resp = self.client.post(restart_url)

    self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
    expected_errors = ["You do not have an admin access to the process of this task."]
    self.assertListEqual(resp.json(), expected_errors)
def test_restart_task_non_final_state(self):
    """A task left in the state given by the class setup cannot be restarted."""
    self.client.force_login(self.user)

    restart_url = reverse("api:task-restart", kwargs={"pk": str(self.task1.id)})
    with self.assertNumQueries(7):
        resp = self.client.post(restart_url)

    self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
    expected_errors = ["Task's state must be in a final state to be restarted."]
    self.assertListEqual(resp.json(), expected_errors)
def test_restart_task_already_restarted(self):
    """A task whose restart slug is already taken in the process is refused."""
    self.client.force_login(self.user)

    # Make task2 look like the first restart of task1
    self.task2.slug = self.task1.slug + "_restart1"
    self.task2.save()
    # task1 itself is in a final state, so only the slug check can fail
    self.task1.state = State.Completed.value
    self.task1.save()

    restart_url = reverse("api:task-restart", kwargs={"pk": str(self.task1.id)})
    with self.assertNumQueries(8):
        resp = self.client.post(restart_url)

    self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
    expected_errors = ["This task has already been restarted"]
    self.assertListEqual(resp.json(), expected_errors)
@patch("arkindex.project.aws.s3")
def test_restart_task(self, s3_mock):
    """
    Restarting task2 (slug task2_restart42) creates task2_restart43.

    Before the call, task 4 is added as a child of task2 and a parent of
    task3. After the call, the original task2 has no children left, and the
    copy has task1 as its only parent and both task3 and task 4 as children.
    """
    # Fake the S3 object backing the task logs
    log_object = s3_mock.Object.return_value
    log_object.bucket_name = "ponos"
    log_object.key = "somelog"
    log_object.get.return_value = {
        "Body": BytesIO(b"Task has been restarted")
    }
    s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"

    # Extra task sitting between task2 and task3
    task4 = self.process.tasks.create(run=self.task1.run, depth=1)
    task4.parents.add(self.task2)
    task4.children.add(self.task3)

    self.task1.state = State.Completed.value
    self.task1.save()
    self.task2.state = State.Error.value
    self.task2.slug = "task2_restart42"
    self.task2.save()

    self.client.force_login(self.user)
    restart_url = reverse("api:task-restart", kwargs={"pk": str(self.task2.id)})
    with self.assertNumQueries(12):
        resp = self.client.post(restart_url)
    self.assertEqual(resp.status_code, status.HTTP_201_CREATED)

    # Three fixture tasks + task4 + the restarted copy
    self.assertEqual(self.process.tasks.count(), 5)
    restarted_task = self.process.tasks.latest("created")
    expected_payload = {
        "id": str(restarted_task.id),
        "depth": 1,
        "agent": None,
        "extra_files": {},
        "full_log": "http://somewhere",
        "gpu": None,
        "logs": "Task has been restarted",
        "parents": [str(self.task1.id)],
        "run": 0,
        "shm_size": None,
        "slug": "task2_restart43",
        "state": "pending",
    }
    self.assertDictEqual(resp.json(), expected_payload)

    # Dependents moved from the original task to its copy
    self.assertQuerysetEqual(self.task2.children.all(), Task.objects.none())
    moved_children = Task.objects.filter(id__in=[self.task3.id, task4.id])
    self.assertQuerysetEqual(restarted_task.children.all(), moved_children)
......@@ -59,7 +59,7 @@ from arkindex.documents.api.ml import (
)
from arkindex.documents.api.search import CorpusSearch, SearchIndexBuild
from arkindex.images.api import IIIFInformationCreate, IIIFURLCreate, ImageCreate, ImageElements, ImageRetrieve
from arkindex.ponos.api import TaskArtifactDownload, TaskArtifacts, TaskDetailsFromAgent, TaskUpdate
from arkindex.ponos.api import TaskArtifactDownload, TaskArtifacts, TaskDetailsFromAgent, TaskRestart, TaskUpdate
from arkindex.process.api import (
ApplyProcessTemplate,
BucketList,
......@@ -327,4 +327,5 @@ api = [
TaskArtifactDownload.as_view(),
name="task-artifact-download",
),
path("task/<uuid:pk>/restart/", TaskRestart.as_view(), name="task-restart"),
]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment