From 14aaffaccaa14c37826437abc8fef609cb2c1471 Mon Sep 17 00:00:00 2001
From: Erwan Rouchet <rouchet@teklia.com>
Date: Thu, 9 Mar 2023 09:47:52 +0100
Subject: [PATCH] Add task token automatically in all Ponos tasks

---
 arkindex/ponos/models.py                       | 16 +++++++++++-----
 arkindex/ponos/tests/test_api.py               | 18 +++++++++++++++---
 .../process/tests/test_create_s3_import.py     |  2 ++
 .../tests/test_create_training_process.py      |  4 ++++
 4 files changed, 32 insertions(+), 8 deletions(-)

diff --git a/arkindex/ponos/models.py b/arkindex/ponos/models.py
index 930ce597d8..4a97105702 100644
--- a/arkindex/ponos/models.py
+++ b/arkindex/ponos/models.py
@@ -416,8 +416,15 @@ class Workflow(models.Model):
         )
 
         # Create tasks without any parent
-        tasks = {
-            slug: self.tasks.create(
+        tasks = {}
+        for slug, recipe in self.recipes.items():
+            # Add the task token to the environment now, as higher-level code cannot add a token
+            # when building workflow recipes since the Task instances do not exist.
+            env = recipe.environment.copy()
+            token = task_token_default()
+            env['ARKINDEX_TASK_TOKEN'] = token
+
+            tasks[slug] = self.tasks.create(
                 run=run,
                 slug=slug,
                 tags=recipe.tags,
@@ -425,16 +432,15 @@ class Workflow(models.Model):
                 image=recipe.image,
                 command=recipe.command,
                 shm_size=recipe.shm_size,
-                env=recipe.environment,
+                env=env,
                 has_docker_socket=recipe.has_docker_socket,
                 image_artifact=Artifact.objects.get(id=recipe.artifact)
                 if recipe.artifact
                 else None,
                 requires_gpu=recipe.requires_gpu,
                 extra_files=recipe.extra_files if recipe.extra_files else {},
+                token=token,
             )
-            for slug, recipe in self.recipes.items()
-        }
 
         # Apply parents
         for slug, recipe in self.recipes.items():
diff --git a/arkindex/ponos/tests/test_api.py b/arkindex/ponos/tests/test_api.py
index 72411497f7..b91a6ffe95 100644
--- a/arkindex/ponos/tests/test_api.py
+++ b/arkindex/ponos/tests/test_api.py
@@ -360,7 +360,11 @@ class TestAPI(APITestCase):
             'image': 'hello-world',
             'command': None,
             'shm_size': None,
-            'env': {'test': 'test_workflow', 'top_env_variable': 'workflow_variable'},
+            'env': {
+                'test': 'test_workflow',
+                'top_env_variable': 'workflow_variable',
+                'ARKINDEX_TASK_TOKEN': self.task1.token,
+            },
             'has_docker_socket': False,
             'image_artifact_url': None,
             'agent_id': None,
@@ -403,7 +407,11 @@ class TestAPI(APITestCase):
             'image': 'hello-world',
             'command': None,
             'shm_size': None,
-            'env': {'test': 'test_workflow', 'top_env_variable': 'workflow_variable'},
+            'env': {
+                'test': 'test_workflow',
+                'top_env_variable': 'workflow_variable',
+                'ARKINDEX_TASK_TOKEN': self.task1.token,
+            },
             'has_docker_socket': False,
             'image_artifact_url': 'http://testserver' + reverse(
                 'api:task-artifact-download',
@@ -437,7 +445,11 @@ class TestAPI(APITestCase):
             'image': 'hello-world',
             'command': None,
             'shm_size': '128g',
-            'env': {'test': 'test_workflow', 'top_env_variable': 'workflow_variable'},
+            'env': {
+                'test': 'test_workflow',
+                'top_env_variable': 'workflow_variable',
+                'ARKINDEX_TASK_TOKEN': self.task1.token,
+            },
             'has_docker_socket': False,
             'image_artifact_url': None,
             'agent_id': None,
diff --git a/arkindex/process/tests/test_create_s3_import.py b/arkindex/process/tests/test_create_s3_import.py
index 604014e34c..2a7ac6e874 100644
--- a/arkindex/process/tests/test_create_s3_import.py
+++ b/arkindex/process/tests/test_create_s3_import.py
@@ -167,6 +167,7 @@ class TestCreateS3Import(FixtureTestCase):
         self.assertDictEqual(task.env, {
             'ARKINDEX_CORPUS_ID': str(self.corpus.id),
             'ARKINDEX_PROCESS_ID': str(process.id),
+            'ARKINDEX_TASK_TOKEN': task.token,
             'ARKINDEX_WORKER_RUN_ID': str(worker_run.id),
             'INGEST_S3_ENDPOINT': 'http://s3.null.teklia.com',
             'INGEST_S3_ACCESS_KEY': '🔑',
@@ -224,6 +225,7 @@ class TestCreateS3Import(FixtureTestCase):
         self.assertDictEqual(task.env, {
             'ARKINDEX_CORPUS_ID': str(self.corpus.id),
             'ARKINDEX_PROCESS_ID': str(process.id),
+            'ARKINDEX_TASK_TOKEN': task.token,
             'ARKINDEX_WORKER_RUN_ID': str(worker_run.id),
             'INGEST_S3_ENDPOINT': 'http://s3.null.teklia.com',
             'INGEST_S3_ACCESS_KEY': '🔑',
diff --git a/arkindex/process/tests/test_create_training_process.py b/arkindex/process/tests/test_create_training_process.py
index e334f08213..1e0e42b8aa 100644
--- a/arkindex/process/tests/test_create_training_process.py
+++ b/arkindex/process/tests/test_create_training_process.py
@@ -329,9 +329,11 @@ class TestCreateTrainingProcess(FixtureTestCase):
         self.assertEqual(sorted(task.env.keys()), [
             'ARKINDEX_CORPUS_ID',
             'ARKINDEX_PROCESS_ID',
+            'ARKINDEX_TASK_TOKEN',
             'ARKINDEX_WORKER_RUN_ID',
             'WORKER_VERSION_ID',
         ])
+        self.assertEqual(task.env['ARKINDEX_TASK_TOKEN'], task.token)
         self.assertEqual(task.requires_gpu, False)
 
         # Check worker run properties
@@ -410,10 +412,12 @@ class TestCreateTrainingProcess(FixtureTestCase):
         self.assertEqual(sorted(task.env.keys()), [
             'ARKINDEX_CORPUS_ID',
             'ARKINDEX_PROCESS_ID',
+            'ARKINDEX_TASK_TOKEN',
             'ARKINDEX_WORKER_RUN_ID',
             'WORKER_VERSION_ID',
         ])
         self.assertEqual(task.requires_gpu, True)
+        self.assertEqual(task.env['ARKINDEX_TASK_TOKEN'], task.token)
         # Check worker run properties
         self.assertEqual(str(training_process.worker_runs.get().id), task.env['ARKINDEX_WORKER_RUN_ID'])
         worker_run = WorkerRun.objects.get(id=task.env['ARKINDEX_WORKER_RUN_ID'])
-- 
GitLab