From dc74a360e902d73c740b37d5956d7b919e253e40 Mon Sep 17 00:00:00 2001
From: manon blanco <blanco@teklia.com>
Date: Thu, 4 Nov 2021 14:48:34 +0000
Subject: [PATCH] Allow a user to set a WorkerConfiguration on a WorkerRun

---
 arkindex/dataimport/api.py                   | 13 ++-
 arkindex/dataimport/serializers/imports.py   | 14 ++-
 arkindex/dataimport/tests/test_workerruns.py | 99 ++++++--------------
 3 files changed, 53 insertions(+), 73 deletions(-)

diff --git a/arkindex/dataimport/api.py b/arkindex/dataimport/api.py
index 6b90496196..9828987a18 100644
--- a/arkindex/dataimport/api.py
+++ b/arkindex/dataimport/api.py
@@ -1096,13 +1096,17 @@ class WorkerRunList(WorkerACLMixin, ListCreateAPIView):
         if not self.has_execution_access(worker):
             raise ValidationError({'worker_version_id': ['You do not have an execution access to this version.']})
 
+        configuration = serializer.validated_data.pop('configuration_id', None)
+        if configuration and configuration.worker_id != worker.id:
+            raise ValidationError({'configuration_id': ['The configuration must be part of the same worker.']})
+
         if process.mode != DataImportMode.Workers:
             raise ValidationError({'dataimport': ['Import mode must be Workers']})
 
         if process.workflow_id is not None:
             raise ValidationError({'__all__': ["Cannot create a WorkerRun on a DataImport that has already started"]})
 
-        serializer.save(dataimport=process)
+        serializer.save(dataimport=process, configuration=configuration)
 
 
 @extend_schema(tags=['imports'])
@@ -1154,6 +1158,13 @@ class WorkerRunDetails(CorpusACLMixin, RetrieveUpdateDestroyAPIView):
         instance.dataimport.worker_runs.filter(parents__contains=[instance.id]).update(parents=ArrayRemove('parents', instance.id))
         return super().perform_destroy(instance)
 
+    def perform_update(self, serializer):
+        run_worker = serializer.instance.version.worker  # worker owning the run's version
+        new_configuration = serializer.validated_data.get('configuration_id')  # None if omitted or null
+        if new_configuration and new_configuration.worker_id != run_worker.id:  # configuration from another worker
+            raise ValidationError({'configuration_id': ['The configuration must be part of the same worker.']})
+        super().perform_update(serializer)
+
 
 @extend_schema_view(post=extend_schema(
     operation_id='CreateImportTranskribus',
diff --git a/arkindex/dataimport/serializers/imports.py b/arkindex/dataimport/serializers/imports.py
index a71de604a5..a99ad4f04b 100644
--- a/arkindex/dataimport/serializers/imports.py
+++ b/arkindex/dataimport/serializers/imports.py
@@ -4,7 +4,14 @@ from django.conf import settings
 from rest_framework import serializers
 from rest_framework.exceptions import ValidationError
 
-from arkindex.dataimport.models import ActivityState, DataFile, DataImport, DataImportMode, WorkerRun
+from arkindex.dataimport.models import (
+    ActivityState,
+    DataFile,
+    DataImport,
+    DataImportMode,
+    WorkerConfiguration,
+    WorkerRun,
+)
 from arkindex.dataimport.serializers.git import RevisionSerializer
 from arkindex.dataimport.serializers.workers import WorkerLightSerializer
 from arkindex.documents.models import Corpus, Element, ElementType
@@ -322,8 +329,7 @@ class WorkerRunSerializer(serializers.ModelSerializer):
     worker_version_id = serializers.UUIDField(source='version_id')
     # Serialize worker with its basic informations
     worker = WorkerLightSerializer(source='version.worker', read_only=True)
-    # A DictField will require valid dicts, but without a child= argument, it will allow any value
-    configuration = serializers.DictField(source='old_configuration', allow_empty=True, default={})
+    configuration_id = serializers.PrimaryKeyRelatedField(queryset=WorkerConfiguration.objects.all(), required=False, allow_null=True)
 
     class Meta:
         model = WorkerRun
@@ -334,7 +340,7 @@ class WorkerRunSerializer(serializers.ModelSerializer):
             'worker_version_id',
             'dataimport_id',
             'worker',
-            'configuration',
+            'configuration_id',
         )
 
 
diff --git a/arkindex/dataimport/tests/test_workerruns.py b/arkindex/dataimport/tests/test_workerruns.py
index 5a66da262c..41a90cb5c8 100644
--- a/arkindex/dataimport/tests/test_workerruns.py
+++ b/arkindex/dataimport/tests/test_workerruns.py
@@ -36,6 +36,9 @@ class TestWorkerRuns(FixtureAPITestCase):
         cls.worker_1 = cls.version_1.worker
         cls.repo = cls.worker_1.repository
         cls.run_1 = cls.dataimport_1.worker_runs.create(version=cls.version_1, parents=[])
+        cls.configuration_1 = cls.worker_1.configurations.create(name="My config", configuration={"key": "value"})
+        worker_version = WorkerVersion.objects.exclude(worker=cls.version_1.worker).first()
+        cls.configuration_2 = worker_version.worker.configurations.create(name="Config")
         cls.dataimport_2 = cls.corpus.imports.create(creator=cls.user, mode=DataImportMode.Workers)
         # Add an execution access right on the worker
         cls.worker_1.memberships.create(user=cls.user, level=Role.Contributor.value)
@@ -73,7 +76,7 @@ class TestWorkerRuns(FixtureAPITestCase):
                 'type': self.worker_1.type,
                 'slug': self.worker_1.slug,
             },
