diff --git a/arkindex/ponos/migrations/0006_task_worker_run.py b/arkindex/ponos/migrations/0006_task_worker_run.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab76c500071a4223e50452f39029c63f72322ff8
--- /dev/null
+++ b/arkindex/ponos/migrations/0006_task_worker_run.py
@@ -0,0 +1,20 @@
+# Generated by Django 4.1.7 on 2023-11-28 11:02
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('process', '0023_alter_workerversion_model_usage'),
+        ('ponos', '0005_remove_task_tags'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='task',
+            name='worker_run',
+            field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='tasks', to='process.workerrun'),
+        ),
+    ]
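Note: this migration only adds a nullable column and needs no data backfill, so it is safe to
apply on an existing database. A quick sanity check after migrating (sketch, from a Django shell):

    from arkindex.ponos.models import Task

    # Tasks created before this change simply keep an empty link
    Task.objects.filter(worker_run__isnull=True).count()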
diff --git a/arkindex/ponos/models.py b/arkindex/ponos/models.py
index 733bec39793df5017774c3e6d7e2c71daa03d980..a2034bb448cad0cd3e57b5ab666df7bb67c16b51 100644
--- a/arkindex/ponos/models.py
+++ b/arkindex/ponos/models.py
@@ -502,6 +502,13 @@ class Task(models.Model):
         related_name="tasks",
         on_delete=models.CASCADE,
     )
+    worker_run = models.ForeignKey(
+        'process.WorkerRun',
+        on_delete=models.SET_NULL,
+        related_name='tasks',
+        null=True,
+        blank=True,
+    )
     parents = models.ManyToManyField(
         "self",
         related_name="children",
diff --git a/arkindex/process/builder.py b/arkindex/process/builder.py
index 4d882ae4f8a346bcd091aae7d347a2a33eb78e7e..a29a5a9a161d6aa67c7964f687ff8992389d8b36 100644
--- a/arkindex/process/builder.py
+++ b/arkindex/process/builder.py
@@ -88,7 +88,8 @@ class ProcessBuilder(object):
         has_docker_socket=False,
         extra_files={},
         requires_gpu=False,
-        shm_size=None
+        shm_size=None,
+        worker_run=None,
     ) -> None:
         """
         Build a Task with default attributes and add it to the current stack.
@@ -111,6 +112,7 @@ class ProcessBuilder(object):
                 has_docker_socket=has_docker_socket,
                 extra_files=extra_files,
                 image_artifact_id=artifact,
+                worker_run=worker_run,
             )
         )
 
@@ -228,6 +230,7 @@ class ProcessBuilder(object):
                 **self.base_env,
                 'ARKINDEX_WORKER_RUN_ID': str(worker_run.id),
             },
+            worker_run=worker_run,
         )
         self._create_worker_versions_cache([(settings.IMPORTS_WORKER_VERSION, None, None)])
 
@@ -249,6 +252,7 @@ class ProcessBuilder(object):
                 **self.base_env,
                 'ARKINDEX_WORKER_RUN_ID': str(worker_run.id),
             },
+            worker_run=worker_run,
         )
         self._create_worker_versions_cache([(settings.IMPORTS_WORKER_VERSION, None, None)])
 
diff --git a/arkindex/process/managers.py b/arkindex/process/managers.py
index 2b50dd4053d0d04efc8751232c229731910b23d5..e0e6bbf15f9919f79f4c82ee9421ca4044da9c1a 100644
--- a/arkindex/process/managers.py
+++ b/arkindex/process/managers.py
@@ -7,6 +7,7 @@ from django.db.models.functions import Coalesce
 from django.db.models.query import QuerySet
 from django.utils.functional import cached_property
 
+from arkindex.ponos.models import Task
 from arkindex.users.managers import BaseACLManager
 from arkindex.users.models import Role
 
@@ -179,6 +180,8 @@ class WorkerRunQuerySet(QuerySet):
             field.related_model
             for field in self.model._meta.get_fields()
             if isinstance(field, ManyToOneRel)
+            # Ignore the reverse link to Ponos tasks
+            and field.related_model is not Task
         ]
         for field in related_models:
             if field.objects.filter(worker_run__in=ids).exists():
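Without this exclusion, the reverse-relation scan on WorkerRun would also pick up the new
ponos.Task link, so every run that ever produced a task would show up in this check; since the
foreign key is SET_NULL, tasks do not need to be considered here. A sketch of what the
comprehension selects after the change, assuming `ids` holds the WorkerRun ids being checked:

    from django.db.models import ManyToOneRel

    from arkindex.ponos.models import Task
    from arkindex.process.models import WorkerRun

    related_models = [
        field.related_model
        for field in WorkerRun._meta.get_fields()
        if isinstance(field, ManyToOneRel) and field.related_model is not Task
    ]
    # Only these models are then probed with .objects.filter(worker_run__in=ids).exists()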
diff --git a/arkindex/process/models.py b/arkindex/process/models.py
index 5b9ae5361478d572752c276ee570823d259730ff..a40c9581a0f24903393437966bf183502068077f 100644
--- a/arkindex/process/models.py
+++ b/arkindex/process/models.py
@@ -953,6 +953,7 @@ class WorkerRun(models.Model):
             run=run,
             token=token,
             process=process,
+            worker_run=self,
             extra_files=extra_files,
             requires_gpu=process.use_gpu and self.version.gpu_usage in (FeatureUsage.Required, FeatureUsage.Supported)
         )
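With build_task() passing worker_run=self, every task generated for a worker run now carries a
database link back to it, alongside the ARKINDEX_WORKER_RUN_ID environment variable used at
runtime. A small sanity-check sketch, assuming a workers process that was just run:

    process.run()
    for task in process.tasks.exclude(worker_run=None):
        # Each linked task points back to a run of the same process
        assert task.worker_run.process_id == process.id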
diff --git a/arkindex/process/signals.py b/arkindex/process/signals.py
index 661bf35172197f796a354fecde62afbe5ea26dd8..63c648796937e118e46f65bb8957a5d1da79bab6 100644
--- a/arkindex/process/signals.py
+++ b/arkindex/process/signals.py
@@ -71,15 +71,25 @@ def generate_summary(sender, instance, **kwargs):
 @receiver(task_failure)
 def stop_started_activities(sender, task, **kwargs):
     """
-    When a Ponos task fails, update any relevant WorkerActivity instances from `started` to `error`.
+    When a Ponos task fails, update the WorkerActivity instances matching its WorkerRun from `started` to `error`.
     This allows to retry a process and get it to re-run on activities that were not finished without skipping.
-
-    This is definitely not perfect, as this can cause WorkerActivities for worker versions of unrelated tasks
-    to also be marked as error when they are successful, but we do not have any reliable link
-    between Ponos tasks and Arkindex worker versions.
     """
     process = Process.objects.filter(id=task.process_id).only('id', 'activity_state').first()
-    if not process or process.activity_state == ActivityState.Disabled:
+    if process is None or process.activity_state == ActivityState.Disabled:
         return
-    count = process.activities.filter(state=WorkerActivityState.Started).update(state=WorkerActivityState.Error)
+
+    extra_filters = {}
+    if task.worker_run:
+        # Look for process activities matching this specific worker version, model version and configuration
+        extra_filters.update({
+            'worker_version_id': task.worker_run.version_id,
+            'model_version_id': task.worker_run.model_version_id,
+            'configuration_id': task.worker_run.configuration_id,
+        })
+
+    count = (
+        process.activities
+        .filter(state=WorkerActivityState.Started, **extra_filters)
+        .update(state=WorkerActivityState.Error)
+    )
     logger.info(f'Updated {count} worker activities from Started to Error as a task failed on {process.id}')
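Concretely, the failure handler now resolves to one of two updates (sketch, same names as above):

    # Task linked to a WorkerRun: only that run's exact triple is affected
    process.activities.filter(
        state=WorkerActivityState.Started,
        worker_version_id=task.worker_run.version_id,
        model_version_id=task.worker_run.model_version_id,
        configuration_id=task.worker_run.configuration_id,
    ).update(state=WorkerActivityState.Error)

    # Task without a WorkerRun (e.g. created before this change): keep the previous
    # behaviour and flip every started activity of the process
    process.activities.filter(
        state=WorkerActivityState.Started,
    ).update(state=WorkerActivityState.Error)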
diff --git a/arkindex/process/tests/test_processes.py b/arkindex/process/tests/test_processes.py
index dd9a7af4dd987b47ba2639cb1046cf3789c316ac..8f7fa64151a9dbc159a81370a1d17ea1118bc15f 100644
--- a/arkindex/process/tests/test_processes.py
+++ b/arkindex/process/tests/test_processes.py
@@ -2933,7 +2933,7 @@ class TestProcesses(FixtureAPITestCase):
         self.assertEqual(process.worker_runs.count(), 2)
 
         self.client.force_login(self.user)
-        with self.assertNumQueries(9):
+        with self.assertNumQueries(11):
             response = self.client.delete(reverse('api:clear-process', kwargs={'pk': str(process.id)}))
         self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
         process.refresh_from_db()
diff --git a/arkindex/process/tests/test_signals.py b/arkindex/process/tests/test_signals.py
index e45ac681eaf7496600213f4831cb48d29342a11c..5b73f455ef16a9d73baf70fa28fe85167923d7b5 100644
--- a/arkindex/process/tests/test_signals.py
+++ b/arkindex/process/tests/test_signals.py
@@ -21,6 +21,7 @@ from arkindex.process.models import (
 from arkindex.process.signals import _list_ancestors
 from arkindex.project.tests import FixtureAPITestCase
 from arkindex.project.tools import build_public_key
+from arkindex.training.models import Model
 
 
 class TestSignals(FixtureAPITestCase):
@@ -366,3 +367,80 @@ class TestSignals(FixtureAPITestCase):
                 self.assertEqual(processed_activity.state, WorkerActivityState.Processed)
                 self.assertEqual(error_activity.state, WorkerActivityState.Error)
                 self.assertEqual(started_activity.state, expected_activity_state)
+
+    @patch("arkindex.ponos.serializers.TaskSerializer.get_logs")
+    @patch("arkindex.ponos.tasks.notify_process_completion.delay")
+    def test_task_failure_filters_activities_worker_run(self, notify_mock, get_logs_mock):
+        """
+        A Ponos task failure only updates worker activities with the same worker version,
+        model version and worker configuration.
+        """
+        get_logs_mock.return_value = None
+        model = Model.objects.create(name='Generic model', public=False)
+        model_version = model.versions.create(hash='b' * 32, archive_hash='a' * 32, size=8)
+        worker_configuration = self.worker_1.configurations.create(name="conf")
+        self.run_1.configuration = worker_configuration
+        self.run_1.save()
+
+        self.process_2.activity_state = ActivityState.Ready
+        self.process_2.save()
+        self.process_2.run()
+        task = self.process_2.tasks.first()
+        task.state = State.Running
+        task.agent = self.agent
+        task.save()
+
+        # Create three activities on the same element, differing by model version and configuration
+        element = self.corpus.elements.first()
+        activity_1 = self.process_2.activities.create(
+            worker_version=self.version_1,
+            configuration=worker_configuration,
+            model_version=None,
+            element=element,
+            started=datetime.now(),
+        )
+        activity_2 = self.process_2.activities.create(
+            worker_version=self.version_1,
+            configuration=None,
+            model_version=model_version,
+            element=element,
+            started=datetime.now(),
+        )
+        activity_3 = self.process_2.activities.create(
+            worker_version=self.version_1,
+            configuration=worker_configuration,
+            model_version=model_version,
+            element=element,
+            started=datetime.now(),
+        )
+
+        cases = [
+            (None, None, None, [activity_1, activity_2, activity_3]),
+            (self.run_1, model_version, None, [activity_2]),
+            (self.run_1, None, worker_configuration, [activity_1]),
+            (self.run_1, model_version, worker_configuration, [activity_3]),
+        ]
+        for worker_run, run_model_version, run_configuration, expected in cases:
+            self.process_2.activities.update(state=WorkerActivityState.Started)
+            if worker_run:
+                worker_run.model_version = run_model_version
+                worker_run.configuration = run_configuration
+                worker_run.save()
+                task.worker_run = worker_run
+                task.save()
+
+            with self.subTest(
+                worker_run=worker_run,
+                model_version=run_model_version,
+                worker_configuration=run_configuration
+            ):
+                resp = self.client.patch(
+                    reverse('api:task-details', kwargs={'pk': task.id}),
+                    data={'state': State.Error.value},
+                    HTTP_AUTHORIZATION=f"Bearer {self.agent.token.access_token}",
+                )
+                self.assertEqual(resp.status_code, status.HTTP_200_OK)
+                for activity in self.process_2.activities.filter(id__in=[e.id for e in expected]):
+                    self.assertEqual(activity.state, WorkerActivityState.Error)
+                for activity in self.process_2.activities.exclude(id__in=[e.id for e in expected]):
+                    self.assertEqual(activity.state, WorkerActivityState.Started)
diff --git a/arkindex/process/tests/test_templates.py b/arkindex/process/tests/test_templates.py
index 43517e50f2af72c3645a43c31b2f7903824316a4..800a2108294681cfe8f90f2343f0bfe7f9445083 100644
--- a/arkindex/process/tests/test_templates.py
+++ b/arkindex/process/tests/test_templates.py
@@ -335,7 +335,7 @@ class TestTemplates(FixtureAPITestCase):
             parents=[],
         )
         # Apply a template that has two other worker runs
-        with self.assertNumQueries(20):
+        with self.assertNumQueries(22):
             response = self.client.post(
                 reverse('api:apply-process-template', kwargs={'pk': str(self.template.id)}),
                 data=json.dumps({"process_id": str(process.id)}),
diff --git a/arkindex/process/tests/test_workerruns.py b/arkindex/process/tests/test_workerruns.py
index 14f6a5ca39aacf641c48cd591c8e383dd556832e..02112aac9790e05527e5a8f23d6ebfe26983af91 100644
--- a/arkindex/process/tests/test_workerruns.py
+++ b/arkindex/process/tests/test_workerruns.py
@@ -2906,7 +2906,7 @@ class TestWorkerRuns(FixtureAPITestCase):
         self.worker_1.memberships.update(level=Role.Guest.value)
         self.client.force_login(self.user)
 
-        with self.assertNumQueries(8):
+        with self.assertNumQueries(9):
             response = self.client.delete(
                 reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)})
             )
@@ -2933,7 +2933,7 @@ class TestWorkerRuns(FixtureAPITestCase):
     def test_delete(self):
         self.client.force_login(self.user)
 
-        with self.assertNumQueries(8):
+        with self.assertNumQueries(9):
             response = self.client.delete(
                 reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)})
             )
@@ -2976,7 +2976,7 @@ class TestWorkerRuns(FixtureAPITestCase):
         self.assertTrue(self.run_1.id in run_3.parents)
         self.client.force_login(self.user)
 
-        with self.assertNumQueries(8):
+        with self.assertNumQueries(9):
             response = self.client.delete(
                 reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)})
             )