Skip to content
Snippets Groups Projects
Commit fc247d34 authored by ml bonhomme's avatar ml bonhomme :bee: Committed by Erwan Rouchet
Browse files

Update remaining worker activities to error when a Task is updated to a final state

parent e489579e
No related branches found
No related tags found
1 merge request: !2428 "Update remaining worker activities to error when a Task is updated to a final state"
......@@ -12,6 +12,7 @@ from rest_framework.exceptions import ValidationError
from arkindex.ponos.models import FINAL_STATES, Agent, AgentMode, Artifact, State, Task
from arkindex.ponos.signals import task_failure
from arkindex.process.models import ActivityState, WorkerActivityState
from arkindex.project.serializer_fields import EnumField
from arkindex.project.triggers import notify_process_completion
......@@ -148,6 +149,32 @@ class TaskSerializer(TaskLightSerializer):
.update(state=State.Pending)
)
# When a task reaches a final state and it is the last task of its worker_run + run + depth to do so,
# all the remaining queued or started worker activities must have their state updated to error
# (to allow the elements to be reprocessed).
# Only check for this when the process' activity state is `ready` and the task is linked to a worker run,
# as it is otherwise not relevant.
if (
instance.process.activity_state == ActivityState.Ready
and instance.worker_run
and not (
instance.process.tasks
.using("default")
.filter(run=instance.run, depth=instance.depth, worker_run=instance.worker_run)
.exclude(state__in=FINAL_STATES)
.exists()
)
):
instance.process.activities.filter(
worker_version_id=instance.worker_run.version_id,
configuration_id=instance.worker_run.configuration_id,
model_version_id=instance.worker_run.model_version_id
).exclude(
state__in=[WorkerActivityState.Processed, WorkerActivityState.Error]
).update(
state=WorkerActivityState.Error
)
# When some tasks have failed, tasks that depend on them will remain unscheduled.
# When a task is completed, its children will be set to pending.
# All children have a depth above zero, since they have at least one dependency.
......
......@@ -11,9 +11,17 @@ from django.urls import reverse
from rest_framework import status
from arkindex.documents.models import Corpus
from arkindex.ponos.models import Agent, AgentMode, Farm, State, Task
from arkindex.process.models import Process, ProcessMode, WorkerVersion
from arkindex.ponos.models import FINAL_STATES, Agent, AgentMode, Farm, State, Task
from arkindex.process.models import (
ActivityState,
Process,
ProcessMode,
WorkerActivity,
WorkerActivityState,
WorkerVersion,
)
from arkindex.project.tests import FixtureAPITestCase
from arkindex.training.models import Model, ModelVersionState
from arkindex.users.models import Right, Role, User
......@@ -335,6 +343,365 @@ class TestAPI(FixtureAPITestCase):
self.task1.refresh_from_db()
self.assertEqual(self.task1.state, State.Stopping)
@patch("arkindex.ponos.models.base64.encodebytes")
@patch("arkindex.project.aws.s3")
@patch("arkindex.users.models.User.objects.get")
@patch("arkindex.ponos.api.TaskDetails.permission_classes", tuple())
def test_update_task_final_state_worker_activities_update(self, get_user_mock, s3_mock, token_mock):
    """
    When a task is updated to a final state, only the remaining worker activities that are not in
    the processed (or already in error) state are updated to the error state.

    Each allowed transition to a final state runs in its own subtest. The queued and started
    activities are restored before every transition, so that each subtest checks the update
    by itself instead of relying on the outcome of the previous one.
    """
    # Configure the only S3 attributes this test expects to be used, then seal the mock
    # so that accessing anything else raises instead of silently creating a new mock
    s3_mock.Object.return_value.bucket_name = "ponos"
    s3_mock.Object.return_value.key = "somelog"
    s3_mock.Object.return_value.get.return_value = {"Body": BytesIO(b"Failed successfully")}
    s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
    seal(s3_mock)
    # Successive return values of the patched base64.encodebytes
    token_mock.side_effect = [b"12345", b"78945"]

    test_process = Process.objects.create(
        mode=ProcessMode.Workers,
        creator=self.user,
        corpus=self.corpus,
        activity_state=ActivityState.Ready,
    )
    init_run = test_process.worker_runs.create(
        version=WorkerVersion.objects.get(worker__slug="initialisation"),
        parents=[],
    )
    test_run = test_process.worker_runs.create(version=self.recognizer, parents=[init_run.id])
    test_process.run()

    # One activity per worker activity state, each on a distinct element
    elem_1, elem_2, elem_3, elem_4 = self.corpus.elements.all()[0:4]
    activity_processed = WorkerActivity.objects.create(
        element_id=elem_1.id,
        state=WorkerActivityState.Processed,
        worker_version=self.recognizer,
        process_id=test_process.id,
        started=datetime.now(timezone.utc),
    )
    activity_queued = WorkerActivity.objects.create(
        element_id=elem_2.id,
        state=WorkerActivityState.Queued,
        worker_version=self.recognizer,
        process_id=test_process.id,
        started=datetime.now(timezone.utc),
    )
    activity_started = WorkerActivity.objects.create(
        element_id=elem_3.id,
        state=WorkerActivityState.Started,
        worker_version=self.recognizer,
        process_id=test_process.id,
        started=datetime.now(timezone.utc),
    )
    activity_error = WorkerActivity.objects.create(
        element_id=elem_4.id,
        state=WorkerActivityState.Error,
        worker_version=self.recognizer,
        process_id=test_process.id,
        started=datetime.now(timezone.utc),
    )

    test_task = test_process.tasks.get(worker_run=test_run)

    # Authenticate from a possible agent
    custom_user = copy.copy(self.docker_agent)
    custom_user.is_active = True
    custom_user.is_authenticated = True
    get_user_mock.return_value = custom_user
    self.client.force_login(self.user)

    # Get the allowed transitions to a final state
    cases = [
        (state_from, state_to)
        for state_from, state_to in self.docker_task_transitions
        if state_to in FINAL_STATES
    ]
    for state_from, state_to in cases:
        with self.subTest(state_from=state_from, state_to=state_to):
            # Restore the activities that a previous transition switched to the error state;
            # otherwise only the first subtest would really exercise the update
            activity_queued.state = WorkerActivityState.Queued
            activity_queued.save()
            activity_started.state = WorkerActivityState.Started
            activity_started.save()

            test_task.state = state_from
            test_task.save()

            resp = self.client.put(
                reverse("api:task-details", args=[test_task.id]),
                data={"state": state_to.value},
            )
            self.assertEqual(resp.status_code, status.HTTP_200_OK)

            test_task.refresh_from_db()
            activity_started.refresh_from_db()
            activity_error.refresh_from_db()
            activity_queued.refresh_from_db()
            activity_processed.refresh_from_db()

            self.assertEqual(test_task.state, state_to)
            # Processed and error activities are left untouched
            self.assertEqual(activity_processed.state, WorkerActivityState.Processed)
            self.assertEqual(activity_error.state, WorkerActivityState.Error)
            # Queued and started activities are switched to the error state
            self.assertEqual(activity_queued.state, WorkerActivityState.Error)
            self.assertEqual(activity_started.state, WorkerActivityState.Error)
@patch("arkindex.ponos.models.base64.encodebytes")
@patch("arkindex.project.aws.s3")
@patch("arkindex.users.models.User.objects.get")
@patch("arkindex.ponos.api.TaskDetails.permission_classes", tuple())
def test_update_task_final_state_worker_activities_update_worker_run(self, get_user_mock, s3_mock, token_mock):
    """
    Moving a task to a final state must only affect the worker activities of that task's
    own worker run; the activities of the other worker runs of the process stay as they are.
    """
    # Configure the mocked S3 object, then seal the mock so that any attribute
    # that was not set up here raises on access
    s3_object = s3_mock.Object.return_value
    s3_object.bucket_name = "ponos"
    s3_object.key = "somelog"
    s3_object.get.return_value = {"Body": BytesIO(b"Failed successfully")}
    s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
    seal(s3_mock)
    # Successive return values of the patched base64.encodebytes
    token_mock.side_effect = [b"12345", b"78945"]

    test_process = Process.objects.create(
        mode=ProcessMode.Workers,
        creator=self.user,
        corpus=self.corpus,
        activity_state=ActivityState.Ready,
    )
    init_version = WorkerVersion.objects.get(worker__slug="initialisation")
    init_run = test_process.worker_runs.create(version=init_version, parents=[])
    test_run = test_process.worker_runs.create(version=self.recognizer, parents=[init_run.id])
    test_process.run()

    # One started activity per worker version, both on the same element
    elem = self.corpus.elements.first()
    init_activity, reco_activity = (
        WorkerActivity.objects.create(
            element_id=elem.id,
            state=WorkerActivityState.Started,
            worker_version=version,
            process_id=test_process.id,
            started=datetime.now(timezone.utc),
        )
        for version in (init_version, self.recognizer)
    )

    test_task = test_process.tasks.get(worker_run=test_run)

    # Authenticate from a possible agent
    custom_user = copy.copy(self.docker_agent)
    custom_user.is_active = True
    custom_user.is_authenticated = True
    get_user_mock.return_value = custom_user
    self.client.force_login(self.user)

    test_task.state = State.Running
    test_task.save()

    task_url = reverse("api:task-details", args=[test_task.id])
    resp = self.client.put(task_url, data={"state": State.Completed.value})
    self.assertEqual(resp.status_code, status.HTTP_200_OK)

    for refreshed in (test_task, init_activity, reco_activity):
        refreshed.refresh_from_db()

    self.assertEqual(test_task.state, State.Completed)
    # The activity of the initialisation worker run is not affected
    self.assertEqual(init_activity.state, WorkerActivityState.Started)
    # The activity of the updated task's worker run is now in error
    self.assertEqual(reco_activity.state, WorkerActivityState.Error)
@patch("arkindex.ponos.models.base64.encodebytes")
@patch("arkindex.project.aws.s3")
@patch("arkindex.users.models.User.objects.get")
@patch("arkindex.ponos.api.TaskDetails.permission_classes", tuple())
def test_update_task_final_state_worker_activities_update_chunks(self, get_user_mock, s3_mock, token_mock):
    """
    When a task is updated to a final state, the worker activities corresponding to that task's worker run
    are only updated if there are no other tasks for that same worker run (different chunks) that are not in
    a final state.

    The second part of the test runs once per final state of the other chunk's task. The activity
    is restored to the started state before each run, so that every final state is verified
    independently of the previous one.
    """
    # Configure the only S3 attributes this test expects to be used, then seal the mock
    # so that accessing anything else raises instead of silently creating a new mock
    s3_mock.Object.return_value.bucket_name = "ponos"
    s3_mock.Object.return_value.key = "somelog"
    s3_mock.Object.return_value.get.return_value = {"Body": BytesIO(b"Failed successfully")}
    s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
    seal(s3_mock)
    # Successive return values of the patched base64.encodebytes
    token_mock.side_effect = [b"12345", b"78945", b"77975"]

    # The process uses two chunks; the test expects two tasks for the single worker run
    test_process = Process.objects.create(
        mode=ProcessMode.Workers,
        creator=self.user,
        corpus=self.corpus,
        chunks=2,
        activity_state=ActivityState.Ready,
    )
    test_run = test_process.worker_runs.create(version=self.recognizer)
    test_process.run()

    elem = self.corpus.elements.first()
    reco_activity = WorkerActivity.objects.create(
        element_id=elem.id,
        state=WorkerActivityState.Started,
        worker_version=self.recognizer,
        process_id=test_process.id,
        started=datetime.now(timezone.utc),
    )

    test_task_1, test_task_2 = test_process.tasks.filter(worker_run=test_run)

    # Authenticate from a possible agent
    custom_user = copy.copy(self.docker_agent)
    custom_user.is_active = True
    custom_user.is_authenticated = True
    get_user_mock.return_value = custom_user
    self.client.force_login(self.user)

    test_task_2.state = State.Running
    test_task_2.save()

    resp = self.client.put(
        reverse("api:task-details", args=[test_task_2.id]),
        data={"state": State.Completed.value},
    )
    self.assertEqual(resp.status_code, status.HTTP_200_OK)

    test_task_2.refresh_from_db()
    reco_activity.refresh_from_db()
    test_task_1.refresh_from_db()

    # Only one of the tasks is in a final state: the WorkerActivityState did not change
    self.assertEqual(test_task_2.state, State.Completed)
    self.assertEqual(test_task_1.state, State.Unscheduled)
    self.assertEqual(reco_activity.state, WorkerActivityState.Started)

    for final_state in FINAL_STATES:
        with self.subTest(final_state=final_state):
            # Restore the activity that a previous iteration switched to the error state;
            # otherwise only the first final state would really exercise the update
            reco_activity.state = WorkerActivityState.Started
            reco_activity.save()

            # Put test_task_1 in a final state as well
            test_task_1.state = final_state
            test_task_2.state = State.Running
            test_task_1.save()
            test_task_2.save()

            resp = self.client.put(
                reverse("api:task-details", args=[test_task_2.id]),
                data={"state": State.Completed.value},
            )
            self.assertEqual(resp.status_code, status.HTTP_200_OK)

            test_task_2.refresh_from_db()
            reco_activity.refresh_from_db()
            test_task_1.refresh_from_db()

            # The remaining worker activity has been updated to the Error state
            self.assertEqual(test_task_2.state, State.Completed)
            self.assertEqual(test_task_1.state, final_state)
            self.assertEqual(reco_activity.state, WorkerActivityState.Error)
