Skip to content
Snippets Groups Projects
Commit 17a27902 authored by ml bonhomme's avatar ml bonhomme :bee: Committed by Erwan Rouchet
Browse files

Filter existing tasks by process when restarting a task

parent fc247d34
No related branches found
No related tags found
1 merge request!2431Filter existing tasks by process when restarting a task
......@@ -218,7 +218,7 @@ class TaskRestart(ProcessACLMixin, CreateAPIView):
basename, *_ = copy.slug.rsplit("_old", 1)
else:
basename = copy.slug
latest_task = Task.objects.filter(run=copy.run, slug__startswith=f"{basename}_old").order_by("-created").first()
latest_task = Task.objects.filter(process=copy.process, run=copy.run, slug__startswith=f"{basename}_old").order_by("-created").first()
if not latest_task:
# There is no previously restarted task: the original task will have the slug slug_old1
suffix = 1
......
......@@ -1256,6 +1256,79 @@ class TestAPI(FixtureAPITestCase):
Task.objects.filter(id__in=[self.task3.id, task4.id]),
)
@override_settings(PONOS_RQ_EXECUTION=False)
@patch("arkindex.project.aws.s3")
def test_restart_task_name_process(self, s3_mock):
"""
When a task is restarted, its name depends on the existing tasks already existing on the
same process, not on tasks that exist on other processes
"""
s3_mock.Object.return_value.bucket_name = "ponos"
s3_mock.Object.return_value.key = "somelog"
s3_mock.Object.return_value.get.return_value = {
"Body": BytesIO(b"Task has been restarted")
}
s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
task_1_slug = self.task1.slug
other_process = Process.objects.create(
farm=Farm.objects.first(),
mode=ProcessMode.Workers,
corpus=self.corpus,
creator=self.user,
)
other_process.tasks.create(
run=0,
depth=0,
slug=f"{task_1_slug}_old1",
expiry=datetime(1970, 1, 1, tzinfo=timezone.utc),
)
self.task1.state = State.Error.value
self.task1.save()
self.client.force_login(self.user)
with self.assertNumQueries(13):
with patch("django.utils.timezone.now") as mock_now:
mock_now.return_value = datetime.now(timezone.utc) + timedelta(minutes=2)
response = self.client.post(
reverse("api:task-restart", kwargs={"pk": str(self.task1.id)})
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(self.process.tasks.count(), 4)
restarted_task = self.process.tasks.latest("created")
self.assertDictEqual(
response.json(),
{
"id": str(restarted_task.id),
"depth": 0,
"agent": None,
"extra_files": {},
"full_log": "http://somewhere",
"gpu": None,
"logs": "Task has been restarted",
"original_task_id": str(self.task1.id),
"parents": [],
"run": 0,
"shm_size": None,
"slug": task_1_slug,
"state": "pending",
"requires_gpu": False,
},
)
self.assertQuerySetEqual(self.task1.children.all(), Task.objects.none())
self.task1.refresh_from_db()
self.assertEqual(self.task1.slug, f"{task_1_slug}_old1")
self.assertNotEqual(self.task1.token, restarted_task.token)
self.assertIsNone(restarted_task.agent_id)
self.assertIsNone(restarted_task.gpu_id)
self.assertQuerySetEqual(
restarted_task.children.all(),
Task.objects.filter(id__in=[self.task2.id]),
)
@override_settings(PONOS_RQ_EXECUTION=True)
@patch("arkindex.ponos.tasks.run_task_rq.delay")
@patch("arkindex.project.aws.s3")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment