Skip to content
Snippets Groups Projects
Commit 381b03b1 authored by Erwan Rouchet's avatar Erwan Rouchet Committed by Bastien Abadie
Browse files

Cancel tasks when they reach their TTL

parent 32600e7c
No related branches found
No related tags found
1 merge request!2484Cancel tasks when they reach their TTL
......@@ -415,6 +415,10 @@ class Task(models.Model):
"""
return self.state in FINAL_STATES
@property
def is_over_ttl(self) -> bool:
return self.ttl > 0 and self.started + timedelta(seconds=self.ttl) < timezone.now()
@property
def logs(self) -> TaskLogs:
return TaskLogs(self)
......
......@@ -16,7 +16,7 @@ from django.template.loader import render_to_string
from django_rq import job
import docker
from arkindex.ponos.models import State, Task
from arkindex.ponos.models import FINAL_STATES, State, Task
from arkindex.ponos.utils import decompress_zst_archive, extract_tar_archive, upload_artifact
from arkindex.process.models import Process, WorkerActivityState
from arkindex.project.tools import should_verify_cert
......@@ -211,9 +211,11 @@ def run_docker_task(client, task, temp_dir):
previous_logs = b""
while container.status == "running":
logs = container.logs()
if logs != previous_logs:
upload_logs(task, logs)
previous_logs = logs
# Handle a task that is being stopped during execution
task.refresh_from_db()
if task.state == State.Stopping:
......@@ -222,10 +224,28 @@ def run_docker_task(client, task, temp_dir):
task.finished = datetime.now(timezone.utc)
task.save()
break
# If a task has been running for longer than what its TTL allows, cancel it
if task.is_over_ttl:
container.stop()
task.state = State.Cancelled
task.finished = datetime.now(timezone.utc)
task.save()
break
sleep(TASK_DOCKER_POLLING)
container.reload()
# Upload logs one last time so we do not miss any data
upload_logs(task, container.logs())
logs = container.logs()
if task.state == State.Cancelled:
# For cancelled tasks, log the cancellation explicitly to make it more distinguishable from Error or Failed states
if logs:
# Add a line break if there were some existing logs
logs += b"\n"
logs += f"[ERROR] This task has been cancelled because it has exceeded its TTL of {task.ttl} second{'s' if task.ttl != 1 else ''}.".encode()
upload_logs(task, logs)
# 6. Retrieve the state of the container
container.reload()
......@@ -248,9 +268,9 @@ def run_docker_task(client, task, temp_dir):
logger.warning(
f"Failed uploading artifact {path} for task {task}: {e}"
)
elif task.state != State.Stopped:
# Stopping a task will usually result in a non-zero exit code,
# but we want to report them as Stopped and not Failed so we skip stopped tasks.
elif task.state not in (State.Stopped, State.Cancelled):
# Canceling or stopping a task will usually result in a non-zero exit code,
# but we want to report them as Stopped or Cancelled and not Failed, so we skip those states.
logger.info("Task failed")
task.state = State.Failed
task.finished = datetime.now(timezone.utc)
......@@ -275,7 +295,7 @@ def run_task_rq(task: Task):
# Automatically update children in case an error occurred
if (parent_state := next(
(parent.state for parent in parents if parent.state in (State.Stopped, State.Error, State.Failed)),
(parent.state for parent in parents if parent.state in FINAL_STATES and parent.state != State.Completed),
None
)) is not None:
task.state = parent_state
......
import tempfile
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest.mock import MagicMock, PropertyMock, call, patch, seal
from django.test import override_settings
import docker
from arkindex.ponos.models import Farm, State, Task
from arkindex.ponos.models import State, Task
from arkindex.ponos.tasks import run_docker_task
from arkindex.process.models import ProcessMode
from arkindex.project.tests import FixtureTestCase
......@@ -16,12 +17,10 @@ class TestRunDockerTask(FixtureTestCase):
@classmethod
def setUpTestData(cls):
super().setUpTestData()
cls.farm = Farm.objects.first()
cls.process = cls.corpus.processes.create(
creator=cls.user,
mode=ProcessMode.Workers,
corpus=cls.corpus,
farm=cls.farm,
)
cls.task = cls.process.tasks.create(
slug="something",
......@@ -296,7 +295,7 @@ class TestRunDockerTask(FixtureTestCase):
def add_artifacts(*args):
# sleep() is called after one iteration of the running task loop is complete.
# We use this mock to not wait during tests, check that the task is properly running,
# and create artifacts that the function should upload once the container exists.
# and create artifacts that the function should not upload once the container exists.
self.task.refresh_from_db()
self.assertEqual(self.task.state, State.Running)
self.assertIsNotNone(self.task.started)
......@@ -362,7 +361,6 @@ class TestRunDockerTask(FixtureTestCase):
self.assertEqual(upload_artifact_mock.call_count, 0)
@override_settings(PONOS_DOCKER_AUTO_REMOVE_CONTAINER=True)
@patch("arkindex.ponos.utils.upload_artifact")
@patch("arkindex.ponos.tasks.upload_logs")
......@@ -460,3 +458,98 @@ class TestRunDockerTask(FixtureTestCase):
self.assertEqual(sleep_mock.call_args, call(1))
self.assertEqual(upload_artifact_mock.call_count, 0)
@override_settings(PONOS_DOCKER_AUTO_REMOVE_CONTAINER=False)
@patch("arkindex.ponos.utils.upload_artifact")
@patch("arkindex.ponos.tasks.upload_logs")
@patch("arkindex.ponos.tasks.sleep")
def test_cancelled(self, sleep_mock, upload_logs_mock, upload_artifact_mock):
client_mock = MagicMock()
container = client_mock.containers.run.return_value
# The first two calls occur while the task is running, the third after it has finished.
container.logs.side_effect = [b"Running", b"Running", b"(whilhelm scream)"]
# This will be accessed after the container has been stopped
container.attrs = {"State": {"ExitCode": 1}}
def add_artifacts(*args):
# sleep() is called after one iteration of the running task loop is complete.
# We use this mock to not wait during tests, check that the task is properly running,
# configure it to exceed its TTL, and create artifacts that the function should not
# upload after cancelling the task.
self.task.refresh_from_db()
self.assertEqual(self.task.state, State.Running)
self.assertIsNotNone(self.task.started)
self.assertIsNone(self.task.finished)
self.task.ttl = 1
self.task.created = self.task.started = datetime.now(timezone.utc) - timedelta(seconds=2)
self.task.save()
# This artifact should never be uploaded
(Path(temp_dir) / str(self.task.id) / "something.txt").write_text("blah")
# Set up all the remaining attributes, then seal the mocks so nonexistent attributes can't be accessed
sleep_mock.side_effect = add_artifacts
seal(sleep_mock)
client_mock.images.pull.return_value = None
client_mock.containers.get.side_effect = docker.errors.NotFound("Not found.")
container.reload.return_value = None
container.stop.return_value = None
# Sealing is not friendly with PropertyMocks, so we put a placeholder first
container.status = None
seal(client_mock)
# Limit the amount of times container.status can be accessed, so we can't get stuck in an infinite loop
type(container).status = PropertyMock(side_effect=[
"running", # Loop that checks whether the container is `created` and needs to be awaited
"running", # First iteration of the running task loop
"running", # Second iteration where the task should be cancelled
"exited", # This should not be called but protects us from an infinite loop if the test doesn't go as planned
])
upload_logs_mock.return_value = None
seal(upload_logs_mock)
# We only mock this so that we can make sure we never upload any artifact
seal(upload_artifact_mock)
with tempfile.TemporaryDirectory() as temp_dir:
run_docker_task(client_mock, self.task, Path(temp_dir))
self.task.refresh_from_db()
self.assertEqual(self.task.state, State.Cancelled)
self.assertIsNotNone(self.task.finished)
self.assertEqual(client_mock.images.pull.call_count, 1)
self.assertEqual(client_mock.images.pull.call_args, call("image"))
self.assertEqual(client_mock.containers.get.call_count, 1)
self.assertEqual(client_mock.containers.get.call_args, call(f"ponos-{self.task.id}"))
self.assertEqual(client_mock.containers.run.call_count, 1)
self.assertEqual(client_mock.containers.run.call_args, call(
"image",
environment={"PONOS_DATA": "/data"},
detach=True,
network="host",
volumes={temp_dir: {"bind": "/data", "mode": "rw"}},
name=f"ponos-{self.task.id}",
))
self.assertEqual(container.reload.call_count, 2)
self.assertEqual(container.reload.call_args_list, [call(), call()])
self.assertEqual(container.logs.call_count, 3)
self.assertEqual(container.logs.call_args_list, [call(), call(), call()])
self.assertEqual(container.stop.call_count, 1)
self.assertEqual(container.stop.call_args, call())
self.assertEqual(upload_logs_mock.call_count, 2)
self.assertEqual(upload_logs_mock.call_args_list, [
call(self.task, b"Running"),
call(self.task, b"(whilhelm scream)\n[ERROR] This task has been cancelled because it has exceeded its TTL of 1 second."),
])
self.assertEqual(sleep_mock.call_count, 1)
self.assertEqual(sleep_mock.call_args, call(1))
self.assertEqual(upload_artifact_mock.call_count, 0)
......@@ -65,7 +65,7 @@ class TestRunTaskRQ(FixtureTestCase):
depth=0,
ttl=0,
)
for state in {State.Stopped, State.Error, State.Failed}:
for state in {State.Stopped, State.Error, State.Failed, State.Cancelled}:
self.task.state = State.Pending
self.task.save()
with self.subTest(state=state):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment