Skip to content
Snippets Groups Projects
Commit 64caf9da authored by Valentin Rigal's avatar Valentin Rigal Committed by Erwan Rouchet
Browse files

Store model version on the training worker run

parent 038fede6
No related branches found
No related tags found
1 merge request!1918Store model version on the training worker run
......@@ -350,14 +350,13 @@ class Process(IndexableModel):
version_id=worker_version.id,
configuration_id=worker_configuration and worker_configuration.id,
parents=[],
model_version=model_version,
)
env = {
'ARKINDEX_WORKER_RUN_ID': str(worker_run.id),
'WORKER_VERSION_ID': str(worker_version.id),
}
if model_version:
env['MODEL_VERSION_ID'] = str(model_version.id)
# Build the training task, as there is no initial task
tasks = {
......
......@@ -411,16 +411,14 @@ class TestCreateTrainingProcess(FixtureTestCase):
'ARKINDEX_CORPUS_ID',
'ARKINDEX_PROCESS_ID',
'ARKINDEX_WORKER_RUN_ID',
'MODEL_VERSION_ID',
'WORKER_VERSION_ID',
])
self.assertEqual(task.env['MODEL_VERSION_ID'], str(self.model_version.id))
self.assertEqual(task.requires_gpu, True)
# Check worker run properties
self.assertEqual(str(training_process.worker_runs.get().id), task.env['ARKINDEX_WORKER_RUN_ID'])
worker_run = WorkerRun.objects.get(id=task.env['ARKINDEX_WORKER_RUN_ID'])
self.assertEqual(worker_run.version_id, self.training_worker_version.id)
self.assertEqual(worker_run.model_version_id, None)
self.assertEqual(worker_run.model_version, self.model_version)
@override_settings(PONOS_RECIPE={})
def test_create_training_shm_size(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment