diff --git a/arkindex/process/api.py b/arkindex/process/api.py index b76556470ba5b7dcaf4c78c3b8edd0b67ea68a60..07b5a46f5b4a78744178bde71dd96ba0e9ae1877 100644 --- a/arkindex/process/api.py +++ b/arkindex/process/api.py @@ -89,7 +89,11 @@ from arkindex.process.serializers.imports import ( ) from arkindex.process.serializers.ingest import BucketSerializer, S3ImportSerializer from arkindex.process.serializers.training import ProcessDatasetSerializer, StartTrainingSerializer -from arkindex.process.serializers.worker_runs import WorkerRunEditSerializer, WorkerRunSerializer +from arkindex.process.serializers.worker_runs import ( + UserWorkerRunSerializer, + WorkerRunEditSerializer, + WorkerRunSerializer, +) from arkindex.process.serializers.workers import ( CorpusWorkerVersionSerializer, DockerWorkerVersionSerializer, @@ -1359,6 +1363,45 @@ class WorkerRunList(ProcessACLMixin, ListCreateAPIView): return context +@extend_schema(tags=['process']) +@extend_schema_view( + post=extend_schema( + operation_id='CreateUserWorkerRun', + description="Create a worker run tied to the user's local process.", + responses={201: WorkerRunSerializer}, + ), +) +class UserWorkerRunCreate(ProcessACLMixin, CreateAPIView): + permission_classes = (IsVerified, ) + serializer_class = UserWorkerRunSerializer + queryset = WorkerRun.objects.none() + + @cached_property + def local_process(self): + try: + return Process.objects.annotate(last_run=Value(None, output_field=IntegerField())).get( + creator=self.request.user, + mode=ProcessMode.Local + ) + except Process.DoesNotExist: + return None + + def get_serializer_context(self): + context = super().get_serializer_context() + context['local_process'] = self.local_process + return context + + def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + self.perform_create(serializer) + run = serializer.instance + return Response( + status=status.HTTP_201_CREATED, + data=WorkerRunSerializer(run, context=self.get_serializer_context()).data, + ) + + @extend_schema(tags=['process']) @extend_schema_view( get=extend_schema( diff --git a/arkindex/process/serializers/worker_runs.py b/arkindex/process/serializers/worker_runs.py index 2dc3161fe19d710fa603bad5a4bc6345abe5ab18..04110e67c2f32d585908f379842542fefab1f294 100644 --- a/arkindex/process/serializers/worker_runs.py +++ b/arkindex/process/serializers/worker_runs.py @@ -7,6 +7,7 @@ from rest_framework.exceptions import ValidationError from arkindex.process.models import ( GitRef, + Process, ProcessMode, WorkerConfiguration, WorkerRun, @@ -190,3 +191,81 @@ class WorkerRunEditSerializer(WorkerRunSerializer): class Meta(WorkerRunSerializer.Meta): # Same as WorkerRunSerializer, but the worker_version_id cannot be edited fields = tuple(set(WorkerRunSerializer.Meta.fields) - {'worker_version_id'}) + + +class UserWorkerRunSerializer(serializers.ModelSerializer): + worker_version_id = serializers.UUIDField( + # The worker version corresponding to the ID is returned by the validate_worker_version_id function + source='version', + help_text=dedent(""" + ID of a worker version to create the worker run from. + + This worker version must not have a Git revision, and its worker must not be linked to a Git repository. + You must have **contributor** access to this worker. + """), + ) + model_version_id = serializers.UUIDField( + # The model version corresponding to the ID is returned by the validate_model_version_id function + source='model_version', + required=False, + help_text=dedent(""" + If set, you must either: + - have **contributor** access to the corresponding model, OR + - have **guest** access to the corresponding model AND the model version must have a tag and be available. + """) + ) + configuration_id = serializers.PrimaryKeyRelatedField( + source='configuration', + required=False, + queryset=WorkerConfiguration.objects.all(), + style={'base_template': 'input.html'}, + ) + + def validate_worker_version_id(self, worker_version_id): + # Check that the worker version exists + try: + worker_version = WorkerVersion.objects.select_related('worker').get(pk=worker_version_id) + except WorkerVersion.DoesNotExist: + raise ValidationError(detail=f'Worker version {worker_version_id} does not exist.') + + # Check that the worker is executable by the request user + if not worker_version.worker.is_executable(self.context['request'].user): + raise ValidationError(detail='You do not have contributor access to this worker.') + + # Check that the version doesn't have a revision, and the related worker is not linked to a repository + if worker_version.revision_id: + raise ValidationError(detail='The worker version used to create a local worker run must not have a revision set.') + if worker_version.worker.repository_id: + raise ValidationError(detail='The worker used to create a local worker run must not be linked to a repository.') + + return worker_version + + def validate_model_version_id(self, model_version_id): + # Check that the model version exists + try: + model_version = ModelVersion.objects.select_related('model').get(pk=model_version_id) + except ModelVersion.DoesNotExist: + raise ValidationError(detail=f'Model version {model_version_id} does not exist.') + + # Check that the user has guest access to the model version + if not model_version.is_executable(self.context['request'].user): + raise ValidationError(detail='You do not have guest access to this model.') + + return model_version + + def validate(self, data): + local_process = self.context['local_process'] + # If the user doesn't already have a local process, create it + if not local_process: + local_process = Process.objects.create( + mode=ProcessMode.Local, + creator=self.context['request'].user + ) + local_process.last_run = None + data['process'] = local_process + + return data + + class Meta: + model = WorkerRun + fields = ('worker_version_id', 'model_version_id', 'configuration_id') diff --git a/arkindex/process/tests/test_user_workerruns.py b/arkindex/process/tests/test_user_workerruns.py index 75e1a51318cb535dc88919078950a5ad3f2e9bf5..dac68b8dc37109b53e8427a8d88dea348e846003 100644 --- a/arkindex/process/tests/test_user_workerruns.py +++ b/arkindex/process/tests/test_user_workerruns.py @@ -5,9 +5,17 @@ from django.urls import reverse from django.utils import timezone from rest_framework import status -from arkindex.process.models import Process, ProcessMode, WorkerRun, WorkerVersion +from arkindex.process.models import ( + Process, + ProcessMode, + WorkerConfiguration, + WorkerRun, + WorkerVersion, + WorkerVersionState, +) from arkindex.project.tests import FixtureAPITestCase -from arkindex.users.models import User +from arkindex.training.models import Model, ModelVersion, ModelVersionState +from arkindex.users.models import Right, Role, User class TestUserWorkerRuns(FixtureAPITestCase): @@ -17,9 +25,23 @@ class TestUserWorkerRuns(FixtureAPITestCase): cls.version_1 = WorkerVersion.objects.get(worker__slug='reco') cls.worker_1 = cls.version_1.worker cls.custom_version = WorkerVersion.objects.get(worker__slug='custom') + cls.other_version = WorkerVersion.objects.create( + worker=cls.custom_version.worker, + revision=None, + version=2, + configuration={ + 'name': 'bartimaeus', + 'category': 'level 14 djinni' + }, + state=WorkerVersionState.Created, + ) # Local worker run cls.local_process = Process.objects.get(mode=ProcessMode.Local, creator=cls.user) cls.local_run = WorkerRun.objects.get(process=cls.local_process) + # Give execution rights to self.user on the custom worker + Right.objects.create(user=cls.user, content_object=cls.custom_version.worker, level=Role.Contributor.value) + + # ListUserWorkerRuns def test_list_user_runs_requires_login(self): with self.assertNumQueries(0): @@ -166,3 +188,247 @@ class TestUserWorkerRuns(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.json()['count'], 0) self.assertEqual(response.json()['results'], []) + + # CreateUserWorkerRun + + def test_create_user_run_requires_login(self): + with self.assertNumQueries(0): + response = self.client.post(reverse('api:user-worker-run-create')) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_create_user_run_requires_verified(self): + self.user.verified_email = False + self.user.save() + self.client.force_login(self.user) + with self.assertNumQueries(2): + response = self.client.post(reverse('api:user-worker-run-create')) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_create_user_run(self): + self.client.force_login(self.user) + with self.assertNumQueries(9): + response = self.client.post( + reverse('api:user-worker-run-create'), + data={'worker_version_id': str(self.other_version.id)} + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + created_run = WorkerRun.objects.get(version=self.other_version) + self.assertDictEqual(response.json(), { + 'id': str(created_run.id), + 'parents': [], + 'configuration': None, + 'model_version': None, + 'worker_version': { + 'configuration': { + 'name': 'bartimaeus', + 'category': 'level 14 djinni' + }, + 'created': self.other_version.created.isoformat().replace('+00:00', 'Z'), + 'docker_image': None, + 'docker_image_iid': None, + 'docker_image_name': None, + 'gpu_usage': 'disabled', + 'id': str(self.other_version.id), + 'model_usage': False, + 'revision': None, + 'state': 'created', + 'version': 2, + 'worker': { + 'id': str(self.other_version.worker.id), + 'name': 'Custom worker', + 'slug': 'custom', + 'type': 'custom' + } + }, + 'summary': 'Worker Custom worker @ version 2', + 'process': { + 'activity_state': 'disabled', + 'corpus': None, + 'id': str(self.local_process.id), + 'mode': 'local', + 'model_id': None, + 'name': None, + 'state': 'unscheduled', + 'test_folder_id': None, + 'train_folder_id': None, + 'use_cache': False, + 'validation_folder_id': None + } + }) + + def test_create_user_run_no_local_process(self): + """ + If the user doesn't already have a local process, one is created + """ + test_user = User.objects.get(email='user2@user.fr') + Right.objects.create(user=test_user, content_object=self.other_version.worker, level=Role.Contributor.value) + self.client.force_login(test_user) + self.assertFalse(Process.objects.filter(mode=ProcessMode.Local, creator=test_user).exists()) + with self.assertNumQueries(9): + response = self.client.post( + reverse('api:user-worker-run-create'), + data={'worker_version_id': str(self.other_version.id)} + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertTrue(Process.objects.filter(mode=ProcessMode.Local, creator=test_user).exists()) + + def test_create_user_run_wv_doesnt_exist(self): + self.client.force_login(self.user) + with self.assertNumQueries(4): + response = self.client.post( + reverse('api:user-worker-run-create'), + data={'worker_version_id': 'aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa'} + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), {'worker_version_id': ['Worker version aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa does not exist.']}) + + def test_create_user_run_wv_has_revision(self): + Right.objects.create(user=self.user, content_object=self.version_1.worker, level=Role.Contributor.value) + self.client.force_login(self.user) + with self.assertNumQueries(6): + response = self.client.post( + reverse('api:user-worker-run-create'), + data={'worker_version_id': str(self.version_1.id)} + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), {'worker_version_id': ['The worker version used to create a local worker run must not have a revision set.']}) + + def test_create_user_run_worker_has_repository(self): + Right.objects.create(user=self.user, content_object=self.version_1.worker, level=Role.Contributor.value) + test_version = self.version_1.worker.versions.create(configuration={}, version=7) + self.client.force_login(self.user) + with self.assertNumQueries(6): + response = self.client.post( + reverse('api:user-worker-run-create'), + data={'worker_version_id': str(test_version.id)} + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), {'worker_version_id': ['The worker used to create a local worker run must not be linked to a repository.']}) + + def test_create_user_run_worker_not_executable(self): + test_user = User.objects.get(email='user2@user.fr') + Right.objects.create(user=test_user, content_object=self.other_version.worker, level=Role.Guest.value) + self.client.force_login(test_user) + with self.assertNumQueries(6): + response = self.client.post( + reverse('api:user-worker-run-create'), + data={'worker_version_id': str(self.other_version.id)} + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), {'worker_version_id': ['You do not have contributor access to this worker.']}) + + def test_create_user_run_mv_doesnt_exist(self): + self.client.force_login(self.user) + with self.assertNumQueries(8): + response = self.client.post( + reverse('api:user-worker-run-create'), + data={'worker_version_id': str(self.other_version.id), 'model_version_id': 'aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa'} + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), {'model_version_id': ['Model version aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa does not exist.']}) + + def test_create_user_run_wc_doesnt_exist(self): + self.client.force_login(self.user) + with self.assertNumQueries(8): + response = self.client.post( + reverse('api:user-worker-run-create'), + data={'worker_version_id': str(self.other_version.id), 'configuration_id': 'aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa'} + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), {'configuration_id': ['Invalid pk "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" - object does not exist.']}) + + def test_create_user_run_model_no_access(self): + test_model = Model.objects.create(name='Some model', public=False) + model_version = ModelVersion.objects.create(model_id=test_model.id, hash='bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', archive_hash='aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', size=8) + self.client.force_login(self.user) + with self.assertNumQueries(10): + response = self.client.post( + reverse('api:user-worker-run-create'), + data={'worker_version_id': str(self.other_version.id), 'model_version_id': str(model_version.id)} + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), {'model_version_id': ['You do not have guest access to this model.']}) + + def test_create_user_run_full(self): + test_model = Model.objects.create(name='Some model', public=False) + model_version = ModelVersion.objects.create( + state=ModelVersionState.Available, + model_id=test_model.id, + hash='bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', + archive_hash='aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', + size=8, + tag='0.1.3' + ) + Right.objects.create(user=self.user, content_object=test_model, level=Role.Guest.value) + test_configuration = WorkerConfiguration.objects.create(worker=self.other_version.worker, name="Some configuration", configuration={'param': 'value'}) + + self.client.force_login(self.user) + + with self.assertNumQueries(12): + response = self.client.post( + reverse('api:user-worker-run-create'), + data={ + 'worker_version_id': str(self.other_version.id), + 'model_version_id': str(model_version.id), + 'configuration_id': str(test_configuration.id) + } + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + created_run = WorkerRun.objects.get(version=self.other_version) + self.assertDictEqual(response.json(), { + 'id': str(created_run.id), + 'parents': [], + 'configuration': { + 'archived': False, + 'configuration': {'param': 'value'}, + 'id': str(test_configuration.id), + 'name': 'Some configuration' + }, + 'model_version': { + 'configuration': {}, + 'id': str(model_version.id), + 'model': { + 'id': str(test_model.id), + 'name': 'Some model' + }, + 'size': 8, + 'state': 'available', + 'tag': '0.1.3' + }, + 'worker_version': { + 'configuration': { + 'name': 'bartimaeus', + 'category': 'level 14 djinni' + }, + 'created': self.other_version.created.isoformat().replace('+00:00', 'Z'), + 'docker_image': None, + 'docker_image_iid': None, + 'docker_image_name': None, + 'gpu_usage': 'disabled', + 'id': str(self.other_version.id), + 'model_usage': False, + 'revision': None, + 'state': 'created', + 'version': 2, + 'worker': { + 'id': str(self.other_version.worker.id), + 'name': 'Custom worker', + 'slug': 'custom', + 'type': 'custom' + } + }, + 'summary': f"Worker Custom worker @ version 2 with model Some model @ {str(model_version.id)[:6]} using configuration 'Some configuration'", + 'process': { + 'activity_state': 'disabled', + 'corpus': None, + 'id': str(self.local_process.id), + 'mode': 'local', + 'model_id': None, + 'name': None, + 'state': 'unscheduled', + 'test_folder_id': None, + 'train_folder_id': None, + 'use_cache': False, + 'validation_folder_id': None + } + }) diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py index b11849ce99e44dde4c16c34379b2b2e7d3bcf1a9..e8d119d879cee29545e314fd39828e358925e612 100644 --- a/arkindex/project/api_v1.py +++ b/arkindex/project/api_v1.py @@ -105,6 +105,7 @@ from arkindex.process.api import ( StartProcess, StartTraining, UpdateWorkerActivity, + UserWorkerRunCreate, UserWorkerRunList, WorkerActivityList, WorkerConfigurationList, @@ -294,6 +295,7 @@ api = [ path('process/<uuid:pk>/workers/', WorkerRunList.as_view(), name='worker-run-list'), path('process/workers/<uuid:pk>/', WorkerRunDetails.as_view(), name='worker-run-details'), path('process/local/workers/', UserWorkerRunList.as_view(), name='user-worker-run-list'), + path('process/local/workers/create/', UserWorkerRunCreate.as_view(), name='user-worker-run-create'), path('process/<uuid:pk>/elements/', ListProcessElements.as_view(), name='process-elements-list'), path('process/<uuid:pk>/activity-stats/', ProcessWorkersActivity.as_view(), name='process-activity-stats'), path('process/<uuid:pk>/template/', CreateProcessTemplate.as_view(), name='create-process-template'), diff --git a/arkindex/training/models.py b/arkindex/training/models.py index 8efa043eb61b7ac4bb2ce566eab39c4cef3fd4d3..a214e23837f95c01556b0acd503746f728d17445 100644 --- a/arkindex/training/models.py +++ b/arkindex/training/models.py @@ -167,6 +167,12 @@ class ModelVersion(S3FileMixin, IndexableModel): """ return sha256((str(self.id) + self.hash + settings.SECRET_KEY).encode('utf-8')).hexdigest() + def is_executable(self, user) -> bool: + """ + Whether this model version is executable by a user + """ + return ModelVersion.objects.executable(user).filter(id=self.id).exists() + def __str__(self): if self.tag: return f'{self.model.name} @ {self.tag} ({self.truncated_id}…)'