diff --git a/arkindex/ponos/migrations/0006_task_worker_run.py b/arkindex/ponos/migrations/0006_task_worker_run.py new file mode 100644 index 0000000000000000000000000000000000000000..ab76c500071a4223e50452f39029c63f72322ff8 --- /dev/null +++ b/arkindex/ponos/migrations/0006_task_worker_run.py @@ -0,0 +1,20 @@ +# Generated by Django 4.1.7 on 2023-11-28 11:02 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('process', '0023_alter_workerversion_model_usage'), + ('ponos', '0005_remove_task_tags'), + ] + + operations = [ + migrations.AddField( + model_name='task', + name='worker_run', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='tasks', to='process.workerrun'), + ), + ] diff --git a/arkindex/ponos/models.py b/arkindex/ponos/models.py index 733bec39793df5017774c3e6d7e2c71daa03d980..a2034bb448cad0cd3e57b5ab666df7bb67c16b51 100644 --- a/arkindex/ponos/models.py +++ b/arkindex/ponos/models.py @@ -502,6 +502,13 @@ class Task(models.Model): related_name="tasks", on_delete=models.CASCADE, ) + worker_run = models.ForeignKey( + 'process.WorkerRun', + on_delete=models.SET_NULL, + related_name='tasks', + null=True, + blank=True, + ) parents = models.ManyToManyField( "self", related_name="children", diff --git a/arkindex/process/builder.py b/arkindex/process/builder.py index 4d882ae4f8a346bcd091aae7d347a2a33eb78e7e..a29a5a9a161d6aa67c7964f687ff8992389d8b36 100644 --- a/arkindex/process/builder.py +++ b/arkindex/process/builder.py @@ -88,7 +88,8 @@ class ProcessBuilder(object): has_docker_socket=False, extra_files={}, requires_gpu=False, - shm_size=None + shm_size=None, + worker_run=None, ) -> None: """ Build a Task with default attributes and add it to the current stack. @@ -111,6 +112,7 @@ class ProcessBuilder(object): has_docker_socket=has_docker_socket, extra_files=extra_files, image_artifact_id=artifact, + worker_run=worker_run, ) ) @@ -228,6 +230,7 @@ class ProcessBuilder(object): **self.base_env, 'ARKINDEX_WORKER_RUN_ID': str(worker_run.id), }, + worker_run=worker_run, ) self._create_worker_versions_cache([(settings.IMPORTS_WORKER_VERSION, None, None)]) @@ -249,6 +252,7 @@ class ProcessBuilder(object): **self.base_env, 'ARKINDEX_WORKER_RUN_ID': str(worker_run.id), }, + worker_run=worker_run, ) self._create_worker_versions_cache([(settings.IMPORTS_WORKER_VERSION, None, None)]) diff --git a/arkindex/process/managers.py b/arkindex/process/managers.py index 2b50dd4053d0d04efc8751232c229731910b23d5..e0e6bbf15f9919f79f4c82ee9421ca4044da9c1a 100644 --- a/arkindex/process/managers.py +++ b/arkindex/process/managers.py @@ -7,6 +7,7 @@ from django.db.models.functions import Coalesce from django.db.models.query import QuerySet from django.utils.functional import cached_property +from arkindex.process.models import Task from arkindex.users.managers import BaseACLManager from arkindex.users.models import Role @@ -179,6 +180,8 @@ class WorkerRunQuerySet(QuerySet): field.related_model for field in self.model._meta.get_fields() if isinstance(field, ManyToOneRel) + # Ignore reverse links to processes tasks + and field.related_model is not Task ] for field in related_models: if field.objects.filter(worker_run__in=ids).exists(): diff --git a/arkindex/process/models.py b/arkindex/process/models.py index 5b9ae5361478d572752c276ee570823d259730ff..a40c9581a0f24903393437966bf183502068077f 100644 --- a/arkindex/process/models.py +++ b/arkindex/process/models.py @@ -953,6 +953,7 @@ class WorkerRun(models.Model): run=run, token=token, process=process, + worker_run=self, extra_files=extra_files, requires_gpu=process.use_gpu and self.version.gpu_usage in (FeatureUsage.Required, FeatureUsage.Supported) ) diff --git a/arkindex/process/signals.py b/arkindex/process/signals.py index 661bf35172197f796a354fecde62afbe5ea26dd8..63c648796937e118e46f65bb8957a5d1da79bab6 100644 --- a/arkindex/process/signals.py +++ b/arkindex/process/signals.py @@ -71,15 +71,25 @@ def generate_summary(sender, instance, **kwargs): @receiver(task_failure) def stop_started_activities(sender, task, **kwargs): """ - When a Ponos task fails, update any relevant WorkerActivity instances from `started` to `error`. + When a Ponos task fails, update WorkerActivity from the same WorkerRun from `started` to `error`. This allows to retry a process and get it to re-run on activities that were not finished without skipping. - - This is definitely not perfect, as this can cause WorkerActivities for worker versions of unrelated tasks - to also be marked as error when they are successful, but we do not have any reliable link - between Ponos tasks and Arkindex worker versions. """ process = Process.objects.filter(id=task.process_id).only('id', 'activity_state').first() - if not process or process.activity_state == ActivityState.Disabled: + if process is None or process.activity_state == ActivityState.Disabled: return - count = process.activities.filter(state=WorkerActivityState.Started).update(state=WorkerActivityState.Error) + + extra_filters = {} + if task.worker_run: + # Look for process activities matching this specific worker version, model version and configuration + extra_filters.update({ + 'worker_version_id': task.worker_run.version_id, + 'model_version_id': task.worker_run.model_version_id, + 'configuration_id': task.worker_run.configuration_id, + }) + + count = ( + process.activities + .filter(state=WorkerActivityState.Started, **extra_filters) + .update(state=WorkerActivityState.Error) + ) logger.info(f'Updated {count} worker activities from Started to Error as a task failed on {process.id}') diff --git a/arkindex/process/tests/test_processes.py b/arkindex/process/tests/test_processes.py index dd9a7af4dd987b47ba2639cb1046cf3789c316ac..8f7fa64151a9dbc159a81370a1d17ea1118bc15f 100644 --- a/arkindex/process/tests/test_processes.py +++ b/arkindex/process/tests/test_processes.py @@ -2933,7 +2933,7 @@ class TestProcesses(FixtureAPITestCase): self.assertEqual(process.worker_runs.count(), 2) self.client.force_login(self.user) - with self.assertNumQueries(9): + with self.assertNumQueries(11): response = self.client.delete(reverse('api:clear-process', kwargs={'pk': str(process.id)})) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) process.refresh_from_db() diff --git a/arkindex/process/tests/test_signals.py b/arkindex/process/tests/test_signals.py index e45ac681eaf7496600213f4831cb48d29342a11c..5b73f455ef16a9d73baf70fa28fe85167923d7b5 100644 --- a/arkindex/process/tests/test_signals.py +++ b/arkindex/process/tests/test_signals.py @@ -21,6 +21,7 @@ from arkindex.process.models import ( from arkindex.process.signals import _list_ancestors from arkindex.project.tests import FixtureAPITestCase from arkindex.project.tools import build_public_key +from arkindex.training.models import Model class TestSignals(FixtureAPITestCase): @@ -366,3 +367,80 @@ class TestSignals(FixtureAPITestCase): self.assertEqual(processed_activity.state, WorkerActivityState.Processed) self.assertEqual(error_activity.state, WorkerActivityState.Error) self.assertEqual(started_activity.state, expected_activity_state) + + @patch("arkindex.ponos.serializers.TaskSerializer.get_logs") + @patch("arkindex.ponos.tasks.notify_process_completion.delay") + def test_task_failure_filters_activities_worker_run(self, notify_mock, get_logs_mock): + """ + A Ponos task failure only updates worker activity with the same worker version, + model version and worker configuration. + """ + get_logs_mock.return_value = None + model = Model.objects.create(name='Generic model', public=False) + model_version = model.versions.create(hash='b' * 32, archive_hash='a' * 32, size=8) + worker_configuration = self.worker_1.configurations.create(name="conf") + self.run_1.configuration = worker_configuration + self.run_1.save() + + self.process_2.activity_state = ActivityState.Ready + self.process_2.save() + self.process_2.run() + task = self.process_2.tasks.first() + task.state = State.Running + task.agent = self.agent + task.save() + + # Create one activity per WorkerActivityState on random elements + element = self.corpus.elements.first() + activity_1 = self.process_2.activities.create( + worker_version=self.version_1, + configuration=worker_configuration, + model_version=None, + element=element, + started=datetime.now(), + ) + activity_2 = self.process_2.activities.create( + worker_version=self.version_1, + configuration=None, + model_version=model_version, + element=element, + started=datetime.now(), + ) + activity_3 = self.process_2.activities.create( + worker_version=self.version_1, + configuration=worker_configuration, + model_version=model_version, + element=element, + started=datetime.now(), + ) + + cases = [ + (None, None, None, [activity_1, activity_2, activity_3]), + (self.run_1, model_version, None, [activity_2]), + (self.run_1, None, worker_configuration, [activity_1]), + (self.run_1, model_version, worker_configuration, [activity_3]), + ] + for worker_run, model_version, worker_configuration, expected in cases: + self.process_2.activities.update(state=WorkerActivityState.Started) + if worker_run: + worker_run.model_version = model_version + worker_run.configuration = worker_configuration + worker_run.save() + task.worker_run = worker_run + task.save() + + with self.subTest( + worker_run=worker_run, + model_version=model_version, + worker_configuration=worker_configuration + ): + resp = self.client.patch( + reverse('api:task-details', kwargs={'pk': task.id}), + data={'state': State.Error.value}, + HTTP_AUTHORIZATION=f"Bearer {self.agent.token.access_token}", + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + for activity in self.process_2.activities.filter(id__in=[e.id for e in expected]): + self.assertEqual(activity.state, WorkerActivityState.Error) + for activity in self.process_2.activities.exclude(id__in=[e.id for e in expected]): + self.assertEqual(activity.state, WorkerActivityState.Started) diff --git a/arkindex/process/tests/test_templates.py b/arkindex/process/tests/test_templates.py index 43517e50f2af72c3645a43c31b2f7903824316a4..800a2108294681cfe8f90f2343f0bfe7f9445083 100644 --- a/arkindex/process/tests/test_templates.py +++ b/arkindex/process/tests/test_templates.py @@ -335,7 +335,7 @@ class TestTemplates(FixtureAPITestCase): parents=[], ) # Apply a template that has two other worker runs - with self.assertNumQueries(20): + with self.assertNumQueries(22): response = self.client.post( reverse('api:apply-process-template', kwargs={'pk': str(self.template.id)}), data=json.dumps({"process_id": str(process.id)}), diff --git a/arkindex/process/tests/test_workerruns.py b/arkindex/process/tests/test_workerruns.py index 14f6a5ca39aacf641c48cd591c8e383dd556832e..02112aac9790e05527e5a8f23d6ebfe26983af91 100644 --- a/arkindex/process/tests/test_workerruns.py +++ b/arkindex/process/tests/test_workerruns.py @@ -2906,7 +2906,7 @@ class TestWorkerRuns(FixtureAPITestCase): self.worker_1.memberships.update(level=Role.Guest.value) self.client.force_login(self.user) - with self.assertNumQueries(8): + with self.assertNumQueries(9): response = self.client.delete( reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)}) ) @@ -2933,7 +2933,7 @@ class TestWorkerRuns(FixtureAPITestCase): def test_delete(self): self.client.force_login(self.user) - with self.assertNumQueries(8): + with self.assertNumQueries(9): response = self.client.delete( reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)}) ) @@ -2976,7 +2976,7 @@ class TestWorkerRuns(FixtureAPITestCase): self.assertTrue(self.run_1.id in run_3.parents) self.client.force_login(self.user) - with self.assertNumQueries(8): + with self.assertNumQueries(9): response = self.client.delete( reverse('api:worker-run-details', kwargs={'pk': str(self.run_1.id)}) )