Skip to content
Snippets Groups Projects
Commit 651d888c authored by ml bonhomme's avatar ml bonhomme :bee:
Browse files

Have the restarted task keep the same name as the original one

parent 54a2c429
No related branches found
No related tags found
1 merge request!2300Have the restarted task keep the same name as the original one
...@@ -225,32 +225,42 @@ class TaskRestart(ProcessACLMixin, CreateAPIView): ...@@ -225,32 +225,42 @@ class TaskRestart(ProcessACLMixin, CreateAPIView):
) )
# TODO Check the original_task_id field directly once it is implemented # TODO Check the original_task_id field directly once it is implemented
# https://gitlab.teklia.com/arkindex/frontend/-/issues/1383 # https://gitlab.teklia.com/arkindex/frontend/-/issues/1383
if task.process.tasks.filter(run=task.run, slug=self.increment(task.slug)).exists(): _, *suffix = task.slug.rsplit("_old", 1)
if suffix:
raise ValidationError( raise ValidationError(
detail="This task has already been restarted" detail="This task has already been restarted"
) )
return task return task
def increment(self, name):
basename, *suffix = name.rsplit("_restart", 1)
suffix = int(suffix[0]) + 1 if suffix and suffix[0].isdigit() else 1
return f"{basename}_restart{suffix}"
@transaction.atomic @transaction.atomic
def create(self, request, pk=None, **kwargs): def create(self, request, pk=None, **kwargs):
copy = self.get_task() copy = self.get_task()
parents = list(copy.parents.all()) parents = list(copy.parents.all())
# Rename the original task
basename, *_ = copy.slug.rsplit("_old", 1)
latest_task = Task.objects.filter(run=copy.run, slug__startswith=f"{basename}_old").order_by("-created").first()
if not latest_task:
# There is no previously restarted task: the original task will have the slug slug_old1
suffix = 1
else:
# If there are previously restarted tasks, then the suffix is incremented
_, *suffix = latest_task.slug.rsplit("_old", 1)
suffix = int(suffix[0]) + 1 if suffix and suffix[0].isdigit() else 1
copy.slug = f"{basename}_old{suffix}"
copy.save()
# Copy the original task
copy.id = uuid.uuid4() copy.id = uuid.uuid4()
copy.slug = basename
copy.state = State.Pending copy.state = State.Pending
copy.token = task_token_default() copy.token = task_token_default()
copy.slug = self.increment(copy.slug)
copy.save() copy.save()
# Create links to retried task parents # Create links to restarted task parents
copy.parents.add(*parents) copy.parents.add(*parents)
# Move all tasks depending on the retried task to the copy # Move all tasks depending on the restarted task to the copy
Task.children.through.objects.filter(to_task_id=pk).update(to_task_id=copy.id) Task.children.through.objects.filter(to_task_id=pk).update(to_task_id=copy.id)
return Response(TaskSerializer(copy).data, status=status.HTTP_201_CREATED) return Response(TaskSerializer(copy).data, status=status.HTTP_201_CREATED)
import uuid import uuid
from datetime import datetime, timedelta, timezone
from io import BytesIO from io import BytesIO
from unittest import expectedFailure from unittest import expectedFailure
from unittest.mock import call, patch, seal from unittest.mock import call, patch, seal
...@@ -582,11 +583,10 @@ class TestAPI(FixtureAPITestCase): ...@@ -582,11 +583,10 @@ class TestAPI(FixtureAPITestCase):
def test_restart_task_already_restarted(self): def test_restart_task_already_restarted(self):
self.client.force_login(self.user) self.client.force_login(self.user)
self.task2.slug = self.task1.slug + "_restart1" self.task1.slug = self.task1.slug + "_old1"
self.task2.save()
self.task1.state = State.Completed.value self.task1.state = State.Completed.value
self.task1.save() self.task1.save()
with self.assertNumQueries(8): with self.assertNumQueries(7):
response = self.client.post( response = self.client.post(
reverse("api:task-restart", kwargs={"pk": str(self.task1.id)}) reverse("api:task-restart", kwargs={"pk": str(self.task1.id)})
) )
...@@ -600,16 +600,92 @@ class TestAPI(FixtureAPITestCase): ...@@ -600,16 +600,92 @@ class TestAPI(FixtureAPITestCase):
def test_restart_task(self, s3_mock): def test_restart_task(self, s3_mock):
""" """
From: From:
task1 → task2_restart42 → task3 task2_old1
↘ ↗
task 4 task1 → task2 → task3
↘ ↗
task 4
To:
task2_old1
task1 → task2_old2
task2 → task3
↘ ↗
task 4
"""
s3_mock.Object.return_value.bucket_name = "ponos"
s3_mock.Object.return_value.key = "somelog"
s3_mock.Object.return_value.get.return_value = {
"Body": BytesIO(b"Task has been restarted")
}
s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
task4 = self.process.tasks.create(run=self.task1.run, depth=1)
task4.parents.add(self.task2)
task4.children.add(self.task3)
task_2_slug = self.task2.slug
with patch("django.utils.timezone.now") as mock_now:
mock_now.return_value = datetime.now(timezone.utc) + timedelta(minutes=1)
old_task_2 = self.process.tasks.create(run=self.task1.run, depth=1, slug=f"{task_2_slug}_old1")
old_task_2.state = State.Error.value
old_task_2.save()
old_task_2.parents.add(self.task1)
self.task1.state = State.Completed.value
self.task1.save()
self.task2.state = State.Error.value
self.task2.save()
self.client.force_login(self.user)
with self.assertNumQueries(13):
with patch("django.utils.timezone.now") as mock_now:
mock_now.return_value = datetime.now(timezone.utc) + timedelta(minutes=2)
response = self.client.post(
reverse("api:task-restart", kwargs={"pk": str(self.task2.id)})
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(self.process.tasks.count(), 6)
restarted_task = self.process.tasks.latest("created")
self.assertDictEqual(
response.json(),
{
"id": str(restarted_task.id),
"depth": 1,
"agent": None,
"extra_files": {},
"full_log": "http://somewhere",
"gpu": None,
"logs": "Task has been restarted",
"parents": [str(self.task1.id)],
"run": 0,
"shm_size": None,
"slug": task_2_slug,
"state": "pending",
},
)
self.assertQuerysetEqual(self.task2.children.all(), Task.objects.none())
self.task2.refresh_from_db()
self.assertEqual(self.task2.slug, f"{task_2_slug}_old2")
self.assertQuerysetEqual(
restarted_task.children.all(),
Task.objects.filter(id__in=[self.task3.id, task4.id]),
)
@patch("arkindex.project.aws.s3")
def test_restart_task_no_previous_restart(self, s3_mock):
"""
From:
task1 → task2 → task3
↘ ↗
task 4
To: To:
task1 → task2_restart42 task1 → task2_old1
task2_restart43 → task3 task2 → task3
↘ ↗ ↘ ↗
task 4 task 4
""" """
s3_mock.Object.return_value.bucket_name = "ponos" s3_mock.Object.return_value.bucket_name = "ponos"
s3_mock.Object.return_value.key = "somelog" s3_mock.Object.return_value.key = "somelog"
...@@ -621,14 +697,14 @@ class TestAPI(FixtureAPITestCase): ...@@ -621,14 +697,14 @@ class TestAPI(FixtureAPITestCase):
task4 = self.process.tasks.create(run=self.task1.run, depth=1) task4 = self.process.tasks.create(run=self.task1.run, depth=1)
task4.parents.add(self.task2) task4.parents.add(self.task2)
task4.children.add(self.task3) task4.children.add(self.task3)
task_2_slug = self.task2.slug
self.task1.state = State.Completed.value self.task1.state = State.Completed.value
self.task1.save() self.task1.save()
self.task2.state = State.Error.value self.task2.state = State.Error.value
self.task2.slug = "task2_restart42"
self.task2.save() self.task2.save()
self.client.force_login(self.user) self.client.force_login(self.user)
with self.assertNumQueries(12): with self.assertNumQueries(13):
response = self.client.post( response = self.client.post(
reverse("api:task-restart", kwargs={"pk": str(self.task2.id)}) reverse("api:task-restart", kwargs={"pk": str(self.task2.id)})
) )
...@@ -648,11 +724,13 @@ class TestAPI(FixtureAPITestCase): ...@@ -648,11 +724,13 @@ class TestAPI(FixtureAPITestCase):
"parents": [str(self.task1.id)], "parents": [str(self.task1.id)],
"run": 0, "run": 0,
"shm_size": None, "shm_size": None,
"slug": "task2_restart43", "slug": task_2_slug,
"state": "pending", "state": "pending",
}, },
) )
self.assertQuerysetEqual(self.task2.children.all(), Task.objects.none()) self.assertQuerysetEqual(self.task2.children.all(), Task.objects.none())
self.task2.refresh_from_db()
self.assertEqual(self.task2.slug, f"{task_2_slug}_old1")
self.assertQuerysetEqual( self.assertQuerysetEqual(
restarted_task.children.all(), restarted_task.children.all(),
Task.objects.filter(id__in=[self.task3.id, task4.id]), Task.objects.filter(id__in=[self.task3.id, task4.id]),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment