From 751d18066e5197c3797778f81db1400fb01417c6 Mon Sep 17 00:00:00 2001 From: Valentin Rigal <rigal@teklia.com> Date: Fri, 22 Mar 2024 11:27:01 +0100 Subject: [PATCH] Add tests --- arkindex/ponos/tests/test_api.py | 120 ++++++++++++++++++++++++++++++- 1 file changed, 118 insertions(+), 2 deletions(-) diff --git a/arkindex/ponos/tests/test_api.py b/arkindex/ponos/tests/test_api.py index ea84101f74..f4ddb01008 100644 --- a/arkindex/ponos/tests/test_api.py +++ b/arkindex/ponos/tests/test_api.py @@ -1,3 +1,4 @@ +import uuid from io import BytesIO from unittest import expectedFailure from unittest.mock import call, patch, seal @@ -8,7 +9,7 @@ from django.urls import reverse from rest_framework import status from arkindex.documents.models import Corpus -from arkindex.ponos.models import FINAL_STATES, State +from arkindex.ponos.models import FINAL_STATES, State, Task from arkindex.process.models import Process, ProcessMode, Revision, WorkerVersion from arkindex.project.tests import FixtureAPITestCase from arkindex.users.models import Right, Role, User @@ -25,7 +26,7 @@ class TestAPI(FixtureAPITestCase): cls.rev = Revision.objects.first() cls.process = Process.objects.get(mode=ProcessMode.Workers) cls.process.run() - cls.task1, cls.task2, cls.task3 = cls.process.tasks.all() + cls.task1, cls.task2, cls.task3 = cls.process.tasks.all().order_by("depth") # Brand new user and corpus with no preexisting rights new_user = User.objects.create(email="another@user.com") @@ -554,3 +555,118 @@ class TestAPI(FixtureAPITestCase): resp = self.client.get(reverse("api:task-details", args=[self.task1.id])) self.assertEqual(resp.status_code, status.HTTP_200_OK) self.assertEqual(resp.json()["logs"], "") + + def test_restart_task_requires_login(self): + with self.assertNumQueries(0): + response = self.client.post( + reverse("api:task-restart", kwargs={"pk": str(uuid.uuid4())}) + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_restart_task_requires_verified(self): + self.user.verified_email = False + self.user.save() + self.client.force_login(self.user) + with self.assertNumQueries(2): + response = self.client.post( + reverse("api:task-restart", kwargs={"pk": str(uuid.uuid4())}) + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_restart_task_not_found(self): + self.client.force_login(self.user) + with self.assertNumQueries(6): + response = self.client.post( + reverse("api:task-restart", kwargs={"pk": str(uuid.uuid4())}) + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + @expectedFailure + def test_restart_task_forbidden(self): + """An admin access to the process is required""" + self.client.force_login(self.user) + self.corpus.memberships.filter(user=self.user).update(level=Role.Guest.value) + self.process.save() + with self.assertNumQueries(10): + response = self.client.post( + reverse("api:task-restart", kwargs={"pk": str(self.task1.id)}) + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertListEqual( + response.json(), + ["You do not have an admin access to the process of this task."], + ) + + def test_restart_task_non_final_state(self): + self.client.force_login(self.user) + with self.assertNumQueries(8): + response = self.client.post( + reverse("api:task-restart", kwargs={"pk": str(self.task1.id)}) + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertListEqual( + response.json(), + ["Task's state must be in a final state to be restarted."], + ) + + @patch("arkindex.project.aws.s3") + def test_restart_task(self, s3_mock): + """ + From: + task1 → task2_restart42 → task3 + ↘ ↗ + task 4 + + To: + task1 → task2_restart42 + ↘ + task2_restart43 → task3 + ↘ ↗ + task 4 + """ + s3_mock.Object.return_value.bucket_name = "ponos" + s3_mock.Object.return_value.key = "somelog" + s3_mock.Object.return_value.get.return_value = { + "Body": BytesIO(b"Task has been restarted") + } + s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere" + + task4 = self.process.tasks.create(run=self.task1.run, depth=1) + task4.parents.add(self.task2) + task4.children.add(self.task3) + self.task1.state = State.Completed.value + self.task1.save() + self.task2.state = State.Error.value + self.task2.slug = "task2_restart42" + self.task2.save() + + self.client.force_login(self.user) + with self.assertNumQueries(12): + response = self.client.post( + reverse("api:task-restart", kwargs={"pk": str(self.task2.id)}) + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(self.process.tasks.count(), 5) + restarted_task = self.process.tasks.latest("created") + self.assertDictEqual( + response.json(), + { + "id": str(restarted_task.id), + "depth": 1, + "agent": None, + "extra_files": {}, + "full_log": "http://somewhere", + "gpu": None, + "logs": "Task has been restarted", + "parents": [str(self.task1.id)], + "run": 0, + "shm_size": None, + "slug": "task2_restart43", + "state": "pending", + }, + ) + self.assertQuerysetEqual(self.task2.children.all(), Task.objects.none()) + self.assertQuerysetEqual( + restarted_task.children.all(), + Task.objects.filter(id__in=[self.task3.id, task4.id]), + ) -- GitLab