-            'configuration': {},
+            'configuration_id': None,
         }])
 
     def test_runs_list_filter_dataimport(self):
@@ -177,7 +180,7 @@ class TestWorkerRuns(FixtureAPITestCase):
                 'type': self.worker_1.type,
                 'slug': self.worker_1.slug,
             },
-            'configuration': {},
+            'configuration_id': None,
         })
 
     def test_runs_post_via_repository_right(self):
@@ -207,18 +210,13 @@ class TestWorkerRuns(FixtureAPITestCase):
 
     def test_create_run_configuration(self):
         self.client.force_login(self.user)
-        with self.assertNumQueries(14):
+        with self.assertNumQueries(15):
             response = self.client.post(
                 reverse('api:worker-run-list', kwargs={'pk': str(self.dataimport_2.id)}),
                 data={
                     'worker_version_id': str(self.version_1.id),
                     'parents': [],
-                    'configuration': {
-                        'a': 'b',
-                        'c': {
-                            'd': 42
-                        }
-                    }
+                    'configuration_id': str(self.configuration_1.id)
                 },
                 format='json'
             )
@@ -237,35 +235,17 @@ class TestWorkerRuns(FixtureAPITestCase):
                 'type': self.worker_1.type,
                 'slug': self.worker_1.slug,
             },
-            'configuration': {
-                'a': 'b',
-                'c': {
-                    'd': 42
-                }
-            }
+            'configuration_id': str(self.configuration_1.id)
         })
 
-    def test_create_run_configuration_object(self):
+    def test_create_run_invalid_configuration(self):
         self.client.force_login(self.user)
-        parameters = [
-            (1, 'Expected a dictionary of items but got type "int".'),
-            ([], 'Expected a dictionary of items but got type "list".'),
-            ('a', 'Expected a dictionary of items but got type "str".'),
-            (None, 'This field may not be null.'),
-        ]
-        for configuration, message in parameters:
-            with self.subTest(configuration=configuration), self.assertNumQueries(2):
-                response = self.client.post(
-                    reverse('api:worker-run-list', kwargs={'pk': str(self.dataimport_2.id)}),
-                    data={
-                        'worker_version_id': str(self.version_1.id),
-                        'parents': [],
-                        'configuration': configuration,
-                    },
-                    format='json'
-                )
-                self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-                self.assertDictEqual(response.json(), {'configuration': [message]})
+        url = reverse('api:worker-run-list', kwargs={'pk': str(self.dataimport_2.id)})
+        payload = {'worker_version_id': str(self.version_1.id), 'parents': [], 'configuration_id': str(self.configuration_2.id)}
+        response = self.client.post(url, data=payload, format='json')  # configuration_2 belongs to another worker
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        expected_errors = {'configuration_id': ['The configuration must be part of the same worker.']}
+        self.assertDictEqual(response.json(), expected_errors)
 
     def test_retrieve_run_requires_login(self):
         response = self.client.get(reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)}))
@@ -307,7 +287,7 @@ class TestWorkerRuns(FixtureAPITestCase):
                 'type': self.worker_1.type,
                 'slug': self.worker_1.slug,
             },
-            'configuration': {},
+            'configuration_id': None,
         })
 
     def test_update_run_requires_login(self):
@@ -430,53 +410,36 @@ class TestWorkerRuns(FixtureAPITestCase):
                 'type': self.worker_1.type,
                 'slug': self.worker_1.slug,
             },
-            'configuration': {},
+            'configuration_id': None,
         })
         self.run_1.refresh_from_db()
         self.assertNotEqual(self.run_1.version_id, dla_version.id)
 
     def test_update_run_configuration(self):
         self.client.force_login(self.user)
-        self.assertDictEqual(self.run_1.old_configuration, {})
-        with self.assertNumQueries(8):
+        self.assertIsNone(self.run_1.configuration)
+        with self.assertNumQueries(9):
             response = self.client.patch(
                 reverse('api:worker-run-details', kwargs={'pk': self.run_1.id}),
                 data={
-                    'configuration': {
-                        'a': 'b',
-                        'c': {
-                            'd': 42
-                        }
-                    }
+                    'configuration_id': str(self.configuration_1.id)
                 },
                 format='json'
             )
             self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.run_1.refresh_from_db()
-        self.assertDictEqual(self.run_1.old_configuration, {
-            'a': 'b',
-            'c': {
-                'd': 42
-            }
-        })
+        self.assertEqual(self.run_1.configuration.id, self.configuration_1.id)
 
-    def test_update_run_configuration_object(self):
+    def test_update_run_invalid_configuration(self):
         self.client.force_login(self.user)
-        parameters = [
-            (1, 'Expected a dictionary of items but got type "int".'),
-            ([], 'Expected a dictionary of items but got type "list".'),
-            ('a', 'Expected a dictionary of items but got type "str".'),
-            (None, 'This field may not be null.'),
-        ]
-        for configuration, message in parameters:
-            with self.subTest(configuration=configuration), self.assertNumQueries(6):
-                response = self.client.patch(
-                    reverse('api:worker-run-details', kwargs={'pk': self.run_1.id}),
-                    data={'configuration': configuration},
-                    format='json'
-                )
-                self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-                self.assertDictEqual(response.json(), {'configuration': [message]})
+        self.assertIsNone(self.run_1.configuration)
+        url = reverse('api:worker-run-details', kwargs={'pk': self.run_1.id})
+        response = self.client.patch(url, data={'configuration_id': str(self.configuration_2.id)}, format='json')
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {'configuration_id': ['The configuration must be part of the same worker.']})
+        # The rejected configuration must not have been saved on the run
+        self.run_1.refresh_from_db()
+        self.assertIsNone(self.run_1.configuration_id)
 
     def test_update_run_dataimport_already_started(self):
         """
@@ -533,7 +496,7 @@ class TestWorkerRuns(FixtureAPITestCase):
                 'type': self.worker_1.type,
                 'slug': self.worker_1.slug,
             },
-            'configuration': {},
+            'configuration_id': None,
         })
 
     def test_delete_run_requires_login(self):
-- 
GitLab