@override_settings(PUBLIC_HOSTNAME="https://arkindex.localhost")
@patch("arkindex.ponos.models.base64.encodebytes")
@patch("arkindex.project.aws.s3")
@patch("arkindex.users.models.User.objects.get")
@patch("arkindex.ponos.api.TaskDetails.permission_classes", tuple())
def test_update_task_final_state_worker_activities_update_same_worker_version(self, get_user_mock, s3_mock, token_mock):
    """
    Moving a task to a final state must only switch to error the remaining worker activities of that
    task's own worker run. Activities of other worker runs that share the same worker version, but use
    a different model version or configuration, must be left untouched.
    """
    # Configure the mocked S3 object, then seal the mock so that any attribute
    # that was not set up here raises on access
    s3_object = s3_mock.Object.return_value
    s3_object.bucket_name = "ponos"
    s3_object.key = "somelog"
    s3_object.get.return_value = {"Body": BytesIO(b"Failed successfully")}
    s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere"
    seal(s3_mock)
    # Successive return values of the patched base64.encodebytes
    token_mock.side_effect = [b"12345", b"78945", b"77975", b"12995"]

    test_model = Model.objects.create(name="Generic model", public=False)
    test_model_version = test_model.versions.create(
        state=ModelVersionState.Available,
        tag="Test",
        hash="A" * 32,
        archive_hash="42",
        size=1337,
    )
    test_configuration = self.recognizer.worker.configurations.create(
        name="Recognizer configuration",
        configuration={"value": "test"},
    )
    test_process = Process.objects.create(
        mode=ProcessMode.Workers,
        creator=self.user,
        corpus=self.corpus,
        activity_state=ActivityState.Ready,
    )

    # Three worker runs on the same worker version:
    # without a model version, with a model version, and with a model version and a configuration
    worker_runs = test_process.worker_runs
    test_run_1 = worker_runs.create(version=self.recognizer)
    test_run = worker_runs.create(
        version=self.recognizer,
        model_version_id=test_model_version.id,
    )
    test_run_2 = worker_runs.create(
        version=self.recognizer,
        model_version_id=test_model_version.id,
        configuration_id=test_configuration.id,
    )
    test_process.run()

    elem = self.corpus.elements.first()

    def create_started_activity(**worker_run_fields):
        # All activities share the element, state, worker version and process;
        # only the model version and configuration differ between them
        return WorkerActivity.objects.create(
            element_id=elem.id,
            state=WorkerActivityState.Started,
            worker_version=self.recognizer,
            process_id=test_process.id,
            started=datetime.now(timezone.utc),
            **worker_run_fields,
        )

    target_activity = create_started_activity(model_version=test_model_version)
    activity_1 = create_started_activity()
    activity_2 = create_started_activity(
        model_version=test_model_version,
        configuration=test_configuration,
    )

    test_task = test_process.tasks.get(worker_run=test_run)
    task_1 = test_process.tasks.get(worker_run=test_run_1)
    task_2 = test_process.tasks.get(worker_run=test_run_2)

    # Authenticate from a possible agent
    custom_user = copy.copy(self.docker_agent)
    custom_user.is_active = True
    custom_user.is_authenticated = True
    get_user_mock.return_value = custom_user
    self.client.force_login(self.user)

    # The tasks of the other worker runs are already in a final state
    test_task.state = State.Running
    task_1.state = State.Stopped
    task_2.state = State.Failed
    for task in (test_task, task_1, task_2):
        task.save()

    task_url = reverse("api:task-details", args=[test_task.id])
    resp = self.client.put(task_url, data={"state": State.Completed.value})
    self.assertEqual(resp.status_code, status.HTTP_200_OK)

    for refreshed in (test_task, target_activity, activity_1, activity_2):
        refreshed.refresh_from_db()

    self.assertEqual(test_task.state, State.Completed)
    # Only the activity matching the updated task's worker run is in error
    self.assertEqual(target_activity.state, WorkerActivityState.Error)
    self.assertEqual(activity_1.state, WorkerActivityState.Started)
    self.assertEqual(activity_2.state, WorkerActivityState.Started)
def test_update_non_running_task_state_stopping(self):
states = list(State)
states.remove(State.Running)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment