From 56906ee3812bf38dead35be5e47e8b48c2e8bffc Mon Sep 17 00:00:00 2001
From: Valentin Rigal <rigal@teklia.com>
Date: Wed, 29 Nov 2023 15:57:42 +0000
Subject: [PATCH] Prevent the death of other tasks when one fails

---
 .../ponos/migrations/0006_task_worker_run.py  | 20 +++++
 arkindex/ponos/models.py                      |  7 ++
 arkindex/process/builder.py                   |  6 +-
 arkindex/process/managers.py                  |  3 +
 arkindex/process/models.py                    |  1 +
 arkindex/process/signals.py                   | 24 ++++--
 arkindex/process/tests/test_processes.py      |  2 +-
 arkindex/process/tests/test_signals.py        | 78 +++++++++++++++++++
 arkindex/process/tests/test_templates.py      |  2 +-
 arkindex/process/tests/test_workerruns.py     |  6 +-
 10 files changed, 136 insertions(+), 13 deletions(-)
 create mode 100644 arkindex/ponos/migrations/0006_task_worker_run.py

diff --git a/arkindex/ponos/migrations/0006_task_worker_run.py b/arkindex/ponos/migrations/0006_task_worker_run.py
new file mode 100644
index 0000000000..ab76c50007
--- /dev/null
+++ b/arkindex/ponos/migrations/0006_task_worker_run.py
@@ -0,0 +1,20 @@
+# Generated by Django 4.1.7 on 2023-11-28 11:02
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('process', '0023_alter_workerversion_model_usage'),
+        ('ponos', '0005_remove_task_tags'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='task',
+            name='worker_run',
+            field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='tasks', to='process.workerrun'),
+        ),
+    ]
diff --git a/arkindex/ponos/models.py b/arkindex/ponos/models.py
index 733bec3979..a2034bb448 100644
--- a/arkindex/ponos/models.py
+++ b/arkindex/ponos/models.py
@@ -502,6 +502,13 @@ class Task(models.Model):
         related_name="tasks",
         on_delete=models.CASCADE,
     )
+    worker_run = models.ForeignKey(
+        'process.WorkerRun',
+        on_delete=models.SET_NULL,
+        related_name='tasks',
+        null=True,
+        blank=True,
+    )
     parents = models.ManyToManyField(
         "self",
         related_name="children",
diff --git a/arkindex/process/builder.py b/arkindex/process/builder.py
index 4d882ae4f8..a29a5a9a16 100644
--- a/arkindex/process/builder.py
+++ b/arkindex/process/builder.py
@@ -88,7 +88,8 @@ class ProcessBuilder(object):
         has_docker_socket=False,
         extra_files={},
         requires_gpu=False,
-        shm_size=None
+        shm_size=None,
+        worker_run=None,
     ) -> None:
         """
         Build a Task with default attributes and add it to the current stack.
@@ -111,6 +112,7 @@ class ProcessBuilder(object):
                 has_docker_socket=has_docker_socket,
                 extra_files=extra_files,
                 image_artifact_id=artifact,
+                worker_run=worker_run,
             )
         )
 
@@ -228,6 +230,7 @@ class ProcessBuilder(object):
                 **self.base_env,
                 'ARKINDEX_WORKER_RUN_ID': str(worker_run.id),
             },
+            worker_run=worker_run,
         )
         self._create_worker_versions_cache([(settings.IMPORTS_WORKER_VERSION, None, None)])
 
@@ -249,6 +252,7 @@ class ProcessBuilder(object):
                 **self.base_env,
                 'ARKINDEX_WORKER_RUN_ID': str(worker_run.id),
             },
+            worker_run=worker_run,
         )
         self._create_worker_versions_cache([(settings.IMPORTS_WORKER_VERSION, None, None)])
 
diff --git a/arkindex/process/managers.py b/arkindex/process/managers.py
index 2b50dd4053..e0e6bbf15f 100644
--- a/arkindex/process/managers.py
+++ b/arkindex/process/managers.py
@@ -7,6 +7,7 @@ from django.db.models.functions import Coalesce
 from django.db.models.query import QuerySet
 from django.utils.functional import cached_property
 
+from arkindex.process.models import Task
 from arkindex.users.managers import BaseACLManager
 from arkindex.users.models import Role
 
@@ -179,6 +180,8 @@ class WorkerRunQuerySet(QuerySet):
             field.related_model
             for field in self.model._meta.get_fields()
             if isinstance(field, ManyToOneRel)
+            # Ignore reverse links to processes tasks
+            and field.related_model is not Task
         ]
         for field in related_models:
             if field.objects.filter(worker_run__in=ids).exists():
diff --git a/arkindex/process/models.py b/arkindex/process/models.py
index 5b9ae53614..a40c9581a0 100644
--- a/arkindex/process/models.py
+++ b/arkindex/process/models.py
@@ -953,6 +953,7 @@ class WorkerRun(models.Model):
             run=run,
             token=token,
             process=process,
+            worker_run=self,
             extra_files=extra_files,
             requires_gpu=process.use_gpu and self.version.gpu_usage in (FeatureUsage.Required, FeatureUsage.Supported)
         )
diff --git a/arkindex/process/signals.py b/arkindex/process/signals.py
index 661bf35172..63c6487969 100644
--- a/arkindex/process/signals.py
+++ b/arkindex/process/signals.py
@@ -71,15 +71,25 @@ def generate_summary(sender, instance, **kwargs):
 @receiver(task_failure)
 def stop_started_activities(sender, task, **kwargs):
     """
