diff --git a/arkindex/documents/models.py b/arkindex/documents/models.py index d73d3893a4ef5902b54da9e7b7f4596b4eee0e70..3a683208e802dcc78d2b870fa1f4696d721beed8 100644 --- a/arkindex/documents/models.py +++ b/arkindex/documents/models.py @@ -86,16 +86,16 @@ class Corpus(IndexableModel): return level is not None and level >= Role.Contributor.value def is_processable(self, user) -> bool: - """ - Whether a user can create and execute processes on this corpus - """ - if user.is_anonymous or getattr(user, "is_agent", False): - return False - if user.is_admin: - return True - from arkindex.users.utils import get_max_level - level = get_max_level(user, self) - return level is not None and level >= Role.Admin.value + """ + Whether a user can create and execute processes on this corpus + """ + if user.is_anonymous or getattr(user, "is_agent", False): + return False + if user.is_admin: + return True + from arkindex.users.utils import get_max_level + level = get_max_level(user, self) + return level is not None and level >= Role.Admin.value class ElementType(models.Model): diff --git a/arkindex/ponos/api.py b/arkindex/ponos/api.py index a24da2fe6791d8c95a3824c038b18196865d6619..2ba88a887ee3b3fdfd6f2286073960b84b23586b 100644 --- a/arkindex/ponos/api.py +++ b/arkindex/ponos/api.py @@ -5,9 +5,8 @@ from django.db import transaction from django.shortcuts import get_object_or_404, redirect from drf_spectacular.utils import extend_schema, extend_schema_view from rest_framework import status -from rest_framework.authentication import SessionAuthentication, TokenAuthentication from rest_framework.exceptions import NotFound, ValidationError -from rest_framework.generics import CreateAPIView, ListCreateAPIView, RetrieveUpdateAPIView, UpdateAPIView +from rest_framework.generics import CreateAPIView, ListCreateAPIView, RetrieveUpdateAPIView from rest_framework.response import Response from rest_framework.views import APIView @@ -15,11 +14,10 @@ from arkindex.ponos.models import FINAL_STATES, Artifact, State, Task, task_toke from arkindex.ponos.permissions import ( IsAgentOrArtifactGuest, IsAgentOrTaskGuest, - IsAssignedAgentOrReadOnly, + IsAssignedAgentOrTaskAdminOrReadOnly, IsAssignedAgentOrTaskOrReadOnly, - IsTaskAdmin, ) -from arkindex.ponos.serializers import ArtifactSerializer, TaskSerializer, TaskTinySerializer +from arkindex.ponos.serializers import ArtifactSerializer, TaskSerializer from arkindex.project.mixins import ProcessACLMixin from arkindex.project.permissions import IsVerified from arkindex.users.models import Role @@ -28,7 +26,7 @@ from arkindex.users.models import Role @extend_schema(tags=["ponos"]) @extend_schema_view( get=extend_schema( - operation_id="RetrieveTaskFromAgent", + operation_id="RetrieveTask", description=dedent(""" Retrieve a Ponos task. @@ -36,23 +34,23 @@ from arkindex.users.models import Role """), ), put=extend_schema( - operation_id="UpdateTaskFromAgent", + operation_id="UpdateTask", description=dedent(""" Update a task. - Requires authentication as the Ponos agent assigned to the task. + Requires authentication as the Ponos agent assigned to the task, **admin** access on the task's process, or to be the creator of the process. """), ), patch=extend_schema( - operation_id="PartialUpdateTaskFromAgent", + operation_id="PartialUpdateTask", description=dedent(""" Partially update a task. - Requires authentication as the Ponos agent assigned to the task. + Requires authentication as the Ponos agent assigned to the task, **admin** access on the task's process, or to be the creator of the process. """), ), ) -class TaskDetailsFromAgent(RetrieveUpdateAPIView): +class TaskDetails(RetrieveUpdateAPIView): # Avoid stale read when a recently assigned agent wants to update # the state of one of its tasks @@ -65,8 +63,9 @@ class TaskDetailsFromAgent(RetrieveUpdateAPIView): permission_classes = ( # On all HTTP methods, require either any Ponos agent, an instance admin, the task itself, or guest access to the process' task IsAgentOrTaskGuest, - # On unsafe HTTP methods, require a Ponos agent assigned to the task. Both permission classes are combined. - IsAssignedAgentOrReadOnly, + # On unsafe HTTP methods, require a Ponos agent assigned to the task or admin access to the process. + # Both permission classes are combined. + IsAssignedAgentOrTaskAdminOrReadOnly, ) serializer_class = TaskSerializer @@ -150,34 +149,6 @@ class TaskArtifactDownload(APIView): return redirect(artifact.s3_url) -@extend_schema(tags=["ponos"]) -@extend_schema_view( - put=extend_schema( - description=dedent(""" - Update a task. - - Requires **admin** access on the task's process, or to be the creator of the process. - Cannot be used with Ponos agent or task authentication. - """), - ), - patch=extend_schema( - description=dedent(""" - Partially update a task. - - Requires **admin** access on the task's process, or to be the creator of the process. - Cannot be used with Ponos agent or task authentication. - """), - ), -) -class TaskUpdate(UpdateAPIView): - # Only allow regular users, not Ponos agents or tasks - authentication_classes = (TokenAuthentication, SessionAuthentication) - # Only allow regular users that have admin access to the task's process - permission_classes = (IsTaskAdmin, ) - queryset = Task.objects.select_related("process__corpus") - serializer_class = TaskTinySerializer - - @extend_schema_view( post=extend_schema( operation_id="RestartTask", diff --git a/arkindex/ponos/permissions.py b/arkindex/ponos/permissions.py index 8afbd90d72e85d10dd3e4c4ad9e3456d9ea2d989..7f401f45f5ff189e2cd199615e8357fbc54db5b0 100644 --- a/arkindex/ponos/permissions.py +++ b/arkindex/ponos/permissions.py @@ -2,7 +2,7 @@ from rest_framework.permissions import SAFE_METHODS from arkindex.ponos.models import Task from arkindex.project.mixins import ProcessACLMixin -from arkindex.project.permissions import IsAuthenticated, IsVerified +from arkindex.project.permissions import IsAuthenticated, IsVerified, require_verified_email from arkindex.users.models import Role @@ -14,17 +14,22 @@ def require_task(request, view): return getattr(request.user, "is_admin", False) or isinstance(request.auth, Task) +def require_verified_or_agent(request, view): + return require_agent_or_admin(request, view) or require_verified_email(request, view) + + def require_agent_or_task(request, view): + return require_agent_or_admin(request, view) or require_task(request, view) + + +def require_agent_or_task_or_verified(request, view): return ( - getattr(request.user, "is_agent", False) + require_agent_or_admin(request, view) or require_task(request, view) + or require_verified_email(request, view) ) -def require_verified_or_agent(request, view): - return getattr(request.user, "verified_email", False) or require_agent_or_admin(request, view) - - class IsAgent(IsAuthenticated): """ Only allow Ponos agents and admins. @@ -32,13 +37,6 @@ class IsAgent(IsAuthenticated): checks = IsAuthenticated.checks + (require_agent_or_admin, ) -class IsAgentOrTask(IsAuthenticated): - """ - Only allow Ponos agents, tasks, and admins. - """ - checks = IsAuthenticated.checks + (require_agent_or_task, ) - - class IsAgentOrReadOnly(IsAgent): """ Restricts write access to Ponos agents and admins, @@ -64,24 +62,63 @@ class IsAssignedAgentOrReadOnly(IsAgentOrReadOnly): return super().has_object_permission(request, view, obj) -class IsAssignedAgentOrTaskOrReadOnly(ProcessACLMixin, IsAgentOrTask): +class IsAssignedAgentOrTaskOrReadOnly(IsAuthenticated): """ - Restricts write access to Ponos agents, Ponos tasks, and admins, and allows read access to anyone. - When checking object write permissions for a Ponos task, requires either a Ponos agent assigned to the task, - or authentication as the task itself. + Restricts write access to the endpoint to Ponos agents, Ponos tasks, or instance admins, + and allows read access to anyone. + + Checking object permissions only works with a Ponos task instance. It allows read access to anyone, + and restricts writes to the task itself, or the Ponos agent assigned to the obj. """ allow_safe_methods = True + checks = IsAuthenticated.checks + (require_agent_or_task, ) def has_object_permission(self, request, view, obj) -> bool: assert isinstance(obj, Task) + # Allow all reads if request.method in SAFE_METHODS: return True + # Allow the task to update itself if isinstance(request.auth, Task): return obj == request.auth - return (super().has_object_permission(request, view, obj) and obj.agent_id == request.user.id) + # Allow the task's agent to update the task + if getattr(request.user, "is_agent", False): + return obj.agent_id == request.user.id + + return False + + +class IsAssignedAgentOrTaskAdminOrReadOnly(IsAuthenticated): + """ + Restricts write access to the endpoint to Ponos agents, Ponos tasks, or instance admins, + and allows read access to anyone. + + Checking object permissions only works with a Ponos task instance. It allows read access to anyone, + and restricts writes to the task itself, the Ponos agent assigned to the task, instance admins and process admins. + """ + allow_safe_methods = True + checks = IsAuthenticated.checks + (require_agent_or_task_or_verified, ) + + def has_object_permission(self, request, view, obj) -> bool: + assert isinstance(obj, Task) + + # Allow all reads + if request.method in SAFE_METHODS: + return True + + # Allow the task to update itself + if isinstance(request.auth, Task): + return obj == request.auth + + # Allow the task's agent to update the task + if getattr(request.user, "is_agent", False): + return obj.agent_id == request.user.id + + # Allow admins on the process' corpus, which also allows instance admins + return obj.process.corpus.is_processable(request.user) class IsTaskAdmin(ProcessACLMixin, IsVerified): diff --git a/arkindex/ponos/serializers.py b/arkindex/ponos/serializers.py index fd7a4883cda092b5fb0d5494321f3ded5e656406..ae35b00896d6a71aa0eea4122880766f517f1225 100644 --- a/arkindex/ponos/serializers.py +++ b/arkindex/ponos/serializers.py @@ -24,19 +24,23 @@ TaskGPUField = import_string(getattr(settings, "TASK_GPU_FIELD", None) or "arkin class TaskLightSerializer(serializers.ModelSerializer): """ Serializes a :class:`~arkindex.ponos.models.Task` instance without logs or agent information. - Used to list tasks inside a process. """ state = EnumField( State, help_text=dedent(""" - Allowed transitions for the state of a task by an agent are defined below: + Current state of the task. + + May be updated by regular users only from `running` to `stopping`, to request that a task be stopped. + + The allowed state transitions for Ponos agents are defined below: Pending ⟶ Running ⟶ Completed └⟶ Error ├⟶ Failed └⟶ Error Stopping ⟶ Stopped └⟶ Error - Slurm agents are also allowed to update state from Pending to Completed or Failed. + + Slurm agents are also allowed to update the state from Pending to Completed or Failed. """).strip(), ) @@ -66,15 +70,19 @@ class TaskLightSerializer(serializers.ModelSerializer): def validate_state(self, state): # Updates from a state to the same state is blocked to avoid side effects on finished tasks - allowed_transitions = { - State.Unscheduled: [State.Pending], - State.Pending: [State.Running, State.Error], - State.Running: [State.Completed, State.Failed, State.Error], - State.Stopping: [State.Stopped, State.Error], - } user = self.context["request"].user - if isinstance(user, Agent) and user.mode == AgentMode.Slurm: - allowed_transitions[State.Pending].extend([State.Completed, State.Failed]) + if isinstance(user, Agent): + allowed_transitions = { + State.Unscheduled: [State.Pending], + State.Pending: [State.Running, State.Error], + State.Running: [State.Completed, State.Failed, State.Error], + State.Stopping: [State.Stopped, State.Error], + } + if user.mode == AgentMode.Slurm: + allowed_transitions[State.Pending].extend([State.Completed, State.Failed]) + else: + allowed_transitions = {State.Running: [State.Stopping]} + if self.instance and state not in allowed_transitions.get(self.instance.state, []): raise ValidationError(f"Transition from state {self.instance.state} to state {state} is forbidden.") return state @@ -166,43 +174,6 @@ class TaskSerializer(TaskLightSerializer): return instance -class TaskTinySerializer(TaskSerializer): - """ - Serializes a :class:`~arkindex.ponos.models.Task` instance with only its state. - Used by humans to update a task. - """ - - state = EnumField( - State, - help_text="The state can only be updated from `running` to `stopping`, to manually stop a task.", - ) - - class Meta: - model = Task - fields = ("id", "state") - read_only_fields = ("id",) - - def validate_state(self, state): - """ - Only allow a user to manually stop a task - """ - if ( - self.instance - and self.instance.state != state - and (self.instance.state != State.Running or state != State.Stopping) - ): - raise ValidationError("State can only be updated from running to stopping") - return state - - def update(self, instance: Task, validated_data) -> Task: - new_state = validated_data.get("state") - - if new_state in FINAL_STATES and new_state != State.Completed: - task_failure.send_robust(self.__class__, task=instance) - - return super().update(instance, validated_data) - - class ArtifactSerializer(serializers.ModelSerializer): """ Serializes a :class:`~arkindex.ponos.models.Artifact` instance to allow diff --git a/arkindex/ponos/tests/test_api.py b/arkindex/ponos/tests/test_api.py index f38ad3a34d9f3ca1e64caf007b9176f3365165d7..a041d7a24433ba84df450f5491680816b7450643 100644 --- a/arkindex/ponos/tests/test_api.py +++ b/arkindex/ponos/tests/test_api.py @@ -3,7 +3,6 @@ import uuid from datetime import datetime, timedelta, timezone from io import BytesIO from itertools import combinations -from unittest import expectedFailure from unittest.mock import call, patch, seal from botocore.exceptions import ClientError @@ -262,10 +261,10 @@ class TestAPI(FixtureAPITestCase): def test_update_task_requires_login(self): with self.assertNumQueries(0): resp = self.client.put( - reverse("api:task-update", args=[self.task1.id]), + reverse("api:task-details", args=[self.task1.id]), data={"state": State.Stopping.value}, ) - self.assertEqual(resp.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN) def test_update_task_requires_verified(self): self.user.verified_email = False @@ -274,57 +273,66 @@ class TestAPI(FixtureAPITestCase): with self.assertNumQueries(2): resp = self.client.put( - reverse("api:task-update", args=[self.task1.id]), + reverse("api:task-details", args=[self.task1.id]), data={"state": State.Stopping.value}, ) self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN) - def test_update_task_forbids_task(self): - with self.assertNumQueries(0): - resp = self.client.put( - reverse("api:task-update", args=[self.task1.id]), - data={"state": State.Stopping.value}, - HTTP_AUTHORIZATION=f"Ponos {self.task1.token}", - ) - self.assertEqual(resp.status_code, status.HTTP_401_UNAUTHORIZED) - - @expectedFailure - def test_update_task_requires_process_admin_corpus(self): - self.process.creator = self.superuser - self.process.save() + @patch("arkindex.users.utils.get_max_level") + def test_update_task_requires_process_admin_corpus(self, get_max_level_mock): self.corpus.public = False self.corpus.save() self.client.force_login(self.user) for role in [None, Role.Guest, Role.Contributor]: with self.subTest(role=role): - self.corpus.memberships.filter(user=self.user).delete() - if role: - self.corpus.memberships.create(user=self.user, level=role.value) + get_max_level_mock.return_value = role.value if role else None - with self.assertNumQueries(5): + with self.assertNumQueries(3): resp = self.client.put( - reverse("api:task-update", args=[self.task1.id]), + reverse("api:task-details", args=[self.task1.id]), data={"state": State.Stopping.value}, ) self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN) - def test_update_running_task_state_stopping(self): + @patch("arkindex.project.aws.s3") + def test_update_running_task_state_stopping(self, s3_mock): + s3_mock.Object.return_value.bucket_name = "ponos" + s3_mock.Object.return_value.key = "somelog" + s3_mock.Object.return_value.get.return_value = {"Body": BytesIO(b"Failed successfully")} + s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere" + seal(s3_mock) + self.task1.state = State.Running self.task1.save() self.client.force_login(self.superuser) - with self.assertNumQueries(4): + with self.assertNumQueries(5): resp = self.client.put( - reverse("api:task-update", args=[self.task1.id]), + reverse("api:task-details", args=[self.task1.id]), data={"state": State.Stopping.value}, ) self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.assertDictEqual(resp.json(), { - "id": str(self.task1.id), - "state": State.Stopping.value, - }) + self.assertDictEqual( + resp.json(), + { + "id": str(self.task1.id), + "run": 0, + "depth": 0, + "slug": "initialisation", + "state": "stopping", + "parents": [], + "original_task_id": None, + "logs": "Failed successfully", + "full_log": "http://somewhere", + "extra_files": {}, + "agent": None, + "gpu": None, + "shm_size": None, + "requires_gpu": False, + }, + ) self.task1.refresh_from_db() self.assertEqual(self.task1.state, State.Stopping) @@ -341,13 +349,13 @@ class TestAPI(FixtureAPITestCase): with self.assertNumQueries(3): resp = self.client.put( - reverse("api:task-update", args=[self.task1.id]), + reverse("api:task-details", args=[self.task1.id]), data={"state": State.Stopping.value}, ) self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual( resp.json(), - {"state": ["State can only be updated from running to stopping"]}, + {"state": [f"Transition from state {state} to state Stopping is forbidden."]}, ) self.task1.refresh_from_db() @@ -355,7 +363,6 @@ class TestAPI(FixtureAPITestCase): def test_update_running_task_non_pending(self): states = list(State) - states.remove(State.Running) states.remove(State.Stopping) self.task1.state = State.Running self.task1.save() @@ -365,59 +372,19 @@ class TestAPI(FixtureAPITestCase): with self.subTest(state=state): with self.assertNumQueries(3): resp = self.client.put( - reverse("api:task-update", args=[self.task1.id]), + reverse("api:task-details", args=[self.task1.id]), data={"state": state.value}, ) self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual( resp.json(), - {"state": ["State can only be updated from running to stopping"]}, - ) - - def test_update_same_state(self): - states = list(State) - - self.client.force_login(self.superuser) - for state in states: - with self.subTest(state=state): - self.task1.state = state - self.task1.save() - resp = self.client.put( - reverse("api:task-update", args=[self.task1.id]), - data={"state": state.value}, + {"state": [f"Transition from state Running to state {state} is forbidden."]}, ) - self.assertEqual(resp.status_code, status.HTTP_200_OK) - - def test_partial_update_task_from_agent_requires_login(self): - with self.assertNumQueries(0): - resp = self.client.patch( - reverse("api:task-details", args=[self.task1.id]), - data={"state": State.Completed.value}, - ) - self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN) - - def test_partial_update_task_from_agent_forbids_users(self): - self.client.force_login(self.user) - with self.assertNumQueries(2): - resp = self.client.patch( - reverse("api:task-details", args=[self.task1.id]), - data={"state": State.Completed.value}, - ) - self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN) - - def test_partial_update_task_from_agent_forbids_task(self): - with self.assertNumQueries(1): - resp = self.client.patch( - reverse("api:task-details", args=[self.task1.id]), - data={"state": State.Completed.value}, - HTTP_AUTHORIZATION=f"Ponos {self.task1.token}", - ) - self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN) @patch("arkindex.users.models.User.objects.get") @patch("arkindex.project.aws.s3") @patch("arkindex.ponos.serializers.notify_process_completion") - @patch("arkindex.ponos.api.TaskDetailsFromAgent.permission_classes", tuple()) + @patch("arkindex.ponos.api.TaskDetails.permission_classes", tuple()) def test_partial_update_task_from_docker_agent_allowed_transitions(self, notify_mock, s3_mock, get_user_mock): s3_mock.Object.return_value.bucket_name = "ponos" s3_mock.Object.return_value.key = "somelog" @@ -448,7 +415,7 @@ class TestAPI(FixtureAPITestCase): @patch("arkindex.users.models.User.objects.get") @patch("arkindex.project.aws.s3") @patch("arkindex.ponos.serializers.notify_process_completion") - @patch("arkindex.ponos.api.TaskDetailsFromAgent.permission_classes", tuple()) + @patch("arkindex.ponos.api.TaskDetails.permission_classes", tuple()) def test_partial_update_task_from_slurm_agent_allowed_transitions(self, notify_mock, s3_mock, get_user_mock): s3_mock.Object.return_value.bucket_name = "ponos" s3_mock.Object.return_value.key = "somelog" @@ -478,7 +445,7 @@ class TestAPI(FixtureAPITestCase): @patch("arkindex.users.models.User.objects.get") @patch("arkindex.project.aws.s3") - @patch("arkindex.ponos.api.TaskDetailsFromAgent.permission_classes", tuple()) + @patch("arkindex.ponos.api.TaskDetails.permission_classes", tuple()) def test_partial_update_task_from_docker_agent_forbidden_transitions(self, s3_mock, get_user_mock): s3_mock.Object.return_value.bucket_name = "ponos" s3_mock.Object.return_value.key = "somelog" @@ -513,7 +480,7 @@ class TestAPI(FixtureAPITestCase): @patch("arkindex.users.models.User.objects.get") @patch("arkindex.project.aws.s3") - @patch("arkindex.ponos.api.TaskDetailsFromAgent.permission_classes", tuple()) + @patch("arkindex.ponos.api.TaskDetails.permission_classes", tuple()) def test_partial_update_task_from_slurm_agent_forbidden_transitions(self, s3_mock, get_user_mock): s3_mock.Object.return_value.bucket_name = "ponos" s3_mock.Object.return_value.key = "somelog" @@ -550,10 +517,10 @@ class TestAPI(FixtureAPITestCase): def test_partial_update_task_requires_login(self): with self.assertNumQueries(0): resp = self.client.patch( - reverse("api:task-update", args=[self.task1.id]), + reverse("api:task-details", args=[self.task1.id]), data={"state": State.Stopping.value}, ) - self.assertEqual(resp.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN) def test_partial_update_task_requires_verified(self): self.user.verified_email = False @@ -562,40 +529,27 @@ class TestAPI(FixtureAPITestCase): with self.assertNumQueries(2): resp = self.client.patch( - reverse("api:task-update", args=[self.task1.id]), + reverse("api:task-details", args=[self.task1.id]), data={"state": State.Stopping.value}, ) self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN) - def test_partial_update_task_forbids_task(self): - with self.assertNumQueries(0): - resp = self.client.patch( - reverse("api:task-update", args=[self.task1.id]), - data={"state": State.Stopping.value}, - HTTP_AUTHORIZATION=f"Ponos {self.task1.token}", - ) - self.assertEqual(resp.status_code, status.HTTP_401_UNAUTHORIZED) - - @expectedFailure - def test_partial_update_task_requires_process_admin_corpus(self): - self.process.creator = self.superuser - self.process.save() + @patch("arkindex.users.utils.get_max_level") + def test_partial_update_task_requires_process_admin_corpus(self, get_max_level_mock): self.corpus.public = False self.corpus.save() self.client.force_login(self.user) for role in [None, Role.Guest, Role.Contributor]: with self.subTest(role=role): - self.corpus.memberships.filter(user=self.user).delete() - if role: - self.corpus.memberships.create(user=self.user, level=role.value) + get_max_level_mock.return_value = role.value if role else None - with self.assertNumQueries(5): + with self.assertNumQueries(3): resp = self.client.patch( - reverse("api:task-update", args=[self.task1.id]), + reverse("api:task-details", args=[self.task1.id]), data={"state": State.Stopping.value}, ) - self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN, resp.json()) def test_partial_update_non_running_task_state_stopping(self): states = list(State) @@ -610,13 +564,13 @@ class TestAPI(FixtureAPITestCase): with self.assertNumQueries(3): resp = self.client.patch( - reverse("api:task-update", args=[self.task1.id]), + reverse("api:task-details", args=[self.task1.id]), data={"state": State.Stopping.value}, ) self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual( resp.json(), - {"state": ["State can only be updated from running to stopping"]}, + {"state": [f"Transition from state {state} to state Stopping is forbidden."]}, ) self.task1.refresh_from_db() @@ -624,7 +578,6 @@ class TestAPI(FixtureAPITestCase): def test_partial_update_running_task_non_pending(self): states = list(State) - states.remove(State.Running) states.remove(State.Stopping) self.task1.state = State.Running self.task1.save() @@ -634,29 +587,15 @@ class TestAPI(FixtureAPITestCase): with self.subTest(state=state): with self.assertNumQueries(3): resp = self.client.patch( - reverse("api:task-update", args=[self.task1.id]), + reverse("api:task-details", args=[self.task1.id]), data={"state": state.value}, ) self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual( resp.json(), - {"state": ["State can only be updated from running to stopping"]}, + {"state": [f"Transition from state Running to state {state} is forbidden."]}, ) - def test_partial_update_same_state(self): - states = list(State) - - self.client.force_login(self.superuser) - for state in states: - with self.subTest(state=state): - self.task1.state = state - self.task1.save() - resp = self.client.patch( - reverse("api:task-update", args=[self.task1.id]), - data={"state": state.value}, - ) - self.assertEqual(resp.status_code, status.HTTP_200_OK) - @patch("arkindex.project.aws.s3") def test_task_logs_unicode_error(self, s3_mock): """ diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py index aab4bb37f5d0ab1382b966088e819e4f16b7c57b..938af5179658db05049d0d79fb11e85298471961 100644 --- a/arkindex/project/api_v1.py +++ b/arkindex/project/api_v1.py @@ -56,7 +56,7 @@ from arkindex.documents.api.ml import ( ) from arkindex.documents.api.search import CorpusSearch, SearchIndexBuild from arkindex.images.api import IIIFInformationCreate, IIIFURLCreate, ImageCreate, ImageElements, ImageRetrieve -from arkindex.ponos.api import TaskArtifactDownload, TaskArtifacts, TaskDetailsFromAgent, TaskRestart, TaskUpdate +from arkindex.ponos.api import TaskArtifactDownload, TaskArtifacts, TaskDetails, TaskRestart from arkindex.process.api import ( ApplyProcessTemplate, BucketList, @@ -307,12 +307,7 @@ api = [ path("openapi/", OpenApiSchemaView.as_view(), name="openapi-schema"), # Ponos - path("task/<uuid:pk>/", TaskUpdate.as_view(), name="task-update"), - path( - "task/<uuid:pk>/from-agent/", - TaskDetailsFromAgent.as_view(), - name="task-details", - ), + path("task/<uuid:pk>/", TaskDetails.as_view(), name="task-details"), path( "task/<uuid:pk>/artifacts/", TaskArtifacts.as_view(), name="task-artifacts" ),