Skip to content
Snippets Groups Projects
Commit 651d888c authored by ml bonhomme's avatar ml bonhomme :bee:
Browse files

Have the restarted task keep the same name as the original one

parent 54a2c429
No related branches found
No related tags found
1 merge request!2300Have the restarted task keep the same name as the original one
......@@ -225,32 +225,42 @@ class TaskRestart(ProcessACLMixin, CreateAPIView):
)
# TODO Check the original_task_id field directly once it is implemented
# https://gitlab.teklia.com/arkindex/frontend/-/issues/1383
if task.process.tasks.filter(run=task.run, slug=self.increment(task.slug)).exists():
_, *suffix = task.slug.rsplit("_old", 1)
if suffix:
raise ValidationError(
detail="This task has already been restarted"
)
return task
def increment(self, name):
basename, *suffix = name.rsplit("_restart", 1)
suffix = int(suffix[0]) + 1 if suffix and suffix[0].isdigit() else 1
return f"{basename}_restart{suffix}"
@transaction.atomic
def create(self, request, pk=None, **kwargs):
copy = self.get_task()
parents = list(copy.parents.all())
# Rename the original task
basename, *_ = copy.slug.rsplit("_old", 1)
latest_task = Task.objects.filter(run=copy.run, slug__startswith=f"{basename}_old").order_by("-created").first()
if not latest_task:
# There is no previously restarted task: the original task will have the slug slug_old1
suffix = 1
else:
# If there are previously restarted tasks, then the suffix is incremented
_, *suffix = latest_task.slug.rsplit("_old", 1)
suffix = int(suffix[0]) + 1 if suffix and suffix[0].isdigit() else 1
copy.slug = f"{basename}_old{suffix}"
copy.save()
# Copy the original task
copy.id = uuid.uuid4()
copy.slug = basename
copy.state = State.Pending
copy.token = task_token_default()
copy.slug = self.increment(copy.slug)
copy.save()
# Create links to retried task parents
# Create links to restarted task parents
copy.parents.add(*parents)
# Move all tasks depending on the retried task to the copy
# Move all tasks depending on the restarted task to the copy
Task.children.through.objects.filter(to_task_id=pk).update(to_task_id=copy.id)
return Response(TaskSerializer(copy).data, status=status.HTTP_201_CREATED)
import uuid
from datetime import datetime, timedelta, timezone
from io import BytesIO
from unittest import expectedFailure
from unittest.mock import call, patch, seal
......@@ -582,11 +583,10 @@ class TestAPI(FixtureAPITestCase):
def test_restart_task_already_restarted(self):
self.client.force_login(self.user)
self.task2.slug = self.task1.slug + "_restart1"
self.task2.save()
self.task1.slug = self.task1.slug + "_old1"
self.task1.state = State.Completed.value
self.task1.save()
with self.assertNumQueries(8):
with self.assertNumQueries(7):
response = self.client.post(
reverse("api:task-restart", kwargs={"pk": str(self.task1.id)})
)
......@@ -600,16 +600,92 @@ class TestAPI(FixtureAPITestCase):
def test_restart_task(self, s3_mock):
"""
From:
task1 → task2_restart42 → task3
↘ ↗
task 4
task2_old1
task1 → task2 → task3
↘ ↗
task 4
To:
task2_old1
task1 → task2_old2
task2 → task3
↘ ↗
task 4
"""
s3_mock.Object.return_value.bucket_name = "ponos"
s3_mock.Object.return_value.key = "somelog"
s3_mock.Object.return_value.get.return_value = {
"Body": BytesIO(b"Task has been restarted")
}
s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
task4 = self.process.tasks.create(run=self.task1.run, depth=1)
task4.parents.add(self.task2)
task4.children.add(self.task3)
task_2_slug = self.task2.slug
with patch("django.utils.timezone.now") as mock_now:
mock_now.return_value = datetime.now(timezone.utc) + timedelta(minutes=1)
old_task_2 = self.process.tasks.create(run=self.task1.run, depth=1, slug=f"{task_2_slug}_old1")
old_task_2.state = State.Error.value
old_task_2.save()
old_task_2.parents.add(self.task1)
self.task1.state = State.Completed.value
self.task1.save()
self.task2.state = State.Error.value
self.task2.save()
self.client.force_login(self.user)
with self.assertNumQueries(13):
with patch("django.utils.timezone.now") as mock_now:
mock_now.return_value = datetime.now(timezone.utc) + timedelta(minutes=2)
response = self.client.post(
reverse("api:task-restart", kwargs={"pk": str(self.task2.id)})
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(self.process.tasks.count(), 6)
restarted_task = self.process.tasks.latest("created")
self.assertDictEqual(
response.json(),
{
"id": str(restarted_task.id),
"depth": 1,
"agent": None,
"extra_files": {},
"full_log": "http://somewhere",
"gpu": None,
"logs": "Task has been restarted",
"parents": [str(self.task1.id)],
"run": 0,
"shm_size": None,
"slug": task_2_slug,
"state": "pending",
},
)
self.assertQuerysetEqual(self.task2.children.all(), Task.objects.none())
self.task2.refresh_from_db()
self.assertEqual(self.task2.slug, f"{task_2_slug}_old2")
self.assertQuerysetEqual(
restarted_task.children.all(),
Task.objects.filter(id__in=[self.task3.id, task4.id]),
)
@patch("arkindex.project.aws.s3")
def test_restart_task_no_previous_restart(self, s3_mock):
"""
From:
task1 → task2 → task3
↘ ↗
task 4
To:
task1 → task2_restart42
task1 → task2_old1
task2_restart43 → task3
↘ ↗
task 4
task2 → task3
↘ ↗
task 4
"""
s3_mock.Object.return_value.bucket_name = "ponos"
s3_mock.Object.return_value.key = "somelog"
......@@ -621,14 +697,14 @@ class TestAPI(FixtureAPITestCase):
task4 = self.process.tasks.create(run=self.task1.run, depth=1)
task4.parents.add(self.task2)
task4.children.add(self.task3)
task_2_slug = self.task2.slug
self.task1.state = State.Completed.value
self.task1.save()
self.task2.state = State.Error.value
self.task2.slug = "task2_restart42"
self.task2.save()
self.client.force_login(self.user)
with self.assertNumQueries(12):
with self.assertNumQueries(13):
response = self.client.post(
reverse("api:task-restart", kwargs={"pk": str(self.task2.id)})
)
......@@ -648,11 +724,13 @@ class TestAPI(FixtureAPITestCase):
"parents": [str(self.task1.id)],
"run": 0,
"shm_size": None,
"slug": "task2_restart43",
"slug": task_2_slug,
"state": "pending",
},
)
self.assertQuerysetEqual(self.task2.children.all(), Task.objects.none())
self.task2.refresh_from_db()
self.assertEqual(self.task2.slug, f"{task_2_slug}_old1")
self.assertQuerysetEqual(
restarted_task.children.all(),
Task.objects.filter(id__in=[self.task3.id, task4.id]),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment