From 66fe35054643d6308ff5bf851eccb8b65d0d98c9 Mon Sep 17 00:00:00 2001
From: mlbonhomme <bonhomme@teklia.com>
Date: Thu, 10 Nov 2022 14:51:26 +0000
Subject: [PATCH] Add gitrefs in RetrieveWorkerRun test

---
 arkindex/process/api.py                   |  8 +--
 arkindex/process/serializers/workers.py   | 27 +++----
 arkindex/process/tests/test_workerruns.py | 85 +++++++++++++++++++++++
 arkindex/process/tests/test_workers.py    | 71 ++++++++++++-------
 4 files changed, 143 insertions(+), 48 deletions(-)

diff --git a/arkindex/process/api.py b/arkindex/process/api.py
index 11ae53abd1..42bc4ce104 100644
--- a/arkindex/process/api.py
+++ b/arkindex/process/api.py
@@ -84,7 +84,7 @@ from arkindex.process.serializers.workers import (
     WorkerSerializer,
     WorkerStatisticsSerializer,
     WorkerTypeSerializer,
-    WorkerVersionEditSerializer,
+    WorkerVersionCreateSerializer,
     WorkerVersionSerializer,
 )
 from arkindex.project.aws import get_ingest_resource
@@ -909,7 +909,7 @@ class WorkerVersionList(WorkerACLMixin, ListCreateAPIView):
     return existing query if a workerVersion already exists for this worker and this revision
     """
     permission_classes = (IsVerified, )
-    serializer_class = WorkerVersionSerializer
+    serializer_class = WorkerVersionCreateSerializer
     queryset = WorkerVersion.objects.none()
 
     @property
@@ -950,7 +950,7 @@ class WorkerVersionList(WorkerACLMixin, ListCreateAPIView):
             pk=self.kwargs['pk']
         )
 
-        revision = serializer.validated_data['revision']
+        revision = serializer.validated_data['revision_id']
         if worker.repository_id != revision.repo_id:
             raise ValidationError({
                 'revision': ['The revision must be part of the same repository as the worker.']
@@ -1007,7 +1007,7 @@ class WorkerVersionRetrieve(RetrieveUpdateAPIView):
     Retrieve a specific worker version. No authentication is required.
     """
     permission_classes = (IsVerifiedOrReadOnly, )
-    serializer_class = WorkerVersionEditSerializer
+    serializer_class = WorkerVersionSerializer
     queryset = WorkerVersion.objects.select_related('worker').all()
 
     def check_object_permissions(self, request, instance):
diff --git a/arkindex/process/serializers/workers.py b/arkindex/process/serializers/workers.py
index b21215c4f9..e083f79d7a 100644
--- a/arkindex/process/serializers/workers.py
+++ b/arkindex/process/serializers/workers.py
@@ -150,7 +150,7 @@ class WorkerVersionSerializer(serializers.ModelSerializer):
     # State defaults to created when instantiating a WorkerVersion
     state = EnumField(WorkerVersionState, required=False)
     worker = WorkerLightSerializer(read_only=True)
-    revision = serializers.UUIDField()
+    revision = RevisionWithRefsSerializer(required=False, read_only=True)
     gpu_usage = EnumField(WorkerVersionGPUUsage, required=False, default=WorkerVersionGPUUsage.Disabled)
     model_usage = serializers.BooleanField(required=False, default=False)
 
@@ -169,24 +169,13 @@ class WorkerVersionSerializer(serializers.ModelSerializer):
             'model_usage',
             'worker'
         )
-        read_only_fields = ('docker_image_name',)
+        read_only_fields = ('docker_image_name', 'revision')
         # Avoid loading all revisions and all Ponos artifacts when opening this endpoint in a browser
         extra_kwargs = {
             'revision': {'style': {'base_template': 'input.html'}},
             'docker_image': {'style': {'base_template': 'input.html'}},
         }
 
-    def to_representation(self, instance):
-        self.fields['revision'] = RevisionWithRefsSerializer(read_only=True)
-        return super(WorkerVersionSerializer, self).to_representation(instance)
-
-    def validate_revision(self, revision_id):
-        # Retrieve a revision from its ID without listing them with a PrimaryKeyRelatedField
-        try:
-            return Revision.objects.get(id=revision_id)
-        except Revision.DoesNotExist:
-            raise ValidationError({'revision': ['Revision with this ID does not exist.']})
-
     def validate_configuration(self, configuration):
         errors = defaultdict(list)
         user_configuration = configuration.get('user_configuration')
@@ -225,11 +214,13 @@ class WorkerVersionSerializer(serializers.ModelSerializer):
         return data
 
 
-class WorkerVersionEditSerializer(WorkerVersionSerializer):
-    """
-    Serialize a worker version run with a non editable revision field
-    """
-    revision = serializers.UUIDField(read_only=True)
+class WorkerVersionCreateSerializer(WorkerVersionSerializer):
+    revision_id = serializers.PrimaryKeyRelatedField(write_only=True, queryset=Revision.objects.all(), style={'base_template': 'input.html'})
+
+    class Meta (WorkerVersionSerializer.Meta):
+        fields = WorkerVersionSerializer.Meta.fields + (
+            'revision_id',
+        )
 
 
 class RepositorySerializer(serializers.ModelSerializer):
diff --git a/arkindex/process/tests/test_workerruns.py b/arkindex/process/tests/test_workerruns.py
index 928b7fcce4..239a886195 100644
--- a/arkindex/process/tests/test_workerruns.py
+++ b/arkindex/process/tests/test_workerruns.py
@@ -1,5 +1,7 @@
 import uuid
+from unittest.mock import patch
 
+from django.test import override_settings
 from django.urls import reverse
 from rest_framework import status
 
@@ -44,6 +46,9 @@ class TestWorkerRuns(FixtureAPITestCase):
         # Add an execution access right on the worker
         cls.worker_1.memberships.create(user=cls.user, level=Role.Contributor.value)
 
+        cls.creds = cls.user.credentials.get()
+        cls.gl_patch = patch('arkindex.process.providers.Gitlab')
+
         # Model and Model version setup
         cls.model_1 = Model.objects.create(name='My model')
         cls.model_1.memberships.create(user=cls.user, level=Role.Contributor.value)
@@ -462,6 +467,86 @@ class TestWorkerRuns(FixtureAPITestCase):
             'worker_version_id': str(self.version_1.id)
         })
 
+    @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost')
+    def test_gitrefs_are_retrieved(self):
+        """
+        Check the GitRefs are retrieved with the revision
+        """
+        from arkindex.process.providers import GitLabProvider
+
+        # Create gitrefs and check they were created
+        self.gl_mock = self.gl_patch.start()
+        commit_refs = [
+            {'name': 'refs/tags/v0.1.0', 'type': 'tag'},
+            {'name': 'refs/heads/branch1', 'type': 'branch'},
+            {'name': 'refs/heads/branch2', 'type': 'branch'},
+        ]
+        revision = Revision.objects.create(
+            repo=self.repo,
+            hash='1',
+            message='commit message',
+            author='bob',
+        )
+
+        for ref in commit_refs:
+            GitLabProvider(credentials=self.creds) \
+                .update_or_create_ref(self.repo, revision, ref['name'], ref['type'])
+
+        refs = [
+            {'name': ref.name, 'type': ref.type, 'repo': ref.repository}
+            for ref in revision.refs.all()
+        ]
+        self.assertListEqual(refs, [
+            {'name': 'refs/tags/v0.1.0', 'type': GitRefType.Tag, 'repo': self.repo},
+            {'name': 'refs/heads/branch1', 'type': GitRefType.Branch, 'repo': self.repo},
+            {'name': 'refs/heads/branch2', 'type': GitRefType.Branch, 'repo': self.repo},
+        ])
+        self.assertEqual(self.repo.refs.count(), 3)
+
+        # Assign the revision with gitrefs to the worker version
+        self.version_1.revision = revision
+        self.version_1.save()
+        self.assertTrue(revision.refs.exists())
+
+        # Check that the gitrefs are retrieved with RetrieveWorkerRun
+        self.client.force_login(self.user)
+        with self.assertNumQueries(8):
+            response = self.client.get(reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)}))
+            self.assertEqual(response.status_code, status.HTTP_200_OK)
+        version_revision = response.json()['worker_version']['revision']
+        refs = version_revision.pop('refs')
+        # Only asserting on the name and type of the refs since they were created in bulk using
+        # GitLabProvider.update_or_create_ref, ignoring the IDs.
+        refs_no_id = [
+            {
+                "name": item["name"],
+                "type": item["type"]
+            }
+            for item in refs
+        ]
+        self.assertDictEqual(version_revision, {
+            "author": "bob",
+            "commit_url": "http://my_repo.fake/workers/worker/commit/1",
+            "created": revision.created.isoformat().replace('+00:00', 'Z'),
+            "hash": "1",
+            "id": str(revision.id),
+            "message": "commit message",
+        })
+        self.assertCountEqual(refs_no_id, [
+            {
+                "name": "refs/tags/v0.1.0",
+                "type": "tag"
+            },
+            {
+                "name": "refs/heads/branch1",
+                "type": "branch"
+            },
+            {
+                "name": "refs/heads/branch2",
+                "type": "branch"
+            }
+        ])
+
     def test_update_run_requires_id_and_parents(self):
         self.client.force_login(self.user)
         with self.assertNumQueries(7):
diff --git a/arkindex/process/tests/test_workers.py b/arkindex/process/tests/test_workers.py
index 22e986b513..73694c0f2c 100644
--- a/arkindex/process/tests/test_workers.py
+++ b/arkindex/process/tests/test_workers.py
@@ -612,7 +612,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
                 self.client.force_login(user)
             response = self.client.post(
                 reverse('api:worker-versions', kwargs={'pk': str(self.worker_1.id)}),
-                data={'revision': str(self.rev2.id), 'configuration': {"test": "test2"}}, format='json'
+                data={'revision_id': str(self.rev2.id), 'configuration': {"test": "test2"}}, format='json'
             )
             self.assertEqual(response.status_code, status_code)
 
@@ -620,7 +620,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         self.client.force_login(self.internal_user)
         response = self.client.post(
             reverse('api:worker-versions', kwargs={'pk': '12341234-1234-1234-1234-123412341234'}),
-            data={'revision': str(self.rev2.id), 'configuration': {"test": "test2"}}, format='json'
+            data={'revision_id': str(self.rev2.id), 'configuration': {"test": "test2"}}, format='json'
         )
         self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
         self.assertEqual(response.json(), {'detail': 'Not found.'})
@@ -630,7 +630,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         response = self.client.post(
             reverse('api:worker-versions', kwargs={'pk': str(self.worker_1.id)}),
             data={
-                'revision': str(self.rev2.id),
+                'revision_id': str(self.rev2.id),
                 'configuration': {"test": "test2"},
                 'state': 'available',
             },
@@ -648,7 +648,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         # A worker version already exists for this worker and this revision
         response = self.client.post(
             reverse('api:worker-versions', kwargs={'pk': str(self.worker_1.id)}),
-            data={'revision': str(self.rev.id), 'configuration': {"test": "test1"}}, format='json'
+            data={'revision_id': str(self.rev.id), 'configuration': {"test": "test1"}}, format='json'
         )
         self.assertEqual(response.status_code, status.HTTP_200_OK)
         data = response.json()
@@ -664,7 +664,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         )
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertDictEqual(response.json(), {
-            'revision': ['This field is required.'],
+            'revision_id': ['This field is required.'],
             'configuration': ['This field is required.']
         })
 
