From fc247d34b5c743efd860ff1dbbf03bb974394afe Mon Sep 17 00:00:00 2001
From: ml bonhomme <bonhomme@teklia.com>
Date: Fri, 6 Sep 2024 07:45:13 +0000
Subject: [PATCH] Update remaining worker activities to error when a Task is
 updated to a final state

---
 arkindex/ponos/serializers.py    |  27 +++
 arkindex/ponos/tests/test_api.py | 371 ++++++++++++++++++++++++++++++-
 2 files changed, 396 insertions(+), 2 deletions(-)

diff --git a/arkindex/ponos/serializers.py b/arkindex/ponos/serializers.py
index ae35b00896..fdc09a5a90 100644
--- a/arkindex/ponos/serializers.py
+++ b/arkindex/ponos/serializers.py
@@ -12,6 +12,7 @@ from rest_framework.exceptions import ValidationError
 
 from arkindex.ponos.models import FINAL_STATES, Agent, AgentMode, Artifact, State, Task
 from arkindex.ponos.signals import task_failure
+from arkindex.process.models import ActivityState, WorkerActivityState
 from arkindex.project.serializer_fields import EnumField
 from arkindex.project.triggers import notify_process_completion
 
@@ -148,6 +149,32 @@ class TaskSerializer(TaskLightSerializer):
                     .update(state=State.Pending)
                 )
 
+            # When a task is updated to a final state, if it is the last one of its worker_run + run + depth to be
+            # so, then all the remaining started or queued worker activities must get their state updated to error
+            # (to allow the elements to be reprocessed)
+            # Only check for this when the process' activity state is `ready` and the task is linked to a worker run,
+            # as it is otherwise not relevant.
+            if (
+                instance.process.activity_state == ActivityState.Ready
+                and instance.worker_run
+                and not (
+                    instance.process.tasks
+                    .using("default")
+                    .filter(run=instance.run, depth=instance.depth, worker_run=instance.worker_run)
+                    .exclude(state__in=FINAL_STATES)
+                    .exists()
+                )
+            ):
+                instance.process.activities.filter(
+                    worker_version_id=instance.worker_run.version_id,
+                    configuration_id=instance.worker_run.configuration_id,
+                    model_version_id=instance.worker_run.model_version_id
+                ).exclude(
+                    state__in=[WorkerActivityState.Processed, WorkerActivityState.Error]
+                ).update(
+                    state=WorkerActivityState.Error
+                )
+
             # When some tasks have failed, tasks that depend on them will remain unscheduled.
             # When a task is completed, its children will be set to pending.
             # All children have a depth above zero, since they have at least one dependency.
diff --git a/arkindex/ponos/tests/test_api.py b/arkindex/ponos/tests/test_api.py
index a6459df950..9b59df1ef9 100644
--- a/arkindex/ponos/tests/test_api.py
+++ b/arkindex/ponos/tests/test_api.py
@@ -11,9 +11,17 @@ from django.urls import reverse
 from rest_framework import status
 
 from arkindex.documents.models import Corpus
-from arkindex.ponos.models import Agent, AgentMode, Farm, State, Task
-from arkindex.process.models import Process, ProcessMode, WorkerVersion
+from arkindex.ponos.models import FINAL_STATES, Agent, AgentMode, Farm, State, Task
+from arkindex.process.models import (
+    ActivityState,
+    Process,
+    ProcessMode,
+    WorkerActivity,
+    WorkerActivityState,
+    WorkerVersion,
+)
 from arkindex.project.tests import FixtureAPITestCase
+from arkindex.training.models import Model, ModelVersionState
 from arkindex.users.models import Right, Role, User
 
 
@@ -335,6 +343,365 @@ class TestAPI(FixtureAPITestCase):
         self.task1.refresh_from_db()
         self.assertEqual(self.task1.state, State.Stopping)
 
