diff --git a/arkindex/ponos/models.py b/arkindex/ponos/models.py index 930ce597d8a2eed88a86287349654c8d83a2e302..4a9710570240cfa821f4bd8ad4cb75573c2eb9cf 100644 --- a/arkindex/ponos/models.py +++ b/arkindex/ponos/models.py @@ -416,8 +416,15 @@ class Workflow(models.Model): ) # Create tasks without any parent - tasks = { - slug: self.tasks.create( + tasks = {} + for slug, recipe in self.recipes.items(): + # Add the task token to the environment now, as higher-level code cannot add a token + # when building workflow recipes since the Task instances do not exist. + env = recipe.environment.copy() + token = task_token_default() + env['ARKINDEX_TASK_TOKEN'] = token + + tasks[slug] = self.tasks.create( run=run, slug=slug, tags=recipe.tags, @@ -425,16 +432,15 @@ class Workflow(models.Model): image=recipe.image, command=recipe.command, shm_size=recipe.shm_size, - env=recipe.environment, + env=env, has_docker_socket=recipe.has_docker_socket, image_artifact=Artifact.objects.get(id=recipe.artifact) if recipe.artifact else None, requires_gpu=recipe.requires_gpu, extra_files=recipe.extra_files if recipe.extra_files else {}, + token=token, ) - for slug, recipe in self.recipes.items() - } # Apply parents for slug, recipe in self.recipes.items(): diff --git a/arkindex/ponos/tests/test_api.py b/arkindex/ponos/tests/test_api.py index 72411497f78be94b54aa472e6874b9cc6a1ea831..b91a6ffe95b120375aea97a3a39d368616808466 100644 --- a/arkindex/ponos/tests/test_api.py +++ b/arkindex/ponos/tests/test_api.py @@ -360,7 +360,11 @@ class TestAPI(APITestCase): 'image': 'hello-world', 'command': None, 'shm_size': None, - 'env': {'test': 'test_workflow', 'top_env_variable': 'workflow_variable'}, + 'env': { + 'test': 'test_workflow', + 'top_env_variable': 'workflow_variable', + 'ARKINDEX_TASK_TOKEN': self.task1.token, + }, 'has_docker_socket': False, 'image_artifact_url': None, 'agent_id': None, @@ -403,7 +407,11 @@ class TestAPI(APITestCase): 'image': 'hello-world', 'command': None, 'shm_size': None, - 'env': {'test': 'test_workflow', 'top_env_variable': 'workflow_variable'}, + 'env': { + 'test': 'test_workflow', + 'top_env_variable': 'workflow_variable', + 'ARKINDEX_TASK_TOKEN': self.task1.token, + }, 'has_docker_socket': False, 'image_artifact_url': 'http://testserver' + reverse( 'api:task-artifact-download', @@ -437,7 +445,11 @@ class TestAPI(APITestCase): 'image': 'hello-world', 'command': None, 'shm_size': '128g', - 'env': {'test': 'test_workflow', 'top_env_variable': 'workflow_variable'}, + 'env': { + 'test': 'test_workflow', + 'top_env_variable': 'workflow_variable', + 'ARKINDEX_TASK_TOKEN': self.task1.token, + }, 'has_docker_socket': False, 'image_artifact_url': None, 'agent_id': None, diff --git a/arkindex/process/tests/test_create_s3_import.py b/arkindex/process/tests/test_create_s3_import.py index 604014e34cbde7ff8347b57bc733d54b68cdda34..2a7ac6e8741c97b540f8f6038ece21748e965360 100644 --- a/arkindex/process/tests/test_create_s3_import.py +++ b/arkindex/process/tests/test_create_s3_import.py @@ -167,6 +167,7 @@ class TestCreateS3Import(FixtureTestCase): self.assertDictEqual(task.env, { 'ARKINDEX_CORPUS_ID': str(self.corpus.id), 'ARKINDEX_PROCESS_ID': str(process.id), + 'ARKINDEX_TASK_TOKEN': task.token, 'ARKINDEX_WORKER_RUN_ID': str(worker_run.id), 'INGEST_S3_ENDPOINT': 'http://s3.null.teklia.com', 'INGEST_S3_ACCESS_KEY': '🔑', @@ -224,6 +225,7 @@ class TestCreateS3Import(FixtureTestCase): self.assertDictEqual(task.env, { 'ARKINDEX_CORPUS_ID': str(self.corpus.id), 'ARKINDEX_PROCESS_ID': str(process.id), + 'ARKINDEX_TASK_TOKEN': task.token, 'ARKINDEX_WORKER_RUN_ID': str(worker_run.id), 'INGEST_S3_ENDPOINT': 'http://s3.null.teklia.com', 'INGEST_S3_ACCESS_KEY': '🔑', diff --git a/arkindex/process/tests/test_create_training_process.py b/arkindex/process/tests/test_create_training_process.py index e334f082138d717e480fb6e530e8be84d207b477..1e0e42b8aa2ff86304ee77cb82b6432b91cacbf4 100644 --- a/arkindex/process/tests/test_create_training_process.py +++ b/arkindex/process/tests/test_create_training_process.py @@ -329,9 +329,11 @@ class TestCreateTrainingProcess(FixtureTestCase): self.assertEqual(sorted(task.env.keys()), [ 'ARKINDEX_CORPUS_ID', 'ARKINDEX_PROCESS_ID', + 'ARKINDEX_TASK_TOKEN', 'ARKINDEX_WORKER_RUN_ID', 'WORKER_VERSION_ID', ]) + self.assertEqual(task.env['ARKINDEX_TASK_TOKEN'], task.token) self.assertEqual(task.requires_gpu, False) # Check worker run properties @@ -410,10 +412,12 @@ class TestCreateTrainingProcess(FixtureTestCase): self.assertEqual(sorted(task.env.keys()), [ 'ARKINDEX_CORPUS_ID', 'ARKINDEX_PROCESS_ID', + 'ARKINDEX_TASK_TOKEN', 'ARKINDEX_WORKER_RUN_ID', 'WORKER_VERSION_ID', ]) self.assertEqual(task.requires_gpu, True) + self.assertEqual(task.env['ARKINDEX_TASK_TOKEN'], task.token) # Check worker run properties self.assertEqual(str(training_process.worker_runs.get().id), task.env['ARKINDEX_WORKER_RUN_ID']) worker_run = WorkerRun.objects.get(id=task.env['ARKINDEX_WORKER_RUN_ID'])