From fc247d34b5c743efd860ff1dbbf03bb974394afe Mon Sep 17 00:00:00 2001 From: ml bonhomme <bonhomme@teklia.com> Date: Fri, 6 Sep 2024 07:45:13 +0000 Subject: [PATCH] Update remaining worker activities to error when a Task is updated to a final state --- arkindex/ponos/serializers.py | 27 +++ arkindex/ponos/tests/test_api.py | 371 ++++++++++++++++++++++++++++++- 2 files changed, 396 insertions(+), 2 deletions(-) diff --git a/arkindex/ponos/serializers.py b/arkindex/ponos/serializers.py index ae35b00896..fdc09a5a90 100644 --- a/arkindex/ponos/serializers.py +++ b/arkindex/ponos/serializers.py @@ -12,6 +12,7 @@ from rest_framework.exceptions import ValidationError from arkindex.ponos.models import FINAL_STATES, Agent, AgentMode, Artifact, State, Task from arkindex.ponos.signals import task_failure +from arkindex.process.models import ActivityState, WorkerActivityState from arkindex.project.serializer_fields import EnumField from arkindex.project.triggers import notify_process_completion @@ -148,6 +149,32 @@ class TaskSerializer(TaskLightSerializer): .update(state=State.Pending) ) + # When a task is updated to a final state, if it is the last one of its worker_run + run + depth to be + # so, then all the remaining started or queued worker activities must get their state updated to error + # (to allow the elements to be reprocessed) + # Only check for this when the process' activity state is `ready` and the task is linked to a worker run, + # as it is otherwise not relevant. 
+ if ( + instance.process.activity_state == ActivityState.Ready + and instance.worker_run + and not ( + instance.process.tasks + .using("default") + .filter(run=instance.run, depth=instance.depth, worker_run=instance.worker_run) + .exclude(state__in=FINAL_STATES) + .exists() + ) + ): + instance.process.activities.filter( + worker_version_id=instance.worker_run.version_id, + configuration_id=instance.worker_run.configuration_id, + model_version_id=instance.worker_run.model_version_id + ).exclude( + state__in=[WorkerActivityState.Processed, WorkerActivityState.Error] + ).update( + state=WorkerActivityState.Error + ) + # When some tasks have failed, tasks that depend on them will remain unscheduled. # When a task is completed, its children will be set to pending. # All children have a depth above zero, since they have at least one dependency. diff --git a/arkindex/ponos/tests/test_api.py b/arkindex/ponos/tests/test_api.py index a6459df950..9b59df1ef9 100644 --- a/arkindex/ponos/tests/test_api.py +++ b/arkindex/ponos/tests/test_api.py @@ -11,9 +11,17 @@ from django.urls import reverse from rest_framework import status from arkindex.documents.models import Corpus -from arkindex.ponos.models import Agent, AgentMode, Farm, State, Task -from arkindex.process.models import Process, ProcessMode, WorkerVersion +from arkindex.ponos.models import FINAL_STATES, Agent, AgentMode, Farm, State, Task +from arkindex.process.models import ( + ActivityState, + Process, + ProcessMode, + WorkerActivity, + WorkerActivityState, + WorkerVersion, +) from arkindex.project.tests import FixtureAPITestCase +from arkindex.training.models import Model, ModelVersionState from arkindex.users.models import Right, Role, User @@ -335,6 +343,365 @@ class TestAPI(FixtureAPITestCase): self.task1.refresh_from_db() self.assertEqual(self.task1.state, State.Stopping) + @patch("arkindex.ponos.models.base64.encodebytes") + @patch("arkindex.project.aws.s3") + @patch("arkindex.users.models.User.objects.get") + 
@patch("arkindex.ponos.api.TaskDetails.permission_classes", tuple()) + def test_update_task_final_state_worker_activities_update(self, get_user_mock, s3_mock, token_mock): + """ + When a task is updated to a final state, only the remaining worker activities that are not in + the processed (or already in error) state are updated to the error state. + """ + s3_mock.Object.return_value.bucket_name = "ponos" + s3_mock.Object.return_value.key = "somelog" + s3_mock.Object.return_value.get.return_value = {"Body": BytesIO(b"Failed successfully")} + s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere" + seal(s3_mock) + + token_mock.side_effect = [b"12345", b"78945",] + + test_process = Process.objects.create( + mode=ProcessMode.Workers, + creator=self.user, + corpus=self.corpus, + activity_state=ActivityState.Ready + ) + init_run = test_process.worker_runs.create(version=WorkerVersion.objects.get(worker__slug="initialisation"), parents=[]) + test_run = test_process.worker_runs.create(version=self.recognizer, parents=[init_run.id]) + + test_process.run() + + elem_1, elem_2, elem_3, elem_4 = self.corpus.elements.all()[0:4] + activity_processed = WorkerActivity.objects.create( + element_id=elem_1.id, + state=WorkerActivityState.Processed, + worker_version=self.recognizer, + process_id=test_process.id, + started=datetime.now(timezone.utc) + ) + activity_queued = WorkerActivity.objects.create( + element_id=elem_2.id, + state=WorkerActivityState.Queued, + worker_version=self.recognizer, + process_id=test_process.id, + started=datetime.now(timezone.utc) + ) + activity_started = WorkerActivity.objects.create( + element_id=elem_3.id, + state=WorkerActivityState.Started, + worker_version=self.recognizer, + process_id=test_process.id, + started=datetime.now(timezone.utc) + ) + activity_error = WorkerActivity.objects.create( + element_id=elem_4.id, + state=WorkerActivityState.Error, + worker_version=self.recognizer, + process_id=test_process.id, + 
started=datetime.now(timezone.utc) + ) + + test_task = test_process.tasks.get(worker_run=test_run) + + # Authenticate from a possible agent + custom_user = copy.copy(self.docker_agent) + custom_user.is_active = True + custom_user.is_authenticated = True + + get_user_mock.return_value = custom_user + self.client.force_login(self.user) + + # Get the allowed transitions to a final state + cases = [item for item in self.docker_task_transitions if item[1] in FINAL_STATES] + + for state_from, state_to in cases: + with self.subTest(state_from=state_from, state_to=state_to): + + test_task.state = state_from + test_task.save() + + resp = self.client.put( + reverse("api:task-details", args=[test_task.id]), + data={"state": state_to.value}, + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + + test_task.refresh_from_db() + activity_started.refresh_from_db() + activity_error.refresh_from_db() + activity_queued.refresh_from_db() + activity_processed.refresh_from_db() + + self.assertEqual(test_task.state, state_to) + self.assertEqual(activity_processed.state, WorkerActivityState.Processed) + self.assertEqual(activity_error.state, WorkerActivityState.Error) + self.assertEqual(activity_queued.state, WorkerActivityState.Error) + self.assertEqual(activity_started.state, WorkerActivityState.Error) + + @patch("arkindex.ponos.models.base64.encodebytes") + @patch("arkindex.project.aws.s3") + @patch("arkindex.users.models.User.objects.get") + @patch("arkindex.ponos.api.TaskDetails.permission_classes", tuple()) + def test_update_task_final_state_worker_activities_update_worker_run(self, get_user_mock, s3_mock, token_mock): + """ + When a task is updated to a final state, only the worker activities corresponding to that task's worker + run are updated. 
+ """ + s3_mock.Object.return_value.bucket_name = "ponos" + s3_mock.Object.return_value.key = "somelog" + s3_mock.Object.return_value.get.return_value = {"Body": BytesIO(b"Failed successfully")} + s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere" + seal(s3_mock) + + token_mock.side_effect = [b"12345", b"78945",] + + test_process = Process.objects.create( + mode=ProcessMode.Workers, + creator=self.user, + corpus=self.corpus, + activity_state=ActivityState.Ready + ) + init_version = WorkerVersion.objects.get(worker__slug="initialisation") + init_run = test_process.worker_runs.create(version=init_version, parents=[]) + test_run = test_process.worker_runs.create(version=self.recognizer, parents=[init_run.id]) + + test_process.run() + + elem = self.corpus.elements.first() + init_activity = WorkerActivity.objects.create( + element_id=elem.id, + state=WorkerActivityState.Started, + worker_version=init_version, + process_id=test_process.id, + started=datetime.now(timezone.utc) + ) + reco_activity = WorkerActivity.objects.create( + element_id=elem.id, + state=WorkerActivityState.Started, + worker_version=self.recognizer, + process_id=test_process.id, + started=datetime.now(timezone.utc) + ) + + test_task = test_process.tasks.get(worker_run=test_run) + + # Authenticate from a possible agent + custom_user = copy.copy(self.docker_agent) + custom_user.is_active = True + custom_user.is_authenticated = True + + get_user_mock.return_value = custom_user + self.client.force_login(self.user) + + test_task.state = State.Running + test_task.save() + + resp = self.client.put( + reverse("api:task-details", args=[test_task.id]), + data={"state": State.Completed.value}, + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + + test_task.refresh_from_db() + init_activity.refresh_from_db() + reco_activity.refresh_from_db() + + self.assertEqual(test_task.state, State.Completed) + self.assertEqual(init_activity.state, WorkerActivityState.Started) + 
self.assertEqual(reco_activity.state, WorkerActivityState.Error) + + @patch("arkindex.ponos.models.base64.encodebytes") + @patch("arkindex.project.aws.s3") + @patch("arkindex.users.models.User.objects.get") + @patch("arkindex.ponos.api.TaskDetails.permission_classes", tuple()) + def test_update_task_final_state_worker_activities_update_chunks(self, get_user_mock, s3_mock, token_mock): + """ + When a task is updated to a final state, the worker activities corresponding to that task's worker run + are only updated if there are no other tasks for that same worker run (different chunks) that are not in + a final state. + """ + s3_mock.Object.return_value.bucket_name = "ponos" + s3_mock.Object.return_value.key = "somelog" + s3_mock.Object.return_value.get.return_value = {"Body": BytesIO(b"Failed successfully")} + s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere" + seal(s3_mock) + + token_mock.side_effect = [b"12345", b"78945", b"77975"] + + test_process = Process.objects.create( + mode=ProcessMode.Workers, + creator=self.user, + corpus=self.corpus, + chunks=2, + activity_state=ActivityState.Ready + ) + test_run = test_process.worker_runs.create(version=self.recognizer) + + test_process.run() + + elem = self.corpus.elements.first() + reco_activity = WorkerActivity.objects.create( + element_id=elem.id, + state=WorkerActivityState.Started, + worker_version=self.recognizer, + process_id=test_process.id, + started=datetime.now(timezone.utc) + ) + + test_task_1, test_task_2 = test_process.tasks.filter(worker_run=test_run).all() + + # Authenticate from a possible agent + custom_user = copy.copy(self.docker_agent) + custom_user.is_active = True + custom_user.is_authenticated = True + + get_user_mock.return_value = custom_user + self.client.force_login(self.user) + + test_task_2.state = State.Running + test_task_2.save() + + resp = self.client.put( + reverse("api:task-details", args=[test_task_2.id]), + data={"state": State.Completed.value}, + ) + 
+ self.assertEqual(resp.status_code, status.HTTP_200_OK) + + test_task_2.refresh_from_db() + reco_activity.refresh_from_db() + test_task_1.refresh_from_db() + + # Only one of the tasks is in a final state: the WorkerActivityState did not change + self.assertEqual(test_task_2.state, State.Completed) + self.assertEqual(test_task_1.state, State.Unscheduled) + self.assertEqual(reco_activity.state, WorkerActivityState.Started) + + for final_state in FINAL_STATES: + with self.subTest(final_state=final_state): + # Put test_task_1 in a final state as well + test_task_1.state = final_state + test_task_2.state = State.Running + test_task_1.save() + test_task_2.save() + + resp = self.client.put( + reverse("api:task-details", args=[test_task_2.id]), + data={"state": State.Completed.value}, + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + + test_task_2.refresh_from_db() + reco_activity.refresh_from_db() + test_task_1.refresh_from_db() + + # The remaining worker activity has been updated to the Error state + self.assertEqual(test_task_2.state, State.Completed) + self.assertEqual(test_task_1.state, final_state) + self.assertEqual(reco_activity.state, WorkerActivityState.Error) + + @override_settings(PUBLIC_HOSTNAME="https://arkindex.localhost") + @patch("arkindex.ponos.models.base64.encodebytes") + @patch("arkindex.project.aws.s3") + @patch("arkindex.users.models.User.objects.get") + @patch("arkindex.ponos.api.TaskDetails.permission_classes", tuple()) + def test_update_task_final_state_worker_activities_update_same_worker_version(self, get_user_mock, s3_mock, token_mock): + """ + When a task is updated to a final state, only the remaining worker activities corresponding to that task's worker + run are updated to the error state, not activities belonging to other worker runs with the same worker version. 
+ """ + s3_mock.Object.return_value.bucket_name = "ponos" + s3_mock.Object.return_value.key = "somelog" + s3_mock.Object.return_value.get.return_value = {"Body": BytesIO(b"Failed successfully")} + s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere" + seal(s3_mock) + + token_mock.side_effect = [b"12345", b"78945", b"77975", b"12995"] + + test_model = Model.objects.create(name="Generic model", public=False) + test_model_version = test_model.versions.create( + state=ModelVersionState.Available, + tag="Test", + hash="A" * 32, + archive_hash="42", + size=1337, + ) + test_configuration = self.recognizer.worker.configurations.create( + name="Recognizer configuration", configuration={"value": "test"} + ) + + test_process = Process.objects.create( + mode=ProcessMode.Workers, + creator=self.user, + corpus=self.corpus, + activity_state=ActivityState.Ready + ) + test_run_1 = test_process.worker_runs.create(version=self.recognizer) + test_run = test_process.worker_runs.create(version=self.recognizer, model_version_id=test_model_version.id) + test_run_2 = test_process.worker_runs.create(version=self.recognizer, model_version_id=test_model_version.id, configuration_id=test_configuration.id) + + test_process.run() + + elem = self.corpus.elements.first() + target_activity = WorkerActivity.objects.create( + element_id=elem.id, + state=WorkerActivityState.Started, + worker_version=self.recognizer, + model_version=test_model_version, + process_id=test_process.id, + started=datetime.now(timezone.utc) + ) + activity_1 = WorkerActivity.objects.create( + element_id=elem.id, + state=WorkerActivityState.Started, + worker_version=self.recognizer, + process_id=test_process.id, + started=datetime.now(timezone.utc) + ) + activity_2 = WorkerActivity.objects.create( + element_id=elem.id, + state=WorkerActivityState.Started, + worker_version=self.recognizer, + model_version=test_model_version, + configuration=test_configuration, + process_id=test_process.id, + 
started=datetime.now(timezone.utc) + ) + + test_task = test_process.tasks.get(worker_run=test_run) + task_1 = test_process.tasks.get(worker_run=test_run_1) + task_2 = test_process.tasks.get(worker_run=test_run_2) + + # Authenticate from a possible agent + custom_user = copy.copy(self.docker_agent) + custom_user.is_active = True + custom_user.is_authenticated = True + + get_user_mock.return_value = custom_user + self.client.force_login(self.user) + + test_task.state = State.Running + task_1.state = State.Stopped + task_2.state = State.Failed + test_task.save() + task_1.save() + task_2.save() + + resp = self.client.put( + reverse("api:task-details", args=[test_task.id]), + data={"state": State.Completed.value}, + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + + test_task.refresh_from_db() + target_activity.refresh_from_db() + activity_1.refresh_from_db() + activity_2.refresh_from_db() + + self.assertEqual(test_task.state, State.Completed) + self.assertEqual(target_activity.state, WorkerActivityState.Error) + self.assertEqual(activity_1.state, WorkerActivityState.Started) + self.assertEqual(activity_2.state, WorkerActivityState.Started) + def test_update_non_running_task_state_stopping(self): states = list(State) states.remove(State.Running) -- GitLab