import logging
import tempfile
from io import BytesIO
from pathlib import Path
from time import sleep
from urllib.parse import urljoin

import requests
from django.conf import settings
from django.core.mail import send_mail
from django.db.models import Count, F, Q
from django.db.models.functions import Round
from django.shortcuts import reverse
from django.template.loader import render_to_string
from django_rq import job

import docker
from arkindex.ponos.models import State, Task
from arkindex.ponos.utils import decompress_zst_archive, extract_tar_archive, upload_artifact
from arkindex.process.models import Process, WorkerActivityState
from docker.errors import APIError, ImageNotFound

logger = logging.getLogger(__name__)

# Delay, in seconds, between two polls of a Docker task's logs
TASK_DOCKER_POLLING = 1
# Timeout for HTTP requests: (connect timeout, read timeout)
REQUEST_TIMEOUT = (30, 60)


@job("default", timeout=settings.RQ_TIMEOUTS["notify_process_completion"])
def notify_process_completion(
    process: Process,
    subject: str = "Your process finished",
) -> None:
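    """
    Email the process creator a summary of the last run: the state of each task
    and aggregated worker activity failures.
    """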
    current_run = process.get_last_run()
    # Inspect the state of tasks in the last run
    tasks_stats = {
        task.slug: task.state
        for task in process.tasks.all()
        # The run is checked in Python, as tasks may be prefetched for the last run only
        if task.run == current_run
    }
    # Aggregate statistics about worker activities failures
    worker_failures = (
        process.activities
        .values("worker_version__worker__name")
        .annotate(
            failures=Count("id", filter=Q(state=WorkerActivityState.Error)),
            total=Count("id")
        )
        .annotate(
            percent=Round(100 * F("failures") / F("total"))
        )
        .filter(failures__gt=0)
        .values("worker_version__worker__name", "failures", "total", "percent")
    )

    send_mail(
        subject=subject,
        message=render_to_string(
            "process_completion.html",
            context={
                "process": process,
                "run": current_run,
                "tasks_stats": tasks_stats,
                "worker_failures": worker_failures,
                "url": urljoin(
                    settings.PUBLIC_HOSTNAME,
                    reverse("frontend-process-details", kwargs={
                        "pk": process.id,
                        "run": current_run,
                    })
                ),
            },
        ),
        from_email=None,
        recipient_list=[process.creator.email],
        fail_silently=False,
    )


def upload_logs(task, text):
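    """Upload log content to the task's logs S3 object; failures are only logged as warnings."""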
    try:
        task.logs.s3_object.upload_fileobj(
            BytesIO(text),
            ExtraArgs={"ContentType": "text/plain; charset=utf-8"},
        )
    except Exception as e:
        logger.warning(f"Failed uploading logs for task {task}: {e}")


def download_extra_files(task) -> None:
    """
    Download the task's extra_files and store them in a dedicated `extra_files` folder.
    This folder is mounted as `PONOS_DATA_DIR/extra_files` in docker containers.
    If a downloaded file has a content-type of `application/zstd`, it is decompressed using zstandard.
    """
    # Download every declared extra file
    for path_name, file_url in task.extra_files.items():
        logger.info(f"Downloading file {path_name} using url: {file_url}")

        # Download file using the provided url
        with requests.get(file_url, stream=True, timeout=REQUEST_TIMEOUT) as resp:
            resp.raise_for_status()

            # Write file to a specific data directory
            extra_files_dir = settings.PONOS_DATA_DIR / "extra_files"
            extra_files_dir.mkdir(exist_ok=True)
            with open(extra_files_dir / path_name, "wb") as f:
                for chunk in resp.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)

        if resp.headers.get("Content-Type") == "application/zstd":
            # If type is `application/zstd`, decompress using zstandard
            archive_fd, archive_path = decompress_zst_archive(
                compressed_archive=extra_files_dir / path_name,
            )
            # Extract Tar archive
            extract_tar_archive(
                archive_path=archive_path,
                archive_fd=archive_fd,
                destination=extra_files_dir,
            )


def run_docker_task(client, task, temp_dir):
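    """
    Run a single Ponos task in Docker: pull its image, fetch parent artifacts,
    start the container, stream its logs, then upload the artifacts it produced.
    """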
    # 1. Pull the docker image
    logger.debug(f"Pulling docker image '{task.image}'")
    try:
        client.images.pull(task.image)
    except (ImageNotFound, APIError) as e:
        # Pulling is allowed to fail when the image is already present locally (local builds)
        if not client.images.list(task.image):
            raise Exception(f"Image not found locally nor remotely: {e}")
        logger.info("Remote image could not be fetched, using the local image.")

    # 2. Fetch artifacts
    logger.info("Fetching artifacts from parents")
    for parent in task.parents.order_by("depth", "id"):
        folder = temp_dir / str(parent.slug)
        folder.mkdir()
        for artifact in parent.artifacts.all():
            path = (folder / artifact.path).resolve()
            # Ensure the path is a child of the parent's folder
            assert str(folder.resolve()) in str(path.resolve()), f"Invalid artifact path: {artifact.path}."
            artifact.download_to(str(path))

    # 3. Run the container asynchronously
    logger.debug("Running container")
    kwargs = {
        "environment": {
            **task.env,
            "PONOS_DATA": settings.PONOS_DATA_DIR,
        },
        "detach": True,
        "network": "host",
        "volumes": {temp_dir: {"bind": str(settings.PONOS_DATA_DIR), "mode": "rw"}},
    }
    artifacts_dir = temp_dir / str(task.id)
    artifacts_dir.mkdir()
    # The symlink only resolves inside the container, where temp_dir is mounted as PONOS_DATA_DIR, so it points to PONOS_DATA_DIR/<task_uuid>/
    (temp_dir / "current").symlink_to(Path(settings.PONOS_DATA_DIR) / str(task.id))

    if task.requires_gpu:
        # Assign all GPUs to that container
        # https://github.com/docker/docker-py/issues/2395#issuecomment-907243275
        kwargs["device_requests"] = [
            docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])
        ]
        logger.info("Starting container with GPU support")
    if task.command is not None:
        kwargs["command"] = task.command
    container = client.containers.run(task.image, **kwargs)
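    # Wait for the container to leave the "created" state before marking the task as running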
    while container.status == "created":
        container.reload()
    task.state = State.Running
    task.save()

    # 4. Download extra_files
    if task.extra_files:
        logger.info("Downloading extra_files for task {!s}".format(task))
        try:
            download_extra_files(task)
        except Exception as e:
            logger.warning(f"Failed downloading extra_files for task {task}: {e}")
            task.state = State.Error
            task.save()
            return

    # 5. Read logs
    logger.debug("Reading logs from the docker container")
    previous_logs = b""
    while container.status == "running":
        logs = container.logs()
        if logs != previous_logs:
            upload_logs(task, logs)
            previous_logs = logs
        # Handle a task that is being stopped during execution
        task.refresh_from_db()
        if task.state == State.Stopping:
            container.stop()
            task.state = State.Stopped
            task.save()
            break
        sleep(TASK_DOCKER_POLLING)
        container.reload()
    # Upload logs one last time so we do not miss any data
    upload_logs(task, container.logs())

    # 6. Retrieve the state of the container
    container.reload()
    exit_code = container.attrs["State"]["ExitCode"]
    if exit_code != 0:
        logger.info("Task failed")
        task.state = State.Failed
        task.save()
        return

    task.state = State.Completed
    task.save()

    # 7. Upload artifacts
    logger.info(f"Uploading artifacts for task {task}")

    for path in Path(artifacts_dir).glob("**/*"):
        if path.is_dir():
            continue
        try:
            upload_artifact(task, path, artifacts_dir)
        except Exception as e:
            logger.warning(
                f"Failed uploading artifact {path} for task {task}: {e}"
            )


@job("tasks", timeout=settings.RQ_TIMEOUTS["task"])
def run_task_rq(task: Task):
    """Run a single task in RQ"""
    # Update task and parents from the DB
    task.refresh_from_db()
    parents = list(task.parents.order_by("depth", "id"))

    client = docker.from_env()

    if not task.image:
        raise ValueError(f"Task {task} has no docker image.")

    if task.state != State.Pending:
        raise ValueError(f"Task {task} must be in pending state to run in RQ.")

    # If a parent ended in a Stopped, Error or Failed state, propagate it to this task instead of running it
    if (parent_state := next(
        (parent.state for parent in parents if parent.state in (State.Stopped, State.Error, State.Failed)),
        None
    )) is not None:
        task.state = parent_state
        task.save()
        return

    with tempfile.TemporaryDirectory(suffix=f"_{task.id}") as temp_dir:
        try:
            run_docker_task(client, task, Path(temp_dir))
        except Exception as e:
            logger.error(f"An unexpected error occurred, updating state to Error: {e}")
            task.state = State.Error
            task.save()
            # Add unexpected error details to task logs
            buffer = BytesIO()
            if task.logs.exists():
                task.logs.s3_object.download_fileobj(buffer)
            text = buffer.getvalue() + f"\nPonos exception: {e}".encode()
            upload_logs(task, text)
            raise e