-    When a Ponos task fails, update any relevant WorkerActivity instances from `started` to `error`.
+    When a Ponos task fails, update WorkerActivity from the same WorkerRun from `started` to `error`.
     This allows to retry a process and get it to re-run on activities that were not finished without skipping.
-
-    This is definitely not perfect, as this can cause WorkerActivities for worker versions of unrelated tasks
-    to also be marked as error when they are successful, but we do not have any reliable link
-    between Ponos tasks and Arkindex worker versions.
     """
     process = Process.objects.filter(id=task.process_id).only('id', 'activity_state').first()
-    if not process or process.activity_state == ActivityState.Disabled:
+    if process is None or process.activity_state == ActivityState.Disabled:
         return
-    count = process.activities.filter(state=WorkerActivityState.Started).update(state=WorkerActivityState.Error)
+
+    extra_filters = {}
+    if task.worker_run:
+        # Look for process activities matching this specific worker version, model version and configuration
+        extra_filters.update({
+            'worker_version_id': task.worker_run.version_id,
+            'model_version_id': task.worker_run.model_version_id,
+            'configuration_id': task.worker_run.configuration_id,
+        })
+
+    count = (
+        process.activities
+        .filter(state=WorkerActivityState.Started, **extra_filters)
+        .update(state=WorkerActivityState.Error)
+    )
     logger.info(f'Updated {count} worker activities from Started to Error as a task failed on {process.id}')
diff --git a/arkindex/process/tests/test_processes.py b/arkindex/process/tests/test_processes.py
index dd9a7af4dd..8f7fa64151 100644
--- a/arkindex/process/tests/test_processes.py
+++ b/arkindex/process/tests/test_processes.py
@@ -2933,7 +2933,7 @@ class TestProcesses(FixtureAPITestCase):
         self.assertEqual(process.worker_runs.count(), 2)
 
         self.client.force_login(self.user)
-        with self.assertNumQueries(9):
+        with self.assertNumQueries(11):
             response = self.client.delete(reverse('api:clear-process', kwargs={'pk': str(process.id)}))
         self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
         process.refresh_from_db()
diff --git a/arkindex/process/tests/test_signals.py b/arkindex/process/tests/test_signals.py
index e45ac681ea..5b73f455ef 100644
--- a/arkindex/process/tests/test_signals.py
+++ b/arkindex/process/tests/test_signals.py
@@ -21,6 +21,7 @@ from arkindex.process.models import (
 from arkindex.process.signals import _list_ancestors
 from arkindex.project.tests import FixtureAPITestCase
 from arkindex.project.tools import build_public_key
+from arkindex.training.models import Model
 
 
 class TestSignals(FixtureAPITestCase):
@@ -366,3 +367,80 @@ class TestSignals(FixtureAPITestCase):
                 self.assertEqual(processed_activity.state, WorkerActivityState.Processed)
                 self.assertEqual(error_activity.state, WorkerActivityState.Error)
                 self.assertEqual(started_activity.state, expected_activity_state)
+
+    @patch("arkindex.ponos.serializers.TaskSerializer.get_logs")
+    @patch("arkindex.ponos.tasks.notify_process_completion.delay")
+    def test_task_failure_filters_activities_worker_run(self, notify_mock, get_logs_mock):
+        """
+        A Ponos task failure only updates worker activity with the same worker version,
+        model version and worker configuration.
+        """
+        get_logs_mock.return_value = None
+        model = Model.objects.create(name='Generic model', public=False)
+        model_version = model.versions.create(hash='b' * 32, archive_hash='a' * 32, size=8)
+        worker_configuration = self.worker_1.configurations.create(name="conf")
+        self.run_1.configuration = worker_configuration
+        self.run_1.save()
+
+        self.process_2.activity_state = ActivityState.Ready
+        self.process_2.save()
+        self.process_2.run()
+        task = self.process_2.tasks.first()
+        task.state = State.Running
+        task.agent = self.agent
+        task.save()
+
+        # Create one activity per WorkerActivityState on random elements
+        element = self.corpus.elements.first()
+        activity_1 = self.process_2.activities.create(
+            worker_version=self.version_1,
+            configuration=worker_configuration,
+            model_version=None,
+            element=element,
+            started=datetime.now(),
+        )
+        activity_2 = self.process_2.activities.create(
+            worker_version=self.version_1,
+            configuration=None,
+            model_version=model_version,
+            element=element,
+            started=datetime.now(),
+        )
+        activity_3 = self.process_2.activities.create(
+            worker_version=self.version_1,
+            configuration=worker_configuration,
+            model_version=model_version,
+            element=element,
+            started=datetime.now(),
+        )
+
+        cases = [
+            (None, None, None, [activity_1, activity_2, activity_3]),
+            (self.run_1, model_version, None, [activity_2]),
+            (self.run_1, None, worker_configuration, [activity_1]),
+            (self.run_1, model_version, worker_configuration, [activity_3]),
+        ]
+        for worker_run, model_version, worker_configuration, expected in cases:
+            self.process_2.activities.update(state=WorkerActivityState.Started)
+            if worker_run:
+                worker_run.model_version = model_version
+                worker_run.configuration = worker_configuration
+                worker_run.save()
+                task.worker_run = worker_run
+                task.save()
+
+            with self.subTest(
+                worker_run=worker_run,
+                model_version=model_version,
+                worker_configuration=worker_configuration
+            ):
+                resp = self.client.patch(
+                    reverse('api:task-details', kwargs={'pk': task.id}),
+                    data={'state': State.Error.value},
+                    HTTP_AUTHORIZATION=f"Bearer {self.agent.token.access_token}",
+                )
+                self.assertEqual(resp.status_code, status.HTTP_200_OK)
+                for activity in self.process_2.activities.filter(id__in=[e.id for e in expected]):
+                    self.assertEqual(activity.state, WorkerActivityState.Error)
+                for activity in self.process_2.activities.exclude(id__in=[e.id for e in expected]):
+                    self.assertEqual(activity.state, WorkerActivityState.Started)
diff --git a/arkindex/process/tests/test_templates.py b/arkindex/process/tests/test_templates.py
index 43517e50f2..800a210829 100644
--- a/arkindex/process/tests/test_templates.py
+++ b/arkindex/process/tests/test_templates.py
@@ -335,7 +335,7 @@ class TestTemplates(FixtureAPITestCase):
             parents=[],
         )
         # Apply a template that has two other worker runs
-        with self.assertNumQueries(20):
+        with self.assertNumQueries(22):
             response = self.client.post(
                 reverse('api:apply-process-template', kwargs={'pk': str(self.template.id)}),
                 data=json.dumps({"process_id": str(process.id)}),
diff --git a/arkindex/process/tests/test_workerruns.py b/arkindex/process/tests/test_workerruns.py
index 14f6a5ca39..02112aac97 100644
--- a/arkindex/process/tests/test_workerruns.py
+++ b/arkindex/process/tests/test_workerruns.py
@@ -2906,7 +2906,7 @@ class TestWorkerRuns(FixtureAPITestCase):
         self.worker_1.memberships.update(level=Role.Guest.value)
         self.client.force_login(self.user)
 
-        with self.assertNumQueries(8):
+        with self.assertNumQueries(9):
             response = self.client.delete(
                 reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)})
             )
@@ -2933,7 +2933,7 @@ class TestWorkerRuns(FixtureAPITestCase):
     def test_delete(self):
         self.client.force_login(self.user)
 
-        with self.assertNumQueries(8):
+        with self.assertNumQueries(9):
             response = self.client.delete(
                 reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)})
             )
@@ -2976,7 +2976,7 @@ class TestWorkerRuns(FixtureAPITestCase):
         self.assertTrue(self.run_1.id in run_3.parents)
         self.client.force_login(self.user)
 
-        with self.assertNumQueries(8):
+        with self.assertNumQueries(9):
             response = self.client.delete(
                 reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)})
             )
-- 
GitLab