From fafde6d57a0791c2ca594e6a890dbbc25f8fb344 Mon Sep 17 00:00:00 2001
From: Erwan Rouchet <rouchet@teklia.com>
Date: Tue, 11 Jul 2023 17:12:08 +0200
Subject: [PATCH] Restrict RetrieveTaskDefinition to Ponos agents

---
 arkindex/ponos/api.py            | 10 ++++++++--
 arkindex/ponos/tests/test_api.py | 16 ++++++++++++----
 2 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/arkindex/ponos/api.py b/arkindex/ponos/api.py
index a2d97cfc9c..505eba9568 100644
--- a/arkindex/ponos/api.py
+++ b/arkindex/ponos/api.py
@@ -230,16 +230,22 @@ class AgentActions(RetrieveAPIView):
 )
 class TaskDefinition(RetrieveAPIView):
     """
-    Obtain a task's definition as an agent or admin.
+    Obtain a task's definition.
     This holds all the required data to start a task, except for the artifacts.
+
+    Requires authentication as a Ponos agent.
     """
 
     # We need to specify the default database to avoid stale reads
     # when a task is updated by an agent, then the agent immediately fetches its definition
     queryset = Task.objects.using("default").select_related("process")
-    permission_classes = (IsAgent,)
     serializer_class = TaskDefinitionSerializer
 
+    # This cannot be restricted to only the agent assigned to the task because agents need to access
+    # the parent tasks of an assigned task to download their artifacts.
+    authentication_classes = (AgentAuthentication, )
+    permission_classes = (IsAgent, )
+
 
 @extend_schema(tags=["ponos"])
 @extend_schema_view(
diff --git a/arkindex/ponos/tests/test_api.py b/arkindex/ponos/tests/test_api.py
index 009de523c8..8316dc0889 100644
--- a/arkindex/ponos/tests/test_api.py
+++ b/arkindex/ponos/tests/test_api.py
@@ -152,13 +152,21 @@ class TestAPI(FixtureAPITestCase):
     def test_task_definition_requires_login(self):
         with self.assertNumQueries(0):
             response = self.client.get(reverse("api:task-definition", args=[self.task1.id]))
-            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+            self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
 
-    def test_task_definition_requires_agent_or_admin(self):
+    def test_task_definition_requires_agent(self):
         self.client.force_login(self.user)
-        with self.assertNumQueries(2):
+        with self.assertNumQueries(0):
             response = self.client.get(reverse("api:task-definition", args=[self.task1.id]))
-            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+            self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
+
+    def test_task_definition_task_forbidden(self):
+        with self.assertNumQueries(0):
+            response = self.client.get(
+                reverse("api:task-definition", args=[self.task1.id]),
+                HTTP_AUTHORIZATION=f'Ponos {self.task1.token}',
+            )
+            self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
 
     def test_task_definition_image_artifact(self):
         self.task1.image_artifact = self.task1.artifacts.create(
-- 
GitLab