diff --git a/arkindex/process/api.py b/arkindex/process/api.py index 11ae53abd171c02245a65b96f965bc4a4a2960c2..42bc4ce104b873d52a181b427146402ba863e49b 100644 --- a/arkindex/process/api.py +++ b/arkindex/process/api.py @@ -84,7 +84,7 @@ from arkindex.process.serializers.workers import ( WorkerSerializer, WorkerStatisticsSerializer, WorkerTypeSerializer, - WorkerVersionEditSerializer, + WorkerVersionCreateSerializer, WorkerVersionSerializer, ) from arkindex.project.aws import get_ingest_resource @@ -909,7 +909,7 @@ class WorkerVersionList(WorkerACLMixin, ListCreateAPIView): return existing query if a workerVersion already exists for this worker and this revision """ permission_classes = (IsVerified, ) - serializer_class = WorkerVersionSerializer + serializer_class = WorkerVersionCreateSerializer queryset = WorkerVersion.objects.none() @property @@ -950,7 +950,7 @@ class WorkerVersionList(WorkerACLMixin, ListCreateAPIView): pk=self.kwargs['pk'] ) - revision = serializer.validated_data['revision'] + revision = serializer.validated_data['revision_id'] if worker.repository_id != revision.repo_id: raise ValidationError({ 'revision': ['The revision must be part of the same repository as the worker.'] @@ -1007,7 +1007,7 @@ class WorkerVersionRetrieve(RetrieveUpdateAPIView): Retrieve a specific worker version. No authentication is required. """ permission_classes = (IsVerifiedOrReadOnly, ) - serializer_class = WorkerVersionEditSerializer + serializer_class = WorkerVersionSerializer queryset = WorkerVersion.objects.select_related('worker').all() def check_object_permissions(self, request, instance): diff --git a/arkindex/process/serializers/workers.py b/arkindex/process/serializers/workers.py index b21215c4f9c7389aeddc6d99247f1caced3b8bd8..e083f79d7a0a098dad6ea1f4b2c3baa27c624e3d 100644 --- a/arkindex/process/serializers/workers.py +++ b/arkindex/process/serializers/workers.py @@ -150,7 +150,7 @@ class WorkerVersionSerializer(serializers.ModelSerializer): # State defaults to created when instantiating a WorkerVersion state = EnumField(WorkerVersionState, required=False) worker = WorkerLightSerializer(read_only=True) - revision = serializers.UUIDField() + revision = RevisionWithRefsSerializer(required=False, read_only=True) gpu_usage = EnumField(WorkerVersionGPUUsage, required=False, default=WorkerVersionGPUUsage.Disabled) model_usage = serializers.BooleanField(required=False, default=False) @@ -169,24 +169,13 @@ class WorkerVersionSerializer(serializers.ModelSerializer): 'model_usage', 'worker' ) - read_only_fields = ('docker_image_name',) + read_only_fields = ('docker_image_name', 'revision') # Avoid loading all revisions and all Ponos artifacts when opening this endpoint in a browser extra_kwargs = { 'revision': {'style': {'base_template': 'input.html'}}, 'docker_image': {'style': {'base_template': 'input.html'}}, } - def to_representation(self, instance): - self.fields['revision'] = RevisionWithRefsSerializer(read_only=True) - return super(WorkerVersionSerializer, self).to_representation(instance) - - def validate_revision(self, revision_id): - # Retrieve a revision from its ID without listing them with a PrimaryKeyRelatedField - try: - return Revision.objects.get(id=revision_id) - except Revision.DoesNotExist: - raise ValidationError({'revision': ['Revision with this ID does not exist.']}) - def validate_configuration(self, configuration): errors = defaultdict(list) user_configuration = configuration.get('user_configuration') @@ -225,11 +214,13 @@ class WorkerVersionSerializer(serializers.ModelSerializer): return data -class WorkerVersionEditSerializer(WorkerVersionSerializer): - """ - Serialize a worker version run with a non editable revision field - """ - revision = serializers.UUIDField(read_only=True) +class WorkerVersionCreateSerializer(WorkerVersionSerializer): + revision_id = serializers.PrimaryKeyRelatedField(write_only=True, queryset=Revision.objects.all(), style={'base_template': 'input.html'}) + + class Meta (WorkerVersionSerializer.Meta): + fields = WorkerVersionSerializer.Meta.fields + ( + 'revision_id', + ) class RepositorySerializer(serializers.ModelSerializer): diff --git a/arkindex/process/tests/test_workerruns.py b/arkindex/process/tests/test_workerruns.py index 928b7fcce4c3873ff61a9315ff8fcebb301726f9..239a88619504fdcf9b08688ef7b54d426ed980bc 100644 --- a/arkindex/process/tests/test_workerruns.py +++ b/arkindex/process/tests/test_workerruns.py @@ -1,5 +1,7 @@ import uuid +from unittest.mock import patch +from django.test import override_settings from django.urls import reverse from rest_framework import status @@ -44,6 +46,9 @@ class TestWorkerRuns(FixtureAPITestCase): # Add an execution access right on the worker cls.worker_1.memberships.create(user=cls.user, level=Role.Contributor.value) + cls.creds = cls.user.credentials.get() + cls.gl_patch = patch('arkindex.process.providers.Gitlab') + # Model and Model version setup cls.model_1 = Model.objects.create(name='My model') cls.model_1.memberships.create(user=cls.user, level=Role.Contributor.value) @@ -462,6 +467,86 @@ class TestWorkerRuns(FixtureAPITestCase): 'worker_version_id': str(self.version_1.id) }) + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') + def test_gitrefs_are_retrieved(self): + """ + Check the GitRefs are retrieved with the revision + """ + from arkindex.process.providers import GitLabProvider + + # Create gitrefs and check they were created + self.gl_mock = self.gl_patch.start() + commit_refs = [ + {'name': 'refs/tags/v0.1.0', 'type': 'tag'}, + {'name': 'refs/heads/branch1', 'type': 'branch'}, + {'name': 'refs/heads/branch2', 'type': 'branch'}, + ] + revision = Revision.objects.create( + repo=self.repo, + hash='1', + message='commit message', + author='bob', + ) + + for ref in commit_refs: + GitLabProvider(credentials=self.creds) \ + .update_or_create_ref(self.repo, revision, ref['name'], ref['type']) + + refs = [ + {'name': ref.name, 'type': ref.type, 'repo': ref.repository} + for ref in revision.refs.all() + ] + self.assertListEqual(refs, [ + {'name': 'refs/tags/v0.1.0', 'type': GitRefType.Tag, 'repo': self.repo}, + {'name': 'refs/heads/branch1', 'type': GitRefType.Branch, 'repo': self.repo}, + {'name': 'refs/heads/branch2', 'type': GitRefType.Branch, 'repo': self.repo}, + ]) + self.assertEqual(self.repo.refs.count(), 3) + + # Assign the revision with gitrefs to the worker version + self.version_1.revision = revision + self.version_1.save() + self.assertTrue(revision.refs.exists()) + + # Check that the gitrefs are retrieved with RetrieveWorkerRun + self.client.force_login(self.user) + with self.assertNumQueries(8): + response = self.client.get(reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)})) + self.assertEqual(response.status_code, status.HTTP_200_OK) + version_revision = response.json()['worker_version']['revision'] + refs = version_revision.pop('refs') + # Only asserting on the name and type of the refs since they were created in bulk using + # GitLabProvider.update_or_create_ref, ignoring the IDs. + refs_no_id = [ + { + "name": item["name"], + "type": item["type"] + } + for item in refs + ] + self.assertDictEqual(version_revision, { + "author": "bob", + "commit_url": "http://my_repo.fake/workers/worker/commit/1", + "created": revision.created.isoformat().replace('+00:00', 'Z'), + "hash": "1", + "id": str(revision.id), + "message": "commit message", + }) + self.assertCountEqual(refs_no_id, [ + { + "name": "refs/tags/v0.1.0", + "type": "tag" + }, + { + "name": "refs/heads/branch1", + "type": "branch" + }, + { + "name": "refs/heads/branch2", + "type": "branch" + } + ]) + def test_update_run_requires_id_and_parents(self): self.client.force_login(self.user) with self.assertNumQueries(7): diff --git a/arkindex/process/tests/test_workers.py b/arkindex/process/tests/test_workers.py index 22e986b513a34cf29addfef409d5c55bf2b5cb73..73694c0f2cf9293b7a014969c9b2a4dd68fb4baf 100644 --- a/arkindex/process/tests/test_workers.py +++ b/arkindex/process/tests/test_workers.py @@ -612,7 +612,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): self.client.force_login(user) response = self.client.post( reverse('api:worker-versions', kwargs={'pk': str(self.worker_1.id)}), - data={'revision': str(self.rev2.id), 'configuration': {"test": "test2"}}, format='json' + data={'revision_id': str(self.rev2.id), 'configuration': {"test": "test2"}}, format='json' ) self.assertEqual(response.status_code, status_code) @@ -620,7 +620,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): self.client.force_login(self.internal_user) response = self.client.post( reverse('api:worker-versions', kwargs={'pk': '12341234-1234-1234-1234-123412341234'}), - data={'revision': str(self.rev2.id), 'configuration': {"test": "test2"}}, format='json' + data={'revision_id': str(self.rev2.id), 'configuration': {"test": "test2"}}, format='json' ) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.json(), {'detail': 'Not found.'}) @@ -630,7 +630,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): response = self.client.post( reverse('api:worker-versions', kwargs={'pk': str(self.worker_1.id)}), data={ - 'revision': str(self.rev2.id), + 'revision_id': str(self.rev2.id), 'configuration': {"test": "test2"}, 'state': 'available', }, @@ -648,7 +648,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): # A worker version already exists for this worker and this revision response = self.client.post( reverse('api:worker-versions', kwargs={'pk': str(self.worker_1.id)}), - data={'revision': str(self.rev.id), 'configuration': {"test": "test1"}}, format='json' + data={'revision_id': str(self.rev.id), 'configuration': {"test": "test1"}}, format='json' ) self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() @@ -664,7 +664,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), { - 'revision': ['This field is required.'], + 'revision_id': ['This field is required.'], 'configuration': ['This field is required.'] }) @@ -681,7 +681,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): self.client.force_login(self.internal_user) response = self.client.post( reverse('api:worker-versions', kwargs={'pk': str(self.worker_1.id)}), - data={'revision': str(rev.id), 'configuration': {"test": "test2"}}, format='json' + data={'revision_id': str(rev.id), 'configuration': {"test": "test2"}}, format='json' ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), { @@ -692,7 +692,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): self.client.force_login(self.internal_user) response = self.client.post( reverse('api:worker-versions', kwargs={'pk': str(self.worker_1.id)}), - data={'revision': str(self.rev2.id), 'configuration': {"test": "test2"}, 'model_usage': True}, format='json' + data={'revision_id': str(self.rev2.id), 'configuration': {"test": "test2"}, 'model_usage': True}, format='json' ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -708,7 +708,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): self.client.force_login(self.internal_user) response = self.client.post( reverse('api:worker-versions', kwargs={'pk': str(self.worker_1.id)}), - data={'revision': str(self.rev2.id), 'configuration': {"test": "test2"}, 'gpu_usage': 'not_supported'}, format='json' + data={'revision_id': str(self.rev2.id), 'configuration': {"test": "test2"}, 'gpu_usage': 'not_supported'}, format='json' ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @@ -717,7 +717,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}), data={ - "revision": str(self.rev2.id), + "revision_id": str(self.rev2.id), "configuration": {"beep": "boop"}, "gpu_usage": "disabled", }, @@ -730,7 +730,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}), data={ - "revision": str(self.rev2.id), + "revision_id": str(self.rev2.id), "configuration": { "user_configuration": { "demo_integer": {"title": "Demo Integer", "type": "int", "required": True, "default": 1}, @@ -764,7 +764,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}), data={ - "revision": str(self.rev2.id), + "revision_id": str(self.rev2.id), "configuration": { "user_configuration": { "demo_dict": {"title": "Demo Dict", "type": "dict", "required": True, "default": {"a": "b", "c": "d"}}, @@ -791,7 +791,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}), data={ - "revision": str(self.rev2.id), + "revision_id": str(self.rev2.id), "configuration": { "user_configuration": { "demo_dict": {"title": "Demo Dict", "type": "dict", "required": True, "default": {"a": ["12", "13"], "c": "d"}}, @@ -808,7 +808,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}), data={ - "revision": str(self.rev2.id), + "revision_id": str(self.rev2.id), "configuration": { "user_configuration": { "demo_choice": {"title": "Decisions", "type": "enum", "required": True, "default": 1, "choices": [1, 2, 3]} @@ -836,7 +836,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}), data={ - "revision": str(self.rev2.id), + "revision_id": str(self.rev2.id), "configuration": { "user_configuration": { "demo_list": {"title": "Demo List", "type": "list", "required": True, "subtype": "int", "default": [1, 2, 3, 4]}, @@ -873,7 +873,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}), data={ - "revision": str(self.rev2.id), + "revision_id": str(self.rev2.id), "configuration": { "user_configuration": { "demo_list": {"title": "Demo List", "type": "list", "required": True, "default": [1, 2, 3, 4]}, @@ -899,7 +899,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}), data={ - "revision": str(self.rev2.id), + "revision_id": str(self.rev2.id), "configuration": { "user_configuration": { "demo_list": {"title": "Demo List", "type": "list", "required": True, "subtype": "int", "default": 12}, @@ -925,7 +925,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}), data={ - "revision": str(self.rev2.id), + "revision_id": str(self.rev2.id), "configuration": { "user_configuration": { "demo_list": {"title": "Demo List", "type": "list", "required": True, "subtype": "dict", "default": [1, 2, 3, 4]}, @@ -951,7 +951,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}), data={ - "revision": str(self.rev2.id), + "revision_id": str(self.rev2.id), "configuration": { "user_configuration": { "demo_list": {"title": "Demo List", "type": "list", "required": True, "subtype": "int", "default": [1, 2, "three", 4]}, @@ -977,7 +977,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}), data={ - "revision": str(self.rev2.id), + "revision_id": str(self.rev2.id), "configuration": { "user_configuration": { "demo_choice": {"title": "Decisions", "type": "enum", "required": True, "default": 1, "choices": "eeee"} @@ -1003,7 +1003,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}), data={ - "revision": str(self.rev2.id), + "revision_id": str(self.rev2.id), "configuration": { "user_configuration": "non" }, @@ -1019,7 +1019,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}), data={ - "revision": str(self.rev2.id), + "revision_id": str(self.rev2.id), "configuration": { "user_configuration": { "something": { @@ -1050,7 +1050,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}), data={ - "revision": str(self.rev2.id), + "revision_id": str(self.rev2.id), "configuration": { "user_configuration": { "one_float": { @@ -1081,7 +1081,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}), data={ - "revision": str(self.rev2.id), + "revision_id": str(self.rev2.id), "configuration": { "user_configuration": { "something": { @@ -1112,7 +1112,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}), data={ - "revision": str(self.rev2.id), + "revision_id": str(self.rev2.id), "configuration": { "user_configuration": { "demo_integer": {"type": "int", "required": True, "default": 1} @@ -1141,7 +1141,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}), data={ - "revision": str(self.rev2.id), + "revision_id": str(self.rev2.id), "configuration": { "user_configuration": { "demo_integer": { @@ -1187,7 +1187,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}), data={ - "revision": str(self.rev2.id), + "revision_id": str(self.rev2.id), "configuration": { "user_configuration": { "param": {"title": 'param', **params} @@ -1336,6 +1336,25 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): self.assertEqual(data['state'], 'error') self.assertEqual(data['gpu_usage'], 'disabled') + def test_cannot_update_worker_version_revision(self): + self.version_1.state = WorkerVersionState.Created + self.version_1.docker_image = None + self.version_1.save() + self.client.force_login(self.internal_user) + with self.assertNumQueries(9): + response = self.client.patch( + reverse('api:version-retrieve', kwargs={'pk': str(self.version_1.id)}), + data={ + 'revision_id': str(self.rev2.id) + }, format='json' + ) + # revision_id just gets ignored + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + self.assertEqual(data['id'], str(self.version_1.id)) + self.version_1.refresh_from_db() + self.assertEqual(data['revision']['id'], str(self.rev.id)) + def test_update_version_valid(self): self.version_1.state = WorkerVersionState.Created self.version_1.docker_image = None