From 751d18066e5197c3797778f81db1400fb01417c6 Mon Sep 17 00:00:00 2001
From: Valentin Rigal <rigal@teklia.com>
Date: Fri, 22 Mar 2024 11:27:01 +0100
Subject: [PATCH] Add tests

---
 arkindex/ponos/tests/test_api.py | 120 ++++++++++++++++++++++++++++++-
 1 file changed, 118 insertions(+), 2 deletions(-)

diff --git a/arkindex/ponos/tests/test_api.py b/arkindex/ponos/tests/test_api.py
index ea84101f74..f4ddb01008 100644
--- a/arkindex/ponos/tests/test_api.py
+++ b/arkindex/ponos/tests/test_api.py
@@ -1,3 +1,4 @@
+import uuid
 from io import BytesIO
 from unittest import expectedFailure
 from unittest.mock import call, patch, seal
@@ -8,7 +9,7 @@ from django.urls import reverse
 from rest_framework import status
 
 from arkindex.documents.models import Corpus
-from arkindex.ponos.models import FINAL_STATES, State
+from arkindex.ponos.models import FINAL_STATES, State, Task
 from arkindex.process.models import Process, ProcessMode, Revision, WorkerVersion
 from arkindex.project.tests import FixtureAPITestCase
 from arkindex.users.models import Right, Role, User
@@ -25,7 +26,7 @@ class TestAPI(FixtureAPITestCase):
         cls.rev = Revision.objects.first()
         cls.process = Process.objects.get(mode=ProcessMode.Workers)
         cls.process.run()
-        cls.task1, cls.task2, cls.task3 = cls.process.tasks.all()
+        cls.task1, cls.task2, cls.task3 = cls.process.tasks.all().order_by("depth")
 
         # Brand new user and corpus with no preexisting rights
         new_user = User.objects.create(email="another@user.com")
@@ -554,3 +555,118 @@ class TestAPI(FixtureAPITestCase):
                 resp = self.client.get(reverse("api:task-details", args=[self.task1.id]))
                 self.assertEqual(resp.status_code, status.HTTP_200_OK)
                 self.assertEqual(resp.json()["logs"], "")
+
+    def test_restart_task_requires_login(self):
+        with self.assertNumQueries(0):
+            response = self.client.post(
+                reverse("api:task-restart", kwargs={"pk": str(uuid.uuid4())})
+            )
+            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+    def test_restart_task_requires_verified(self):
+        self.user.verified_email = False
+        self.user.save()
+        self.client.force_login(self.user)
+        with self.assertNumQueries(2):
+            response = self.client.post(
+                reverse("api:task-restart", kwargs={"pk": str(uuid.uuid4())})
+            )
+            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+    def test_restart_task_not_found(self):
+        self.client.force_login(self.user)
+        with self.assertNumQueries(6):
+            response = self.client.post(
+                reverse("api:task-restart", kwargs={"pk": str(uuid.uuid4())})
+            )
+            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+    @expectedFailure
+    def test_restart_task_forbidden(self):
+        """An admin access to the process is required"""
+        self.client.force_login(self.user)
+        self.corpus.memberships.filter(user=self.user).update(level=Role.Guest.value)
+        self.process.save()
+        with self.assertNumQueries(10):
+            response = self.client.post(
+                reverse("api:task-restart", kwargs={"pk": str(self.task1.id)})
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+            self.assertListEqual(
+                response.json(),
+                ["You do not have an admin access to the process of this task."],
+            )
+
+    def test_restart_task_non_final_state(self):
+        self.client.force_login(self.user)
+        with self.assertNumQueries(8):
+            response = self.client.post(
+                reverse("api:task-restart", kwargs={"pk": str(self.task1.id)})
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+            self.assertListEqual(
+                response.json(),
+                ["Task's state must be in a final state to be restarted."],
+            )
+
+    @patch("arkindex.project.aws.s3")
+    def test_restart_task(self, s3_mock):
+        """
+        From:
+        task1 → task2_restart42 → task3
+                            ↘      ↗
+                             task 4
+
+        To:
+        task1  → task2_restart42
+             ↘
+               task2_restart43 → task3
+                           ↘      ↗
+                            task 4
+        """
+        s3_mock.Object.return_value.bucket_name = "ponos"
+        s3_mock.Object.return_value.key = "somelog"
+        s3_mock.Object.return_value.get.return_value = {
+            "Body": BytesIO(b"Task has been restarted")
+        }
+        s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
+
+        task4 = self.process.tasks.create(run=self.task1.run, depth=1)
+        task4.parents.add(self.task2)
+        task4.children.add(self.task3)
+        self.task1.state = State.Completed.value
+        self.task1.save()
+        self.task2.state = State.Error.value
+        self.task2.slug = "task2_restart42"
+        self.task2.save()
+
+        self.client.force_login(self.user)
+        with self.assertNumQueries(12):
+            response = self.client.post(
+                reverse("api:task-restart", kwargs={"pk": str(self.task2.id)})
+            )
+            self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+        self.assertEqual(self.process.tasks.count(), 5)
+        restarted_task = self.process.tasks.latest("created")
+        self.assertDictEqual(
+            response.json(),
+            {
+                "id": str(restarted_task.id),
+                "depth": 1,
+                "agent": None,
+                "extra_files": {},
+                "full_log": "http://somewhere",
+                "gpu": None,
+                "logs": "Task has been restarted",
+                "parents": [str(self.task1.id)],
+                "run": 0,
+                "shm_size": None,
+                "slug": "task2_restart43",
+                "state": "pending",
+            },
+        )
+        self.assertQuerysetEqual(self.task2.children.all(), Task.objects.none())
+        self.assertQuerysetEqual(
+            restarted_task.children.all(),
+            Task.objects.filter(id__in=[self.task3.id, task4.id]),
+        )
-- 
GitLab