Skip to content
Snippets Groups Projects

RestartTask endpoint

Merged Valentin Rigal requested to merge restart-task into master
All threads resolved!
Files
3
import uuid
from io import BytesIO
from unittest import expectedFailure
from unittest.mock import call, patch, seal
@@ -8,7 +9,7 @@ from django.urls import reverse
from rest_framework import status
from arkindex.documents.models import Corpus
from arkindex.ponos.models import FINAL_STATES, State
from arkindex.ponos.models import FINAL_STATES, State, Task
from arkindex.process.models import Process, ProcessMode, Revision, WorkerVersion
from arkindex.project.tests import FixtureAPITestCase
from arkindex.users.models import Right, Role, User
@@ -25,7 +26,7 @@ class TestAPI(FixtureAPITestCase):
cls.rev = Revision.objects.first()
cls.process = Process.objects.get(mode=ProcessMode.Workers)
cls.process.run()
cls.task1, cls.task2, cls.task3 = cls.process.tasks.all()
cls.task1, cls.task2, cls.task3 = cls.process.tasks.all().order_by("depth")
# Brand new user and corpus with no preexisting rights
new_user = User.objects.create(email="another@user.com")
@@ -555,3 +556,118 @@ class TestAPI(FixtureAPITestCase):
resp = self.client.get(reverse("api:task-details", args=[self.task1.id]))
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(resp.json()["logs"], "")
def test_restart_task_requires_login(self):
with self.assertNumQueries(0):
response = self.client.post(
reverse("api:task-restart", kwargs={"pk": str(uuid.uuid4())})
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_restart_task_requires_verified(self):
self.user.verified_email = False
self.user.save()
self.client.force_login(self.user)
with self.assertNumQueries(2):
response = self.client.post(
reverse("api:task-restart", kwargs={"pk": str(uuid.uuid4())})
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_restart_task_not_found(self):
self.client.force_login(self.user)
with self.assertNumQueries(6):
response = self.client.post(
reverse("api:task-restart", kwargs={"pk": str(uuid.uuid4())})
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
@expectedFailure
def test_restart_task_forbidden(self):
"""An admin access to the process is required"""
self.client.force_login(self.user)
self.corpus.memberships.filter(user=self.user).update(level=Role.Guest.value)
self.process.save()
with self.assertNumQueries(10):
response = self.client.post(
reverse("api:task-restart", kwargs={"pk": str(self.task1.id)})
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertListEqual(
response.json(),
["You do not have an admin access to the process of this task."],
)
def test_restart_task_non_final_state(self):
self.client.force_login(self.user)
with self.assertNumQueries(8):
response = self.client.post(
reverse("api:task-restart", kwargs={"pk": str(self.task1.id)})
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertListEqual(
response.json(),
["Task's state must be in a final state to be restarted."],
)
@patch("arkindex.project.aws.s3")
def test_restart_task(self, s3_mock):
"""
From:
task1 → task2_restart42 → task3
↘ ↗
task 4
To:
task1 → task2_restart42
task2_restart43 → task3
↘ ↗
task 4
"""
s3_mock.Object.return_value.bucket_name = "ponos"
s3_mock.Object.return_value.key = "somelog"
s3_mock.Object.return_value.get.return_value = {
"Body": BytesIO(b"Task has been restarted")
}
s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
task4 = self.process.tasks.create(run=self.task1.run, depth=1)
task4.parents.add(self.task2)
task4.children.add(self.task3)
self.task1.state = State.Completed.value
self.task1.save()
self.task2.state = State.Error.value
self.task2.slug = "task2_restart42"
self.task2.save()
self.client.force_login(self.user)
with self.assertNumQueries(12):
response = self.client.post(
reverse("api:task-restart", kwargs={"pk": str(self.task2.id)})
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(self.process.tasks.count(), 5)
restarted_task = self.process.tasks.latest("created")
self.assertDictEqual(
response.json(),
{
"id": str(restarted_task.id),
"depth": 1,
"agent": None,
"extra_files": {},
"full_log": "http://somewhere",
"gpu": None,
"logs": "Task has been restarted",
"parents": [str(self.task1.id)],
"run": 0,
"shm_size": None,
"slug": "task2_restart43",
"state": "pending",
},
)
self.assertQuerysetEqual(self.task2.children.all(), Task.objects.none())
self.assertQuerysetEqual(
restarted_task.children.all(),
Task.objects.filter(id__in=[self.task3.id, task4.id]),
)
Loading