From 64caf9da8c1605e4bc4ae847f1c06677d9cecd26 Mon Sep 17 00:00:00 2001
From: Valentin Rigal <rigal@teklia.com>
Date: Tue, 14 Feb 2023 16:54:28 +0000
Subject: [PATCH] Store model version on the training worker run

---
 arkindex/process/models.py                             | 3 +--
 arkindex/process/tests/test_create_training_process.py | 4 +---
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/arkindex/process/models.py b/arkindex/process/models.py
index ca182da0ec..2d22c239c5 100644
--- a/arkindex/process/models.py
+++ b/arkindex/process/models.py
@@ -350,14 +350,13 @@ class Process(IndexableModel):
                 version_id=worker_version.id,
                 configuration_id=worker_configuration and worker_configuration.id,
                 parents=[],
+                model_version=model_version,
             )
 
             env = {
                 'ARKINDEX_WORKER_RUN_ID': str(worker_run.id),
                 'WORKER_VERSION_ID': str(worker_version.id),
             }
-            if model_version:
-                env['MODEL_VERSION_ID'] = str(model_version.id)
 
             # Build the training task, as there is no initial task
             tasks = {
diff --git a/arkindex/process/tests/test_create_training_process.py b/arkindex/process/tests/test_create_training_process.py
index 794305150a..e334f08213 100644
--- a/arkindex/process/tests/test_create_training_process.py
+++ b/arkindex/process/tests/test_create_training_process.py
@@ -411,16 +411,14 @@ class TestCreateTrainingProcess(FixtureTestCase):
             'ARKINDEX_CORPUS_ID',
             'ARKINDEX_PROCESS_ID',
             'ARKINDEX_WORKER_RUN_ID',
-            'MODEL_VERSION_ID',
             'WORKER_VERSION_ID',
         ])
-        self.assertEqual(task.env['MODEL_VERSION_ID'], str(self.model_version.id))
         self.assertEqual(task.requires_gpu, True)
         # Check worker run properties
         self.assertEqual(str(training_process.worker_runs.get().id), task.env['ARKINDEX_WORKER_RUN_ID'])
         worker_run = WorkerRun.objects.get(id=task.env['ARKINDEX_WORKER_RUN_ID'])
         self.assertEqual(worker_run.version_id, self.training_worker_version.id)
-        self.assertEqual(worker_run.model_version_id, None)
+        self.assertEqual(worker_run.model_version, self.model_version)
 
     @override_settings(PONOS_RECIPE={})
     def test_create_training_shm_size(self):
-- 
GitLab