diff --git a/arkindex/ponos/api.py b/arkindex/ponos/api.py index a7f0864565e5ad980716882e12d64213969c84c5..5338c65487ddcf7400adb554c7e18405fd1b0840 100644 --- a/arkindex/ponos/api.py +++ b/arkindex/ponos/api.py @@ -1,12 +1,17 @@ +import uuid from textwrap import dedent +from django.db import transaction from django.shortcuts import get_object_or_404, redirect from drf_spectacular.utils import extend_schema, extend_schema_view +from rest_framework import serializers, status from rest_framework.authentication import SessionAuthentication, TokenAuthentication -from rest_framework.generics import ListCreateAPIView, RetrieveUpdateAPIView, UpdateAPIView +from rest_framework.exceptions import NotFound, ValidationError +from rest_framework.generics import CreateAPIView, ListCreateAPIView, RetrieveUpdateAPIView, UpdateAPIView +from rest_framework.response import Response from rest_framework.views import APIView -from arkindex.ponos.models import Artifact, Task +from arkindex.ponos.models import FINAL_STATES, Artifact, State, Task, task_token_default from arkindex.ponos.permissions import ( IsAgentOrArtifactGuest, IsAgentOrTaskGuest, @@ -15,6 +20,9 @@ from arkindex.ponos.permissions import ( IsTaskAdmin, ) from arkindex.ponos.serializers import ArtifactSerializer, TaskSerializer, TaskTinySerializer +from arkindex.project.mixins import ProcessACLMixin +from arkindex.project.permissions import IsVerified +from arkindex.users.models import Role @extend_schema(tags=["ponos"]) @@ -168,3 +176,81 @@ class TaskUpdate(UpdateAPIView): permission_classes = (IsTaskAdmin, ) queryset = Task.objects.select_related("process__corpus") serializer_class = TaskTinySerializer + + +@extend_schema_view( + post=extend_schema( + operation_id="RestartTask", + tags=["ponos"], + description=dedent( + """ + Restart a task by creating a fresh copy and moving dependent tasks to the new one. + + Scenario restarting `my_worker` task: + ``` + init_elements → my_worker → other worker + ``` + ``` + init_elements → my_worker + ↘ + my_worker_2 → other worker + ``` + + Requires an **admin** access to the task's process. + The task must be in a final state to be restarted. + """ + ), + responses={201: TaskSerializer}, + ), +) +class TaskRestart(ProcessACLMixin, CreateAPIView): + permission_classes = (IsVerified,) + serializer_class = serializers.Serializer + + def get_task(self): + task = get_object_or_404( + Task.objects.prefetch_related("parents").select_related("process__corpus"), + pk=self.kwargs["pk"], + ) + access_level = self.process_access_level(task.process) + if access_level is None: + raise NotFound + if access_level < Role.Admin.value: + raise ValidationError( + detail="You do not have an admin access to the process of this task." + ) + if task.state not in FINAL_STATES: + raise ValidationError( + detail="Task's state must be in a final state to be restarted." + ) + # TODO Check the original_task_id field directly once it is implemented + # https://gitlab.teklia.com/arkindex/frontend/-/issues/1383 + if task.process.tasks.filter(run=task.run, slug=self.increment(task.slug)).exists(): + raise ValidationError( + detail="This task has already been restarted" + ) + return task + + def increment(self, name): + basename, *suffix = name.rsplit("_restart", 1) + suffix = int(suffix[0]) + 1 if suffix and suffix[0].isdigit() else 1 + return f"{basename}_restart{suffix}" + + @transaction.atomic + def create(self, request, pk=None, **kwargs): + copy = self.get_task() + parents = list(copy.parents.all()) + + copy.id = uuid.uuid4() + copy.state = State.Pending + copy.token = task_token_default() + copy.slug = self.increment(copy.slug) + copy.save() + + # Create links to retried task parents + copy.parents.add(*parents) + + # Move all tasks depending on the retried task to the copy + Task.children.through.objects.filter(to_task_id=pk).update(to_task_id=copy.id) + + return Response(TaskSerializer(copy).data, status=status.HTTP_201_CREATED) diff --git a/arkindex/ponos/tests/test_api.py b/arkindex/ponos/tests/test_api.py index ea84101f743adbdee613894007f037bf919c8596..5c05f001402e29cb25ac3cecc57a90ba413738c5 100644 --- a/arkindex/ponos/tests/test_api.py +++ b/arkindex/ponos/tests/test_api.py @@ -1,3 +1,4 @@ +import uuid from io import BytesIO from unittest import expectedFailure from unittest.mock import call, patch, seal @@ -8,7 +9,7 @@ from django.urls import reverse from rest_framework import status from arkindex.documents.models import Corpus -from arkindex.ponos.models import FINAL_STATES, State +from arkindex.ponos.models import FINAL_STATES, State, Task from arkindex.process.models import Process, ProcessMode, Revision, WorkerVersion from arkindex.project.tests import FixtureAPITestCase from arkindex.users.models import Right, Role, User @@ -25,7 +26,7 @@ class TestAPI(FixtureAPITestCase): cls.rev = Revision.objects.first() cls.process = Process.objects.get(mode=ProcessMode.Workers) cls.process.run() - cls.task1, cls.task2, cls.task3 = cls.process.tasks.all() + cls.task1, cls.task2, cls.task3 = cls.process.tasks.all().order_by("depth") # Brand new user and corpus with no preexisting rights new_user = User.objects.create(email="another@user.com") @@ -554,3 +555,133 @@ class TestAPI(FixtureAPITestCase): resp = self.client.get(reverse("api:task-details", args=[self.task1.id])) self.assertEqual(resp.status_code, status.HTTP_200_OK) self.assertEqual(resp.json()["logs"], "") + + def test_restart_task_requires_login(self): + with self.assertNumQueries(0): + response = self.client.post( + reverse("api:task-restart", kwargs={"pk": str(uuid.uuid4())}) + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_restart_task_requires_verified(self): + self.user.verified_email = False + self.user.save() + self.client.force_login(self.user) + with self.assertNumQueries(2): + response = self.client.post( + reverse("api:task-restart", kwargs={"pk": str(uuid.uuid4())}) + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_restart_task_not_found(self): + self.client.force_login(self.user) + with self.assertNumQueries(6): + response = self.client.post( + reverse("api:task-restart", kwargs={"pk": str(uuid.uuid4())}) + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + @patch("arkindex.project.mixins.get_max_level") + def test_restart_task_forbidden(self, get_max_level_mock): + """An admin access to the process is required""" + get_max_level_mock.return_value = Role.Guest.value + self.client.force_login(self.user) + with self.assertNumQueries(7): + response = self.client.post( + reverse("api:task-restart", kwargs={"pk": str(self.task1.id)}) + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertListEqual( + response.json(), + ["You do not have an admin access to the process of this task."], + ) + + def test_restart_task_non_final_state(self): + self.client.force_login(self.user) + with self.assertNumQueries(7): + response = self.client.post( + reverse("api:task-restart", kwargs={"pk": str(self.task1.id)}) + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertListEqual( + response.json(), + ["Task's state must be in a final state to be restarted."], + ) + + def test_restart_task_already_restarted(self): + self.client.force_login(self.user) + self.task2.slug = self.task1.slug + "_restart1" + self.task2.save() + self.task1.state = State.Completed.value + self.task1.save() + with self.assertNumQueries(8): + response = self.client.post( + reverse("api:task-restart", kwargs={"pk": str(self.task1.id)}) + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertListEqual( + response.json(), + ["This task has already been restarted"], + ) + + @patch("arkindex.project.aws.s3") + def test_restart_task(self, s3_mock): + """ + From: + task1 → task2_restart42 → task3 + ↘ ↗ + task 4 + + To: + task1 → task2_restart42 + ↘ + task2_restart43 → task3 + ↘ ↗ + task 4 + """ + s3_mock.Object.return_value.bucket_name = "ponos" + s3_mock.Object.return_value.key = "somelog" + s3_mock.Object.return_value.get.return_value = { + "Body": BytesIO(b"Task has been restarted") + } + s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere" + + task4 = self.process.tasks.create(run=self.task1.run, depth=1) + task4.parents.add(self.task2) + task4.children.add(self.task3) + self.task1.state = State.Completed.value + self.task1.save() + self.task2.state = State.Error.value + self.task2.slug = "task2_restart42" + self.task2.save() + + self.client.force_login(self.user) + with self.assertNumQueries(12): + response = self.client.post( + reverse("api:task-restart", kwargs={"pk": str(self.task2.id)}) + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(self.process.tasks.count(), 5) + restarted_task = self.process.tasks.latest("created") + self.assertDictEqual( + response.json(), + { + "id": str(restarted_task.id), + "depth": 1, + "agent": None, + "extra_files": {}, + "full_log": "http://somewhere", + "gpu": None, + "logs": "Task has been restarted", + "parents": [str(self.task1.id)], + "run": 0, + "shm_size": None, + "slug": "task2_restart43", + "state": "pending", + }, + ) + self.assertQuerysetEqual(self.task2.children.all(), Task.objects.none()) + self.assertQuerysetEqual( + restarted_task.children.all(), + Task.objects.filter(id__in=[self.task3.id, task4.id]), + ) diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py index aad7ba77c22be98ea347cf4cedeff218c4061f92..87433cb65a2ba231389f89dd705f66d523c8db76 100644 --- a/arkindex/project/api_v1.py +++ b/arkindex/project/api_v1.py @@ -59,7 +59,7 @@ from arkindex.documents.api.ml import ( ) from arkindex.documents.api.search import CorpusSearch, SearchIndexBuild from arkindex.images.api import IIIFInformationCreate, IIIFURLCreate, ImageCreate, ImageElements, ImageRetrieve -from arkindex.ponos.api import TaskArtifactDownload, TaskArtifacts, TaskDetailsFromAgent, TaskUpdate +from arkindex.ponos.api import TaskArtifactDownload, TaskArtifacts, TaskDetailsFromAgent, TaskRestart, TaskUpdate from arkindex.process.api import ( ApplyProcessTemplate, BucketList, @@ -327,4 +327,5 @@ api = [ TaskArtifactDownload.as_view(), name="task-artifact-download", ), + path("task/<uuid:pk>/restart/", TaskRestart.as_view(), name="task-restart"), ]