diff --git a/arkindex/ponos/api.py b/arkindex/ponos/api.py index c5b6a4d0c5043eeabb9c5cb1c20383fb171be753..eb0fc55e095b5eda5f67773994843575cb2131e4 100644 --- a/arkindex/ponos/api.py +++ b/arkindex/ponos/api.py @@ -218,7 +218,7 @@ class TaskRestart(ProcessACLMixin, CreateAPIView): basename, *_ = copy.slug.rsplit("_old", 1) else: basename = copy.slug - latest_task = Task.objects.filter(run=copy.run, slug__startswith=f"{basename}_old").order_by("-created").first() + latest_task = Task.objects.filter(process=copy.process, run=copy.run, slug__startswith=f"{basename}_old").order_by("-created").first() if not latest_task: # There is no previously restarted task: the original task will have the slug slug_old1 suffix = 1 diff --git a/arkindex/ponos/tests/test_api.py b/arkindex/ponos/tests/test_api.py index 9b59df1ef9b4c35a5252fc5dd7ff24563988f9d2..09ddc8be17e695286d0be15452ba8eee980c8599 100644 --- a/arkindex/ponos/tests/test_api.py +++ b/arkindex/ponos/tests/test_api.py @@ -1256,6 +1256,79 @@ class TestAPI(FixtureAPITestCase): Task.objects.filter(id__in=[self.task3.id, task4.id]), ) + @override_settings(PONOS_RQ_EXECUTION=False) + @patch("arkindex.project.aws.s3") + def test_restart_task_name_process(self, s3_mock): + """ + When a task is restarted, its name depends on the existing tasks already existing on the + same process, not on tasks that exist on other processes + """ + s3_mock.Object.return_value.bucket_name = "ponos" + s3_mock.Object.return_value.key = "somelog" + s3_mock.Object.return_value.get.return_value = { + "Body": BytesIO(b"Task has been restarted") + } + s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere" + + task_1_slug = self.task1.slug + + other_process = Process.objects.create( + farm=Farm.objects.first(), + mode=ProcessMode.Workers, + corpus=self.corpus, + creator=self.user, + ) + other_process.tasks.create( + run=0, + depth=0, + slug=f"{task_1_slug}_old1", + expiry=datetime(1970, 1, 1, tzinfo=timezone.utc), + ) + + self.task1.state = State.Error.value + self.task1.save() + + self.client.force_login(self.user) + with self.assertNumQueries(13): + with patch("django.utils.timezone.now") as mock_now: + mock_now.return_value = datetime.now(timezone.utc) + timedelta(minutes=2) + response = self.client.post( + reverse("api:task-restart", kwargs={"pk": str(self.task1.id)}) + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + self.assertEqual(self.process.tasks.count(), 4) + restarted_task = self.process.tasks.latest("created") + self.assertDictEqual( + response.json(), + { + "id": str(restarted_task.id), + "depth": 0, + "agent": None, + "extra_files": {}, + "full_log": "http://somewhere", + "gpu": None, + "logs": "Task has been restarted", + "original_task_id": str(self.task1.id), + "parents": [], + "run": 0, + "shm_size": None, + "slug": task_1_slug, + "state": "pending", + "requires_gpu": False, + }, + ) + self.assertQuerySetEqual(self.task1.children.all(), Task.objects.none()) + self.task1.refresh_from_db() + self.assertEqual(self.task1.slug, f"{task_1_slug}_old1") + self.assertNotEqual(self.task1.token, restarted_task.token) + self.assertIsNone(restarted_task.agent_id) + self.assertIsNone(restarted_task.gpu_id) + self.assertQuerySetEqual( + restarted_task.children.all(), + Task.objects.filter(id__in=[self.task2.id]), + ) + @override_settings(PONOS_RQ_EXECUTION=True) @patch("arkindex.ponos.tasks.run_task_rq.delay") @patch("arkindex.project.aws.s3")