@@ -681,7 +681,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         self.client.force_login(self.internal_user)
         response = self.client.post(
             reverse('api:worker-versions', kwargs={'pk': str(self.worker_1.id)}),
-            data={'revision': str(rev.id), 'configuration': {"test": "test2"}}, format='json'
+            data={'revision_id': str(rev.id), 'configuration': {"test": "test2"}}, format='json'
         )
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertDictEqual(response.json(), {
@@ -692,7 +692,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         self.client.force_login(self.internal_user)
         response = self.client.post(
             reverse('api:worker-versions', kwargs={'pk': str(self.worker_1.id)}),
-            data={'revision': str(self.rev2.id), 'configuration': {"test": "test2"}, 'model_usage': True}, format='json'
+            data={'revision_id': str(self.rev2.id), 'configuration': {"test": "test2"}, 'model_usage': True}, format='json'
         )
         self.assertEqual(response.status_code, status.HTTP_201_CREATED)
 
@@ -708,7 +708,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         self.client.force_login(self.internal_user)
         response = self.client.post(
             reverse('api:worker-versions', kwargs={'pk': str(self.worker_1.id)}),
-            data={'revision': str(self.rev2.id), 'configuration': {"test": "test2"}, 'gpu_usage': 'not_supported'}, format='json'
+            data={'revision_id': str(self.rev2.id), 'configuration': {"test": "test2"}, 'gpu_usage': 'not_supported'}, format='json'
         )
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
 
@@ -717,7 +717,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         response = self.client.post(
             reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}),
             data={
-                "revision": str(self.rev2.id),
+                "revision_id": str(self.rev2.id),
                 "configuration": {"beep": "boop"},
                 "gpu_usage": "disabled",
             },
@@ -730,7 +730,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         response = self.client.post(
             reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}),
             data={
-                "revision": str(self.rev2.id),
+                "revision_id": str(self.rev2.id),
                 "configuration": {
                     "user_configuration": {
                         "demo_integer": {"title": "Demo Integer", "type": "int", "required": True, "default": 1},
@@ -764,7 +764,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         response = self.client.post(
             reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}),
             data={
-                "revision": str(self.rev2.id),
+                "revision_id": str(self.rev2.id),
                 "configuration": {
                     "user_configuration": {
                         "demo_dict": {"title": "Demo Dict", "type": "dict", "required": True, "default": {"a": "b", "c": "d"}},
@@ -791,7 +791,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         response = self.client.post(
             reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}),
             data={
-                "revision": str(self.rev2.id),
+                "revision_id": str(self.rev2.id),
                 "configuration": {
                     "user_configuration": {
                         "demo_dict": {"title": "Demo Dict", "type": "dict", "required": True, "default": {"a": ["12", "13"], "c": "d"}},
@@ -808,7 +808,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         response = self.client.post(
             reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}),
             data={
-                "revision": str(self.rev2.id),
+                "revision_id": str(self.rev2.id),
                 "configuration": {
                     "user_configuration": {
                         "demo_choice": {"title": "Decisions", "type": "enum", "required": True, "default": 1, "choices": [1, 2, 3]}
@@ -836,7 +836,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         response = self.client.post(
             reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}),
             data={
-                "revision": str(self.rev2.id),
+                "revision_id": str(self.rev2.id),
                 "configuration": {
                     "user_configuration": {
                         "demo_list": {"title": "Demo List", "type": "list", "required": True, "subtype": "int", "default": [1, 2, 3, 4]},
@@ -873,7 +873,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         response = self.client.post(
             reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}),
             data={
-                "revision": str(self.rev2.id),
+                "revision_id": str(self.rev2.id),
                 "configuration": {
                     "user_configuration": {
                         "demo_list": {"title": "Demo List", "type": "list", "required": True, "default": [1, 2, 3, 4]},
@@ -899,7 +899,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         response = self.client.post(
             reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}),
             data={
-                "revision": str(self.rev2.id),
+                "revision_id": str(self.rev2.id),
                 "configuration": {
                     "user_configuration": {
                         "demo_list": {"title": "Demo List", "type": "list", "required": True, "subtype": "int", "default": 12},
@@ -925,7 +925,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         response = self.client.post(
             reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}),
             data={
-                "revision": str(self.rev2.id),
+                "revision_id": str(self.rev2.id),
                 "configuration": {
                     "user_configuration": {
                         "demo_list": {"title": "Demo List", "type": "list", "required": True, "subtype": "dict", "default": [1, 2, 3, 4]},
@@ -951,7 +951,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         response = self.client.post(
             reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}),
             data={
-                "revision": str(self.rev2.id),
+                "revision_id": str(self.rev2.id),
                 "configuration": {
                     "user_configuration": {
                         "demo_list": {"title": "Demo List", "type": "list", "required": True, "subtype": "int", "default": [1, 2, "three", 4]},
@@ -977,7 +977,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         response = self.client.post(
             reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}),
             data={
-                "revision": str(self.rev2.id),
+                "revision_id": str(self.rev2.id),
                 "configuration": {
                     "user_configuration": {
                         "demo_choice": {"title": "Decisions", "type": "enum", "required": True, "default": 1, "choices": "eeee"}
@@ -1003,7 +1003,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         response = self.client.post(
             reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}),
             data={
-                "revision": str(self.rev2.id),
+                "revision_id": str(self.rev2.id),
                 "configuration": {
                     "user_configuration": "non"
                 },
@@ -1019,7 +1019,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         response = self.client.post(
             reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}),
             data={
-                "revision": str(self.rev2.id),
+                "revision_id": str(self.rev2.id),
                 "configuration": {
                     "user_configuration": {
                         "something": {
@@ -1050,7 +1050,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         response = self.client.post(
             reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}),
             data={
-                "revision": str(self.rev2.id),
+                "revision_id": str(self.rev2.id),
                 "configuration": {
                     "user_configuration": {
                         "one_float": {
@@ -1081,7 +1081,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         response = self.client.post(
             reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}),
             data={
-                "revision": str(self.rev2.id),
+                "revision_id": str(self.rev2.id),
                 "configuration": {
                     "user_configuration": {
                         "something": {
@@ -1112,7 +1112,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         response = self.client.post(
             reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}),
             data={
-                "revision": str(self.rev2.id),
+                "revision_id": str(self.rev2.id),
                 "configuration": {
                     "user_configuration": {
                         "demo_integer": {"type": "int", "required": True, "default": 1}
@@ -1141,7 +1141,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         response = self.client.post(
             reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}),
             data={
-                "revision": str(self.rev2.id),
+                "revision_id": str(self.rev2.id),
                 "configuration": {
                     "user_configuration": {
                         "demo_integer": {
@@ -1187,7 +1187,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
                 response = self.client.post(
                     reverse("api:worker-versions", kwargs={"pk": str(self.worker_2.id)}),
                     data={
-                        "revision": str(self.rev2.id),
+                        "revision_id": str(self.rev2.id),
                         "configuration": {
                             "user_configuration": {
                                 "param": {"title": 'param', **params}
@@ -1336,6 +1336,25 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
         self.assertEqual(data['state'], 'error')
         self.assertEqual(data['gpu_usage'], 'disabled')
 
+    def test_cannot_update_worker_version_revision(self):
+        self.version_1.state = WorkerVersionState.Created
+        self.version_1.docker_image = None
+        self.version_1.save()
+        self.client.force_login(self.internal_user)
+        with self.assertNumQueries(9):
+            response = self.client.patch(
+                reverse('api:version-retrieve', kwargs={'pk': str(self.version_1.id)}),
+                data={
+                    'revision_id': str(self.rev2.id)
+                }, format='json'
+            )
+            # revision_id just gets ignored
+            self.assertEqual(response.status_code, status.HTTP_200_OK)
+        data = response.json()
+        self.assertEqual(data['id'], str(self.version_1.id))
+        self.version_1.refresh_from_db()
+        self.assertEqual(data['revision']['id'], str(self.rev.id))
+
     def test_update_version_valid(self):
         self.version_1.state = WorkerVersionState.Created
         self.version_1.docker_image = None
-- 
GitLab