Skip to content
Snippets Groups Projects
Commit ceba3bdf authored by ml bonhomme's avatar ml bonhomme :bee: Committed by Erwan Rouchet
Browse files

Do not copy init elements worker runs in Process.copy_runs

parent 0400360e
No related branches found
No related tags found
1 merge request!2392Do not copy init elements worker runs in Process.copy_runs
......@@ -314,24 +314,28 @@ class Process(IndexableModel):
# Keep the relation between template_worker_runs and new_process_worker_runs ids
new_runs = {}
# First we copy each worker runs
# First we copy each worker runs, excluding elements initialisation worker runs
for run in self.worker_runs.all():
# Create a new WorkerRun with same version, configuration and parents.
new_run = WorkerRun(
process=new_process,
version_id=run.version_id,
model_version_id=run.model_version_id,
parents=run.parents,
configuration_id=run.configuration_id,
summary=run.summary
)
# Save the correspondence between this process' worker_run and the new one
new_runs[run.id] = new_run
if run.version_id != WorkerVersion.objects.init_elements_version.id:
# Create a new WorkerRun with same version, configuration and parents.
new_run = WorkerRun(
process=new_process,
version_id=run.version_id,
model_version_id=run.model_version_id,
parents=run.parents,
configuration_id=run.configuration_id,
summary=run.summary
)
# Save the correspondence between this process' worker_run and the new one
new_runs[run.id] = new_run
# Remap parent ids correctly
for run in new_runs.values():
# If there are parents, we need to use the correct worker_run id, the newly created one
run.parents = [new_runs[parent_id].id for parent_id in run.parents]
run.parents = [
new_runs[parent_id].id for parent_id in run.parents
if parent_id in new_runs
]
WorkerRun.objects.bulk_create(new_runs.values())
def list_elements(self):
......
......@@ -6,7 +6,14 @@ from rest_framework import status
from rest_framework.reverse import reverse
from arkindex.documents.models import Corpus
from arkindex.process.models import ProcessMode, WorkerConfiguration, WorkerRun, WorkerVersion, WorkerVersionState
from arkindex.process.models import (
Process,
ProcessMode,
WorkerConfiguration,
WorkerRun,
WorkerVersion,
WorkerVersionState,
)
from arkindex.project.tests import FixtureAPITestCase
from arkindex.training.models import Model, ModelVersionState
from arkindex.users.models import Role, User
......@@ -54,23 +61,23 @@ class TestTemplates(FixtureAPITestCase):
name="A config",
configuration={"param1": "value1"},
)
run_1 = cls.process_template.worker_runs.create(
cls.run_1 = cls.process_template.worker_runs.create(
version=cls.version_1, parents=[], configuration=cls.worker_configuration
)
cls.process_template.worker_runs.create(
version=cls.version_2,
parents=[run_1.id],
parents=[cls.run_1.id],
)
cls.model = Model.objects.create(name="moo")
cls.model_version = cls.model.versions.create(state=ModelVersionState.Available)
run_1 = cls.template.worker_runs.create(
cls.template_run_1 = cls.template.worker_runs.create(
version=cls.version_1, parents=[], configuration=cls.worker_configuration
)
cls.template.worker_runs.create(
version=cls.version_2,
parents=[run_1.id],
parents=[cls.template_run_1.id],
model_version=cls.model_version,
)
......@@ -83,7 +90,7 @@ class TestTemplates(FixtureAPITestCase):
def test_create(self):
self.client.force_login(self.user)
with self.assertNumQueries(8):
with self.assertNumQueries(9):
response = self.client.post(
reverse(
"api:create-process-template", kwargs={"pk": str(self.process_template.id)}
......@@ -102,6 +109,41 @@ class TestTemplates(FixtureAPITestCase):
self.assertTrue(self.process_template.worker_runs.filter(version=parent_run.version).exists())
self.assertTrue(self.process_template.worker_runs.filter(version=child_run.version).exists())
def test_create_excludes_init_elements(self):
init_run = self.process_template.worker_runs.create(version=WorkerVersion.objects.init_elements_version)
self.run_1.parents = [init_run.id]
self.run_1.save()
self.client.force_login(self.user)
with self.assertNumQueries(8):
response = self.client.post(
reverse(
"api:create-process-template", kwargs={"pk": str(self.process_template.id)}
),
{"name": "test_template"},
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.json()["mode"], "template")
new_process_id = response.json()["id"]
# Only two worker runs for the created template
self.assertEqual(self.process_template.worker_runs.count(), 3)
self.assertEqual(Process.objects.get(id=new_process_id).worker_runs.count(), 2)
# No elements initialisation run in the created template
self.assertFalse(WorkerRun.objects.filter(process_id=new_process_id, version_id=WorkerVersion.objects.init_elements_version).exists())
child_run, parent_run = WorkerRun.objects.select_related("version__worker").filter(process__id=new_process_id).order_by("version__worker__slug").all()
# Check dependencies
self.assertListEqual(parent_run.parents, [])
self.assertListEqual(child_run.parents, [parent_run.id])
# Check that every new worker_run is the same as one of the template's
self.assertTrue(self.process_template.worker_runs.filter(version=parent_run.version).exists())
self.assertTrue(self.process_template.worker_runs.filter(version=child_run.version).exists())
def test_create_requires_authentication(self):
response = self.client.post(
reverse("api:create-process-template", kwargs={"pk": str(self.process_template.id)})
......@@ -310,7 +352,7 @@ class TestTemplates(FixtureAPITestCase):
def test_apply(self):
self.assertIsNotNone(self.version_2.docker_image_iid)
self.client.force_login(self.user)
with self.assertNumQueries(9):
with self.assertNumQueries(10):
response = self.client.post(
reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}),
data=json.dumps({"process_id": str(self.process.id)}),
......@@ -333,6 +375,44 @@ class TestTemplates(FixtureAPITestCase):
self.assertIsNone(child_run.configuration_id)
self.assertListEqual(child_run.parents, [parent_run.id])
def test_apply_excludes_init_elements(self):
init_run = self.template.worker_runs.create(version=WorkerVersion.objects.init_elements_version)
self.template_run_1.parents = [init_run.id]
self.template_run_1.save()
self.client.force_login(self.user)
with self.assertNumQueries(9):
response = self.client.post(
reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}),
data=json.dumps({"process_id": str(self.process.id)}),
content_type="application/json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json()["template_id"], str(self.template.id))
created_process_id = response.json()["id"]
# Only two workers runs in the created process
self.assertEqual(self.template.worker_runs.count(), 3)
self.assertEqual(Process.objects.get(id=created_process_id).worker_runs.count(), 2)
child_run, parent_run = WorkerRun.objects.select_related("version__worker").filter(process__id=created_process_id).order_by("version__worker__slug").all()
self.assertEqual(parent_run.process_id, self.process.id)
self.assertEqual(parent_run.version_id, self.version_1.id)
self.assertIsNone(parent_run.model_version_id)
self.assertEqual(parent_run.configuration_id, self.worker_configuration.id)
self.assertListEqual(parent_run.parents, [])
self.assertEqual(child_run.process_id, self.process.id)
self.assertEqual(child_run.version_id, self.version_2.id)
self.assertEqual(child_run.model_version_id, self.model_version.id)
self.assertIsNone(child_run.configuration_id)
self.assertListEqual(child_run.parents, [parent_run.id])
# No elements initialisation run in the created process
self.assertFalse(WorkerRun.objects.filter(process_id=created_process_id, version_id=WorkerVersion.objects.init_elements_version))
def test_apply_delete_previous_worker_runs(self):
self.client.force_login(self.user)
# Create a process with one worker run already
......@@ -344,7 +424,7 @@ class TestTemplates(FixtureAPITestCase):
parents=[],
)
# Apply a template that has two other worker runs
with self.assertNumQueries(11):
with self.assertNumQueries(12):
response = self.client.post(
reverse("api:apply-process-template", kwargs={"pk": str(self.template.id)}),
data=json.dumps({"process_id": str(process.id)}),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment