diff --git a/arkindex/process/models.py b/arkindex/process/models.py index ca182da0ec3a0e746e6c53f02e0f43d210f2cc67..2d22c239c5cd076fb6fbb0984f19a49a04a0e5a2 100644 --- a/arkindex/process/models.py +++ b/arkindex/process/models.py @@ -350,14 +350,13 @@ class Process(IndexableModel): version_id=worker_version.id, configuration_id=worker_configuration and worker_configuration.id, parents=[], + model_version=model_version, ) env = { 'ARKINDEX_WORKER_RUN_ID': str(worker_run.id), 'WORKER_VERSION_ID': str(worker_version.id), } - if model_version: - env['MODEL_VERSION_ID'] = str(model_version.id) # Build the training task, as there is no initial task tasks = { diff --git a/arkindex/process/tests/test_create_training_process.py b/arkindex/process/tests/test_create_training_process.py index 794305150a1b28be19c42daa7f66244e909db05d..e334f082138d717e480fb6e530e8be84d207b477 100644 --- a/arkindex/process/tests/test_create_training_process.py +++ b/arkindex/process/tests/test_create_training_process.py @@ -411,16 +411,14 @@ class TestCreateTrainingProcess(FixtureTestCase): 'ARKINDEX_CORPUS_ID', 'ARKINDEX_PROCESS_ID', 'ARKINDEX_WORKER_RUN_ID', - 'MODEL_VERSION_ID', 'WORKER_VERSION_ID', ]) - self.assertEqual(task.env['MODEL_VERSION_ID'], str(self.model_version.id)) self.assertEqual(task.requires_gpu, True) # Check worker run properties self.assertEqual(str(training_process.worker_runs.get().id), task.env['ARKINDEX_WORKER_RUN_ID']) worker_run = WorkerRun.objects.get(id=task.env['ARKINDEX_WORKER_RUN_ID']) self.assertEqual(worker_run.version_id, self.training_worker_version.id) - self.assertEqual(worker_run.model_version_id, None) + self.assertEqual(worker_run.model_version, self.model_version) @override_settings(PONOS_RECIPE={}) def test_create_training_shm_size(self):