Commit 96ee6d26 authored by Valentin Rigal, committed by Erwan Rouchet

Execute docker tasks in RQ

parent 87e7af5d
1 merge request: !2227 Execute docker tasks in RQ
Showing changes with 288 additions and 18 deletions
......@@ -19,7 +19,7 @@ binary:
CI_PROJECT_DIR=$(ROOT_DIR) CI_REGISTRY_IMAGE=$(IMAGE_TAG) $(ROOT_DIR)/ci/build.sh Dockerfile.binary -binary
worker:
arkindex/manage.py rqworker -v 2 default high
arkindex/manage.py rqworker -v 2 default high tasks
test-fixtures:
$(eval export PGPASSWORD=devdata)
......
......@@ -160,6 +160,11 @@ We use [rq](https://python-rq.org/), integrated via [django-rq](https://pypi.org
To run them, use `make worker` to start an RQ worker. You will need to have Redis running; `make slim` or `make` in the architecture will provide it. `make` in the architecture also provides an RQ worker running in Docker from a binary build.
Process tasks are run in RQ by default (Community Edition). Two RQ workers must be running at the same time to run a process with worker activities, so that the initialisation task can wait for the worker activity task to finish:
```sh
$ manage.py rqworker -v 3 default high & manage.py rqworker -v 3 tasks
```
## Metrics
The application serves metrics for Prometheus under the `/metrics` prefix.
A specific port can be used by setting the `PROMETHEUS_METRICS_PORT` environment variable, thus separating the application from the metrics API.
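For example, a minimal sketch of such a split setup, assuming the variable is read at application startup; the serving command and the port value `8001` are only illustrative:
```sh
$ export PROMETHEUS_METRICS_PORT=8001
$ arkindex/manage.py runserver
# In another terminal, scrape the metrics API on its dedicated port
$ curl http://localhost:8001/metrics
```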
import logging
import tempfile
from io import BytesIO
from pathlib import Path
from time import sleep
from urllib.parse import urljoin
from django.conf import settings
......@@ -9,10 +13,17 @@ from django.shortcuts import reverse
from django.template.loader import render_to_string
from django_rq import job
import docker
from arkindex.ponos.models import State, Task
from arkindex.ponos.utils import upload_artifact
from arkindex.process.models import Process, WorkerActivityState
from docker.errors import APIError, ImageNotFound
logger = logging.getLogger(__name__)
# Delay for polling docker task's logs in seconds
TASK_DOCKER_POLLING = 1
@job("default", timeout=settings.RQ_TIMEOUTS["notify_process_completion"])
def notify_process_completion(
......@@ -64,3 +75,153 @@ def notify_process_completion(
        recipient_list=[process.creator.email],
        fail_silently=False,
    )

def upload_logs(task, text):
    try:
        task.logs.s3_object.upload_fileobj(
            BytesIO(text),
            ExtraArgs={"ContentType": "text/plain; charset=utf-8"},
        )
    except Exception as e:
        logger.warning(f"Failed uploading logs for task {task}: {e}")

def run_docker_task(client, task, temp_dir):
    # 1. Pull the docker image
    logger.debug(f"Pulling docker image '{task.image}'")
    try:
        client.images.pull(task.image)
    except (ImageNotFound, APIError) as e:
        # Pulling is allowed to fail when the image is already present locally (local builds)
        if not client.images.list(task.image):
            raise Exception(f"Image not found locally nor remotely: {e}")
        logger.info("Remote image could not be fetched, using the local image.")

    # 2. Fetch artifacts
    logger.info("Fetching artifacts from parents")
    for parent in task.parents.order_by("depth", "id"):
        folder = temp_dir / str(parent.slug)
        folder.mkdir()
        for artifact in parent.artifacts.all():
            path = (folder / artifact.path).resolve()
            # Ensure the resolved path stays within the parent's folder
            assert str(folder.resolve()) in str(path.resolve()), f"Invalid artifact path: {artifact.path}."
            artifact.download_to(str(path))
    # 3. Run the container asynchronously
    logger.debug("Running container")
    kwargs = {
        "environment": {
            **task.env,
            "PONOS_DATA": settings.PONOS_DATA_DIR,
        },
        "detach": True,
        "network": "host",
        "volumes": {temp_dir: {"bind": str(settings.PONOS_DATA_DIR), "mode": "rw"}},
    }
    artifacts_dir = temp_dir / str(task.id)
    artifacts_dir.mkdir()
    # The symlink only resolves inside the Docker container, where temp_dir is bound to PONOS_DATA_DIR/<task_uuid>/
    (temp_dir / "current").symlink_to(Path(settings.PONOS_DATA_DIR) / str(task.id))

    if task.requires_gpu:
        # Assign all GPUs to that container
        # https://github.com/docker/docker-py/issues/2395#issuecomment-907243275
        kwargs["device_requests"] = [
            docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])
        ]
        logger.info("Starting container with GPU support")

    if task.command is not None:
        kwargs["command"] = task.command

    container = client.containers.run(task.image, **kwargs)
    while container.status == "created":
        container.reload()
    task.state = State.Running
    task.save()

    # 4. Read logs
    logger.debug("Reading logs from the docker container")
    previous_logs = b""
    while container.status == "running":
        logs = container.logs()
        if logs != previous_logs:
            upload_logs(task, logs)
            previous_logs = logs

        # Handle a task that is being stopped during execution
        task.refresh_from_db()
        if task.state == State.Stopping:
            container.stop()
            task.state = State.Stopped
            task.save()
            break

        sleep(TASK_DOCKER_POLLING)
        container.reload()

    # Upload logs one last time so we do not miss any data
    upload_logs(task, container.logs())
    # 5. Retrieve the state of the container
    container.reload()
    exit_code = container.attrs["State"]["ExitCode"]
    if exit_code != 0:
        logger.info("Task failed")
        task.state = State.Failed
        task.save()
        return

    task.state = State.Completed
    task.save()

    # 6. Upload artifacts
    logger.info(f"Uploading artifacts for task {task}")
    for path in Path(artifacts_dir).glob("**/*"):
        if path.is_dir():
            continue
        try:
            upload_artifact(task, path, artifacts_dir)
        except Exception as e:
            logger.warning(
                f"Failed uploading artifact {path} for task {task}: {e}"
            )
@job("tasks", timeout=settings.RQ_TIMEOUTS["task"])
def run_task_rq(task: Task):
"""Run a single task in RQ"""
# Update task and parents from the DB
task.refresh_from_db()
parents = list(task.parents.order_by("depth", "id"))
client = docker.from_env()
if not task.image:
raise ValueError(f"Task {task} has no docker image.")
if task.state != State.Pending:
raise ValueError(f"Task {task} must be in pending state to run in RQ.")
# Automatically update children in case an error occurred
if (parent_state := next(
(parent.state for parent in parents if parent.state in (State.Stopped, State.Error, State.Failed)),
None
)) is not None:
task.state = parent_state
task.save()
return
with tempfile.TemporaryDirectory(suffix=f"_{task.id}") as temp_dir:
try:
run_docker_task(client, task, Path(temp_dir))
except Exception as e:
logger.error(f"An unexpected error occurred, updating state to Error: {e}")
task.state = State.Error
task.save()
# Add unexpected error details to task logs
text = BytesIO()
if task.logs.exists():
task.logs.s3_object.download_fileobj(text)
text = text.getvalue()
text += f"\nPonos exception: {e}".encode()
upload_logs(task, text)
raise e
from unittest.mock import call, patch
from django.test import override_settings
from arkindex.ponos.models import Farm
from arkindex.process.models import ProcessMode, WorkerVersion
from arkindex.project.tests import FixtureAPITestCase

class TestModels(FixtureAPITestCase):
    @classmethod
    def setUpTestData(cls):
        super().setUpTestData()
        cls.farm = Farm.objects.create(name="Farm")
        cls.process = cls.corpus.processes.create(
            creator=cls.user,
            mode=ProcessMode.Workers,
            corpus=cls.corpus,
            farm=cls.farm,
        )
        cls.worker_version1 = WorkerVersion.objects.get(worker__slug="reco")
        cls.worker_version2 = WorkerVersion.objects.get(worker__slug="dla")
        cls.run1 = cls.process.worker_runs.create(version=cls.worker_version1, parents=[])
        cls.run2 = cls.process.worker_runs.create(version=cls.worker_version2, parents=[cls.run1.id])

    @override_settings(PONOS_RQ_EXECUTION=True)
    @patch("arkindex.ponos.tasks.run_task_rq.delay")
    def test_run_process_schedules_tasks(self, run_task_mock):
        run_task_mock.side_effect = ["init_job", "job1", "job2"]
        # on_commit callbacks are not run automatically, as savepoints are used for test transactions
        with self.captureOnCommitCallbacks(execute=True):
            self.process.run()
        init, t1, t2 = self.process.tasks.order_by("depth")
        self.assertEqual(run_task_mock.call_count, 3)
        call1, call2, call3 = run_task_mock.call_args_list
        self.assertEqual(call1, call(init))
        self.assertTupleEqual(call2.args, (t1,))
        self.assertListEqual(list(call2.kwargs.keys()), ["depends_on"])
        task1_depends = call2.kwargs["depends_on"]
        self.assertEqual(vars(task1_depends), {
            "dependencies": ["init_job"],
            "allow_failure": True,
            "enqueue_at_front": False,
        })
        self.assertTupleEqual(call3.args, (t2,))
        self.assertListEqual(list(call3.kwargs.keys()), ["depends_on"])
        task2_depends = call3.kwargs["depends_on"]
        self.assertEqual(vars(task2_depends), {
            "dependencies": ["job1"],
            "allow_failure": True,
            "enqueue_at_front": False,
        })
import os
import magic
from arkindex.ponos.models import Task
......@@ -8,3 +11,14 @@ def is_admin_or_ponos_task(request):
def get_process_from_task_auth(request):
    if isinstance(request.auth, Task):
        return request.auth.process

def upload_artifact(task, path, artifacts_dir):
    content_type = magic.from_file(path, mime=True)
    size = os.path.getsize(path)
    artifact = task.artifacts.create(
        path=os.path.relpath(path, artifacts_dir),
        content_type=content_type,
        size=size,
    )
    artifact.s3_object.upload_file(str(path), ExtraArgs={"ContentType": content_type})
......@@ -393,3 +393,4 @@ class ProcessBuilder(object):
for child_slug, parent_slugs in self.tasks_parents.items()
for parent_slug in parent_slugs
)
return tasks
import urllib.parse
import uuid
from functools import partial
from typing import Optional
from django.conf import settings
......@@ -390,11 +391,16 @@ class Process(IndexableModel):
"""
Build and start a new run for this process.
"""
from arkindex.project.triggers import schedule_tasks
process_builder = ProcessBuilder(self)
process_builder.validate()
process_builder.build()
# Save all tasks and their relations
process_builder.save()
if settings.PONOS_RQ_EXECUTION:
# Trigger tasks execution in RQ after the current transaction so tasks have been created
transaction.on_commit(partial(schedule_tasks, process=self, run=process_builder.run))
self.started = timezone.now()
self.finished = None
......
from django.urls import reverse
from rest_framework import status
from rest_framework.serializers import DateTimeField
......
......@@ -154,6 +154,8 @@ def get_settings_parser(base_dir):
job_timeouts_parser.add_option("process_delete", type=int, default=3600)
job_timeouts_parser.add_option("reindex_corpus", type=int, default=7200)
job_timeouts_parser.add_option("notify_process_completion", type=int, default=120)
# Task execution in RQ timeouts after 10 hours by default
job_timeouts_parser.add_option("task", type=int, default=36000)
csrf_parser = parser.add_subparser("csrf", default={})
csrf_parser.add_option("cookie_name", type=str, default="arkindex.csrf")
......
......@@ -338,21 +338,17 @@ elif conf["cache"]["type"] == CacheType.Dummy:
}
}
_rq_queue_conf = {
"HOST": conf["redis"]["host"],
"PORT": conf["redis"]["port"],
"DB": conf["redis"]["db"],
"PASSWORD": conf["redis"]["password"],
"DEFAULT_TIMEOUT": conf["redis"]["timeout"],
}
RQ_QUEUES = {
"default": {
"HOST": conf["redis"]["host"],
"PORT": conf["redis"]["port"],
"DB": conf["redis"]["db"],
"PASSWORD": conf["redis"]["password"],
"DEFAULT_TIMEOUT": conf["redis"]["timeout"],
},
"high": {
"HOST": conf["redis"]["host"],
"PORT": conf["redis"]["port"],
"DB": conf["redis"]["db"],
"PASSWORD": conf["redis"]["password"],
"DEFAULT_TIMEOUT": conf["redis"]["timeout"],
}
"default": _rq_queue_conf,
"high": _rq_queue_conf,
"tasks": _rq_queue_conf,
}
RQ_TIMEOUTS = conf["job_timeouts"]
......@@ -504,6 +500,9 @@ PONOS_DEFAULT_ENV = _ponos_env
PONOS_PRIVATE_KEY = conf["ponos"]["private_key"]
PONOS_DEFAULT_FARM = conf["ponos"]["default_farm"]
PONOS_ARTIFACT_MAX_SIZE = conf["ponos"]["artifact_max_size"]
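
# Process tasks are run in RQ by default (Community Edition)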
PONOS_RQ_EXECUTION = True
# Base data directory for RQ tasks execution (in the docker container)
PONOS_DATA_DIR = "/data"
# Docker images used by our ponos workflow
ARKINDEX_TASKS_IMAGE = conf["docker"]["tasks_image"]
......@@ -596,6 +595,9 @@ if TEST_ENV:
warnings.filterwarnings("error", category=RuntimeWarning, module="django.core.paginator")
warnings.filterwarnings("error", category=RuntimeWarning, module="rest_framework.pagination")
# Disable RQ tasks scheduler during tests
PONOS_RQ_EXECUTION = False
# Optional unit tests runner with code coverage
try:
import xmlrunner # noqa
......
......@@ -61,6 +61,7 @@ job_timeouts:
  notify_process_completion: 120
  process_delete: 3600
  reindex_corpus: 7200
  task: 36000
  worker_results_delete: 3600
jwt_signing_key: null
local_imageserver_id: 1
......
......@@ -43,6 +43,7 @@ job_timeouts:
  move_element:
    a: b
  reindex_corpus: {}
  task: ''
  worker_results_delete: null
jwt_signing_key: null
local_imageserver_id: 1
......
......@@ -24,6 +24,7 @@ job_timeouts:
  export_corpus: "int() argument must be a string, a bytes-like object or a real number, not 'list'"
  move_element: "int() argument must be a string, a bytes-like object or a real number, not 'dict'"
  reindex_corpus: "int() argument must be a string, a bytes-like object or a real number, not 'dict'"
  task: "invalid literal for int() with base 10: ''"
  worker_results_delete: "int() argument must be a string, a bytes-like object or a real number, not 'NoneType'"
ponos:
  artifact_max_size: cannot convert float NaN to integer
......
......@@ -75,7 +75,8 @@ job_timeouts:
  notify_process_completion: 6
  process_delete: 7
  reindex_corpus: 8
  worker_results_delete: 9
  task: 9
  worker_results_delete: 10
jwt_signing_key: deadbeef
local_imageserver_id: 45
metrics_port: 4242
......
......@@ -5,6 +5,7 @@ from typing import Literal, Optional, Union
from uuid import UUID
from django.db.models import Prefetch, prefetch_related_objects
from rq.job import Dependency
from arkindex.documents import export
from arkindex.documents import tasks as documents_tasks
......@@ -211,3 +212,17 @@ def notify_process_completion(process: Process):
        process=process,
        subject=f"Your process {process_name} finished {state_msg[state]}",
    )

def schedule_tasks(process: Process, run: int):
    """Run tasks of a process in RQ, one by one"""
    tasks = process.tasks.using("default").filter(run=run).order_by("depth", "id")
    # Initially mark all tasks as pending
    tasks.update(state=State.Pending)
    # Build a simple dependency scheme between tasks, based on depth
    parent_job = None
    for task in tasks:
        kwargs = {}
        if parent_job:
            kwargs["depends_on"] = Dependency(jobs=[parent_job], allow_failure=True)
        parent_job = ponos_tasks.run_task_rq.delay(task, **kwargs)
......@@ -34,6 +34,12 @@ from arkindex.users.serializers import (
logger = logging.getLogger(__name__)
# Process tasks running in RQ are hidden from user jobs
VISIBLE_QUEUES = [
    q for q in QUEUES.keys()
    if q != "tasks"
]
@extend_schema(tags=["users"])
@extend_schema_view(
......@@ -304,7 +310,7 @@ class JobRetrieve(RetrieveDestroyAPIView):
serializer_class = JobSerializer
def get_object(self):
for queue_name in QUEUES.keys():
for queue_name in VISIBLE_QUEUES:
job = get_queue(queue_name).fetch_job(str(self.kwargs["pk"]))
if not job:
continue
......
......@@ -9,7 +9,9 @@ django-pgtrigger==4.7.0
django-rq==2.8.1
djangorestframework==3.12.4
djangorestframework-simplejwt==5.2.2
docker==7.0.0
drf-spectacular==0.18.2
python-magic==0.4.27
python-memcached==1.59
pytz==2023.3
PyYAML==6.0
......