diff --git a/.gitignore b/.gitignore index 394e382c96578fc8e53820a373af9a91479ec0c4..7cde2291adad29a847d6444d485d04730962cc73 100644 --- a/.gitignore +++ b/.gitignore @@ -13,7 +13,6 @@ workers local_settings.py .coverage htmlcov -ponos *.key arkindex/config.yml test-report.xml diff --git a/arkindex/documents/fixtures/data.json b/arkindex/documents/fixtures/data.json index afec6b200f5288e91efbbfa7d288506ab37e844b..1f63c8fe43c51bc2a7afda9ffc41e89f744dd328 100644 --- a/arkindex/documents/fixtures/data.json +++ b/arkindex/documents/fixtures/data.json @@ -3777,6 +3777,7 @@ "updated": "2020-02-02T01:23:45.678Z", "expiry": "2050-03-03T01:23:45.678Z", "extra_files": "{}", + "token": "MjViMzE5NDQtNzc2YS00YThjLWE1YWUtY2RhYTY2ZmE0OTIzCg==", "parents": [] } }, diff --git a/arkindex/ponos/api.py b/arkindex/ponos/api.py index 2bbe8df38c473d6dd21c2b8aa9555e197e40ce49..79593e4ac550658834509bcd703b0bbf99c0db93 100644 --- a/arkindex/ponos/api.py +++ b/arkindex/ponos/api.py @@ -26,7 +26,7 @@ from arkindex.ponos.models import Agent, Artifact, Farm, Secret, State, Task, Wo from arkindex.ponos.permissions import ( IsAgent, IsAgentOrArtifactAdmin, - IsAgentOrInternal, + IsAgentOrTask, IsAgentOrTaskAdmin, IsAssignedAgentOrReadOnly, ) @@ -411,10 +411,12 @@ class TaskUpdate(UpdateAPIView): ) class SecretDetails(RetrieveAPIView): """ - Retrieve a Ponos secret content as cleartext + Retrieve a Ponos secret content as cleartext. + + Requires authentication as an internal user, a Ponos agent or a Ponos task. 
""" - permission_classes = (IsAgentOrInternal, ) + permission_classes = (IsAgentOrTask, ) serializer_class = ClearTextSecretSerializer def get_object(self): diff --git a/arkindex/ponos/authentication.py b/arkindex/ponos/authentication.py index 8eaa6bf008d9abd3500dbf5bdf3b538fb62e9bfc..58455c0b212f281121aac0d7b4bd1bad0bc338cf 100644 --- a/arkindex/ponos/authentication.py +++ b/arkindex/ponos/authentication.py @@ -1,7 +1,10 @@ +from django.core.exceptions import ObjectDoesNotExist +from drf_spectacular.authentication import TokenScheme from drf_spectacular.contrib.rest_framework_simplejwt import SimpleJWTScheme +from rest_framework.authentication import TokenAuthentication from rest_framework.exceptions import AuthenticationFailed -from arkindex.ponos.models import Agent +from arkindex.ponos.models import Agent, Task from rest_framework_simplejwt.authentication import JWTAuthentication from rest_framework_simplejwt.exceptions import InvalidToken from rest_framework_simplejwt.settings import api_settings @@ -68,3 +71,40 @@ class AgentAuthenticationExtension(SimpleJWTScheme): target_class = "arkindex.ponos.authentication.AgentAuthentication" name = "agentAuth" + + +class TaskAuthentication(TokenAuthentication): + keyword = 'Ponos' + model = Task + + def authenticate_credentials(self, key): + try: + task = Task.objects.select_related('workflow__process__creator').get(token=key) + except Task.DoesNotExist: + # Same error message as the standard TokenAuthentication + raise AuthenticationFailed('Invalid token.') + + # There is no Workflow.process_id, since the FK is on Process.workflow_id, + # and accessing Workflow.process when there is no process causes an exception + # instead of returning None. 
+ try: + process = task.workflow.process + except ObjectDoesNotExist: + raise AuthenticationFailed('Task has no process.') + + if not process.creator_id or not process.creator.is_active: + # Same error message as the standard TokenAuthentication + raise AuthenticationFailed('User inactive or deleted.') + + # Must return a 2-tuple that will be set as (self.request.user, self.request.auth) + return (process.creator, task) + + +class TaskAuthenticationExtension(TokenScheme): + target_class = "arkindex.ponos.authentication.TaskAuthentication" + name = "taskAuth" + # The TokenScheme has a priority of -1 and matches both TokenAuthentication and its subclasses; + # we set the priority to a higher number to make this extension match first, and disable + # subclass matching so that this only applies to the TaskAuthentication. + priority = 0 + match_subclasses = False diff --git a/arkindex/ponos/migrations/0037_task_token.py b/arkindex/ponos/migrations/0037_task_token.py new file mode 100644 index 0000000000000000000000000000000000000000..bb11372fc4392f47d87393344bb5c11ec4611a23 --- /dev/null +++ b/arkindex/ponos/migrations/0037_task_token.py @@ -0,0 +1,38 @@ +# Generated by Django 4.1.7 on 2023-03-07 15:19 + +from django.db import migrations, models + +from arkindex.ponos.models import task_token_default + + +def add_task_tokens(apps, schema_editor): + Task = apps.get_model('ponos', 'Task') + to_update = [] + for task in Task.objects.filter(token=None).only('id').iterator(): + task.token = task_token_default() + to_update.append(task) + Task.objects.bulk_update(to_update, ['token'], batch_size=100) + + +class Migration(migrations.Migration): + + dependencies = [ + ('ponos', '0036_hstore_task_env_and_extra_files'), + ] + + operations = [ + migrations.AddField( + model_name='task', + name='token', + field=models.CharField( + max_length=52, + # Make the field temporarily nullable and not unique, so that we can + # fill the tokens on existing tasks before adding the 
constraints. + null=True, + ), + ), + migrations.RunPython( + add_task_tokens, + reverse_code=migrations.RunPython.noop, + ), + ] diff --git a/arkindex/ponos/migrations/0038_task_token_unique.py b/arkindex/ponos/migrations/0038_task_token_unique.py new file mode 100644 index 0000000000000000000000000000000000000000..5de95c85ee7afba67531028d0fd1fb9eaac1c586 --- /dev/null +++ b/arkindex/ponos/migrations/0038_task_token_unique.py @@ -0,0 +1,24 @@ +# Generated by Django 4.1.7 on 2023-03-07 15:25 + +from django.db import migrations, models + +from arkindex.ponos.models import task_token_default + + +class Migration(migrations.Migration): + + dependencies = [ + ('ponos', '0037_task_token'), + ] + + operations = [ + migrations.AlterField( + model_name='task', + name='token', + field=models.CharField( + default=task_token_default, + max_length=52, + unique=True, + ), + ), + ] diff --git a/arkindex/ponos/models.py b/arkindex/ponos/models.py index f6a9cd2d6259580f10d1603fff642b593695d352..2893f67cd526ace99f0ed1807ba1d6983788218a 100644 --- a/arkindex/ponos/models.py +++ b/arkindex/ponos/models.py @@ -1,3 +1,4 @@ +import base64 import logging import os.path import random @@ -415,8 +416,15 @@ class Workflow(models.Model): ) # Create tasks without any parent - tasks = { - slug: self.tasks.create( + tasks = {} + for slug, recipe in self.recipes.items(): + # Add the task token to the environment now, as higher-level code cannot add a token + # when building workflow recipes since the Task instances do not exist. 
+ env = recipe.environment.copy() + token = task_token_default() + env['ARKINDEX_TASK_TOKEN'] = token + + tasks[slug] = self.tasks.create( run=run, slug=slug, tags=recipe.tags, @@ -424,16 +432,15 @@ class Workflow(models.Model): image=recipe.image, command=recipe.command, shm_size=recipe.shm_size, - env=recipe.environment, + env=env, has_docker_socket=recipe.has_docker_socket, image_artifact=Artifact.objects.get(id=recipe.artifact) if recipe.artifact else None, requires_gpu=recipe.requires_gpu, extra_files=recipe.extra_files if recipe.extra_files else {}, + token=token, ) - for slug, recipe in self.recipes.items() - } # Apply parents for slug, recipe in self.recipes.items(): @@ -605,6 +612,15 @@ def expiry_default(): return timezone.now() + timedelta(days=30) +def task_token_default(): + """ + Default value for Task.token. + + :rtype: str + """ + return base64.encodebytes(uuid.uuid4().bytes + uuid.uuid4().bytes).strip().decode('utf-8') + + class Task(models.Model): """ A task created from a workflow's recipe. 
@@ -670,6 +686,13 @@ class Task(models.Model): # Remote files required to start the container extra_files = HStoreField(default=dict) + token = models.CharField( + default=task_token_default, + # task_token_default() returns 44 characters (base64 of 32 bytes); 52 also fits longer fixture tokens + max_length=52, + unique=True, + ) + objects = TaskManager() class Meta: diff --git a/arkindex/ponos/permissions.py b/arkindex/ponos/permissions.py index 58d8de04e741535d94c92d528e3e0dc060a111fc..84ca2a9009ac22a08ccdff9be0a3d1b0b5e0931a 100644 --- a/arkindex/ponos/permissions.py +++ b/arkindex/ponos/permissions.py @@ -5,19 +5,28 @@ from arkindex.project.mixins import CorpusACLMixin from arkindex.project.permissions import IsAuthenticated, require_internal -def require_agent(request, view): +def require_agent_or_admin(request, view): return getattr(request.user, 'is_admin', False) or getattr(request.user, 'is_agent', False) -def require_agent_or_internal(request, view): - return require_internal(request, view) or getattr(request.user, 'is_agent', False) +def require_task(request, view): + # For backwards compatibility, internal users are considered to be authenticated as a Ponos task. + # TODO: Remove the internal check once APIs are restricted to the new task authentication + return isinstance(request.auth, Task) or require_internal(request, view) + + +def require_agent_or_task(request, view): + return ( + getattr(request.user, 'is_agent', False) + or require_task(request, view) + ) class IsAgent(IsAuthenticated): """ Only allow Ponos agents and admins. 
""" - checks = IsAuthenticated.checks + (require_agent, ) + checks = IsAuthenticated.checks + (require_agent_or_admin, ) class IsAgentOrReadOnly(IsAgent): @@ -59,7 +68,7 @@ class IsAgentOrTaskAdmin(CorpusACLMixin, IsAuthenticated): self.request = request return ( - require_agent(request, view) + require_agent_or_admin(request, view) or require_internal(request, view) or ( task.workflow.process is not None @@ -79,8 +88,12 @@ class IsAgentOrArtifactAdmin(IsAgentOrTaskAdmin): return super().has_object_permission(request, view, artifact.task) -class IsAgentOrInternal(IsAuthenticated): +class IsTask(IsAuthenticated): + checks = (require_task, ) + + +class IsAgentOrTask(IsAuthenticated): """ - Allow access to agents or internal users, and not admins. + Allow access to Ponos agents or tasks. """ - checks = (require_agent_or_internal, ) + checks = (require_agent_or_task, ) diff --git a/arkindex/ponos/serializers.py b/arkindex/ponos/serializers.py index 1f8fa91b0e5c44101ce965a48f6820646a7ef8e6..655bd6daf148da9cfb3715069301e626e8eb8de1 100644 --- a/arkindex/ponos/serializers.py +++ b/arkindex/ponos/serializers.py @@ -22,6 +22,7 @@ from arkindex.ponos.models import ( State, Task, Workflow, + task_token_default, ) from arkindex.ponos.serializer_fields import Base64Field, PublicKeyField from arkindex.ponos.signals import task_failure @@ -614,6 +615,11 @@ class NewTaskSerializer(serializers.ModelSerializer): data["depth"] = max(parent.depth for parent in parents) + 1 + # Set the task token manually so that we can immediately copy it to the environment variables, + # just like what is done in Workflow.build_tasks() + data["token"] = task_token_default() + data["env"]['ARKINDEX_TASK_TOKEN'] = data['token'] + return super().validate(data) diff --git a/arkindex/ponos/tests/test_api.py b/arkindex/ponos/tests/test_api.py index 629455cc5e18b16b66f466f67c918a9568f3d815..50b9f1573a9adfabafaff3361553f2bab78abbde 100644 --- a/arkindex/ponos/tests/test_api.py +++ 
b/arkindex/ponos/tests/test_api.py @@ -18,9 +18,11 @@ from django.utils import timezone from rest_framework import status from rest_framework.test import APITestCase +from arkindex.documents.models import Corpus from arkindex.ponos.api import timezone as api_tz from arkindex.ponos.authentication import AgentUser from arkindex.ponos.models import FINAL_STATES, GPU, Agent, Farm, Secret, State, Task, Workflow, encrypt +from arkindex.process.models import Process, ProcessMode from arkindex.project.tools import build_public_key from arkindex.users.models import User from rest_framework_simplejwt.tokens import AccessToken, RefreshToken @@ -52,7 +54,6 @@ def str_date(d): class TestAPI(APITestCase): @classmethod def setUpTestData(cls): - super().setUpTestData() cls.farm = Farm.objects.create(name="Wheat farm") pubkey = build_public_key() @@ -359,7 +360,11 @@ class TestAPI(APITestCase): 'image': 'hello-world', 'command': None, 'shm_size': None, - 'env': {'test': 'test_workflow', 'top_env_variable': 'workflow_variable'}, + 'env': { + 'test': 'test_workflow', + 'top_env_variable': 'workflow_variable', + 'ARKINDEX_TASK_TOKEN': self.task1.token, + }, 'has_docker_socket': False, 'image_artifact_url': None, 'agent_id': None, @@ -402,7 +407,11 @@ class TestAPI(APITestCase): 'image': 'hello-world', 'command': None, 'shm_size': None, - 'env': {'test': 'test_workflow', 'top_env_variable': 'workflow_variable'}, + 'env': { + 'test': 'test_workflow', + 'top_env_variable': 'workflow_variable', + 'ARKINDEX_TASK_TOKEN': self.task1.token, + }, 'has_docker_socket': False, 'image_artifact_url': 'http://testserver' + reverse( 'api:task-artifact-download', @@ -436,7 +445,11 @@ class TestAPI(APITestCase): 'image': 'hello-world', 'command': None, 'shm_size': '128g', - 'env': {'test': 'test_workflow', 'top_env_variable': 'workflow_variable'}, + 'env': { + 'test': 'test_workflow', + 'top_env_variable': 'workflow_variable', + 'ARKINDEX_TASK_TOKEN': self.task1.token, + }, 'has_docker_socket': False, 
'image_artifact_url': None, 'agent_id': None, @@ -1629,24 +1642,27 @@ class TestAPI(APITestCase): image="registry.gitlab.com/test", ) - response = self.client.post( - reverse("api:task-create"), - data={ - "workflow_id": str(self.workflow.id), - "slug": "test_task", - "image": "registry.gitlab.com/test", - "parents": [str(self.task1.id), str(self.task2.id), str(task3.id)], - "command": "echo Test", - "env": {"test": "test", "test2": "test2"}, - }, - format="json", - ) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - data = response.json() - del data["id"] + with self.assertNumQueries(12): + response = self.client.post( + reverse("api:task-create"), + data={ + "workflow_id": str(self.workflow.id), + "slug": "test_task", + "image": "registry.gitlab.com/test", + "parents": [str(self.task1.id), str(self.task2.id), str(task3.id)], + "command": "echo Test", + "env": {"test": "test", "test2": "test2"}, + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + new_task = self.workflow.tasks.get(slug='test_task') + self.assertDictEqual( - data, + response.json(), { + "id": str(new_task.id), "workflow_id": str(self.workflow.id), "slug": "test_task", "parents": [str(self.task1.id), str(self.task2.id), str(task3.id)], @@ -1656,6 +1672,7 @@ class TestAPI(APITestCase): "test": "test", "top_env_variable": "workflow_variable", "test2": "test2", + "ARKINDEX_TASK_TOKEN": new_task.token, }, "run": 0, "depth": 4, @@ -1671,25 +1688,28 @@ class TestAPI(APITestCase): image="registry.gitlab.com/test", ) - response = self.client.post( - reverse("api:task-create"), - data={ - "workflow_id": str(self.workflow.id), - "slug": "test_task", - "image": "registry.gitlab.com/test", - "parents": [str(self.task1.id), str(self.task2.id), str(task3.id)], - "command": "echo Test", - "env": {"test": "test", "test2": "test2"}, - "has_docker_socket": True, - }, - format="json", - ) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - 
data = response.json() - del data["id"] + with self.assertNumQueries(12): + response = self.client.post( + reverse("api:task-create"), + data={ + "workflow_id": str(self.workflow.id), + "slug": "test_task", + "image": "registry.gitlab.com/test", + "parents": [str(self.task1.id), str(self.task2.id), str(task3.id)], + "command": "echo Test", + "env": {"test": "test", "test2": "test2"}, + "has_docker_socket": True, + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + new_task = self.workflow.tasks.get(slug='test_task') + self.assertDictEqual( - data, + response.json(), { + "id": str(new_task.id), "workflow_id": str(self.workflow.id), "slug": "test_task", "parents": [str(self.task1.id), str(self.task2.id), str(task3.id)], @@ -1699,6 +1719,7 @@ class TestAPI(APITestCase): "test": "test", "top_env_variable": "workflow_variable", "test2": "test2", + "ARKINDEX_TASK_TOKEN": new_task.token, }, "run": 0, "depth": 4, @@ -1708,7 +1729,7 @@ class TestAPI(APITestCase): def test_retrieve_secret_requires_auth(self): """ - Only agents may access a secret details + Only agents or tasks may access a secret details """ response = self.client.get( reverse("api:secret-details", kwargs={"name": "abc"}) @@ -1743,11 +1764,104 @@ class TestAPI(APITestCase): content=encrypt(b"1337" * 4, "1337$"), ) self.assertEqual(secret.content, b"\xc1\x81\xc0\xceo") - response = self.client.get( - reverse("api:secret-details", kwargs={"name": account_name}), - HTTP_AUTHORIZATION=f'Bearer {self.agent.token.access_token}', + + with self.assertNumQueries(2): + response = self.client.get( + reverse("api:secret-details", kwargs={"name": account_name}), + HTTP_AUTHORIZATION=f'Bearer {self.agent.token.access_token}', + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + self.assertDictEqual( + response.json(), + { + "id": str(secret.id), + "name": account_name, + "content": "1337$", + }, + ) + + @override_settings(PONOS_PRIVATE_KEY=PONOS_PRIVATE_KEY) + def 
test_task_retrieve_secret_requires_process(self): + account_name = "bank_account/0001/private" + secret = Secret.objects.create( + name=account_name, + nonce=b"1337" * 4, + content=encrypt(b"1337" * 4, "1337$"), + ) + self.assertEqual(secret.content, b"\xc1\x81\xc0\xceo") + + with self.assertNumQueries(1): + response = self.client.get( + reverse("api:secret-details", kwargs={"name": account_name}), + HTTP_AUTHORIZATION=f'Ponos {self.task1.token}', + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + self.assertDictEqual( + response.json(), + {'detail': 'Task has no process.'}, + ) + + @override_settings(PONOS_PRIVATE_KEY=PONOS_PRIVATE_KEY) + def test_task_retrieve_secret_requires_active_user(self): + account_name = "bank_account/0001/private" + secret = Secret.objects.create( + name=account_name, + nonce=b"1337" * 4, + content=encrypt(b"1337" * 4, "1337$"), + ) + self.assertEqual(secret.content, b"\xc1\x81\xc0\xceo") + + Process.objects.create( + creator=self.user, + mode=ProcessMode.Repository, + workflow=self.workflow, + ) + self.user.is_active = False + self.user.save() + + with self.assertNumQueries(1): + response = self.client.get( + reverse("api:secret-details", kwargs={"name": account_name}), + HTTP_AUTHORIZATION=f'Ponos {self.task1.token}', + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + self.assertDictEqual( + response.json(), + {'detail': 'User inactive or deleted.'}, + ) + + @override_settings(PONOS_PRIVATE_KEY=PONOS_PRIVATE_KEY) + def test_task_retrieve_secret(self): + account_name = "bank_account/0001/private" + secret = Secret.objects.create( + name=account_name, + nonce=b"1337" * 4, + content=encrypt(b"1337" * 4, "1337$"), + ) + self.assertEqual(secret.content, b"\xc1\x81\xc0\xceo") + + internal_user = User.objects.create_user( + 'internal@internal.fr', + 'Pa$$w0rd', + internal=True, ) - self.assertEqual(response.status_code, status.HTTP_200_OK) + corpus = Corpus.objects.create(name='Test corpus') + 
corpus.processes.create( + mode=ProcessMode.Workers, + creator=internal_user, + workflow=self.workflow, + ) + + with self.assertNumQueries(2): + response = self.client.get( + reverse("api:secret-details", kwargs={"name": account_name}), + HTTP_AUTHORIZATION=f'Ponos {self.task1.token}', + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual( response.json(), { diff --git a/arkindex/process/api.py b/arkindex/process/api.py index 643b309c838b36d20f1be300b18bb6c2f25ccad8..b05345440844f08717fd3ca6f2ffbf149da78c83 100644 --- a/arkindex/process/api.py +++ b/arkindex/process/api.py @@ -39,6 +39,7 @@ from rest_framework.views import APIView from arkindex.documents.models import Corpus, Element from arkindex.ponos.models import STATES_ORDERING, State +from arkindex.ponos.permissions import IsTask from arkindex.process.models import ( ActivityState, DataFile, @@ -100,7 +101,7 @@ from arkindex.project.mixins import ( WorkerACLMixin, ) from arkindex.project.pagination import CustomCursorPagination -from arkindex.project.permissions import IsInternal, IsVerified, IsVerifiedOrReadOnly +from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly from arkindex.project.tools import PercentileCont, RTrimChr from arkindex.project.triggers import process_delete from arkindex.training.models import ModelVersionState @@ -1475,10 +1476,10 @@ class ListProcessElements(CorpusACLMixin, ListAPIView): class UpdateWorkerActivity(GenericAPIView): """ - Makes a worker (internal user) able to update its activity on an element + Allow a Ponos task or an internal user to update an element's state Only allow defined evolutions of the element's state """ - permission_classes = (IsInternal, ) + permission_classes = (IsTask, ) serializer_class = WorkerActivitySerializer queryset = WorkerActivity.objects.none() @@ -1511,7 +1512,7 @@ class UpdateWorkerActivity(GenericAPIView): operation_id='UpdateWorkerActivity', description=( 'Updates the activity of a 
worker version on an element.\n\n' - 'The user must be **internal** to perform this request.\n\n' + 'The user must be **internal** or a Ponos task to perform this request.\n\n' 'A **HTTP_409_CONFLICT** is returned in case the body is valid but the update failed.' ), ) diff --git a/arkindex/process/serializers/workers.py b/arkindex/process/serializers/workers.py index 7b3a21c8df3a4f940c6bd4c67de64da78451daad..1af738412458b525b8edd2a1188538b569a2c0a7 100644 --- a/arkindex/process/serializers/workers.py +++ b/arkindex/process/serializers/workers.py @@ -8,6 +8,7 @@ from drf_spectacular.utils import extend_schema_field from rest_framework import serializers from rest_framework.exceptions import ValidationError +from arkindex.ponos.models import Task from arkindex.process.models import ( Process, Repository, @@ -302,6 +303,12 @@ class WorkerActivitySerializer(serializers.ModelSerializer): 'worker_version_id', ) + def validate_process_id(self, process): + request = self.context.get('request') + if request and isinstance(request.auth, Task) and process.workflow_id != request.auth.workflow_id: + raise serializers.ValidationError('Only WorkerActivities for the process of the currently authenticated task can be updated.') + return process + class WorkerConfigurationListSerializer(serializers.ModelSerializer): configuration = serializers.DictField(allow_empty=False) diff --git a/arkindex/process/tests/test_create_s3_import.py b/arkindex/process/tests/test_create_s3_import.py index 604014e34cbde7ff8347b57bc733d54b68cdda34..2a7ac6e8741c97b540f8f6038ece21748e965360 100644 --- a/arkindex/process/tests/test_create_s3_import.py +++ b/arkindex/process/tests/test_create_s3_import.py @@ -167,6 +167,7 @@ class TestCreateS3Import(FixtureTestCase): self.assertDictEqual(task.env, { 'ARKINDEX_CORPUS_ID': str(self.corpus.id), 'ARKINDEX_PROCESS_ID': str(process.id), + 'ARKINDEX_TASK_TOKEN': task.token, 'ARKINDEX_WORKER_RUN_ID': str(worker_run.id), 'INGEST_S3_ENDPOINT': 
'http://s3.null.teklia.com', 'INGEST_S3_ACCESS_KEY': '🔑', @@ -224,6 +225,7 @@ class TestCreateS3Import(FixtureTestCase): self.assertDictEqual(task.env, { 'ARKINDEX_CORPUS_ID': str(self.corpus.id), 'ARKINDEX_PROCESS_ID': str(process.id), + 'ARKINDEX_TASK_TOKEN': task.token, 'ARKINDEX_WORKER_RUN_ID': str(worker_run.id), 'INGEST_S3_ENDPOINT': 'http://s3.null.teklia.com', 'INGEST_S3_ACCESS_KEY': '🔑', diff --git a/arkindex/process/tests/test_create_training_process.py b/arkindex/process/tests/test_create_training_process.py index e334f082138d717e480fb6e530e8be84d207b477..1e0e42b8aa2ff86304ee77cb82b6432b91cacbf4 100644 --- a/arkindex/process/tests/test_create_training_process.py +++ b/arkindex/process/tests/test_create_training_process.py @@ -329,9 +329,11 @@ class TestCreateTrainingProcess(FixtureTestCase): self.assertEqual(sorted(task.env.keys()), [ 'ARKINDEX_CORPUS_ID', 'ARKINDEX_PROCESS_ID', + 'ARKINDEX_TASK_TOKEN', 'ARKINDEX_WORKER_RUN_ID', 'WORKER_VERSION_ID', ]) + self.assertEqual(task.env['ARKINDEX_TASK_TOKEN'], task.token) self.assertEqual(task.requires_gpu, False) # Check worker run properties @@ -410,10 +412,12 @@ class TestCreateTrainingProcess(FixtureTestCase): self.assertEqual(sorted(task.env.keys()), [ 'ARKINDEX_CORPUS_ID', 'ARKINDEX_PROCESS_ID', + 'ARKINDEX_TASK_TOKEN', 'ARKINDEX_WORKER_RUN_ID', 'WORKER_VERSION_ID', ]) self.assertEqual(task.requires_gpu, True) + self.assertEqual(task.env['ARKINDEX_TASK_TOKEN'], task.token) # Check worker run properties self.assertEqual(str(training_process.worker_runs.get().id), task.env['ARKINDEX_WORKER_RUN_ID']) worker_run = WorkerRun.objects.get(id=task.env['ARKINDEX_WORKER_RUN_ID']) diff --git a/arkindex/process/tests/test_workeractivity.py b/arkindex/process/tests/test_workeractivity.py index 0912ef0eefa42333f280b04273df62a5920fb30e..e68787395b1ae9556ada812a7554f729454b3d3b 100644 --- a/arkindex/process/tests/test_workeractivity.py +++ b/arkindex/process/tests/test_workeractivity.py @@ -7,6 +7,7 @@ from django.urls 
import reverse from rest_framework import status from arkindex.documents.models import Corpus, Element +from arkindex.ponos.models import Farm, Workflow from arkindex.process.models import ( ActivityState, Process, @@ -201,6 +202,65 @@ class TestWorkerActivity(FixtureTestCase): ) self.assertEqual(response.status_code, status_code) + def test_put_activity_task(self): + """ + Ponos task authentication can also be used to update a WorkerActivity. + Tasks can update any WorkerActivity in their process, without any filter on the WorkerVersion, + since there is no link between a WorkerVersion and a Task yet. + """ + self.process.start() + task = self.process.workflow.tasks.get(slug=self.worker_version.slug) + + with self.assertNumQueries(4): + response = self.client.put( + reverse('api:update-worker-activity', kwargs={'pk': str(self.worker_version.id)}), + { + 'element_id': str(self.element.id), + 'process_id': str(self.process.id), + 'state': WorkerActivityState.Started.value, + }, + content_type='application/json', + HTTP_AUTHORIZATION=f'Ponos {task.token}', + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + self.activity.refresh_from_db() + self.assertEqual(self.activity.state, WorkerActivityState.Started) + + def test_put_activity_task_other_process(self): + """ + When using Ponos task authentication, updating a WorkerActivity of another process should be forbidden. + Unlike the usual HTTP 409, we can raise a HTTP 400 here as the check requires no extra query and can be done + in the serializer before the UPDATE query that normally updates any WorkerActivity if it exists. 
+ """ + process2 = Process.objects.create(mode=ProcessMode.Repository, creator=self.user) + process2.workflow = Workflow.objects.create( + farm=Farm.objects.first(), + recipe="{tasks: {test: {image: hello-world}}}", + ) + process2.save() + task = process2.workflow.build_tasks()['test'] + + with self.assertNumQueries(2): + response = self.client.put( + reverse('api:update-worker-activity', kwargs={'pk': str(self.worker_version.id)}), + { + 'element_id': str(self.element.id), + 'process_id': str(self.process.id), + 'state': WorkerActivityState.Started.value, + }, + content_type='application/json', + HTTP_AUTHORIZATION=f'Ponos {task.token}', + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertDictEqual(response.json(), { + 'process_id': ['Only WorkerActivities for the process of the currently authenticated task can be updated.'], + }) + + self.activity.refresh_from_db() + self.assertEqual(self.activity.state, WorkerActivityState.Queued) + def test_put_activity_wrong_worker_version(self): """ Raises a generic error in case the worker version does not exists because a single SQL request is performed @@ -217,7 +277,7 @@ class TestWorkerActivity(FixtureTestCase): }, content_type='application/json', ) - self.assertEqual(response.status_code, status.HTTP_409_CONFLICT) + self.assertEqual(response.status_code, status.HTTP_409_CONFLICT) self.assertDictEqual(response.json(), { '__all__': [ 'Either this worker activity does not exist, is assigned to another process, or ' diff --git a/arkindex/project/settings.py b/arkindex/project/settings.py index b5e354e42b09227551003cbe2c025cc5d00a1013..820a6c4de196c0b75dbfc0f1ba3b43e47e9e5113 100644 --- a/arkindex/project/settings.py +++ b/arkindex/project/settings.py @@ -202,6 +202,7 @@ REST_FRAMEWORK = { 'rest_framework.authentication.SessionAuthentication', 'rest_framework.authentication.TokenAuthentication', 'arkindex.ponos.authentication.AgentAuthentication', + 
'arkindex.ponos.authentication.TaskAuthentication', ), 'DEFAULT_PAGINATION_CLASS': 'arkindex.project.pagination.PageNumberPagination', 'DEFAULT_SCHEMA_CLASS': 'arkindex.project.openapi.AutoSchema', diff --git a/arkindex/project/tests/openapi/test_schema.py b/arkindex/project/tests/openapi/test_schema.py index 1dd19e06d262faafdbf18951997c2f46e22a2b93..ead901bae65971f8a0bcaa38072f54f8d7149c90 100644 --- a/arkindex/project/tests/openapi/test_schema.py +++ b/arkindex/project/tests/openapi/test_schema.py @@ -56,6 +56,7 @@ class TestAutoSchema(TestCase): {'cookieAuth': []}, {'tokenAuth': []}, {'agentAuth': []}, + {'taskAuth': []}, # Allows no authentication too {}, ],