Skip to content
Snippets Groups Projects

RestartTask endpoint

Merged Valentin Rigal requested to merge restart-task into master
3 files
+ 223
5
Compare changes
  • Side-by-side
  • Inline
Files
3
 
import uuid
from io import BytesIO
from io import BytesIO
from unittest import expectedFailure
from unittest import expectedFailure
from unittest.mock import call, patch, seal
from unittest.mock import call, patch, seal
@@ -8,7 +9,7 @@ from django.urls import reverse
@@ -8,7 +9,7 @@ from django.urls import reverse
from rest_framework import status
from rest_framework import status
from arkindex.documents.models import Corpus
from arkindex.documents.models import Corpus
from arkindex.ponos.models import FINAL_STATES, State
from arkindex.ponos.models import FINAL_STATES, State, Task
from arkindex.process.models import Process, ProcessMode, Revision, WorkerVersion
from arkindex.process.models import Process, ProcessMode, Revision, WorkerVersion
from arkindex.project.tests import FixtureAPITestCase
from arkindex.project.tests import FixtureAPITestCase
from arkindex.users.models import Right, Role, User
from arkindex.users.models import Right, Role, User
@@ -25,7 +26,7 @@ class TestAPI(FixtureAPITestCase):
@@ -25,7 +26,7 @@ class TestAPI(FixtureAPITestCase):
cls.rev = Revision.objects.first()
cls.rev = Revision.objects.first()
cls.process = Process.objects.get(mode=ProcessMode.Workers)
cls.process = Process.objects.get(mode=ProcessMode.Workers)
cls.process.run()
cls.process.run()
cls.task1, cls.task2, cls.task3 = cls.process.tasks.all()
cls.task1, cls.task2, cls.task3 = cls.process.tasks.all().order_by("depth")
# Brand new user and corpus with no preexisting rights
# Brand new user and corpus with no preexisting rights
new_user = User.objects.create(email="another@user.com")
new_user = User.objects.create(email="another@user.com")
@@ -554,3 +555,133 @@ class TestAPI(FixtureAPITestCase):
@@ -554,3 +555,133 @@ class TestAPI(FixtureAPITestCase):
resp = self.client.get(reverse("api:task-details", args=[self.task1.id]))
resp = self.client.get(reverse("api:task-details", args=[self.task1.id]))
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(resp.json()["logs"], "")
self.assertEqual(resp.json()["logs"], "")
 
 
def test_restart_task_requires_login(self):
 
with self.assertNumQueries(0):
 
response = self.client.post(
 
reverse("api:task-restart", kwargs={"pk": str(uuid.uuid4())})
 
)
 
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
 
 
def test_restart_task_requires_verified(self):
 
self.user.verified_email = False
 
self.user.save()
 
self.client.force_login(self.user)
 
with self.assertNumQueries(2):
 
response = self.client.post(
 
reverse("api:task-restart", kwargs={"pk": str(uuid.uuid4())})
 
)
 
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
 
 
def test_restart_task_not_found(self):
 
self.client.force_login(self.user)
 
with self.assertNumQueries(6):
 
response = self.client.post(
 
reverse("api:task-restart", kwargs={"pk": str(uuid.uuid4())})
 
)
 
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
 
 
@patch("arkindex.project.mixins.get_max_level")
 
def test_restart_task_forbidden(self, get_max_level_mock):
 
"""An admin access to the process is required"""
 
get_max_level_mock.return_value = Role.Guest.value
 
self.client.force_login(self.user)
 
with self.assertNumQueries(7):
 
response = self.client.post(
 
reverse("api:task-restart", kwargs={"pk": str(self.task1.id)})
 
)
 
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
 
self.assertListEqual(
 
response.json(),
 
["You do not have an admin access to the process of this task."],
 
)
 
 
def test_restart_task_non_final_state(self):
 
self.client.force_login(self.user)
 
with self.assertNumQueries(7):
 
response = self.client.post(
 
reverse("api:task-restart", kwargs={"pk": str(self.task1.id)})
 
)
 
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
 
self.assertListEqual(
 
response.json(),
 
["Task's state must be in a final state to be restarted."],
 
)
 
 
def test_restart_task_already_restarted(self):
 
self.client.force_login(self.user)
 
self.task2.slug = self.task1.slug + "_restart1"
 
self.task2.save()
 
self.task1.state = State.Completed.value
 
self.task1.save()
 
with self.assertNumQueries(8):
 
response = self.client.post(
 
reverse("api:task-restart", kwargs={"pk": str(self.task1.id)})
 
)
 
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
 
self.assertListEqual(
 
response.json(),
 
["This task has already been restarted"],
 
)
 
 
@patch("arkindex.project.aws.s3")
 
def test_restart_task(self, s3_mock):
 
"""
 
From:
 
task1 → task2_restart42 → task3
 
↘ ↗
 
task 4
 
 
To:
 
task1 → task2_restart42
 
 
task2_restart43 → task3
 
↘ ↗
 
task 4
 
"""
 
s3_mock.Object.return_value.bucket_name = "ponos"
 
s3_mock.Object.return_value.key = "somelog"
 
s3_mock.Object.return_value.get.return_value = {
 
"Body": BytesIO(b"Task has been restarted")
 
}
 
s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
 
 
task4 = self.process.tasks.create(run=self.task1.run, depth=1)
 
task4.parents.add(self.task2)
 
task4.children.add(self.task3)
 
self.task1.state = State.Completed.value
 
self.task1.save()
 
self.task2.state = State.Error.value
 
self.task2.slug = "task2_restart42"
 
self.task2.save()
 
 
self.client.force_login(self.user)
 
with self.assertNumQueries(12):
 
response = self.client.post(
 
reverse("api:task-restart", kwargs={"pk": str(self.task2.id)})
 
)
 
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
 
self.assertEqual(self.process.tasks.count(), 5)
 
restarted_task = self.process.tasks.latest("created")
 
self.assertDictEqual(
 
response.json(),
 
{
 
"id": str(restarted_task.id),
 
"depth": 1,
 
"agent": None,
 
"extra_files": {},
 
"full_log": "http://somewhere",
 
"gpu": None,
 
"logs": "Task has been restarted",
 
"parents": [str(self.task1.id)],
 
"run": 0,
 
"shm_size": None,
 
"slug": "task2_restart43",
 
"state": "pending",
 
},
 
)
 
self.assertQuerysetEqual(self.task2.children.all(), Task.objects.none())
 
self.assertQuerysetEqual(
 
restarted_task.children.all(),
 
Task.objects.filter(id__in=[self.task3.id, task4.id]),
 
)
Loading