From 79637973c2b19618214de8a8f982d4d233347336 Mon Sep 17 00:00:00 2001 From: Valentin Rigal <rigal@teklia.com> Date: Thu, 26 Oct 2023 14:38:35 +0000 Subject: [PATCH] Create a worker version without a revision --- arkindex/process/api.py | 100 ++++++---------- arkindex/process/serializers/workers.py | 42 ++++++- arkindex/process/tests/test_workers.py | 151 +++++++++++++++--------- 3 files changed, 170 insertions(+), 123 deletions(-) diff --git a/arkindex/process/api.py b/arkindex/process/api.py index 20bbb748b8..b76556470b 100644 --- a/arkindex/process/api.py +++ b/arkindex/process/api.py @@ -1048,19 +1048,20 @@ class WorkerTypesList(ListAPIView): ], ), post=extend_schema( - description=( - 'Create a version for a given worker ID, revision ID and JSON configuration.\n\n' - 'The user must be **internal** to perform this request.' - ) - ), + description=dedent(""" + Create a new version for a worker. + + Authentication can be done: + * Using a user authentication (via a cookie or token). + The user must have an administrator access to the worker. + * Using a ponos task authentication. + The worker must be linked to a repository. + + The `revision_id` parameter must be set for workers linked to a repository only. + """) + ) ) class WorkerVersionList(WorkerACLMixin, ListCreateAPIView): - """ - List versions, their revision and associated git references for a given worker UUID - - Create a WorkerVersion instance for a worker and a revision with a given JSON configuration, - return existing query if a workerVersion already exists for this worker and this revision - """ permission_classes = (IsVerified, ) serializer_class = WorkerVersionCreateSerializer queryset = WorkerVersion.objects.none() @@ -1069,75 +1070,42 @@ class WorkerVersionList(WorkerACLMixin, ListCreateAPIView): def simple_mode(self): return self.request.query_params.get('mode', 'complete').lower() == 'simple' - def get_queryset(self): + @cached_property + def worker(self): worker = get_object_or_404( Worker.objects.select_related('repository'), pk=self.kwargs['pk'] ) - - if not self.has_execution_access(worker): + if self.request.method in permissions.SAFE_METHODS and not self.has_execution_access(worker): raise PermissionDenied(detail='You do not have an execution access to this worker.') + if ( + self.request.method not in permissions.SAFE_METHODS + # Either a task authentication or an admin access is required for creation + and not isinstance(self.request.auth, Task) + and not self.has_admin_access(worker) + ): + raise PermissionDenied(detail='You do not have an admin access to this worker.') + return worker + def get_serializer_context(self): + context = super().get_serializer_context() + context['worker'] = self.worker + return context + + def get_queryset(self): filters = Q() if self.simple_mode: # Limit output to versions with tags or master/main branches - filters = Q(revision__refs__type=GitRefType.Tag) | Q(revision__refs__type=GitRefType.Branch, revision__refs__name__in=["master", "main"]) - - return worker.versions \ + filters = ( + Q(revision__refs__type=GitRefType.Tag) + | Q(revision__refs__type=GitRefType.Branch, revision__refs__name__in=["master", "main"]) + ) + return self.worker.versions \ .filter(filters) \ .prefetch_related('revision__repo', 'revision__refs', 'worker__type') \ .distinct() \ .order_by('-revision__created') - def create(self, request, *args, **kwargs): - if not isinstance(request.auth, Task): - raise PermissionDenied(detail='Only Ponos tasks can create a new version from a worker.') - - serializer = self.get_serializer(data=request.data) - serializer.is_valid(raise_exception=True) - - # We need to use the default database to avoid stale read - # when creating a worker version on a newly created worker - worker = get_object_or_404( - Worker.objects.using('default'), - pk=self.kwargs['pk'] - ) - - revision = serializer.validated_data.get('revision_id') - if revision is None and worker.repository is not None: - raise ValidationError({ - 'revision_id': ['The revision must be set on a worker that is linked to a repository.'] - }) - elif revision and worker.repository is None: - raise ValidationError({ - 'revision_id': ['The revision cannot be set on a worker that is not linked to a repository.'] - }) - elif revision and worker.repository_id != revision.repo_id: - raise ValidationError({ - 'revision_id': ['The revision must be part of the same repository as the worker.'] - }) - - # Define a version number on workers that are not linked to a repository - version = None - if revision is None: - last_version = worker.versions.using('default').aggregate(last_version=Max('version'))["last_version"] - version = last_version + 1 if last_version else 1 - - worker_version, created = WorkerVersion.objects.get_or_create( - worker=worker, - revision=revision, - version=version, - defaults={ - 'configuration': serializer.validated_data['configuration'], - 'gpu_usage': serializer.validated_data['gpu_usage'], - 'model_usage': serializer.validated_data['model_usage'], - } - ) - - reponse_status = status.HTTP_201_CREATED if created else status.HTTP_200_OK - - return Response(WorkerVersionSerializer(worker_version).data, status=reponse_status) - @extend_schema_view( get=extend_schema( diff --git a/arkindex/process/serializers/workers.py b/arkindex/process/serializers/workers.py index 842e03a92e..49fd979832 100644 --- a/arkindex/process/serializers/workers.py +++ b/arkindex/process/serializers/workers.py @@ -6,7 +6,7 @@ from enum import Enum from textwrap import dedent from django.db import transaction -from django.db.models import Q +from django.db.models import Max, Q from drf_spectacular.utils import extend_schema_field from rest_framework import serializers from rest_framework.exceptions import NotFound, PermissionDenied, ValidationError @@ -296,8 +296,9 @@ class WorkerVersionCreateSerializer(WorkerVersionSerializer): Git revision for this version. This field is required on workers linked to a repository. - On other workers, an automatic version number is attributed to the version. + On other workers, it cannot be set as an automatic version number is attributed. """).strip(), + source='revision' ) class Meta (WorkerVersionSerializer.Meta): @@ -305,6 +306,43 @@ class WorkerVersionCreateSerializer(WorkerVersionSerializer): 'revision_id', ) + def validate_revision_id(self, revision): + worker = self.context['worker'] + if worker.versions.using('default').filter(revision=revision).exists(): + raise ValidationError('A version of this worker already exists with this revision') + if isinstance(self.context['request'].auth, Task): + if worker.repository_id is None: + # Task authentication is forbidden on workers not linked to a repository + raise ValidationError( + 'Task authentication requires to create a version on workers linked to a repository.' + ) + elif worker.repository_id is not None: + raise ValidationError( + 'Ponos authentication is required to create a version on a worker linked to a repository.' + ) + if worker.repository is None: + raise ValidationError('A revision cannot be set on a worker that is not linked to a repository.') + if worker.repository_id != revision.repo_id: + raise ValidationError('The revision must be part of the same repository as the worker.') + return revision + + def validate(self, data): + super().validate(data) + worker = self.context['worker'] + revision = data.get('revision') + if revision is None and self.context['worker'].repository_id is not None: + raise ValidationError( + {'revision_id': ['A revision must be set on a worker that is linked to a repository.']} + ) + # Define a version number on workers that are not linked to a repository + version = None + if revision is None: + last_version = worker.versions.using('default').aggregate(last_version=Max('version'))["last_version"] + version = last_version + 1 if last_version else 1 + data['version'] = version + data['worker_id'] = worker.id + return data + class RepositorySerializer(serializers.ModelSerializer): """ diff --git a/arkindex/process/tests/test_workers.py b/arkindex/process/tests/test_workers.py index 9d318af89d..035b80582e 100644 --- a/arkindex/process/tests/test_workers.py +++ b/arkindex/process/tests/test_workers.py @@ -700,21 +700,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): [str(tagged_rev.id), str(master_rev.id), str(main_rev.id)] ) - def test_versions_post_requires_task(self): - """ - Only Ponos tasks are able to create new versions - """ - users = [None, self.user, self.superuser] - for user in users: - if user: - self.client.force_login(user) - response = self.client.post( - reverse('api:worker-versions', kwargs={'pk': str(self.worker_reco.id)}), - data={'revision_id': str(self.rev2.id), 'configuration': {"test": "test2"}}, format='json' - ) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - - def test_versions_post_non_existing_worker(self): + def test_create_version_non_existing_worker(self): response = self.client.post( reverse('api:worker-versions', kwargs={'pk': '12341234-1234-1234-1234-123412341234'}), data={'revision_id': str(self.rev2.id), 'configuration': {"test": "test2"}}, @@ -724,7 +710,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.json(), {'detail': 'Not found.'}) - def test_versions_post_available_requires_docker_image(self): + def test_create_version_available_requires_docker_image(self): response = self.client.post( reverse('api:worker-versions', kwargs={'pk': str(self.worker_reco.id)}), data={ @@ -742,7 +728,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): ] }) - def test_versions_post_return_existing_worker_version(self): + def test_create_version_return_existing_revision(self): # A worker version already exists for this worker and this revision response = self.client.post( reverse('api:worker-versions', kwargs={'pk': str(self.worker_reco.id)}), @@ -750,14 +736,12 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): format='json', HTTP_AUTHORIZATION=f'Ponos {self.task.token}', ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - data = response.json() - self.assertEqual(data['id'], str(self.version_1.id)) - self.assertEqual(data['configuration'], {"test": 42}) - self.assertEqual(data['revision']['id'], str(self.rev.id)) - self.assertEqual(data['state'], 'available') + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + 'revision_id': ['A version of this worker already exists with this revision'] + }) - def test_versions_post_empty(self): + def test_create_version_empty(self): response = self.client.post( reverse('api:worker-versions', kwargs={'pk': str(self.worker_reco.id)}), HTTP_AUTHORIZATION=f'Ponos {self.task.token}', @@ -765,7 +749,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), {'configuration': ['This field is required.']}) - def test_versions_post_unrelated_revision(self): + def test_create_version_unrelated_revision(self): """ It is not possible to create a version between a worker and a revision of two different repositories """ @@ -786,42 +770,99 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): 'revision_id': ['The revision must be part of the same repository as the worker.'] }) - def test_create_worker_version_revision_requires_worker_repository(self): + def test_create_version_revision_requires_worker_repository(self): + self.worker_custom.memberships.filter(user=self.user).update(level=Role.Admin.value) + self.client.force_login(self.user) response = self.client.post( reverse('api:worker-versions', kwargs={'pk': str(self.worker_custom.id)}), data={'revision_id': str(self.rev2.id), 'configuration': {"test": "test2"}, 'model_usage': True}, format='json', - HTTP_AUTHORIZATION=f'Ponos {self.task.token}', ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), { - 'revision_id': ['The revision cannot be set on a worker that is not linked to a repository.'] + 'revision_id': ['A revision cannot be set on a worker that is not linked to a repository.'] }) - def test_create_worker_version_null_revision_requires_null_worker_repository(self): + def test_create_version_null_revision_requires_null_worker_repository(self): + self.worker_custom.memberships.filter(user=self.user).update(level=Role.Admin.value) + self.client.force_login(self.user) response = self.client.post( reverse('api:worker-versions', kwargs={'pk': str(self.worker_reco.id)}), data={'configuration': {"test": "test2"}, 'model_usage': True}, format='json', - HTTP_AUTHORIZATION=f'Ponos {self.task.token}', ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), { - 'revision_id': ['The revision must be set on a worker that is linked to a repository.'] + 'revision_id': ['A revision must be set on a worker that is linked to a repository.'], + }) + + def test_create_version_null_worker_repository_requires_null_revision(self): + self.worker_custom.memberships.filter(user=self.user).update(level=Role.Admin.value) + self.client.force_login(self.user) + response = self.client.post( + reverse('api:worker-versions', kwargs={'pk': str(self.worker_custom.id)}), + data={'revision_id': str(self.rev2.id), 'configuration': {"test": "test2"}, 'model_usage': True}, + format='json', + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + 'revision_id': ['A revision cannot be set on a worker that is not linked to a repository.'] + }) + + def test_create_version_null_revision_forbidden_task_auth(self): + """ + Ponos Task auth cannot create a version on a worker that is not linked to a repository. + """ + self.worker_custom.versions.create(version=41, configuration={}) + with self.assertNumQueries(4): + response = self.client.post( + reverse('api:worker-versions', kwargs={'pk': str(self.worker_custom.id)}), + data={'revision_id': str(self.rev2.id), 'configuration': {"test": "test2"}, 'model_usage': True}, + format='json', + HTTP_AUTHORIZATION=f'Ponos {self.task.token}', + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + 'revision_id': ['Task authentication requires to create a version on workers linked to a repository.'] + }) + + def test_create_version_user_auth_requires_admin(self): + self.client.force_login(self.user) + with self.assertNumQueries(6): + response = self.client.post( + reverse('api:worker-versions', kwargs={'pk': str(self.worker_custom.id)}), + data={'configuration': {"test": "val"}, 'model_usage': True}, + format='json', + ) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertDictEqual(response.json(), {'detail': 'You do not have an admin access to this worker.'}) + + def test_create_version_user_auth_requires_null_repository(self): + self.client.force_login(self.user) + with self.assertNumQueries(10): + response = self.client.post( + reverse('api:worker-versions', kwargs={'pk': str(self.worker_dla.id)}), + data={'revision_id': str(self.rev2.id), 'configuration': {"test": "test2"}, 'model_usage': True}, + format='json', + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + 'revision_id': ['Ponos authentication is required to create a version on a worker linked to a repository.'] }) - def test_create_worker_version_null_revision(self): + def test_create_version_null_revision(self): """ A worker version can be created with no revision on a worker that has no repository. Its version number is automatically incremented. """ + self.worker_custom.memberships.filter(user=self.user).update(level=Role.Admin.value) self.worker_custom.versions.create(version=41, configuration={}) - with self.assertNumQueries(8): + self.client.force_login(self.user) + with self.assertNumQueries(9): response = self.client.post( reverse('api:worker-versions', kwargs={'pk': str(self.worker_custom.id)}), data={'configuration': {"test": "val"}, 'model_usage': True}, format='json', - HTTP_AUTHORIZATION=f'Ponos {self.task.token}', ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) data = response.json() @@ -846,7 +887,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): }, }) - def test_create_worker_version(self): + def test_create_version(self): response = self.client.post( reverse('api:worker-versions', kwargs={'pk': str(self.worker_reco.id)}), data={'revision_id': str(self.rev2.id), 'configuration': {"test": "test2"}, 'model_usage': True}, @@ -863,7 +904,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): self.assertEqual(data['gpu_usage'], 'disabled') self.assertEqual(data['model_usage'], True) - def test_create_worker_version_wrong_gpu_usage(self): + def test_create_version_wrong_gpu_usage(self): response = self.client.post( reverse('api:worker-versions', kwargs={'pk': str(self.worker_reco.id)}), data={'revision_id': str(self.rev2.id), 'configuration': {"test": "test2"}, 'gpu_usage': 'not_supported'}, @@ -872,7 +913,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - def test_no_user_configuration_ok(self): + def test_create_version_no_user_configuration_ok(self): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), data={ @@ -885,7 +926,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) - def test_valid_user_configuration(self): + def test_create_version_valid_user_configuration(self): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), data={ @@ -919,7 +960,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): } }) - def test_valid_user_configuration_dict(self): + def test_create_version_valid_user_configuration_dict(self): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), data={ @@ -946,7 +987,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): } }) - def test_user_configuration_dict_strings_only(self): + def test_create_version_user_configuration_dict_strings_only(self): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), data={ @@ -963,7 +1004,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - def test_valid_user_configuration_enum(self): + def test_create_version_valid_user_configuration_enum(self): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), data={ @@ -991,7 +1032,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): } }) - def test_valid_user_configuration_list(self): + def test_create_version_valid_user_configuration_list(self): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), data={ @@ -1027,7 +1068,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): } }) - def test_invalid_user_configuration_list_requires_subtype(self): + def test_create_version_invalid_user_configuration_list_requires_subtype(self): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), data={ @@ -1053,7 +1094,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): } }) - def test_invalid_user_configuration_list_wrong_default(self): + def test_create_version_invalid_user_configuration_list_wrong_default(self): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), data={ @@ -1079,7 +1120,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): } }) - def test_invalid_user_configuration_list_wrong_subtype(self): + def test_create_version_invalid_user_configuration_list_wrong_subtype(self): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), data={ @@ -1105,7 +1146,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): } }) - def test_invalid_user_configuration_list_wrong_default_subtype(self): + def test_create_version_invalid_user_configuration_list_wrong_default_subtype(self): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), data={ @@ -1131,7 +1172,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): } }) - def test_invalid_user_configuration_not_list_choices(self): + def test_create_version_invalid_user_configuration_not_list_choices(self): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), data={ @@ -1157,7 +1198,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): } }) - def test_invalid_user_configuration_not_dict(self): + def test_create_version_invalid_user_configuration_not_dict(self): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), data={ @@ -1173,7 +1214,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.json(), {"configuration": {"user_configuration": [["Expected a dictionary of items but got type \"str\"."]]}}) - def test_invalid_user_configuration_item_not_dict(self): + def test_create_version_invalid_user_configuration_item_not_dict(self): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), data={ @@ -1196,7 +1237,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): }] }}) - def test_invalid_user_configuration_wrong_field_type(self): + def test_create_version_invalid_user_configuration_wrong_field_type(self): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), data={ @@ -1227,7 +1268,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): } }) - def test_invalid_user_configuration_wrong_default_type(self): + def test_create_version_invalid_user_configuration_wrong_default_type(self): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), data={ @@ -1258,7 +1299,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): } }) - def test_invalid_user_configuration_choices_no_enum(self): + def test_create_version_invalid_user_configuration_choices_no_enum(self): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), data={ @@ -1289,7 +1330,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): } }) - def test_invalid_user_configuration_missing_key(self): + def test_create_version_invalid_user_configuration_missing_key(self): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), data={ @@ -1318,7 +1359,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): } ) - def test_invalid_user_configuration_invalid_key(self): + def test_create_version_invalid_user_configuration_invalid_key(self): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), data={ @@ -1352,7 +1393,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): } }) - def test_invalid_user_configuration_default_value(self): + def test_create_version_invalid_user_configuration_default_value(self): cases = [ ({"type": "int", "default": False}, 'int'), ({"type": "int", "default": True}, 'int'), -- GitLab