+    @patch("arkindex.ponos.models.base64.encodebytes")
+    @patch("arkindex.project.aws.s3")
+    @patch("arkindex.users.models.User.objects.get")
+    @patch("arkindex.ponos.api.TaskDetails.permission_classes", tuple())
+    def test_update_task_final_state_worker_activities_update(self, get_user_mock, s3_mock, token_mock):
+        """
+        When a task is updated to a final state, only the remaining worker activities that are not in
+        the processed (or already in error) state are updated to the error state.
+        """
+        s3_mock.Object.return_value.bucket_name = "ponos"
+        s3_mock.Object.return_value.key = "somelog"
+        s3_mock.Object.return_value.get.return_value = {"Body": BytesIO(b"Failed successfully")}
+        s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
+        seal(s3_mock)
+
+        token_mock.side_effect = [b"12345", b"78945",]
+
+        test_process = Process.objects.create(
+            mode=ProcessMode.Workers,
+            creator=self.user,
+            corpus=self.corpus,
+            activity_state=ActivityState.Ready
+        )
+        init_run = test_process.worker_runs.create(version=WorkerVersion.objects.get(worker__slug="initialisation"), parents=[])
+        test_run = test_process.worker_runs.create(version=self.recognizer, parents=[init_run.id])
+
+        test_process.run()
+
+        elem_1, elem_2, elem_3, elem_4 = self.corpus.elements.all()[0:4]
+        activity_processed = WorkerActivity.objects.create(
+            element_id=elem_1.id,
+            state=WorkerActivityState.Processed,
+            worker_version=self.recognizer,
+            process_id=test_process.id,
+            started=datetime.now(timezone.utc)
+        )
+        activity_queued = WorkerActivity.objects.create(
+            element_id=elem_2.id,
+            state=WorkerActivityState.Queued,
+            worker_version=self.recognizer,
+            process_id=test_process.id,
+            started=datetime.now(timezone.utc)
+        )
+        activity_started = WorkerActivity.objects.create(
+            element_id=elem_3.id,
+            state=WorkerActivityState.Started,
+            worker_version=self.recognizer,
+            process_id=test_process.id,
+            started=datetime.now(timezone.utc)
+        )
+        activity_error = WorkerActivity.objects.create(
+            element_id=elem_4.id,
+            state=WorkerActivityState.Error,
+            worker_version=self.recognizer,
+            process_id=test_process.id,
+            started=datetime.now(timezone.utc)
+        )
+
+        test_task = test_process.tasks.get(worker_run=test_run)
+
+        # Authenticate from a possible agent
+        custom_user = copy.copy(self.docker_agent)
+        custom_user.is_active = True
+        custom_user.is_authenticated = True
+
+        get_user_mock.return_value = custom_user
+        self.client.force_login(self.user)
+
+        # Get the allowed transitions to a final state
+        cases = [item for item in self.docker_task_transitions if item[1] in FINAL_STATES]
+
+        for state_from, state_to in cases:
+            with self.subTest(state_from=state_from, state_to=state_to):
+
+                test_task.state = state_from
+                test_task.save()
+
+                resp = self.client.put(
+                    reverse("api:task-details", args=[test_task.id]),
+                    data={"state": state_to.value},
+                )
+                self.assertEqual(resp.status_code, status.HTTP_200_OK)
+
+                test_task.refresh_from_db()
+                activity_started.refresh_from_db()
+                activity_error.refresh_from_db()
+                activity_queued.refresh_from_db()
+                activity_processed.refresh_from_db()
+
+                self.assertEqual(test_task.state, state_to)
+                self.assertEqual(activity_processed.state, WorkerActivityState.Processed)
+                self.assertEqual(activity_error.state, WorkerActivityState.Error)
+                self.assertEqual(activity_queued.state, WorkerActivityState.Error)
+                self.assertEqual(activity_started.state, WorkerActivityState.Error)
+
+    @patch("arkindex.ponos.models.base64.encodebytes")
+    @patch("arkindex.project.aws.s3")
+    @patch("arkindex.users.models.User.objects.get")
+    @patch("arkindex.ponos.api.TaskDetails.permission_classes", tuple())
+    def test_update_task_final_state_worker_activities_update_worker_run(self, get_user_mock, s3_mock, token_mock):
+        """
+        When a task is updated to a final state, only the worker activities corresponding to that task's worker
+        run are updated.
+        """
+        s3_mock.Object.return_value.bucket_name = "ponos"
+        s3_mock.Object.return_value.key = "somelog"
+        s3_mock.Object.return_value.get.return_value = {"Body": BytesIO(b"Failed successfully")}
+        s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
+        seal(s3_mock)
+
+        token_mock.side_effect = [b"12345", b"78945",]
+
+        test_process = Process.objects.create(
+            mode=ProcessMode.Workers,
+            creator=self.user,
+            corpus=self.corpus,
+            activity_state=ActivityState.Ready
+        )
+        init_version = WorkerVersion.objects.get(worker__slug="initialisation")
+        init_run = test_process.worker_runs.create(version=init_version, parents=[])
+        test_run = test_process.worker_runs.create(version=self.recognizer, parents=[init_run.id])
+
+        test_process.run()
+
+        elem = self.corpus.elements.first()
+        init_activity = WorkerActivity.objects.create(
+            element_id=elem.id,
+            state=WorkerActivityState.Started,
+            worker_version=init_version,
+            process_id=test_process.id,
+            started=datetime.now(timezone.utc)
+        )
+        reco_activity = WorkerActivity.objects.create(
+            element_id=elem.id,
+            state=WorkerActivityState.Started,
+            worker_version=self.recognizer,
+            process_id=test_process.id,
+            started=datetime.now(timezone.utc)
+        )
+
+        test_task = test_process.tasks.get(worker_run=test_run)
+
+        # Authenticate from a possible agent
+        custom_user = copy.copy(self.docker_agent)
+        custom_user.is_active = True
+        custom_user.is_authenticated = True
+
+        get_user_mock.return_value = custom_user
+        self.client.force_login(self.user)
+
+        test_task.state = State.Running
+        test_task.save()
+
+        resp = self.client.put(
+            reverse("api:task-details", args=[test_task.id]),
+            data={"state": State.Completed.value},
+        )
+        self.assertEqual(resp.status_code, status.HTTP_200_OK)
+
+        test_task.refresh_from_db()
+        init_activity.refresh_from_db()
+        reco_activity.refresh_from_db()
+
+        self.assertEqual(test_task.state, State.Completed)
+        self.assertEqual(init_activity.state, WorkerActivityState.Started)
+        self.assertEqual(reco_activity.state, WorkerActivityState.Error)
+
+    @patch("arkindex.ponos.models.base64.encodebytes")
+    @patch("arkindex.project.aws.s3")
+    @patch("arkindex.users.models.User.objects.get")
+    @patch("arkindex.ponos.api.TaskDetails.permission_classes", tuple())
+    def test_update_task_final_state_worker_activities_update_chunks(self, get_user_mock, s3_mock, token_mock):
+        """
+        When a task is updated to a final state, the worker activities corresponding to that task's worker run
+        are only updated if there are no other tasks for that same worker run (different chunks) that are not in
+        a final state.
+        """
+        s3_mock.Object.return_value.bucket_name = "ponos"
+        s3_mock.Object.return_value.key = "somelog"
+        s3_mock.Object.return_value.get.return_value = {"Body": BytesIO(b"Failed successfully")}
+        s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
+        seal(s3_mock)
+
+        token_mock.side_effect = [b"12345", b"78945", b"77975"]
+
+        test_process = Process.objects.create(
+            mode=ProcessMode.Workers,
+            creator=self.user,
+            corpus=self.corpus,
+            chunks=2,
+            activity_state=ActivityState.Ready
+        )
+        test_run = test_process.worker_runs.create(version=self.recognizer)
+
+        test_process.run()
+
+        elem = self.corpus.elements.first()
+        reco_activity = WorkerActivity.objects.create(
+            element_id=elem.id,
+            state=WorkerActivityState.Started,
+            worker_version=self.recognizer,
+            process_id=test_process.id,
+            started=datetime.now(timezone.utc)
+        )
+
+        test_task_1, test_task_2 = test_process.tasks.filter(worker_run=test_run).all()
+
+        # Authenticate from a possible agent
+        custom_user = copy.copy(self.docker_agent)
+        custom_user.is_active = True
+        custom_user.is_authenticated = True
+
+        get_user_mock.return_value = custom_user
+        self.client.force_login(self.user)
+
+        test_task_2.state = State.Running
+        test_task_2.save()
+
+        resp = self.client.put(
+            reverse("api:task-details", args=[test_task_2.id]),
+            data={"state": State.Completed.value},
+        )
+        self.assertEqual(resp.status_code, status.HTTP_200_OK)
+
+        test_task_2.refresh_from_db()
+        reco_activity.refresh_from_db()
+        test_task_1.refresh_from_db()
+
+        # Only one of the tasks is in a final state: the WorkerActivityState did not change
+        self.assertEqual(test_task_2.state, State.Completed)
+        self.assertEqual(test_task_1.state, State.Unscheduled)
+        self.assertEqual(reco_activity.state, WorkerActivityState.Started)
+
+        for final_state in FINAL_STATES:
+            with self.subTest(final_state=final_state):
+                # Put test_task_1 in a final state as well
+                test_task_1.state = final_state
+                test_task_2.state = State.Running
+                test_task_1.save()
+                test_task_2.save()
+
+                resp = self.client.put(
+                    reverse("api:task-details", args=[test_task_2.id]),
+                    data={"state": State.Completed.value},
+                )
+                self.assertEqual(resp.status_code, status.HTTP_200_OK)
+
+                test_task_2.refresh_from_db()
+                reco_activity.refresh_from_db()
+                test_task_1.refresh_from_db()
+
+                # The remaining worker activity has been updated to the Error state
+                self.assertEqual(test_task_2.state, State.Completed)
+                self.assertEqual(test_task_1.state, final_state)
+                self.assertEqual(reco_activity.state, WorkerActivityState.Error)
+
+    @override_settings(PUBLIC_HOSTNAME="https://arkindex.localhost")
+    @patch("arkindex.ponos.models.base64.encodebytes")
+    @patch("arkindex.project.aws.s3")
+    @patch("arkindex.users.models.User.objects.get")
+    @patch("arkindex.ponos.api.TaskDetails.permission_classes", tuple())
+    def test_update_task_final_state_worker_activities_update_same_worker_version(self, get_user_mock, s3_mock, token_mock):
+        """
+        When a task is updated to a final state, only the remaining worker activities corresponding to that task's worker
+        run are updated to the error state, not activities belonging to other workers runs with the same worker version.
+        """
+        s3_mock.Object.return_value.bucket_name = "ponos"
+        s3_mock.Object.return_value.key = "somelog"
+        s3_mock.Object.return_value.get.return_value = {"Body": BytesIO(b"Failed successfully")}
+        s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
+        seal(s3_mock)
+
+        token_mock.side_effect = [b"12345", b"78945", b"77975", b"12995"]
+
+        test_model = Model.objects.create(name="Generic model", public=False)
+        test_model_version = test_model.versions.create(
+            state=ModelVersionState.Available,
+            tag="Test",
+            hash="A" * 32,
+            archive_hash="42",
+            size=1337,
+        )
+        test_configuration = self.recognizer.worker.configurations.create(
+            name="Recognizer configuration", configuration={"value": "test"}
+        )
+
+        test_process = Process.objects.create(
+            mode=ProcessMode.Workers,
+            creator=self.user,
+            corpus=self.corpus,
+            activity_state=ActivityState.Ready
+        )
+        test_run_1 = test_process.worker_runs.create(version=self.recognizer)
+        test_run = test_process.worker_runs.create(version=self.recognizer, model_version_id=test_model_version.id)
+        test_run_2 = test_process.worker_runs.create(version=self.recognizer, model_version_id=test_model_version.id, configuration_id=test_configuration.id)
+
+        test_process.run()
+
+        elem = self.corpus.elements.first()
+        target_activity = WorkerActivity.objects.create(
+            element_id=elem.id,
+            state=WorkerActivityState.Started,
+            worker_version=self.recognizer,
+            model_version=test_model_version,
+            process_id=test_process.id,
+            started=datetime.now(timezone.utc)
+        )
+        activity_1 = WorkerActivity.objects.create(
+            element_id=elem.id,
+            state=WorkerActivityState.Started,
+            worker_version=self.recognizer,
+            process_id=test_process.id,
+            started=datetime.now(timezone.utc)
+        )
+        activity_2 = WorkerActivity.objects.create(
+            element_id=elem.id,
+            state=WorkerActivityState.Started,
+            worker_version=self.recognizer,
+            model_version=test_model_version,
+            configuration=test_configuration,
+            process_id=test_process.id,
+            started=datetime.now(timezone.utc)
+        )
+
+        test_task = test_process.tasks.get(worker_run=test_run)
+        task_1 = test_process.tasks.get(worker_run=test_run_1)
+        task_2 = test_process.tasks.get(worker_run=test_run_2)
+
+        # Authenticate from a possible agent
+        custom_user = copy.copy(self.docker_agent)
+        custom_user.is_active = True
+        custom_user.is_authenticated = True
+
+        get_user_mock.return_value = custom_user
+        self.client.force_login(self.user)
+
+        test_task.state = State.Running
+        task_1.state = State.Stopped
+        task_2.state = State.Failed
+        test_task.save()
+        task_1.save()
+        task_2.save()
+
+        resp = self.client.put(
+            reverse("api:task-details", args=[test_task.id]),
+            data={"state": State.Completed.value},
+        )
+        self.assertEqual(resp.status_code, status.HTTP_200_OK)
+
+        test_task.refresh_from_db()
+        target_activity.refresh_from_db()
+        activity_1.refresh_from_db()
+        activity_2.refresh_from_db()
+
+        self.assertEqual(test_task.state, State.Completed)
+        self.assertEqual(target_activity.state, WorkerActivityState.Error)
+        self.assertEqual(activity_1.state, WorkerActivityState.Started)
+        self.assertEqual(activity_2.state, WorkerActivityState.Started)
+
     def test_update_non_running_task_state_stopping(self):
         states = list(State)
         states.remove(State.Running)
-- 
GitLab