From 30f88fc5f5523b87a80cc2088d8d0b64ed331c00 Mon Sep 17 00:00:00 2001
From: Theo Lesage <tlesage@teklia.com>
Date: Mon, 6 May 2024 14:23:32 +0000
Subject: [PATCH] Link tasks to their restarts

---
 .../tests/tasks/test_corpus_delete.py         |  8 ++++++--
 arkindex/ponos/admin.py                       |  2 ++
 arkindex/ponos/api.py                         | 13 +++++++------
 .../migrations/0009_task_original_task.py     | 19 +++++++++++++++++++
 arkindex/ponos/models.py                      |  8 +++++++-
 arkindex/ponos/serializers.py                 |  2 ++
 arkindex/ponos/tests/test_api.py              | 16 ++++++++++++----
 7 files changed, 55 insertions(+), 13 deletions(-)
 create mode 100644 arkindex/ponos/migrations/0009_task_original_task.py

diff --git a/arkindex/documents/tests/tasks/test_corpus_delete.py b/arkindex/documents/tests/tasks/test_corpus_delete.py
index 527431b3f5..5a0e02f3c5 100644
--- a/arkindex/documents/tests/tasks/test_corpus_delete.py
+++ b/arkindex/documents/tests/tasks/test_corpus_delete.py
@@ -49,7 +49,7 @@ class TestDeleteCorpus(FixtureTestCase):
         )
         element_process.elements.add(element)
         worker_run = element_process.worker_runs.create(version=cls.worker_version, parents=[])
-        task_1, task_2, task_3 = Task.objects.bulk_create(
+        task_1, task_2, task_3, task_4 = Task.objects.bulk_create(
             [
                 Task(
                     run=0,
@@ -58,11 +58,15 @@ class TestDeleteCorpus(FixtureTestCase):
                     worker_run=worker_run,
                     slug=f"unscheduled task {i}",
                     state=State.Unscheduled,
-                ) for i in range(1, 4)
+                ) for i in range(0, 4)
             ]
         )
         task_1.parents.set([task_2])
         task_3.parents.set([task_1, task_2])
+        task_3.slug += "_old1"
+        task_4.original_task_id = task_3.id
+        task_3.save()
+        task_4.save()
         element.worker_run = worker_run
         element.worker_version = cls.worker_version
         element.save()
diff --git a/arkindex/ponos/admin.py b/arkindex/ponos/admin.py
index d2210a4668..0920c2190d 100644
--- a/arkindex/ponos/admin.py
+++ b/arkindex/ponos/admin.py
@@ -41,6 +41,7 @@ class TaskAdmin(admin.ModelAdmin):
         "updated",
         "container",
         "shm_size",
+        "original_task",
     )
     fieldsets = (
         (
@@ -54,6 +55,7 @@ class TaskAdmin(admin.ModelAdmin):
                     "state",
                     "process",
                     "priority",
+                    "original_task",
                 ),
             },
         ),
diff --git a/arkindex/ponos/api.py b/arkindex/ponos/api.py
index 0c123b1082..368a44055b 100644
--- a/arkindex/ponos/api.py
+++ b/arkindex/ponos/api.py
@@ -223,12 +223,9 @@ class TaskRestart(ProcessACLMixin, CreateAPIView):
             raise ValidationError(
                 detail="Task's state must be in a final state to be restarted."
             )
-        # TODO Check the original_task_id field directly once it is implemented
-        # https://gitlab.teklia.com/arkindex/frontend/-/issues/1383
-        _, *suffix = task.slug.rsplit("_old", 1)
-        if suffix:
+        if task.restarts.exists():
             raise ValidationError(
-                detail="This task has already been restarted"
+                detail="This task has already been restarted."
             )
         return task
 
@@ -238,7 +235,10 @@ class TaskRestart(ProcessACLMixin, CreateAPIView):
         parents = list(copy.parents.all())
 
         # Rename the original task
-        basename, *_ = copy.slug.rsplit("_old", 1)
+        if copy.original_task_id:
+            basename, *_ = copy.slug.rsplit("_old", 1)
+        else:
+            basename = copy.slug
         latest_task = Task.objects.filter(run=copy.run, slug__startswith=f"{basename}_old").order_by("-created").first()
         if not latest_task:
             # There is no previously restarted task: the original task will have the slug slug_old1
@@ -251,6 +251,7 @@ class TaskRestart(ProcessACLMixin, CreateAPIView):
         copy.save()
 
         # Copy the original task
+        copy.original_task_id = copy.id
         copy.id = uuid.uuid4()
         copy.slug = basename
         copy.state = State.Pending
diff --git a/arkindex/ponos/migrations/0009_task_original_task.py b/arkindex/ponos/migrations/0009_task_original_task.py
new file mode 100644
index 0000000000..a922fbdc9a
--- /dev/null
+++ b/arkindex/ponos/migrations/0009_task_original_task.py
@@ -0,0 +1,19 @@
+# Generated by Django 4.1.7 on 2024-04-23 12:19
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    """Add the nullable self-referencing ``original_task`` foreign key (reverse name ``restarts``) to ``Task``."""
+    dependencies = [
+        ("ponos", "0008_agent_mode"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="task",
+            name="original_task",
+            field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name="restarts", to="ponos.task"),
+        ),
+    ]
diff --git a/arkindex/ponos/models.py b/arkindex/ponos/models.py
index 26e6186920..76a91f8a15 100644
--- a/arkindex/ponos/models.py
+++ b/arkindex/ponos/models.py
@@ -327,12 +327,18 @@ class Task(models.Model):
         related_name="children",
         symmetrical=False,
     )
-
     container = models.CharField(
         max_length=64,
         null=True,
         blank=True,
     )
+    original_task = models.ForeignKey(  # task this one was restarted from, if any
+        "self",
+        on_delete=models.SET_NULL,  # deleting the original keeps its restarts, just unlinked
+        null=True,
+        blank=True,
+        related_name="restarts",  # reverse accessor: task.restarts
+    )
 
     created = models.DateTimeField(auto_now_add=True)
     updated = models.DateTimeField(auto_now=True)
diff --git a/arkindex/ponos/serializers.py b/arkindex/ponos/serializers.py
index 932cc7c0c9..392e993f43 100644
--- a/arkindex/ponos/serializers.py
+++ b/arkindex/ponos/serializers.py
@@ -89,6 +89,7 @@ class TaskSerializer(TaskLightSerializer):
             "agent",
             "gpu",
             "extra_files",
+            "original_task_id",
         )
         read_only_fields = TaskLightSerializer.Meta.read_only_fields + (
             "logs",
@@ -96,6 +97,7 @@ class TaskSerializer(TaskLightSerializer):
             "agent",
             "gpu",
             "extra_files",
+            "original_task_id",
         )
 
     @extend_schema_field(serializers.CharField())
diff --git a/arkindex/ponos/tests/test_api.py b/arkindex/ponos/tests/test_api.py
index 986db55f01..7e9f2db698 100644
--- a/arkindex/ponos/tests/test_api.py
+++ b/arkindex/ponos/tests/test_api.py
@@ -74,6 +74,7 @@ class TestAPI(FixtureAPITestCase):
                 "slug": "initialisation",
                 "state": "unscheduled",
                 "parents": [],
+                "original_task_id": None,
                 "logs": "Failed successfully",
                 "full_log": "http://somewhere",
                 "extra_files": {},
@@ -157,6 +158,7 @@ class TestAPI(FixtureAPITestCase):
                         "slug": "initialisation",
                         "state": "unscheduled",
                         "parents": [],
+                        "original_task_id": None,
                         "logs": "Failed successfully",
                         "full_log": "http://somewhere",
                         "extra_files": {},
@@ -198,6 +200,7 @@ class TestAPI(FixtureAPITestCase):
                 "slug": "initialisation",
                 "state": "unscheduled",
                 "parents": [],
+                "original_task_id": None,
                 "logs": "Failed successfully",
                 "full_log": "http://somewhere",
                 "extra_files": {},
@@ -586,14 +589,16 @@ class TestAPI(FixtureAPITestCase):
         self.task1.slug = self.task1.slug + "_old1"
         self.task1.state = State.Completed.value
         self.task1.save()
-        with self.assertNumQueries(7):
+        self.task2.original_task_id = self.task1.id
+        self.task2.save()
+        with self.assertNumQueries(8):
             response = self.client.post(
                 reverse("api:task-restart", kwargs={"pk": str(self.task1.id)})
             )
             self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
             self.assertListEqual(
                 response.json(),
-                ["This task has already been restarted"],
+                ["This task has already been restarted."],
             )
 
     @patch("arkindex.project.aws.s3")
@@ -630,6 +635,7 @@ class TestAPI(FixtureAPITestCase):
             mock_now.return_value = datetime.now(timezone.utc) + timedelta(minutes=1)
             old_task_2 = self.process.tasks.create(run=self.task1.run, depth=1, slug=f"{task_2_slug}_old1")
             old_task_2.state = State.Error.value
+            old_task_2.original_task_id = self.task1.id
             old_task_2.save()
         old_task_2.parents.add(self.task1)
         self.task1.state = State.Completed.value
@@ -638,7 +644,7 @@ class TestAPI(FixtureAPITestCase):
         self.task2.save()
 
         self.client.force_login(self.user)
-        with self.assertNumQueries(13):
+        with self.assertNumQueries(14):
             with patch("django.utils.timezone.now") as mock_now:
                 mock_now.return_value = datetime.now(timezone.utc) + timedelta(minutes=2)
                 response = self.client.post(
@@ -657,6 +663,7 @@ class TestAPI(FixtureAPITestCase):
                 "full_log": "http://somewhere",
                 "gpu": None,
                 "logs": "Task has been restarted",
+                "original_task_id": str(self.task2.id),
                 "parents": [str(self.task1.id)],
                 "run": 0,
                 "shm_size": None,
@@ -704,7 +711,7 @@ class TestAPI(FixtureAPITestCase):
         self.task2.save()
 
         self.client.force_login(self.user)
-        with self.assertNumQueries(13):
+        with self.assertNumQueries(14):
             response = self.client.post(
                 reverse("api:task-restart", kwargs={"pk": str(self.task2.id)})
             )
@@ -721,6 +728,7 @@ class TestAPI(FixtureAPITestCase):
                 "full_log": "http://somewhere",
                 "gpu": None,
                 "logs": "Task has been restarted",
+                "original_task_id": str(self.task2.id),
                 "parents": [str(self.task1.id)],
                 "run": 0,
                 "shm_size": None,
-- 
GitLab