diff --git a/arkindex/ponos/api.py b/arkindex/ponos/api.py index a0a43c5ed8f59b9fe496d70a5f958ccbfe6acf04..a24da2fe6791d8c95a3824c038b18196865d6619 100644 --- a/arkindex/ponos/api.py +++ b/arkindex/ponos/api.py @@ -256,6 +256,8 @@ class TaskRestart(ProcessACLMixin, CreateAPIView): copy.slug = basename copy.state = State.Pending copy.token = task_token_default() + copy.agent_id = None + copy.gpu_id = None copy.save() # Create links to restarted task parents diff --git a/arkindex/ponos/tests/test_api.py b/arkindex/ponos/tests/test_api.py index 5b9b3914fe2e655bc97817e55155d7e4fba32551..f38ad3a34d9f3ca1e64caf007b9176f3365165d7 100644 --- a/arkindex/ponos/tests/test_api.py +++ b/arkindex/ponos/tests/test_api.py @@ -50,7 +50,7 @@ class TestAPI(FixtureAPITestCase): cls.docker_agent = Agent.objects.create( mode=AgentMode.Docker, farm=cls.farm, - last_ping=datetime.now(), + last_ping=datetime.now(timezone.utc), cpu_cores=42, cpu_frequency=42e8, ram_total=42e3 @@ -58,7 +58,7 @@ class TestAPI(FixtureAPITestCase): cls.slurm_agent = Agent.objects.create( mode=AgentMode.Slurm, farm=cls.farm, - last_ping=datetime.now(), + last_ping=datetime.now(timezone.utc), ) @property @@ -804,17 +804,27 @@ class TestAPI(FixtureAPITestCase): task4.parents.add(self.task2) task4.children.add(self.task3) task_2_slug = self.task2.slug + with patch("django.utils.timezone.now") as mock_now: mock_now.return_value = datetime.now(timezone.utc) + timedelta(minutes=1) old_task_2 = self.process.tasks.create(run=self.task1.run, depth=1, slug=f"{task_2_slug}_old1") old_task_2.state = State.Error.value old_task_2.original_task_id = self.task1.id old_task_2.save() + old_task_2.parents.add(self.task1) + self.task1.state = State.Completed.value self.task1.save() self.task2.state = State.Error.value self.task2.requires_gpu = True + self.task2.agent = self.docker_agent + self.task2.gpu = self.docker_agent.gpus.create( + id=uuid.uuid4(), + name="gee pee you", + index=0, + ram_total=42e9, + ) self.task2.save() self.client.force_login(self.user) @@ -825,6 +835,7 @@ class TestAPI(FixtureAPITestCase): reverse("api:task-restart", kwargs={"pk": str(self.task2.id)}) ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(self.process.tasks.count(), 6) restarted_task = self.process.tasks.latest("created") self.assertDictEqual( @@ -849,6 +860,9 @@ class TestAPI(FixtureAPITestCase): self.assertQuerySetEqual(self.task2.children.all(), Task.objects.none()) self.task2.refresh_from_db() self.assertEqual(self.task2.slug, f"{task_2_slug}_old2") + self.assertNotEqual(self.task2.token, restarted_task.token) + self.assertIsNone(restarted_task.agent_id) + self.assertIsNone(restarted_task.gpu_id) self.assertQuerySetEqual( restarted_task.children.all(), Task.objects.filter(id__in=[self.task3.id, task4.id]),