diff --git a/arkindex/ponos/__init__.py b/arkindex/ponos/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..071fa5fb3e4b85dd8cdc47bb5ee63de475a7031d --- /dev/null +++ b/arkindex/ponos/__init__.py @@ -0,0 +1,3 @@ +from ponos.helpers import get_ponos_run, get_ponos_task_id, is_ponos_task # noqa: F401 + +__all__ = ["is_ponos_task", "get_ponos_task_id", "get_ponos_run"] diff --git a/arkindex/ponos/admin.py b/arkindex/ponos/admin.py new file mode 100644 index 0000000000000000000000000000000000000000..d39d3f0ff14ae4ca579c52d011f8143a95114ae5 --- /dev/null +++ b/arkindex/ponos/admin.py @@ -0,0 +1,283 @@ +from django import forms +from django.contrib import admin, messages +from django.core.exceptions import ValidationError +from django.db.models import Max, TextField +from enumfields.admin import EnumFieldListFilter + +from ponos.keys import gen_nonce +from ponos.models import ( + FINAL_STATES, + GPU, + Agent, + Artifact, + Farm, + Secret, + Task, + Workflow, + encrypt, +) + + +class ArtifactInline(admin.TabularInline): + model = Artifact + fields = ("id", "path", "size", "content_type", "created") + readonly_fields = fields + extra = 0 + + +class TaskAdmin(admin.ModelAdmin): + list_display = ( + "id", + "slug", + "run", + "state", + "workflow_id", + "updated", + "agent", + "tags", + "priority", + ) + list_filter = [("state", EnumFieldListFilter), "agent"] + list_select_related = ("agent",) + inlines = [ + ArtifactInline, + ] + readonly_fields = ( + "id", + "created", + "updated", + "container", + "short_logs", + "shm_size", + ) + fieldsets = ( + ( + None, + { + "fields": ( + "id", + "slug", + "run", + "depth", + "state", + "workflow", + "agent", + "tags", + "priority", + ), + }, + ), + ("GPU", {"fields": ("requires_gpu", "gpu")}), + ("Dates", {"fields": ("created", "updated", "expiry")}), + ( + "Docker", + { + "fields": ( + "image", + "command", + "env", + "container", + "shm_size", + "short_logs", + "extra_files", + ) + }, + ), + ) + + +class TaskInline(admin.TabularInline): + model = Task + fields = ("id", "slug", "run", "depth", "state", "container") + readonly_fields = ("slug", "run", "depth", "container") + extra = 0 + + +def workflow_retry(modeladmin, request, queryset): + """ + Retry selected workflows + """ + for w in queryset.all(): + w.retry() + + +class WorkflowAdmin(admin.ModelAdmin): + list_display = ("id", "updated", "state") + actions = (workflow_retry,) + readonly_fields = ( + "id", + "state", + ) + inlines = [ + TaskInline, + ] + + # Use a monospace font for the workflow recipe + formfield_overrides = { + TextField: { + "widget": admin.widgets.AdminTextareaWidget( + attrs={"style": "font-family: monospace"} + ) + }, + } + + def get_queryset(self, *args, **kwargs): + return ( + super() + .get_queryset(*args, **kwargs) + .prefetch_related("tasks") + .annotate(last_run=Max("tasks__run")) + ) + + +class GPUInline(admin.TabularInline): + model = GPU + fields = ("id", "name", "index", "ram_total") + readonly_fields = fields + extra = 0 + + +class AgentAdmin(admin.ModelAdmin): + model = Agent + list_display = ("id", "hostname") + inlines = (GPUInline,) + readonly_fields = ( + "id", + "created", + "updated", + "last_ping", + # Use custom admin fields to format total RAM and CPU max frequency + "ram_total_human", + "cpu_frequency_human", + "cpu_load", + "ram_load", + "cpu_cores", + "public_key", + ) + fieldsets = ( + ( + None, + {"fields": ("id", "farm", "hostname")}, + ), + ("Dates", {"fields": ("created", "updated", "last_ping")}), + ("Auth", {"fields": 
("public_key",)}), + ( + "Hardware", + { + "fields": ( + "ram_total_human", + "cpu_frequency_human", + "cpu_cores", + "cpu_load", + "ram_load", + ) + }, + ), + ) + + def has_add_permission(self, request): + return False + + def ram_total_human(self, instance): + """Returns total amount of RAM expressed in GiB""" + return "{:.1f} GiB".format((instance.ram_total or 0) / (1024**3)) + + def cpu_frequency_human(self, instance): + """Returns CPU max frequency expressed in GHz""" + return "{:.1f} GHz".format((instance.cpu_frequency or 0) / 1e9) + + # Overrides both admin deletion methods to show error messages when agents + # have tasks in non-final states. This will display a "deleted successfully" + # message along the errors, because the Django admin does not make it easy to + # remove this message, but nothing gets actually deleted. + + def delete_model(self, request, agent): + try: + super().delete_model(request, agent) + except ValidationError as e: + messages.error(request, e.message) + + def delete_queryset(self, request, queryset): + hostnames = ( + Task.objects.filter(agent__in=queryset) + .exclude(state__in=FINAL_STATES) + .values_list("agent__hostname", flat=True) + .distinct() + ) + if hostnames: + messages.error( + request, + "The following agents have tasks in non-final states and cannot be deleted: {}".format( + ", ".join(list(hostnames)) + ), + ) + else: + super().delete_queryset(request, queryset) + + +class FarmAdmin(admin.ModelAdmin): + model = Farm + list_display = ("id", "name") + fields = ("id", "name", "seed") + readonly_fields = ("id",) + + +class ClearTextSecretForm(forms.ModelForm): + """ + Allow an administrator to edit a secret content as a cleartext value + A nonce is generated for newly created secrets + """ + + content = forms.CharField(widget=forms.Textarea, required=True) + + def __init__(self, *args, **kwargs): + self.instance = kwargs.get("instance") + # Set initial decrypted value + if self.instance is not None: + if "initial" not in kwargs: + kwargs["initial"] = {} + kwargs["initial"]["content"] = self.instance.decrypt() + + super().__init__(*args, **kwargs) + + def clean_name(self): + # Check that name is not already used + secrets = Secret.objects.all() + if self.instance: + secrets = Secret.objects.exclude(pk=self.instance.pk) + + if secrets.filter(name=self.cleaned_data["name"]).exists(): + raise ValidationError("A secret with this name already exists.") + + return self.cleaned_data["name"] + + def clean(self): + if not self.cleaned_data.get("content"): + raise ValidationError("You must specify some content") + + if self.instance: + nonce = self.instance.nonce + else: + nonce = gen_nonce() + + # Encrypt secret content + encrypted_content = encrypt(nonce, self.cleaned_data["content"]) + return {**self.cleaned_data, "nonce": nonce, "content": encrypted_content} + + class Meta: + model = Secret + fields = ("id", "name", "content") + + +class SecretAdmin(admin.ModelAdmin): + form = ClearTextSecretForm + + fields = ("name", "content") + + +admin.site.register(Task, TaskAdmin) +admin.site.register(Workflow, WorkflowAdmin) +admin.site.register(Agent, AgentAdmin) +admin.site.register(Farm, FarmAdmin) +admin.site.register(Secret, SecretAdmin) +workflow_retry.short_description = "Retry all selected workflows (a new run is created)" diff --git a/arkindex/ponos/api.py b/arkindex/ponos/api.py new file mode 100644 index 0000000000000000000000000000000000000000..1696b9d4efb808830ce99c72e04f8f07fbbdd065 --- /dev/null +++ b/arkindex/ponos/api.py @@ -0,0 +1,387 @@ +import 
hashlib +import uuid +from collections import defaultdict +from textwrap import dedent + +from django.core.exceptions import PermissionDenied +from django.db.models import Count, Max, Q +from django.shortcuts import get_object_or_404, redirect +from django.utils import timezone +from drf_spectacular.utils import OpenApiExample, extend_schema, extend_schema_view +from rest_framework.exceptions import ValidationError +from rest_framework.generics import ( + CreateAPIView, + ListAPIView, + ListCreateAPIView, + RetrieveAPIView, + RetrieveUpdateAPIView, + UpdateAPIView, +) +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response +from rest_framework.views import APIView +from rest_framework_simplejwt.views import TokenRefreshView + +from ponos.authentication import AgentAuthentication +from ponos.keys import load_private_key +from ponos.models import Agent, Artifact, Farm, Secret, State, Task, Workflow +from ponos.permissions import IsAgent, IsAssignedAgentOrReadOnly +from ponos.recipe import parse_recipe +from ponos.renderers import PublicKeyPEMRenderer +from ponos.serializers import ( + AgentActionsSerializer, + AgentCreateSerializer, + AgentDetailsSerializer, + AgentLightSerializer, + AgentStateSerializer, + ArtifactSerializer, + ClearTextSecretSerializer, + FarmSerializer, + NewTaskSerializer, + TaskDefinitionSerializer, + TaskSerializer, + TaskTinySerializer, + WorkflowSerializer, +) + + +class PublicKeyEndpoint(APIView): + """ + Fetch the server's public key in PEM format to perform agent registration. + """ + + renderer_classes = (PublicKeyPEMRenderer,) + + @extend_schema( + operation_id="GetPublicKey", + responses={200: {"type": "string"}}, + tags=["ponos"], + auth=[{"agentAuth": []}], + examples=[ + OpenApiExample( + name="Public key response", + response_only=True, + media_type="application/x-pem-file", + status_codes=["200"], + value=dedent( + """ + -----BEGIN PUBLIC KEY----- + MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAEmK2L6lwGzSVZwFSo0eR1z4XV6jJwjeWK + YCiPKdMcQnn6u5J016k9U8xZm6XyFnmgvkhnC3wreGBTFzwLCLZCD+F3vo5x8ivz + aTgNWsA3WFlqjSIEGz+PAVHSNMobBaJm + -----END PUBLIC KEY----- + """ + ), + ) + ], + ) + def get(self, request, *args, **kwargs): + return Response(load_private_key().public_key()) + + +@extend_schema(tags=["ponos"]) +@extend_schema_view( + get=extend_schema(description="Retrieve a Ponos workflow status"), + put=extend_schema(description="Update a workflow's status and tasks"), + patch=extend_schema(description="Partially update a workflow"), +) +class WorkflowDetails(RetrieveUpdateAPIView): + """ + Retrieve information about a workflow, or update its state. + Updating a workflow's state to :attr:`~ponos.models.State.Stopping` will cause it to stop. + """ + + queryset = Workflow.objects.prefetch_related("tasks__parents").annotate( + last_run=Max("tasks__run") + ) + serializer_class = WorkflowSerializer + + def perform_update(self, serializer): + serializer.instance.stop() + + +@extend_schema(tags=["ponos"]) +@extend_schema_view( + get=extend_schema( + operation_id="RetrieveTaskFromAgent", description="Retrieve a Ponos task status" + ), + put=extend_schema( + operation_id="UpdateTaskFromAgent", description="Update a task, from an agent" + ), + patch=extend_schema( + operation_id="PartialUpdateTaskFromAgent", + description="Partially update a task, from an agent", + ), +) +class TaskDetailsFromAgent(RetrieveUpdateAPIView): + """ + Retrieve information about a single task, including its logs. 
+ Authenticated agents assigned to a task can use this endpoint to report its state. + """ + + # Avoid stale read when a recently assigned agent wants to update + # the state of one of its tasks + queryset = Task.objects.all().using("default") + permission_classes = (IsAssignedAgentOrReadOnly,) + serializer_class = TaskSerializer + + +@extend_schema_view( + post=extend_schema( + operation_id="CreateAgent", + tags=["ponos"], + ), +) +class AgentRegister(CreateAPIView): + """ + Perform agent registration and authentication. + """ + + serializer_class = AgentCreateSerializer + + def get_object(self): + if not hasattr(self.request, "data") or "public_key" not in self.request.data: + return + key_hash = uuid.UUID( + hashlib.md5(self.request.data["public_key"].encode("utf-8")).hexdigest() + ) + return Agent.objects.filter(id=key_hash).first() + + def get_serializer(self, *args, **kwargs): + return super().get_serializer(self.get_object(), *args, **kwargs) + + +@extend_schema_view( + post=extend_schema( + operation_id="RefreshAgentToken", + tags=["ponos"], + ) +) +class AgentTokenRefresh(TokenRefreshView): + """ + Refresh a Ponos agent token when it expires + """ + + +@extend_schema(tags=["ponos"]) +class AgentDetails(RetrieveAPIView): + """ + Retrieve details of an agent including its running tasks + """ + + serializer_class = AgentDetailsSerializer + queryset = Agent.objects.all() + + +@extend_schema( + description="List the state of all Ponos agents", + tags=["ponos"], +) +class AgentsState(ListAPIView): + """ + List all agents on the system with their health state. + No authentication or permission is required to read agent states. + """ + + serializer_class = AgentStateSerializer + + queryset = ( + Agent.objects.all() + .annotate( + running_tasks_count=Count("tasks", filter=Q(tasks__state=State.Running)) + ) + .prefetch_related("farm") + .order_by("hostname") + ) + + +@extend_schema(tags=["ponos"]) +class AgentActions(RetrieveAPIView): + """ + Fetch the next actions an agent should perform. + """ + + permission_classes = (IsAgent,) + authentication_classes = (AgentAuthentication,) + serializer_class = AgentActionsSerializer + + def get_object(self): + return self.request.user + + def retrieve(self, request, *args, **kwargs): + # Update agent load and last_ping timestamp + errors = defaultdict(list) + cpu_load = request.query_params.get("cpu_load") + ram_load = request.query_params.get("ram_load") + if not cpu_load: + errors["cpu_load"].append("This query parameter is required.") + if not ram_load: + errors["ram_load"].append("This query parameter is required.") + if errors: + raise ValidationError(errors) + # Handle field validation with DRF, as for a PATCH + agent_serializer = AgentLightSerializer( + self.request.user, + data={"cpu_load": cpu_load, "ram_load": ram_load}, + partial=True, + ) + agent_serializer.is_valid(raise_exception=True) + # Update agent load and last_ping attributes + agent_serializer.save(last_ping=timezone.now()) + # Retrieve next tasks after the DB has been updated with the new agent load + response = super().retrieve(request, *args, **kwargs) + return response + + +@extend_schema_view( + get=extend_schema( + operation_id="RetrieveTaskDefinition", + tags=["ponos"], + ) +) +class TaskDefinition(RetrieveAPIView): + """ + Obtain a task's definition as an agent or admin. + This holds all the required data to start a task, except for the artifacts. 
+ """ + + # We need to specify the default database to avoid stale reads + # when a task is updated by an agent, then the agent immediately fetches its definition + queryset = Task.objects.using("default").select_related("workflow") + permission_classes = (IsAgent,) + serializer_class = TaskDefinitionSerializer + + +@extend_schema(tags=["ponos"]) +@extend_schema_view( + get=extend_schema( + operation_id="ListArtifacts", + description="List all the artifacts of a task", + ), + post=extend_schema( + operation_id="CreateArtifact", description="Create an artifact on a task" + ), +) +class TaskArtifacts(ListCreateAPIView): + """ + List all artifacts linked to a task or create one + """ + + # Used for OpenAPI schema serialization: the ID in the path is the task ID + queryset = Task.objects.none() + permission_classes = (IsAuthenticated,) + serializer_class = ArtifactSerializer + + # Force no pagination, even when global settings add them + pagination_class = None + + def get_task(self): + if "pk" not in self.kwargs: + return Artifact.objects.none() + return get_object_or_404(Task, pk=self.kwargs["pk"]) + + def get_queryset(self): + task = self.get_task() + return task.artifacts.all() + + def perform_create(self, serializer): + user = self.request.user + if not user.is_staff and not getattr(user, "is_agent", False): + raise PermissionDenied() + + # Assign task when creating through the API + serializer.save(task=self.get_task()) + + +class TaskArtifactDownload(APIView): + """ + Redirect to the S3 url of an artifact in order to download it + """ + + permission_classes = (IsAgent,) + + def get_object(self, pk, path): + artifact = get_object_or_404(Artifact, task_id=pk, path=path) + self.check_object_permissions(self.request, artifact) + return artifact + + @extend_schema( + operation_id="DownloadArtifact", + responses={302: None}, + tags=["ponos"], + ) + def get(self, request, *args, **kwargs): + artifact = self.get_object(*args, **kwargs) + return redirect(artifact.s3_get_url) + + +@extend_schema_view( + post=extend_schema( + operation_id="CreateTask", + tags=["ponos"], + ) +) +class TaskCreate(CreateAPIView): + """ + Create a task with a parent + """ + + serializer_class = NewTaskSerializer + + def perform_create(self, serializer): + """ + Merge workflow's recipe environment variables with the ones set on the task + Task variables have priority over Workflow variables + """ + env, _ = parse_recipe( + Workflow.objects.get(id=serializer.validated_data["workflow_id"]).recipe + ) + env.update(serializer.validated_data["env"]) + serializer.save(env=env) + + +@extend_schema(tags=["ponos"]) +@extend_schema_view( + put=extend_schema( + description="Update a task, allowing humans to change the task's state" + ), + patch=extend_schema( + description="Partially update a task, allowing humans to change the task's state" + ), +) +class TaskUpdate(UpdateAPIView): + """ + Admins and task creators can use this endpoint to update a task + Permissions must be implemented by the top-level Django application + """ + + queryset = Task.objects.all() + serializer_class = TaskTinySerializer + + +@extend_schema_view( + get=extend_schema( + operation_id="RetrieveSecret", + tags=["ponos"], + ) +) +class SecretDetails(RetrieveAPIView): + """ + Retrieve a Ponos secret content as cleartext + """ + + permission_classes = (IsAgent,) + serializer_class = ClearTextSecretSerializer + + def get_object(self): + return get_object_or_404(Secret, name=self.kwargs.get("name")) + + +@extend_schema(tags=["ponos"]) +class 
FarmList(ListAPIView): + """ + List all available farms + """ + + serializer_class = FarmSerializer + queryset = Farm.objects.order_by("name") diff --git a/arkindex/ponos/apps.py b/arkindex/ponos/apps.py new file mode 100644 index 0000000000000000000000000000000000000000..8f83f089c26820d70aa5f18cf85be81fd77f57b0 --- /dev/null +++ b/arkindex/ponos/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class PonosConfig(AppConfig): + name = "ponos" + verbose_name = "Ponos: tasks manager" diff --git a/arkindex/ponos/authentication.py b/arkindex/ponos/authentication.py new file mode 100644 index 0000000000000000000000000000000000000000..327f2b11809e1d03fe16b069551561bf323735fe --- /dev/null +++ b/arkindex/ponos/authentication.py @@ -0,0 +1,73 @@ +from drf_spectacular.contrib.rest_framework_simplejwt import SimpleJWTScheme +from rest_framework.exceptions import AuthenticationFailed +from rest_framework_simplejwt.authentication import JWTAuthentication +from rest_framework_simplejwt.exceptions import InvalidToken +from rest_framework_simplejwt.settings import api_settings + +from ponos.models import Agent + + +class AgentUser(Agent): + """ + A proxy model to implement the Django User interface on a Ponos agent. + Allows Django REST Framework's usual permissions (like IsAuthenticated) to work. + """ + + is_staff: bool = False + is_superuser: bool = False + is_active: bool = True + is_agent: bool = True + + @property + def username(self) -> str: + return str(self.id) + + @property + def is_anonymous(self) -> bool: + return False + + @property + def is_authenticated(self) -> bool: + return True + + def get_username(self) -> str: + return self.username + + class Meta: + proxy = True + + def save(self, *args, update_fields=None, **kwargs): + """ + django.contrib.auth adds a 'user_logged_in' signal and a 'update_last_login' receiver, + which tries to update the User.last_login field. + The receiver is only added when the User model has a last_login field, which depends on + the Django app integrating Ponos and not on Ponos itself, so we have to handle it. + The receiver calls .save(update_fields=['last_login']), so we remove it. + An empty `update_fields` will cause a save to abort silently. + """ + if update_fields is not None and "last_login" in update_fields: + update_fields = set(update_fields) - {"last_login"} + return super().save(*args, update_fields=update_fields, **kwargs) + + +class AgentAuthentication(JWTAuthentication): + """ + Allows authenticating as a Ponos agent using a JSON Web Token. + """ + + def get_user(self, validated_token): + if api_settings.USER_ID_CLAIM not in validated_token: + raise InvalidToken("Token does not hold agent information") + try: + return AgentUser.objects.get(id=validated_token[api_settings.USER_ID_CLAIM]) + except AgentUser.DoesNotExist: + raise AuthenticationFailed("Agent not found") + + +class AgentAuthenticationExtension(SimpleJWTScheme): + """ + drf_spectacular extension to make it recognize the AgentAuthentication. 
+ """ + + target_class = "ponos.authentication.AgentAuthentication" + name = "agentAuth" diff --git a/arkindex/ponos/aws.py b/arkindex/ponos/aws.py new file mode 100644 index 0000000000000000000000000000000000000000..1d22e777afac5808c15906c74e75c565361e9d8f --- /dev/null +++ b/arkindex/ponos/aws.py @@ -0,0 +1,29 @@ +from boto3.session import Session +from botocore.config import Config +from django.conf import settings + +session = Session( + aws_access_key_id=settings.PONOS_AWS_ACCESS_KEY, + aws_secret_access_key=settings.PONOS_AWS_SECRET_KEY, +) + +config = Config( + region_name=settings.PONOS_AWS_REGION, + signature_version="s3v4", +) + +s3 = session.resource( + "s3", + endpoint_url=settings.PONOS_AWS_ENDPOINT, + config=config, +) + + +def object_url(action: str, s3_obj) -> str: + """ + Helper to obtain a presigned URL for a client method such as ``get_object`` + on an S3 Object instance. + """ + return s3.meta.client.generate_presigned_url( + action, Params={"Bucket": s3_obj.bucket_name, "Key": s3_obj.key} + ) diff --git a/arkindex/ponos/fields.py b/arkindex/ponos/fields.py new file mode 100644 index 0000000000000000000000000000000000000000..a7d91ce39427fd6a45d404341ff3f2cc7f6aa1fb --- /dev/null +++ b/arkindex/ponos/fields.py @@ -0,0 +1,56 @@ +from django.core.exceptions import ValidationError +from django.db.models import CharField, JSONField + +from ponos import forms + + +class CommaSeparatedListField(CharField): + """ + Store a list of strings as a single comma-separated string into a CharField + """ + + def from_db_value(self, value, *args): + return self.to_python(value) + + def to_python(self, value): + value = super().to_python(value) + if value is None: + return value + if not value: + return [] + return value.split(",") + + def get_prep_value(self, value): + if not isinstance(value, str) and value is not None: + value = ",".join(map(str, value)) + return value + + +class StringDictField(JSONField): + """ + A JSONField that only accepts objects with string values, and allows empty objects. + """ + + empty_values = [None, "", [], ()] + + def validate(self, value, model_instance): + super().validate(value, model_instance) + + if not isinstance(value, dict): + raise ValidationError( + "Field value should be an object.", + code="invalid", + params={"value": value}, + ) + + if not all(isinstance(v, str) for v in value.values()): + raise ValidationError( + "All object values should be strings.", + code="invalid", + params={"value": value}, + ) + + def formfield(self, **kwargs): + # Ponos does not use forms, but this method allows the Django admin to use a custom form field + # that will allow an empty object to be used. + return super().formfield(**{"form_class": forms.EmptyableJSONField, **kwargs}) diff --git a/arkindex/ponos/forms.py b/arkindex/ponos/forms.py new file mode 100644 index 0000000000000000000000000000000000000000..80e3b4a5fc796e8d4973677d7a79e14e03f8b75b --- /dev/null +++ b/arkindex/ponos/forms.py @@ -0,0 +1,10 @@ +from django.forms import JSONField + + +class EmptyableJSONField(JSONField): + """ + A JSONField that accepts empty objects. 
+ """ + + # JSONField.empty_values, but allowing the empty dict + empty_values = [None, "", [], ()] diff --git a/arkindex/ponos/helpers.py b/arkindex/ponos/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..431082ee42113b254e5333687c368867a429b9a1 --- /dev/null +++ b/arkindex/ponos/helpers.py @@ -0,0 +1,37 @@ +import os +from uuid import UUID + + +def get_ponos_run(): + """ + Obtain the Ponos run number inside a task from the environment variables, + provided the agent implementation sets the ``PONOS_RUN`` variable. + + :returns: A positive or zero integer representing the task's run, + or None if ``PONOS_RUN`` was not defined. + """ + value = os.environ.get("PONOS_RUN") + if value is None: + return + return int(value) + + +def get_ponos_task_id(): + """ + Obtain the Ponos task ID when running inside a task from an environment variable, + provided the agent implementation sets the ``PONOS_TASK`` variable. + + :returns: A UUID representing the task's ID, + or None if ``PONOS_TASK`` was not defined. + """ + value = os.environ.get("PONOS_TASK") + if not value: + return + return UUID(value) + + +def is_ponos_task() -> bool: + """ + Check whether this process is running inside a Ponos task container. + """ + return isinstance(get_ponos_task_id(), UUID) diff --git a/arkindex/ponos/keys.py b/arkindex/ponos/keys.py new file mode 100644 index 0000000000000000000000000000000000000000..31d77be0cd4cbd9d9e266157758111dcb94b9713 --- /dev/null +++ b/arkindex/ponos/keys.py @@ -0,0 +1,117 @@ +import logging +import os.path +from os import urandom + +from cryptography.exceptions import InvalidKey +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.hashes import SHA256 +from cryptography.hazmat.primitives.kdf.hkdf import HKDF +from cryptography.hazmat.primitives.serialization import ( + Encoding, + NoEncryption, + PrivateFormat, + load_pem_private_key, +) +from django.conf import settings + +logger = logging.getLogger(__name__) + + +def gen_nonce(size=16): + """ + Generates a simple nonce + Number size is defined in bytes (defaults to 128 bits) + https://cryptography.io/en/latest/glossary/#term-nonce + """ + return urandom(size) + + +def gen_private_key(dest_path) -> None: + """ + Generates an elliptic curve private key and saves it to a local file in PEM format. + See https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ec/ + + :param dest_path: Path to save the new key to. + :type dest_path: str or path-like object + """ + key = ec.generate_private_key( + ec.SECP384R1(), + default_backend(), + ) + + with open(dest_path, "wb") as f: + f.write( + key.private_bytes( + Encoding.PEM, + PrivateFormat.PKCS8, + NoEncryption(), + ) + ) + + +def load_private_key(): + """ + Load an existing private key from the path given in the ``PONOS_PRIVATE_KEY`` setting. + + :returns: An elliptic curve private key instance. 
+ :raises Exception: When the Django ``DEBUG`` setting is set to False + and the server is misconfigured, or the key is not found or invalid + """ + + def _abort(message): + """ + On Debug, be nice to developers and just display a warning + On Prod, simply crash + """ + if getattr(settings, "DEBUG", False): + logger.warning("Please fix your security configuration: {}".format(message)) + else: + raise Exception(message) + + if not getattr(settings, "PONOS_PRIVATE_KEY", None): + return _abort("Missing setting PONOS_PRIVATE_KEY") + + if not os.path.exists(settings.PONOS_PRIVATE_KEY): + return _abort( + "Invalid PONOS_PRIVATE_KEY path: {}".format(settings.PONOS_PRIVATE_KEY) + ) + + with open(settings.PONOS_PRIVATE_KEY, "rb") as f: + key = load_pem_private_key( + f.read(), + password=None, + backend=default_backend(), + ) + assert isinstance( + key, ec.EllipticCurvePrivateKey + ), "Private key {} is not an ECDH key".format(settings.PONOS_PRIVATE_KEY) + return key + + +def check_agent_key(agent_public_key, agent_derivation, seed) -> bool: + """ + Authenticates a new agent using its public key and a derivation + of its private key with the server's public key. + + :param agent_public_key: An agent's public key. + :type agent_public_key: cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey + :param bytes agent_derivation: A bytestring representing the server's public key + derived with the agent's private key using HKDF and holding a :class:`~ponos.models.Farm`'s seed. + :param str seed: The expected farm seed. + """ + shared_key = load_private_key().exchange(ec.ECDH(), agent_public_key) + + hkdf = HKDF( + algorithm=SHA256(), + backend=default_backend(), + length=32, + salt=None, + info=seed.encode("utf-8"), + ) + + try: + hkdf.verify(shared_key, agent_derivation) + return True + except InvalidKey: + return False diff --git a/arkindex/ponos/management/__init__.py b/arkindex/ponos/management/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/arkindex/ponos/management/commands/__init__.py b/arkindex/ponos/management/commands/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/arkindex/ponos/management/commands/generate_private_key.py b/arkindex/ponos/management/commands/generate_private_key.py new file mode 100644 index 0000000000000000000000000000000000000000..d0ca0ae550612c6551acbf59d00caf1a6f632c1e --- /dev/null +++ b/arkindex/ponos/management/commands/generate_private_key.py @@ -0,0 +1,19 @@ +from django.core.management.base import BaseCommand + +from ponos.keys import gen_private_key + + +class Command(BaseCommand): + help = "Generate a Ponos server private key" + + def add_arguments(self, parser): + parser.add_argument( + "path", + help="Destination file for the key in PEM format.", + ) + + def handle(self, *args, path=None, **kwargs): + gen_private_key(path) + self.stdout.write( + self.style.SUCCESS("Generated a new private key to {}".format(path)) + ) diff --git a/arkindex/ponos/managers.py b/arkindex/ponos/managers.py new file mode 100644 index 0000000000000000000000000000000000000000..afd9badf052529d60127300936683028a8f30dcd --- /dev/null +++ b/arkindex/ponos/managers.py @@ -0,0 +1,36 @@ +from django.db import connections +from django.db.models import Manager + + +class TaskManager(Manager): + def parents(self, task): + """ + List all tasks that a task depends on, directly or not. 
+ + When using a PostgreSQL database, this will use a recursive query. + This falls back to a recursive generator on other databases. + """ + if connections[self.db].vendor == "postgresql": + return self.raw( + """ + with recursive parents(id) as ( + select to_task_id + from ponos_task_parents j + where j.from_task_id = %s + union all + select j.to_task_id + from ponos_task_parents as j, parents + where j.from_task_id = parents.id + ) + select distinct * from parents + """, + [task.id], + ) + else: + + def _parents_python(task): + yield from task.parents.all() + for parent in task.parents.all(): + yield from _parents_python(parent) + + return _parents_python(task) diff --git a/arkindex/ponos/migrations/0001_initial.py b/arkindex/ponos/migrations/0001_initial.py new file mode 100644 index 0000000000000000000000000000000000000000..f9af1350750aaa06ccedad4c2e76ca5da4fe885f --- /dev/null +++ b/arkindex/ponos/migrations/0001_initial.py @@ -0,0 +1,80 @@ +# Generated by Django 2.1.4 on 2018-12-20 16:03 + +import uuid + +import django.db.models.deletion +import enumfields.fields +from django.db import migrations, models + +import ponos.models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [] + + operations = [ + migrations.CreateModel( + name="Task", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, primary_key=True, serialize=False + ), + ), + ("run", models.PositiveIntegerField()), + ("depth", models.PositiveIntegerField()), + ("slug", models.CharField(max_length=250)), + ( + "state", + enumfields.fields.EnumField( + default="unscheduled", enum=ponos.models.State, max_length=20 + ), + ), + ("container", models.CharField(blank=True, max_length=64, null=True)), + ("created", models.DateTimeField(auto_now_add=True)), + ("updated", models.DateTimeField(auto_now=True)), + ( + "parent", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="children", + to="ponos.Task", + ), + ), + ], + options={"ordering": ("workflow", "run", "depth", "slug")}, + ), + migrations.CreateModel( + name="Workflow", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, primary_key=True, serialize=False + ), + ), + ("recipe", models.TextField()), + ("created", models.DateTimeField(auto_now_add=True)), + ("updated", models.DateTimeField(auto_now=True)), + ], + ), + migrations.AddField( + model_name="task", + name="workflow", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="tasks", + to="ponos.Workflow", + ), + ), + migrations.AlterUniqueTogether( + name="task", + unique_together={("workflow", "run", "slug")}, + ), + ] diff --git a/arkindex/ponos/migrations/0002_recipe_validator.py b/arkindex/ponos/migrations/0002_recipe_validator.py new file mode 100644 index 0000000000000000000000000000000000000000..80d380e4fb003b052d81a38f065bfa6c55614ab9 --- /dev/null +++ b/arkindex/ponos/migrations/0002_recipe_validator.py @@ -0,0 +1,20 @@ +# Generated by Django 2.1.7 on 2019-06-24 08:57 + +from django.db import migrations, models + +import ponos.models + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0001_initial"), + ] + + operations = [ + migrations.AlterField( + model_name="workflow", + name="recipe", + field=models.TextField(validators=[ponos.models.recipe_validator]), + ), + ] diff --git a/arkindex/ponos/migrations/0003_no_empty_slugs.py b/arkindex/ponos/migrations/0003_no_empty_slugs.py new file mode 100644 index 
0000000000000000000000000000000000000000..6f3d1978d60b9bf6b50d5242860c1227dd2a8f7b --- /dev/null +++ b/arkindex/ponos/migrations/0003_no_empty_slugs.py @@ -0,0 +1,22 @@ +# Generated by Django 2.1.7 on 2019-06-24 09:05 + +import django.core.validators +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0002_recipe_validator"), + ] + + operations = [ + migrations.AlterField( + model_name="task", + name="slug", + field=models.CharField( + max_length=250, + validators=[django.core.validators.MinLengthValidator(1)], + ), + ), + ] diff --git a/arkindex/ponos/migrations/0004_agent.py b/arkindex/ponos/migrations/0004_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..26a3a3a84600be534045edd81f9a9e402caf08f5 --- /dev/null +++ b/arkindex/ponos/migrations/0004_agent.py @@ -0,0 +1,76 @@ +# Generated by Django 2.1.7 on 2019-06-05 08:48 + +import uuid + +import django.core.validators +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0003_no_empty_slugs"), + ] + + operations = [ + migrations.CreateModel( + name="Agent", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, primary_key=True, serialize=False + ), + ), + ("created", models.DateTimeField(auto_now_add=True)), + ("updated", models.DateTimeField(auto_now=True)), + ( + "token", + models.CharField( + max_length=250, + unique=True, + validators=[ + django.core.validators.RegexValidator(r"^[0-9a-f]{64}$") + ], + ), + ), + ("hostname", models.SlugField(db_index=False, max_length=64)), + ( + "cpu_cores", + models.PositiveSmallIntegerField( + validators=[django.core.validators.MinValueValidator(1)], + ), + ), + ( + "cpu_frequency", + models.BigIntegerField( + validators=[django.core.validators.MinValueValidator(1)], + ), + ), + ("gpu_names", models.TextField()), + ("gpu_count", models.PositiveSmallIntegerField()), + ], + ), + migrations.AlterModelOptions( + name="workflow", + options={"ordering": ("-updated",)}, + ), + migrations.AddField( + model_name="task", + name="agent", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="tasks", + to="ponos.Agent", + ), + ), + migrations.CreateModel( + name="AgentUser", + fields=[], + options={"proxy": True, "indexes": []}, + bases=("ponos.agent",), + ), + ] diff --git a/arkindex/ponos/migrations/0005_gpu_names_blank.py b/arkindex/ponos/migrations/0005_gpu_names_blank.py new file mode 100644 index 0000000000000000000000000000000000000000..5a33f4aea40121a12a35b110081360547a31e583 --- /dev/null +++ b/arkindex/ponos/migrations/0005_gpu_names_blank.py @@ -0,0 +1,18 @@ +# Generated by Django 2.1.7 on 2019-07-02 08:00 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0004_agent"), + ] + + operations = [ + migrations.AlterField( + model_name="agent", + name="gpu_names", + field=models.TextField(blank=True, null=True), + ), + ] diff --git a/arkindex/ponos/migrations/0006_add_parents.py b/arkindex/ponos/migrations/0006_add_parents.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9f103262922647fdf4533e376f24184b95eb93 --- /dev/null +++ b/arkindex/ponos/migrations/0006_add_parents.py @@ -0,0 +1,20 @@ +# Generated by Django 2.1.7 on 2019-06-21 09:31 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + 
("ponos", "0005_gpu_names_blank"), + ] + + operations = [ + migrations.AddField( + model_name="task", + name="parents", + field=models.ManyToManyField( + related_name="children", to="ponos.Task", symmetrical=False + ), + ), + ] diff --git a/arkindex/ponos/migrations/0007_migrate_parents.py b/arkindex/ponos/migrations/0007_migrate_parents.py new file mode 100644 index 0000000000000000000000000000000000000000..d3d630314886714c8edd99b90d8ce35bc008612a --- /dev/null +++ b/arkindex/ponos/migrations/0007_migrate_parents.py @@ -0,0 +1,51 @@ +# Generated by Django 2.1.7 on 2019-06-21 09:31 + +from django.db import migrations +from django.db.models import Count + + +def parent_to_parents(apps, schema_editor): + db_alias = schema_editor.connection.alias + Task = apps.get_model("ponos", "Task") + TaskParent = Task.parents.through + + new_parents = [] + for task in ( + Task.objects.using(db_alias).exclude(parent=None).only("id", "parent_id") + ): + new_parents.append( + TaskParent( + from_task_id=task.id, + to_task_id=task.parent_id, + ) + ) + TaskParent.objects.using(db_alias).bulk_create(new_parents) + + +def parents_to_parent(apps, schema_editor): + db_alias = schema_editor.connection.alias + Task = apps.get_model("ponos", "Task") + assert ( + not Task.objects.using(db_alias) + .annotate(parents_count=Count("parents")) + .filter(parents_count__gt=1) + .exists() + ), "Cannot migrate task with multiple parents backwards" + + for task in Task.objects.using(db_alias).filter(parents__isnull=False): + task.parent = task.parents.get() + task.save() + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0006_add_parents"), + ] + + operations = [ + migrations.RunPython( + parent_to_parents, + parents_to_parent, + ), + ] diff --git a/arkindex/ponos/migrations/0008_remove_parent.py b/arkindex/ponos/migrations/0008_remove_parent.py new file mode 100644 index 0000000000000000000000000000000000000000..396fc3dedc3cdfb3e46e07a172997e6efc1265e5 --- /dev/null +++ b/arkindex/ponos/migrations/0008_remove_parent.py @@ -0,0 +1,17 @@ +# Generated by Django 2.1.7 on 2019-06-21 09:32 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0007_migrate_parents"), + ] + + operations = [ + migrations.RemoveField( + model_name="task", + name="parent", + ), + ] diff --git a/arkindex/ponos/migrations/0009_tags.py b/arkindex/ponos/migrations/0009_tags.py new file mode 100644 index 0000000000000000000000000000000000000000..667b586caad0e5ce3fefad0ab25af4d977574f6e --- /dev/null +++ b/arkindex/ponos/migrations/0009_tags.py @@ -0,0 +1,25 @@ +# Generated by Django 2.1.7 on 2019-07-10 14:20 + +from django.db import migrations + +import ponos.fields + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0008_remove_parent"), + ] + + operations = [ + migrations.AddField( + model_name="agent", + name="tags", + field=ponos.fields.CommaSeparatedListField(default=list, max_length=250), + ), + migrations.AddField( + model_name="task", + name="tags", + field=ponos.fields.CommaSeparatedListField(default=list, max_length=250), + ), + ] diff --git a/arkindex/ponos/migrations/0010_farm.py b/arkindex/ponos/migrations/0010_farm.py new file mode 100644 index 0000000000000000000000000000000000000000..8ca3cf970ffbb4b100a2e0ab3285b015abb6d276 --- /dev/null +++ b/arkindex/ponos/migrations/0010_farm.py @@ -0,0 +1,91 @@ +# Generated by Django 2.1.7 on 2019-07-03 14:07 + +import uuid + +import django.core.validators +import django.db.models.deletion +from 
django.db import migrations, models + +import ponos.models + + +def default_farm(apps, schema_editor): + db_alias = schema_editor.connection.alias + Agent = apps.get_model("ponos", "Agent") + if not Agent.objects.using(db_alias).exists(): + return + + Farm = apps.get_model("ponos", "Farm") + default_farm = Farm.objects.using(db_alias).create(name="default") + Agent.objects.using(db_alias).all().update(farm=default_farm) + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0009_tags"), + ] + + operations = [ + migrations.CreateModel( + name="Farm", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, + primary_key=True, + serialize=False, + ), + ), + ( + "name", + models.CharField( + max_length=250, + ), + ), + ( + "seed", + models.CharField( + default=ponos.models.generate_seed, + max_length=64, + unique=True, + validators=[ + django.core.validators.RegexValidator("^[0-9a-f]{64}$") + ], + ), + ), + ], + ), + migrations.AddField( + model_name="agent", + name="farm", + field=models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + to="ponos.Farm", + blank=True, + null=True, + ), + ), + migrations.RunPython( + default_farm, + reverse_code=migrations.RunPython.noop, + elidable=True, + ), + migrations.AlterField( + model_name="agent", + name="farm", + field=models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + to="ponos.Farm", + ), + ), + migrations.AddField( + model_name="agent", + name="public_key", + field=models.TextField( + default="", + ), + preserve_default=False, + ), + ] diff --git a/arkindex/ponos/migrations/0011_remove_agent_token.py b/arkindex/ponos/migrations/0011_remove_agent_token.py new file mode 100644 index 0000000000000000000000000000000000000000..6000d2cf471a709a1ed861f126fc47203ed9d19f --- /dev/null +++ b/arkindex/ponos/migrations/0011_remove_agent_token.py @@ -0,0 +1,17 @@ +# Generated by Django 2.1.7 on 2019-07-04 13:34 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0010_farm"), + ] + + operations = [ + migrations.RemoveField( + model_name="agent", + name="token", + ), + ] diff --git a/arkindex/ponos/migrations/0012_advanced_tags.py b/arkindex/ponos/migrations/0012_advanced_tags.py new file mode 100644 index 0000000000000000000000000000000000000000..2a98bed6e6f340e9cc39bee0e3d36c0637fb7195 --- /dev/null +++ b/arkindex/ponos/migrations/0012_advanced_tags.py @@ -0,0 +1,25 @@ +# Generated by Django 2.2.5 on 2019-09-04 13:34 + +from django.db import migrations + +import ponos.fields + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0011_remove_agent_token"), + ] + + operations = [ + migrations.RenameField( + model_name="agent", + old_name="tags", + new_name="include_tags", + ), + migrations.AddField( + model_name="agent", + name="exclude_tags", + field=ponos.fields.CommaSeparatedListField(default=list, max_length=250), + ), + ] diff --git a/arkindex/ponos/migrations/0013_artifact.py b/arkindex/ponos/migrations/0013_artifact.py new file mode 100644 index 0000000000000000000000000000000000000000..8f212321b4434e15296111680da3d9aec65c4067 --- /dev/null +++ b/arkindex/ponos/migrations/0013_artifact.py @@ -0,0 +1,49 @@ +# Generated by Django 2.2.12 on 2020-05-27 09:29 + +import uuid + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0012_advanced_tags"), + ] + + operations = [ + migrations.CreateModel( + name="Artifact", + 
fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, primary_key=True, serialize=False + ), + ), + ("path", models.CharField(max_length=500)), + ("size", models.PositiveIntegerField()), + ( + "content_type", + models.CharField( + default="application/octet-stream", max_length=250 + ), + ), + ( + "task", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="artifacts", + to="ponos.Task", + ), + ), + ("created", models.DateTimeField(auto_now_add=True)), + ("updated", models.DateTimeField(auto_now=True)), + ], + options={ + "ordering": ("task", "path"), + "unique_together": {("task", "path")}, + }, + ), + ] diff --git a/arkindex/ponos/migrations/0014_modify_task_model.py b/arkindex/ponos/migrations/0014_modify_task_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ed0393f195262473413268753960da2d49aaef29 --- /dev/null +++ b/arkindex/ponos/migrations/0014_modify_task_model.py @@ -0,0 +1,66 @@ +# Generated by Django 2.2.12 on 2020-06-05 08:53 + +from django.db import migrations, models + +from ponos.recipe import parse_recipe + +try: + import jsonfield + + json_field_cls = jsonfield.JSONField +except ImportError: + json_field_cls = models.JSONField + + +def parse_workflows_recipe(apps, schema_editor): + """ + Retrieve and parse the recipe for each workflow to populate new attributes + (image, command, env) on its associated tasks + """ + + db_alias = schema_editor.connection.alias + Workflow = apps.get_model("ponos", "Workflow") + for workflow in Workflow.objects.using(db_alias).all(): + _, recipes = parse_recipe(workflow.recipe) + for slug, task_recipe in recipes.items(): + for task in workflow.tasks.filter(slug=slug): + task.image = task_recipe.image + if task_recipe.command: + task.command = task_recipe.command + if task_recipe.environment: + task.env = task_recipe.environment + task.save() + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0013_artifact"), + ] + + operations = [ + migrations.AddField( + model_name="task", + name="command", + field=models.TextField(null=True), + ), + migrations.AddField( + model_name="task", + name="env", + field=json_field_cls(null=True), + ), + migrations.AddField( + model_name="task", + name="image", + field=models.CharField(null=True, max_length=250), + ), + migrations.RunPython( + parse_workflows_recipe, + reverse_code=migrations.RunPython.noop, + ), + migrations.AlterField( + model_name="task", + name="image", + field=models.CharField(max_length=250), + ), + ] diff --git a/arkindex/ponos/migrations/0015_task_has_docker_socket.py b/arkindex/ponos/migrations/0015_task_has_docker_socket.py new file mode 100644 index 0000000000000000000000000000000000000000..de50d041674c133fdc9d393a2528c7d2bef95102 --- /dev/null +++ b/arkindex/ponos/migrations/0015_task_has_docker_socket.py @@ -0,0 +1,18 @@ +# Generated by Django 2.2.12 on 2020-06-11 08:30 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0014_modify_task_model"), + ] + + operations = [ + migrations.AddField( + model_name="task", + name="has_docker_socket", + field=models.BooleanField(default=False), + ), + ] diff --git a/arkindex/ponos/migrations/0016_task_image_artifact.py b/arkindex/ponos/migrations/0016_task_image_artifact.py new file mode 100644 index 0000000000000000000000000000000000000000..aa3cf3ea307e44e8fc1805e56e5754ad55864980 --- /dev/null +++ b/arkindex/ponos/migrations/0016_task_image_artifact.py @@ -0,0 +1,24 @@ +# Generated by 
Django 2.2.13 on 2020-07-07 08:42 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0015_task_has_docker_socket"), + ] + + operations = [ + migrations.AddField( + model_name="task", + name="image_artifact", + field=models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="tasks_using_image", + to="ponos.Artifact", + ), + ), + ] diff --git a/arkindex/ponos/migrations/0017_new_jsonfield.py b/arkindex/ponos/migrations/0017_new_jsonfield.py new file mode 100644 index 0000000000000000000000000000000000000000..24e2875da652649b89095790f04cd2cea926923f --- /dev/null +++ b/arkindex/ponos/migrations/0017_new_jsonfield.py @@ -0,0 +1,18 @@ +# Generated by Django 3.1 on 2020-08-10 14:59 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0016_task_image_artifact"), + ] + + operations = [ + migrations.AlterField( + model_name="task", + name="env", + field=models.JSONField(null=True), + ), + ] diff --git a/arkindex/ponos/migrations/0018_auto_20200814_0818.py b/arkindex/ponos/migrations/0018_auto_20200814_0818.py new file mode 100644 index 0000000000000000000000000000000000000000..6ce079bf148d6e03cfb4f8a4b8c2620938750686 --- /dev/null +++ b/arkindex/ponos/migrations/0018_auto_20200814_0818.py @@ -0,0 +1,18 @@ +# Generated by Django 3.1 on 2020-08-14 08:18 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0017_new_jsonfield"), + ] + + operations = [ + migrations.AlterField( + model_name="artifact", + name="size", + field=models.PositiveBigIntegerField(), + ), + ] diff --git a/arkindex/ponos/migrations/0019_secret.py b/arkindex/ponos/migrations/0019_secret.py new file mode 100644 index 0000000000000000000000000000000000000000..cb30c62056dd16f064601b52e47f992ab2e550cd --- /dev/null +++ b/arkindex/ponos/migrations/0019_secret.py @@ -0,0 +1,34 @@ +# Generated by Django 3.1 on 2020-09-24 14:12 + +import uuid + +from django.db import migrations, models + +import ponos.keys + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0018_auto_20200814_0818"), + ] + + operations = [ + migrations.CreateModel( + name="Secret", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, primary_key=True, serialize=False + ), + ), + ("name", models.CharField(max_length=250, unique=True)), + ( + "nonce", + models.BinaryField(default=ponos.keys.gen_nonce, max_length=16), + ), + ("content", models.BinaryField(editable=True)), + ], + ), + ] diff --git a/arkindex/ponos/migrations/0020_fix_admin_blank.py b/arkindex/ponos/migrations/0020_fix_admin_blank.py new file mode 100644 index 0000000000000000000000000000000000000000..d1e47bf09aec75e20c02f390d845067d087e9182 --- /dev/null +++ b/arkindex/ponos/migrations/0020_fix_admin_blank.py @@ -0,0 +1,35 @@ +# Generated by Django 3.1.2 on 2020-10-19 10:58 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0019_secret"), + ] + + operations = [ + migrations.AlterField( + model_name="task", + name="command", + field=models.TextField(blank=True, null=True), + ), + migrations.AlterField( + model_name="task", + name="env", + field=models.JSONField(blank=True, null=True), + ), + migrations.AlterField( + model_name="task", + name="image_artifact", + field=models.ForeignKey( + 
blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="tasks_using_image", + to="ponos.artifact", + ), + ), + ] diff --git a/arkindex/ponos/migrations/0021_add_ping_and_ram_fields.py b/arkindex/ponos/migrations/0021_add_ping_and_ram_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..d681f6af656ad549b669e85ba252f8966a3eb66a --- /dev/null +++ b/arkindex/ponos/migrations/0021_add_ping_and_ram_fields.py @@ -0,0 +1,53 @@ +# Generated by Django 3.1 on 2020-10-27 10:44 + +import django.core.validators +from django.db import migrations, models + + +def set_default_ping(apps, schema_editor): + Agent = apps.get_model("ponos", "Agent") + Agent.objects.all().update(last_ping=models.F("updated")) + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0020_fix_admin_blank"), + ] + + operations = [ + migrations.AddField( + model_name="agent", + name="cpu_load", + field=models.FloatField(blank=True, null=True), + ), + migrations.AddField( + model_name="agent", + name="ram_load", + field=models.FloatField(blank=True, null=True), + ), + migrations.AddField( + model_name="agent", + name="ram_total", + field=models.BigIntegerField( + default=1, validators=[django.core.validators.MinValueValidator(1)] + ), + preserve_default=False, + ), + migrations.AddField( + model_name="agent", + name="last_ping", + field=models.DateTimeField(auto_now=True), + preserve_default=False, + ), + migrations.RunPython( + set_default_ping, + reverse_code=migrations.RunPython.noop, + elidable=True, + ), + migrations.AlterField( + model_name="agent", + name="last_ping", + field=models.DateTimeField(editable=False), + ), + ] diff --git a/arkindex/ponos/migrations/0022_rm_excluded_included_tags.py b/arkindex/ponos/migrations/0022_rm_excluded_included_tags.py new file mode 100644 index 0000000000000000000000000000000000000000..ae2e173b008a6d203722a7a033a420d1205e85a5 --- /dev/null +++ b/arkindex/ponos/migrations/0022_rm_excluded_included_tags.py @@ -0,0 +1,21 @@ +# Generated by Django 3.1 on 2020-11-02 10:37 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0021_add_ping_and_ram_fields"), + ] + + operations = [ + migrations.RemoveField( + model_name="agent", + name="exclude_tags", + ), + migrations.RemoveField( + model_name="agent", + name="include_tags", + ), + ] diff --git a/arkindex/ponos/migrations/0023_gpus.py b/arkindex/ponos/migrations/0023_gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..1d9cf6b56357a8e8137cce4349519df7f1466b7b --- /dev/null +++ b/arkindex/ponos/migrations/0023_gpus.py @@ -0,0 +1,62 @@ +# Generated by Django 3.1.2 on 2020-11-18 09:30 + +import django.core.validators +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0022_rm_excluded_included_tags"), + ] + + operations = [ + migrations.RemoveField( + model_name="agent", + name="gpu_count", + ), + migrations.RemoveField( + model_name="agent", + name="gpu_names", + ), + migrations.CreateModel( + name="GPU", + fields=[ + ( + "id", + models.UUIDField(primary_key=True, serialize=False), + ), + ("name", models.CharField(max_length=250)), + ("index", models.PositiveIntegerField()), + ( + "ram_total", + models.BigIntegerField( + validators=[django.core.validators.MinValueValidator(1)] + ), + ), + ( + "agent", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + 
related_name="gpus", + to="ponos.agent", + ), + ), + ], + options={ + "unique_together": {("agent_id", "index")}, + }, + ), + migrations.AddField( + model_name="task", + name="gpu", + field=models.OneToOneField( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="task", + to="ponos.gpu", + ), + ), + ] diff --git a/arkindex/ponos/migrations/0024_task_requires_gpu.py b/arkindex/ponos/migrations/0024_task_requires_gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..0d919b167e18a0e673c10073e1431d01223b302c --- /dev/null +++ b/arkindex/ponos/migrations/0024_task_requires_gpu.py @@ -0,0 +1,18 @@ +# Generated by Django 3.1.2 on 2020-12-10 16:00 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0023_gpus"), + ] + + operations = [ + migrations.AddField( + model_name="task", + name="requires_gpu", + field=models.BooleanField(default=False), + ), + ] diff --git a/arkindex/ponos/migrations/0025_workflow_farm.py b/arkindex/ponos/migrations/0025_workflow_farm.py new file mode 100644 index 0000000000000000000000000000000000000000..b679ae75622d79f82dada0c8b072ff7368a35b86 --- /dev/null +++ b/arkindex/ponos/migrations/0025_workflow_farm.py @@ -0,0 +1,53 @@ +# Generated by Django 3.1.5 on 2021-02-23 15:52 + +import uuid + +import django.db.models.deletion +from django.db import migrations, models + +DEFAULT_FARM_ID = uuid.uuid4() + + +def set_default_farm(apps, schema_editor): + Workflow = apps.get_model("ponos", "Workflow") + Farm = apps.get_model("ponos", "Farm") + workflows = Workflow.objects.all() + if not workflows.exists(): + return + default_farm = Farm.objects.order_by().first() + if not default_farm: + # Create a default farm if required + default_farm = Farm.objects.create(name="Default farm") + workflows.update(farm_id=default_farm.id) + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0024_task_requires_gpu"), + ] + + operations = [ + migrations.AddField( + model_name="workflow", + name="farm", + field=models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, null=True, to="ponos.farm" + ), + ), + migrations.RunPython( + set_default_farm, + reverse_code=migrations.RunPython.noop, + # No workflow exists initially + elidable=True, + ), + migrations.AlterField( + model_name="workflow", + name="farm", + field=models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + related_name="workflows", + to="ponos.farm", + ), + ), + ] diff --git a/arkindex/ponos/migrations/0026_alter_artifact_size.py b/arkindex/ponos/migrations/0026_alter_artifact_size.py new file mode 100644 index 0000000000000000000000000000000000000000..f2898ad0571057f98d4dddc92a07b117cecb5cf0 --- /dev/null +++ b/arkindex/ponos/migrations/0026_alter_artifact_size.py @@ -0,0 +1,34 @@ +# Generated by Django 3.2.3 on 2021-07-07 09:09 + +import django.core.validators +from django.db import migrations, models + + +def delete_large_artifacts(apps, schema_editor): + Artifact = apps.get_model("ponos", "Artifact") + Artifact.objects.filter( + models.Q(size__lt=1) | models.Q(size__gt=5368709120) + ).delete() + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0025_workflow_farm"), + ] + + operations = [ + migrations.RunPython( + delete_large_artifacts, reverse_code=migrations.RunPython.noop + ), + migrations.AlterField( + model_name="artifact", + name="size", + field=models.BigIntegerField( + validators=[ + 
django.core.validators.MinValueValidator(1), + django.core.validators.MaxValueValidator(5368709120), + ] + ), + ), + ] diff --git a/arkindex/ponos/migrations/0027_task_priority.py b/arkindex/ponos/migrations/0027_task_priority.py new file mode 100644 index 0000000000000000000000000000000000000000..cb1033c6a0f7624de65e1f21fb0d89a3353cb5d9 --- /dev/null +++ b/arkindex/ponos/migrations/0027_task_priority.py @@ -0,0 +1,18 @@ +# Generated by Django 3.2.5 on 2021-07-20 10:56 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0026_alter_artifact_size"), + ] + + operations = [ + migrations.AddField( + model_name="task", + name="priority", + field=models.PositiveIntegerField(default=10), + ), + ] diff --git a/arkindex/ponos/migrations/0028_workflow_finished.py b/arkindex/ponos/migrations/0028_workflow_finished.py new file mode 100644 index 0000000000000000000000000000000000000000..fef4a65faef0718ae7a65f5fbf5fe331305553bf --- /dev/null +++ b/arkindex/ponos/migrations/0028_workflow_finished.py @@ -0,0 +1,26 @@ +# Generated by Django 4.0.1 on 2022-02-15 09:55 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0027_task_priority"), + ] + + operations = [ + migrations.AddField( + model_name="workflow", + name="finished", + field=models.DateTimeField(blank=True, null=True), + ), + migrations.AddConstraint( + model_name="workflow", + constraint=models.CheckConstraint( + check=models.Q(finished=None) + | models.Q(finished__gte=models.F("created")), + name="ponos_workflow_finished_after_created", + ), + ), + ] diff --git a/arkindex/ponos/migrations/0029_task_expiry.py b/arkindex/ponos/migrations/0029_task_expiry.py new file mode 100644 index 0000000000000000000000000000000000000000..932463a91f091803220d127d00d743681d1d84a5 --- /dev/null +++ b/arkindex/ponos/migrations/0029_task_expiry.py @@ -0,0 +1,36 @@ +# Generated by Django 4.0.1 on 2022-03-28 15:53 + +from datetime import timedelta + +from django.db import migrations, models + +from ponos.models import expiry_default + + +def set_expiry(apps, schema_editor): + Task = apps.get_model("ponos", "Task") + Task.objects.update(expiry=models.F("updated") + timedelta(days=30)) + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0028_workflow_finished"), + ] + + operations = [ + migrations.AddField( + model_name="task", + name="expiry", + field=models.DateTimeField(null=True), + ), + migrations.RunPython( + set_expiry, + reverse_code=migrations.RunPython.noop, + ), + migrations.AlterField( + model_name="task", + name="expiry", + field=models.DateTimeField(default=expiry_default), + ), + ] diff --git a/arkindex/ponos/migrations/0030_task_extra_files.py b/arkindex/ponos/migrations/0030_task_extra_files.py new file mode 100644 index 0000000000000000000000000000000000000000..f36fa9f7c7526f369725a2e7f06183388e53aa5d --- /dev/null +++ b/arkindex/ponos/migrations/0030_task_extra_files.py @@ -0,0 +1,18 @@ +# Generated by Django 4.0.2 on 2022-06-07 11:33 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0029_task_expiry"), + ] + + operations = [ + migrations.AddField( + model_name="task", + name="extra_files", + field=models.JSONField(default=dict), + ), + ] diff --git a/arkindex/ponos/migrations/0031_emptyable_jsonfield.py b/arkindex/ponos/migrations/0031_emptyable_jsonfield.py new file mode 100644 index 
0000000000000000000000000000000000000000..f1b9bef89df4104d9a0a9d58c88d94de77802c1e --- /dev/null +++ b/arkindex/ponos/migrations/0031_emptyable_jsonfield.py @@ -0,0 +1,25 @@ +# Generated by Django 4.0.5 on 2022-06-27 15:03 + +from django.db import migrations + +import ponos.fields + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0030_task_extra_files"), + ] + + operations = [ + migrations.AlterField( + model_name="task", + name="env", + field=ponos.fields.StringDictField(blank=True, null=True), + ), + migrations.AlterField( + model_name="task", + name="extra_files", + field=ponos.fields.StringDictField(default=dict), + ), + ] diff --git a/arkindex/ponos/migrations/0032_stringify_json.py b/arkindex/ponos/migrations/0032_stringify_json.py new file mode 100644 index 0000000000000000000000000000000000000000..ae2a768810615fb3084ca7c1a59be5400a5b7ed4 --- /dev/null +++ b/arkindex/ponos/migrations/0032_stringify_json.py @@ -0,0 +1,68 @@ +# Generated by Django 4.0.5 on 2022-06-27 15:08 + +from django.db import migrations + + +def stringify_json(apps, schema_editor): + if schema_editor.connection.vendor == "postgresql": + # Postgres' jsonb type allows to stringify all payloads in one query per column + schema_editor.execute( + """ + UPDATE ponos_task + SET env = ( + SELECT jsonb_object(array_agg(key), array_agg(value)) + FROM jsonb_each_text(env) + ) + WHERE id in ( + SELECT id FROM ponos_task, jsonb_each(env) AS env_items + WHERE jsonb_typeof(env_items.value) <> 'string' + ); + """ + ) + schema_editor.execute( + """ + UPDATE ponos_task + SET extra_files = ( + SELECT jsonb_object(array_agg(key), array_agg(value)) + FROM jsonb_each_text(extra_files) + ) + WHERE id in ( + SELECT id FROM ponos_task, jsonb_each(extra_files) AS extra_files_items + WHERE jsonb_typeof(extra_files_items.value) <> 'string' + ); + """ + ) + + else: + Task = apps.get_model("ponos", "Task") + to_update = [] + for task in Task.objects.only("id", "env", "extra_files"): + updated = False + + if task.env and not all(isinstance(value, str) for value in task.env): + task.env = {key: str(value) for key, value in task.env.items()} + updated = True + + if task.extra_files and not all( + isinstance(value, str) for value in task.extra_files + ): + task.extra_files = { + key: str(value) for key, value in task.extra_files.items() + } + updated = True + + if updated: + to_update.append(task) + + Task.objects.bulk_update(to_update, fields=["env", "extra_files"]) + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0031_emptyable_jsonfield"), + ] + + operations = [ + migrations.RunPython(stringify_json, reverse_code=migrations.RunPython.noop), + ] diff --git a/arkindex/ponos/migrations/0033_task_shm_size.py b/arkindex/ponos/migrations/0033_task_shm_size.py new file mode 100644 index 0000000000000000000000000000000000000000..9623c8b669b387c17df7a2bc4a074b8973bd495d --- /dev/null +++ b/arkindex/ponos/migrations/0033_task_shm_size.py @@ -0,0 +1,20 @@ +# Generated by Django 4.0.4 on 2022-10-20 14:37 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0032_stringify_json"), + ] + + operations = [ + migrations.AddField( + model_name="task", + name="shm_size", + field=models.CharField( + blank=True, editable=False, max_length=80, null=True + ), + ), + ] diff --git a/arkindex/ponos/migrations/0034_alter_artifact_size.py b/arkindex/ponos/migrations/0034_alter_artifact_size.py new file mode 100644 index 
0000000000000000000000000000000000000000..06a5d668f549b36cbea97fd43e2986b0f6357835 --- /dev/null +++ b/arkindex/ponos/migrations/0034_alter_artifact_size.py @@ -0,0 +1,24 @@ +# Generated by Django 4.1.3 on 2023-01-09 09:52 + +from django.core.validators import MinValueValidator +from django.db import migrations, models + +from ponos.models import artifact_max_size +from ponos.validators import MaxValueValidator + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0033_task_shm_size"), + ] + + operations = [ + migrations.AlterField( + model_name="artifact", + name="size", + field=models.BigIntegerField( + validators=[MinValueValidator(1), MaxValueValidator(artifact_max_size)] + ), + ), + ] diff --git a/arkindex/ponos/migrations/__init__.py b/arkindex/ponos/migrations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/arkindex/ponos/models.py b/arkindex/ponos/models.py new file mode 100644 index 0000000000000000000000000000000000000000..7ef1b2994c752f7d62de75a621b41e268ce02f20 --- /dev/null +++ b/arkindex/ponos/models.py @@ -0,0 +1,908 @@ +import logging +import os.path +import random +import uuid +from collections import namedtuple +from datetime import timedelta +from hashlib import sha256 + +from botocore.exceptions import ClientError +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from django.conf import settings +from django.core.exceptions import ValidationError +from django.core.validators import MinLengthValidator, MinValueValidator, RegexValidator +from django.db import models, transaction +from django.db.models import Q +from django.urls import reverse +from django.utils import timezone +from django.utils.functional import cached_property +from enumfields import Enum, EnumField +from rest_framework_simplejwt.tokens import RefreshToken +from yaml import YAMLError + +from ponos.aws import object_url, s3 +from ponos.fields import CommaSeparatedListField, StringDictField +from ponos.keys import gen_nonce +from ponos.managers import TaskManager +from ponos.recipe import parse_recipe, recipe_depth +from ponos.validators import MaxValueValidator + +# Maximum allowed time until an agent is considered inactive since last request +AGENT_TIMEOUT = timedelta( + **getattr(settings, "PONOS_ACTIVE_AGENT_TIMEOUT", {"seconds": 30}) +) +# Estimation of required resources to run a task on an agent. Defaults to 1 core and 1GB of RAM +AGENT_SLOT = getattr(settings, "PONOS_TASK_SLOT", {"cpu": 1, "ram": 1e9}) + +Action = namedtuple("Action", "action, task") +Action.__doc__ = """ +Describes an instruction sent to an agent. +Can optionally be associated with a :class:`Task`. +""" + +Action.action.__doc__ = """ +Type of action to perform. + +:type: ponos.models.ActionType +""" + +Action.task.__doc__ = """ +Optionally associated Task, when relevant. + +:type: ponos.models.Task +""" + +logger = logging.getLogger(__name__) + + +class ActionType(Enum): + """ + Describes which action an agent should perform. + """ + + StartTask = "start_task" + """ + Instruct the agent to start a new task. + The action must have an associated task. + """ + + StopTask = "stop_task" + """ + Instruct the agent to stop a running task. + The action must have an associated task. + """ + + Kill = "kill" + """ + Instruct the agent to shut itself down. + This is the only case where the agent will have a 0 exit code. 
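As a side note, the two scheduling constants defined at the top of models.py are read from optional Django settings; a deployment could override the defaults along these lines (a sketch with illustrative values only, the setting names come from the module above):

# settings.py -- illustrative values
# Consider agents inactive 60 seconds after their last request
PONOS_ACTIVE_AGENT_TIMEOUT = {"seconds": 60}
# Reserve 2 cores and 4 GB of RAM per task when estimating agent load
PONOS_TASK_SLOT = {"cpu": 2, "ram": 4e9}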
+ """ + + +def generate_seed() -> str: + return "{:064x}".format(random.getrandbits(256)) + + +class Farm(models.Model): + """ + A group of agents, whose ID and seed can be used to register new agents + automatically in auto-scaling contexts. + """ + + id = models.UUIDField(default=uuid.uuid4, primary_key=True) + name = models.CharField(max_length=250) + seed = models.CharField( + max_length=64, + unique=True, + default=generate_seed, + validators=[RegexValidator(r"^[0-9a-f]{64}$")], + ) + + def __str__(self) -> str: + return "Farm {}".format(self.name) + + +class Agent(models.Model): + """ + A remote host that can run tasks. + """ + + id = models.UUIDField(default=uuid.uuid4, primary_key=True) + created = models.DateTimeField(auto_now_add=True) + updated = models.DateTimeField(auto_now=True) + farm = models.ForeignKey(Farm, on_delete=models.PROTECT) + public_key = models.TextField() + + hostname = models.SlugField(max_length=64, db_index=False) + cpu_cores = models.PositiveSmallIntegerField(validators=[MinValueValidator(1)]) + cpu_frequency = models.BigIntegerField(validators=[MinValueValidator(1)]) + # Total amount of RAM on this agent in bytes + ram_total = models.BigIntegerField(validators=[MinValueValidator(1)]) + # Last minute average CPU load measure on this agent + cpu_load = models.FloatField(null=True, blank=True) + # Last RAM load measure expressed as a percentage (0 ≤ ram_load ≤ 1) + ram_load = models.FloatField(null=True, blank=True) + last_ping = models.DateTimeField(editable=False) + + @property + def token(self) -> RefreshToken: + """ + JSON Web Token for this agent. + """ + return RefreshToken.for_user(self) + + @property + def active(self) -> bool: + return self.last_ping >= timezone.now() - AGENT_TIMEOUT + + def __str__(self) -> str: + return self.hostname + + def delete(self) -> None: + if self.tasks.exclude(state__in=FINAL_STATES).exists(): + raise ValidationError( + "This agent has one or more tasks in non-final states." + ) + return super().delete() + + def _estimate_new_tasks_cost(self, tasks=1): + """ + Metric used to estimate the load on this agent starting new tasks. + Used as the cost function to minimize overall agents load while attributing tasks. + + :param tasks: Number of tasks to estimate the cost for. + :returns: A cost expressed as a percentage. If > 1, the agent would be overloaded. + """ + cpu_cost = (self.cpu_load + tasks * AGENT_SLOT["cpu"]) / self.cpu_cores + ram_cost = self.ram_load + tasks * AGENT_SLOT["ram"] / self.ram_total + return max(cpu_cost, ram_cost) + + def next_tasks(self): + """ + Compute the next tasks that should be run by an agent. + + :returns: A list of tasks. 
+ """ + if self._estimate_new_tasks_cost() >= 1: + # The capacity of this agent does not allow a new task to be started + return [] + + # Filter pending task on the agent farm ordered by higher priority first and then seniority (older first) + pending_tasks = Task.objects.filter( + Q(workflow__farm_id=self.farm_id) + & Q( + Q(state=State.Pending, agent=None) | Q(state=State.Unscheduled, depth=0) + ) + ).order_by("-priority", "updated") + + if not pending_tasks.exists(): + return [] + + # Retrieve active agents within the same farm + active_agents = Agent.objects.filter( + last_ping__gte=timezone.now() - AGENT_TIMEOUT + ).filter(farm_id=self.farm_id) + + # List available gpus (without any pending or running task assigned) + available_gpus = list( + self.gpus.filter(Q(task__isnull=True) & ~Q(task__state__in=ACTIVE_STATES)) + ) + + # Simulate the attribution of pending tasks minimizing overall load + attributed_tasks = {agent: [] for agent in active_agents} + for task in pending_tasks: + + # High priority for tasks with GPU requirements + if task.requires_gpu: + if available_gpus: + task.gpu = available_gpus.pop() + min_cost_agent = self + else: + # Skip tasks requiring GPU when none is available + logger.info(f"No GPU available for task {task.id} - {task}") + continue + else: + # Compare the cost of adding a new task on compatible agents + min_cost_agent = min( + active_agents, + key=lambda agent: agent._estimate_new_tasks_cost( + tasks=len(attributed_tasks[agent]) + 1 + ), + ) + + # Append the task to the queue of the agent with the minimal cost + tasks = attributed_tasks[min_cost_agent] + if min_cost_agent._estimate_new_tasks_cost(len(tasks) + 1) > 1: + # Attributing the next task would overload the system + break + tasks.append(task) + + # Return tasks attributed to the agent making the request + return attributed_tasks[self] + + @transaction.atomic + def next_actions(self): + """ + Compute the next actions to send to an agent. + + This method must run in a single transaction to avoid fetching tasks that are being + assigned to another agent. + + :returns: List of :obj:`Action`. + """ + pending_tasks = self.tasks.filter(state=State.Pending) + next_tasks = self.next_tasks() + stop_tasks = self.tasks.filter(state=State.Stopping) + + actions = [] + for task in stop_tasks: + actions.append(Action(ActionType.StopTask, task)) + for task in pending_tasks: + actions.append(Action(ActionType.StartTask, task)) + for task in next_tasks: + task.agent = self + task.state = State.Pending + task.save() + actions.append(Action(ActionType.StartTask, task)) + return actions + + +class State(Enum): + """ + Describes the possible :class:`Task` states. + + Inspired by + `Taskcluster's states <https://docs.taskcluster.net/docs/manual/tasks/runs>`_ + """ + + Unscheduled = "unscheduled" + """ + This is the default initial state of all tasks. Tasks will keep this state in two cases: + + * They do not have dependencies in a :class:`Workflow`, but no :class:`Agent` + has yet been assigned to them; + * They depend on other tasks that are not yet :attr:`~State.Completed`. + """ + + Pending = "pending" + """ + Intermediate state where a task is being assigned to an :class:`Agent` and starting. + + Tasks without dependencies will enter this state when they are assigned to an agent. + Tasks with dependencies will reach this state as soon as all their dependencies are met. + + Once a task is assigned to an agent, the agent is responsible for updating its state. 
+ """ + + Running = "running" + """ + State where a task is currently being run by an :class:`Agent`. + Agents will update the state of their assigned pending tasks to Running as soon + as the Docker container has started. API users may start monitoring for logs. + """ + + Completed = "completed" + """ + State where a task has finished successfully. + Reaching this state means the artifacts are available, and any dependent tasks with + all their dependencies met will enter the Pending state. + """ + + Failed = "failed" + """ + State where a task has been run, but ended with a non-zero error code. + More information might be found in the task's logs. + """ + + Error = "error" + """ + State where a task may or may not have been run, but an error occurred on the + :class:`Agent`'s side. This is usually caused by Docker or host-related issues. + """ + + Stopping = "stopping" + """ + State that can be set by API users to ask for a running task to be stopped. + The task's assigned agent will be instructed to stop the task, and will report the + Stopped state when done. + """ + + Stopped = "stopped" + """ + State where a task that entered the Stopping state has successfully stopped. + """ + + +# States where a task is considered final. +# Once a task reaches a final state, its state should no longer change. +FINAL_STATES = ( + State.Completed, + State.Failed, + State.Error, + State.Stopped, +) + +# States where a task is considered active on an agent +ACTIVE_STATES = ( + State.Pending, + State.Running, + State.Stopping, +) + +# Tasks priority to determine the overall state of multiple tasks +# If there are failed tasks, the workflow is failed. +# Else, if there are any errors, the workflow errored. +# If there are any running tasks, no matter in which state another task may be, +# the workflow is set as running to prevent retrying and allow stopping. +STATES_ORDERING = [ + State.Running, + State.Failed, + State.Error, + State.Stopping, + State.Stopped, + State.Pending, + State.Unscheduled, + State.Completed, +] + + +def recipe_validator(value) -> None: + """ + Validator for the ``recipe`` field in a :class:`Workflow`. + + :raises ValidationError: When the recipe is not valid YAML, or not a valid recipe. + """ + try: + parse_recipe(value) + except (YAMLError, AssertionError) as e: + raise ValidationError(str(e)) + + +class Workflow(models.Model): + """ + A group of tasks that can depend on each other, + created from a YAML recipe. + """ + + id = models.UUIDField(default=uuid.uuid4, primary_key=True) + recipe = models.TextField(validators=[recipe_validator]) + created = models.DateTimeField(auto_now_add=True) + updated = models.DateTimeField(auto_now=True) + finished = models.DateTimeField(blank=True, null=True) + farm = models.ForeignKey( + to="ponos.Farm", related_name="workflows", on_delete=models.PROTECT + ) + + class Meta: + ordering = ("-updated",) + constraints = [ + # A workflow cannot be finished before it is created + models.CheckConstraint( + check=models.Q(finished=None) + | models.Q(finished__gte=models.F("created")), + name="ponos_workflow_finished_after_created", + ) + ] + + def __str__(self) -> str: + return str(self.id) + + @cached_property + def recipes(self): + _, recipes = parse_recipe(self.recipe) + return recipes + + def get_absolute_url(self) -> str: + """ + :returns: URL to the workflow details API for this workflow. 
+ """ + return reverse("ponos:workflow-details", args=[self.id]) + + @transaction.atomic + def build_tasks(self, run: int = 0): + """ + Parse this workflow's recipe and create unscheduled :class:`Task` instances in database. + + :param run: Run number to assign to each task. + :returns: A dict mapping task slugs to task instances. + :rtype: dict + """ + assert isinstance(run, int) + assert run >= 0 + assert not self.tasks.filter(run=run).exists(), "Run {} already exists".format( + run + ) + + # Create tasks without any parent + tasks = { + slug: self.tasks.create( + run=run, + slug=slug, + tags=recipe.tags, + depth=recipe_depth(slug, self.recipes), + image=recipe.image, + command=recipe.command, + shm_size=recipe.shm_size, + env=recipe.environment, + has_docker_socket=recipe.has_docker_socket, + image_artifact=Artifact.objects.get(id=recipe.artifact) + if recipe.artifact + else None, + requires_gpu=recipe.requires_gpu, + extra_files=recipe.extra_files if recipe.extra_files else {}, + ) + for slug, recipe in self.recipes.items() + } + + # Apply parents + for slug, recipe in self.recipes.items(): + task = tasks[slug] + task.parents.set([tasks[parent] for parent in recipe.parents]) + + return tasks + + @property + def state(self) -> State: + """ + Deduce the workflow's state from the tasks of its latest run. + A workflow's state is deduced by picking the first state that any tasks in this run have, in this order: + + #. :attr:`~State.Running` + #. :attr:`~State.Failed` + #. :attr:`~State.Error` + #. :attr:`~State.Stopping` + #. :attr:`~State.Stopped` + #. :attr:`~State.Pending` + #. :attr:`~State.Unscheduled` + """ + return self.get_state(self.get_last_run()) + + def get_state(self, run): + """ + A separate method to get a workflow's state on a given run. + + Most users will only use the :meth:`Workflow.state` property to get the state for a workflow's last run. + However, when trying to get a state for many workflows at once, using ``.annotate(last_run=Max('tasks__run'))`` + and using the annotation with this method will prevent many useless SQL requests. + + Further performance improvements can be achieved with ``prefetch_related('tasks')``. + """ + # Negative run numbers never have tasks + if run < 0: + return State.Unscheduled + + # This prevents performing another SQL request when tasks have already been prefetched. + # See https://stackoverflow.com/a/19651840/5990435 + if ( + hasattr(self, "_prefetched_objects_cache") + and self.tasks.field.remote_field.get_cache_name() + in self._prefetched_objects_cache + ): + task_states = set(t.state for t in self.tasks.all() if t.run == run) + else: + task_states = set( + self.tasks.filter(run=run).values_list("state", flat=True) + ) + + # This run has no tasks + if not task_states: + return State.Unscheduled + + # All tasks have the same state + if len(task_states) == 1: + return task_states.pop() + + for state in STATES_ORDERING: + if state in task_states: + return state + + raise NotImplementedError("Something went wrong") + + def get_last_run(self) -> int: + """ + Get the last run number. If the ``last_run`` attribute is defined on this workflow, + possibly from a ``.annotate(last_run=Max('tasks__run'))`` annotation in a Django QuerySet, + this method will return the attribute's value instead of making another SQL request. 
+ """ + if hasattr(self, "last_run"): + if self.last_run is None: + return -1 + return self.last_run + + if not self.tasks.exists(): + return -1 + return self.tasks.all().aggregate(models.Max("run"))["run__max"] + + def is_final(self) -> bool: + """ + Helper to tell whether a workflow is final. + A workflow is considered final when it is in one of the final states: + :attr:`~State.Completed`, :attr:`~State.Failed`, :attr:`~State.Error`, :attr:`~State.Stopped` + + :returns: Whether or not the workflow is considered final. + """ + return self.state in FINAL_STATES + + @property + def expiry(self): + """ + A workflow's expiry date. This is the latest expiry date of its tasks. + No action is taken when a workflow is expired. + + :returns: The latest expiry date of the workflow's tasks, or None if there are no tasks. + :rtype: datetime or None + """ + # This prevents performing another SQL request when tasks have already been prefetched. + # See https://stackoverflow.com/a/19651840/5990435 + if ( + hasattr(self, "_prefetched_objects_cache") + and self.tasks.field.remote_field.get_cache_name() + in self._prefetched_objects_cache + ): + return max(t.expiry for t in self.tasks.all()) + else: + return self.tasks.aggregate(models.Max("expiry"))["expiry__max"] + + def start(self): + """ + Build new :class:`Task` instances associated to this workflow + + :raises AssertionError: If the workflow has already been started + :returns: A dict mapping task slugs to task instances. + :rtype: dict + """ + assert ( + not self.tasks.exists() + ), "Could not start the workflow as it has associated tasks" + built_tasks = self.build_tasks() + # setting last_run to 0 when a workflow gets started to avoid unnecessary db queries + self.last_run = 0 + return built_tasks + + def retry(self): + """ + Create new :class:`Task` instances with a new run number and resets the completion date. + + :raises AssertionError: If the workflow is not in a final state. + :returns: A dict mapping task slugs to task instances. + :rtype: dict + """ + last_run = self.get_last_run() + assert self.is_final() + tasks = self.build_tasks(last_run + 1) + # setting last_run so that subsequent calls to get_last_run do not require db queries + self.last_run = last_run + 1 + self.finished = None + self.save() + return tasks + + def stop(self) -> None: + """ + Fully stop the workflow by updating every running task to the :attr:`~State.Stopping` state, + and every unscheduled task to the :attr:`~State.Stopped` state. + """ + assert not self.is_final() + stopping_count = Task.objects.filter( + workflow=self, state__in=[State.Pending, State.Running] + ).update(state=State.Stopping) + Task.objects.filter(workflow=self, state=State.Unscheduled).update( + state=State.Stopped + ) + # If all the tasks are immediately stopped, then UpdateTask will not be able to update + # the finished attribute, so we do it here. + if not stopping_count: + self.finished = timezone.now() + self.save() + + +def expiry_default(): + """ + Default value for Task.expiry. + + :rtype: datetime + """ + return timezone.now() + timedelta(days=30) + + +class Task(models.Model): + """ + A task created from a workflow's recipe. 
+ """ + + id = models.UUIDField(default=uuid.uuid4, primary_key=True) + run = models.PositiveIntegerField() + depth = models.PositiveIntegerField() + # blank=False is only used in Django form validation, to prevent blank slugs we use a minimum length + slug = models.CharField(max_length=250, validators=[MinLengthValidator(1)]) + priority = models.PositiveIntegerField(default=10) + state = EnumField(State, default=State.Unscheduled, max_length=20) + tags = CommaSeparatedListField(default=list, max_length=250) + image = models.CharField(max_length=250) + shm_size = models.CharField(max_length=80, blank=True, null=True, editable=False) + command = models.TextField(blank=True, null=True) + env = StringDictField(blank=True, null=True) + has_docker_socket = models.BooleanField(default=False) + image_artifact = models.ForeignKey( + "ponos.Artifact", + related_name="tasks_using_image", + on_delete=models.SET_NULL, + blank=True, + null=True, + ) + + agent = models.ForeignKey( + Agent, + related_name="tasks", + blank=True, + null=True, + on_delete=models.SET_NULL, + ) + requires_gpu = models.BooleanField(default=False) + gpu = models.OneToOneField( + "ponos.GPU", + related_name="task", + blank=True, + null=True, + on_delete=models.SET_NULL, + ) + workflow = models.ForeignKey( + Workflow, + related_name="tasks", + on_delete=models.CASCADE, + ) + parents = models.ManyToManyField( + "self", + related_name="children", + symmetrical=False, + ) + + container = models.CharField( + max_length=64, + null=True, + blank=True, + ) + + created = models.DateTimeField(auto_now_add=True) + updated = models.DateTimeField(auto_now=True) + expiry = models.DateTimeField(default=expiry_default) + + # Remote files required to start the container + extra_files = StringDictField(default=dict) + + objects = TaskManager() + + class Meta: + unique_together = (("workflow", "run", "slug"),) + ordering = ("workflow", "run", "depth", "slug") + + def __str__(self) -> str: + return "Task {}, run {}, depth {}".format(self.slug, self.run, self.depth) + + def get_absolute_url(self) -> str: + """ + :returns: URL to the workflow details API for this workflow. + """ + return reverse("ponos:task-details", args=[self.id]) + + def is_final(self) -> bool: + """ + Helper to tell whether a task is final. + A task is considered final when it is in one of the final states: + :attr:`~State.Completed`, :attr:`~State.Failed`, :attr:`~State.Error`, :attr:`~State.Stopped` + + :returns: Whether or not the workflow is considered final. + """ + return self.state in FINAL_STATES + + @cached_property + def s3_logs(self): + """ + Get an S3 object instance for this task's logs file. + This does not check for the file's existence. + """ + return s3.Object( + settings.PONOS_S3_LOGS_BUCKET, + os.path.join( + str(self.workflow.id), + "run_{}".format(self.run), + "{!s}.log".format(self.id), + ), + ) + + @cached_property + def s3_logs_get_url(self) -> str: + """ + Get a presigned S3 GET URL for the task's logs file. + This does not check for the file's existence. + """ + return object_url("get_object", self.s3_logs) + + @cached_property + def s3_logs_put_url(self) -> str: + """ + Get a presigned S3 PUT URL for the task's logs file. + This does not check for the file's existence. + """ + return object_url("put_object", self.s3_logs) + + @property + def has_logs(self) -> bool: + """ + Test for the logs file's existence on S3. 
+ """ + try: + self.s3_logs.load() + return True + except ClientError as e: + if e.response["Error"]["Code"] != "404": + raise + return False + + @property + def short_logs(self) -> str: + """ + Fetch the last N bytes from the task's log file on S3, and return them decoded. + An empty string is returned if the logs file is missing. + The decoded text is cut to start after the first found newline character to avoid incomplete lines. + + The logs will start with ``[Logs were truncated]`` to help API users or frontends + make it clear there are more logs available, unless less bytes than the requested amount are returned. + + The amount of fetched bytes is defined by the ``PONOS_LOG_TAIL`` setting, which defaults to 10000. + """ + max_length = getattr(settings, "PONOS_LOG_TAIL", 10000) + + try: + log_bytes = self.s3_logs.get( + Range="bytes=-{}".format(max_length), + )["Body"].read() + + except s3.meta.client.exceptions.NoSuchKey: + return "" + + text = log_bytes.decode("utf-8", errors="replace") + + if len(text) < max_length: + return text + + # Not starting from the beginning of the file + # Start at the next new line + text = text[text.find("\n") + 1 :] + return "[Logs were truncated]\n" + text + + +def artifact_max_size(): + """ + AWS restricts uploads to 5GiB per PUT request, + but some S3 implementations might not have this restriction, + so we default to 5 GiB and allow overriding through the Django settings. + + We use a function to allow this setting to change at runtime, + which makes unit testing a lot easier. + """ + setting = getattr(settings, "PONOS_ARTIFACT_MAX_SIZE", None) + return setting if setting is not None else 5 * 1024**3 + + +class Artifact(models.Model): + """ + A task Artifact (Json report, docker images, ML Models...) + """ + + id = models.UUIDField(default=uuid.uuid4, primary_key=True) + task = models.ForeignKey(Task, related_name="artifacts", on_delete=models.CASCADE) + path = models.CharField(max_length=500) + size = models.BigIntegerField( + validators=[ + MinValueValidator(1), + MaxValueValidator(artifact_max_size), + ] + ) + content_type = models.CharField(max_length=250, default="application/octet-stream") + created = models.DateTimeField(auto_now_add=True) + updated = models.DateTimeField(auto_now=True) + + class Meta: + ordering = ("task", "path") + unique_together = (("task", "path"),) + + @cached_property + def s3(self): + """ + Get an S3 object instance for this Artifact + This does not check for the file's existence. + """ + return s3.Object( + settings.PONOS_S3_ARTIFACTS_BUCKET, + os.path.join( + str(self.task.workflow_id), + str(self.task.id), + str(self.path), + ), + ) + + @cached_property + def s3_get_url(self) -> str: + """ + Get a presigned S3 GET URL to download this artifact + This does not check for the file's existence. + """ + return object_url("get_object", self.s3) + + @cached_property + def s3_put_url(self) -> str: + """ + Get a presigned S3 PUT URL to store this Artifact on S3 + This does not check for the file's existence. 
+ """ + return object_url("put_object", self.s3) + + +def build_aes_cipher(nonce): + """ + Initialize an AES cipher using the Ponos private key + """ + key_path = getattr(settings, "PONOS_PRIVATE_KEY", None) + assert key_path and os.path.exists(key_path), "Missing a PONOS_PRIVATE_KEY" + # Use the Ponos private key SHA-256 as the AES key + with open(settings.PONOS_PRIVATE_KEY, "rb") as f: + ponos_key = f.read() + aes_key = sha256(ponos_key).digest() + return Cipher(algorithms.AES(aes_key), modes.CTR(nonce)) + + +def encrypt(nonce, plain_text): + """ + Encrypt a plain text using the AES cipher model content ciphering a text from the Ponos private key + """ + cipher = build_aes_cipher(nonce) + encryptor = cipher.encryptor() + plain_text = plain_text.encode() + return encryptor.update(plain_text) + + +class Secret(models.Model): + """ + A secret encrypted with a derivate of the Ponos server private key (ECDH) + Secret content is encrypted with AES, using a key generated serializing ponos private key (ECDH) with a salt + """ + + id = models.UUIDField(default=uuid.uuid4, primary_key=True) + name = models.CharField(max_length=250, unique=True) + nonce = models.BinaryField(max_length=16, default=gen_nonce) + content = models.BinaryField(editable=True) + + def __str__(self): + return self.name + + def decrypt(self): + """ + Returns the plain text deciphered from the Ponos private key + """ + decryptor = build_aes_cipher(self.nonce).decryptor() + plain_text = decryptor.update(self.content) + return plain_text.decode() + + +class GPU(models.Model): + """ + A Graphic Card, attached to an agent + """ + + # The specific GPU uuid defined by the card itself + # It's used as GPU-{id} to identity the GPU for docker + id = models.UUIDField(primary_key=True) + agent = models.ForeignKey(Agent, related_name="gpus", on_delete=models.CASCADE) + name = models.CharField(max_length=250) + + # The numeric index assigned to the card by the host + index = models.PositiveIntegerField() + + # RAM management + ram_total = models.BigIntegerField(validators=[MinValueValidator(1)]) + + class Meta: + unique_together = (("agent_id", "index"),) + + def __str__(self): + return self.name diff --git a/arkindex/ponos/permissions.py b/arkindex/ponos/permissions.py new file mode 100644 index 0000000000000000000000000000000000000000..69525e6be6bd3370e12f8de2df80ce70d74fbe41 --- /dev/null +++ b/arkindex/ponos/permissions.py @@ -0,0 +1,43 @@ +from rest_framework.permissions import SAFE_METHODS, IsAuthenticated + +from ponos.models import Task + + +class IsAgent(IsAuthenticated): + """ + Only allow Ponos agents and admins + """ + + def has_permission(self, request, view) -> bool: + if ( + request.user.is_staff + or hasattr(request.user, "is_agent") + and request.user.is_agent + ): + return super().has_permission(request, view) + return False + + +class IsAgentOrReadOnly(IsAgent): + """ + Restricts write access to Ponos agents and admins + """ + + def has_permission(self, request, view) -> bool: + return request.method in SAFE_METHODS or super().has_permission(request, view) + + +class IsAssignedAgentOrReadOnly(IsAgentOrReadOnly): + """ + Restricts write access to Ponos agents and admins, + and restricts write access on tasks to admins and agents assigned to them. 
+ """ + + def has_object_permission(self, request, view, obj) -> bool: + if ( + isinstance(obj, Task) + and not request.user.is_staff + and not request.user == obj.agent + ): + return request.method in SAFE_METHODS + return super().has_object_permission(request, view, obj) diff --git a/arkindex/ponos/recipe.py b/arkindex/ponos/recipe.py new file mode 100644 index 0000000000000000000000000000000000000000..0daeff53911fac83a76af64671190e23d69d87ff --- /dev/null +++ b/arkindex/ponos/recipe.py @@ -0,0 +1,139 @@ +import re +from collections import namedtuple +from uuid import UUID + +import yaml +from django.core.validators import URLValidator + +TaskRecipe = namedtuple( + "TaskRecipe", + "image, parents, command, shm_size, environment, tags, has_docker_socket, artifact, requires_gpu, extra_files", +) + + +def validate_task(task, top_env): + """ + Validate a task description + - image (required) + - command (optional) + - parents (optional) + - shm_size (optional) + - env (optional) + - tags (optional) + - has_docker_socket (optional) + - artifact (optional) + - requires_gpu (optional) + - extra_files (optional) + """ + assert isinstance(task, dict), "Task should be a dict" + assert "image" in task, "Missing image" + + if "parents" in task: + assert isinstance(task["parents"], list), "Task parents should be a list" + if "tags" in task: + assert isinstance(task["tags"], list), "Task tags should be a list" + if task.get("shm_size"): + value = str(task["shm_size"]) + assert re.match( + r"^[0-9]+[bkmgBKMG]?$", str(value) + ), f"{str(value)} is not a valid value for shm_size" + if str(value).isdigit(): + assert int(value) > 0, "shm_size value must be greater than 0" + else: + assert int(value[:-1]) > 0, "shm_size value must be greater than 0" + + # Add optional local environment variables + env = top_env.copy() + if "env" in task: + assert isinstance( + task["env"], dict + ), "Task environment variables should be a dict" + env.update(task["env"]) + + if "has_docker_socket" in task: + assert isinstance( + task["has_docker_socket"], bool + ), "Task has_docker_socket should be a boolean" + + if "artifact" in task: + assert isinstance(task["artifact"], str), "Task artifact should be a string" + try: + UUID(task["artifact"]) + except (TypeError, ValueError): + raise AssertionError("Task artifact should be a valid UUID string") + + if "requires_gpu" in task: + assert isinstance( + task["requires_gpu"], bool + ), "Task requires_gpu should be a boolean" + + if "extra_files" in task: + assert isinstance( + task["extra_files"], dict + ), "Task extra_files should be a dict of strings/url" + url_validator = URLValidator(schemes=["http", "https"]) + for key, value in task["extra_files"].items(): + assert isinstance(key, str), "All Task extra_files keys should be strings" + assert isinstance( + value, str + ), "All Task extra_files values should be strings" + url_validator(value) + + return TaskRecipe( + task["image"], + task.get("parents", []), + task.get("command"), + task.get("shm_size", None), + env, + task.get("tags", []), + task.get("has_docker_socket", False), + task.get("artifact", ""), + task.get("requires_gpu", False), + task.get("extra_files", {}), + ) + + +def parse_recipe(recipe): + """ + Parse a recipe and check its content + Build a dict with all recipes per slug + """ + content = yaml.safe_load(recipe) + assert isinstance(content, dict), "Recipe should be a dict" + + # Load optional environment variables + env = content.get("env", {}) + assert isinstance(env, dict) + assert all(map(lambda x: 
isinstance(x, str), env.keys())) + env = {k: str(v) for k, v in env.items()} + + # Load tasks + tasks = content.get("tasks") + assert tasks, "No tasks" + assert isinstance(tasks, dict), "Tasks should be a dict" + assert all(tasks), "Tasks should have non-blank slugs" + + # Validate all tasks + tasks = {slug: validate_task(task, env) for slug, task in tasks.items()} + assert len(tasks) > 0, "No tasks recipes" + + # Check all parents exists + for slug, task in tasks.items(): + for parent in task.parents: + assert parent in tasks, "Missing parent {} for task {}".format(parent, slug) + + return env, tasks + + +def recipe_depth(slug, recipes, depth=0): + """ + Find recursively the depth of a recipe + """ + recipe = recipes.get(slug) + assert recipe is not None, "Recipe not found" + assert isinstance(recipe, TaskRecipe), "Recipe is not a TaskRecipe" + + if not recipe.parents: + return depth + + return max((recipe_depth(parent, recipes, depth + 1) for parent in recipe.parents)) diff --git a/arkindex/ponos/renderers.py b/arkindex/ponos/renderers.py new file mode 100644 index 0000000000000000000000000000000000000000..9fe700619135885dd9126ad13c3b6cfd5693eb34 --- /dev/null +++ b/arkindex/ponos/renderers.py @@ -0,0 +1,19 @@ +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat +from rest_framework.renderers import BaseRenderer + + +class PublicKeyPEMRenderer(BaseRenderer): + """ + A Django REST Framework renderer to serialize public keys as PEM. + """ + + media_type = "application/x-pem-file" + format = "pem" + + def render(self, data: ec.EllipticCurvePublicKey, *args, **kwargs) -> bytes: + assert isinstance(data, ec.EllipticCurvePublicKey) + return data.public_bytes( + encoding=Encoding.PEM, + format=PublicFormat.SubjectPublicKeyInfo, + ) diff --git a/arkindex/ponos/requirements-server.txt b/arkindex/ponos/requirements-server.txt new file mode 100644 index 0000000000000000000000000000000000000000..2eef61a81ffb64381dfded56bae190ac7c28d885 --- /dev/null +++ b/arkindex/ponos/requirements-server.txt @@ -0,0 +1,9 @@ +boto3==1.18.13 +cryptography==3.4.7 +Django==4.1.5 +django-enumfields==2.1.1 +djangorestframework==3.12.4 +djangorestframework-simplejwt==5.2.2 +drf-spectacular==0.18.2 +pytz==2022.6 +pyyaml==6.0 diff --git a/arkindex/ponos/serializer_fields.py b/arkindex/ponos/serializer_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..8b950c2dd63cfe692eef083a8d326d2791535cf8 --- /dev/null +++ b/arkindex/ponos/serializer_fields.py @@ -0,0 +1,85 @@ +import base64 + +from cryptography.exceptions import UnsupportedAlgorithm +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.serialization import ( + Encoding, + PublicFormat, + load_pem_public_key, +) +from enumfields import Enum +from rest_framework import serializers + + +class EnumField(serializers.ChoiceField): + """ + Serializes an enum field into a JSON-compatible value. 
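A short sketch of how the recipe helpers in ponos/recipe.py above behave on a recipe with a top-level environment (image names and values are illustrative):

from ponos.recipe import parse_recipe, recipe_depth

env, tasks = parse_recipe("""
env:
  LANG: fr
tasks:
  extract:
    image: registry.example.com/extractor
  analyze:
    image: registry.example.com/analyzer
    parents:
      - extract
    env:
      BATCH_SIZE: "16"
""")

assert env == {"LANG": "fr"}
# Task environments inherit the top-level env and add their own variables
assert tasks["analyze"].environment == {"LANG": "fr", "BATCH_SIZE": "16"}
# Depth is the length of the longest parent chain above a task
assert recipe_depth("analyze", tasks) == 1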
+ """ + + def __init__(self, enum: Enum, *args, **kwargs): + assert issubclass(enum, Enum) or not enum + self.enum = enum + choices = [(item.value, item.name) for item in self.enum] + super().__init__(choices, *args, **kwargs) + + def to_representation(self, obj) -> str: + if not isinstance(obj, self.enum): + obj = self.to_internal_value(obj) + return obj.value + + def to_internal_value(self, data): + assert self.enum is not None, "No enum set on EnumField" + try: + return self.enum(data) + except ValueError: + raise serializers.ValidationError( + "Value is not of type {}".format(self.enum.__name__) + ) + + +class PublicKeyField(serializers.CharField): + """ + An EC public key, serialized in PEM format + """ + + default_error_messages = { + "invalid_pem": "Incorrect PEM data", + "unsupported_algorithm": "Key algorithm is not supported", + "not_ec": "Key is not an EC public key", + } + + def to_internal_value(self, data) -> ec.EllipticCurvePublicKey: + data = super().to_internal_value(data) + try: + key = load_pem_public_key( + data.encode("utf-8"), + backend=default_backend(), + ) + except ValueError: + self.fail("invalid_pem") + except UnsupportedAlgorithm: + self.fail("unsupported_algorithm") + + if not isinstance(key, ec.EllipticCurvePublicKey): + self.fail("not_ec") + + return key + + def to_representation(self, key: ec.EllipticCurvePublicKey) -> str: + return key.public_bytes( + Encoding.PEM, + PublicFormat.SubjectPublicKeyInfo, + ).decode("utf-8") + + +class Base64Field(serializers.CharField): + """ + A base64-encoded bytestring. + """ + + def to_internal_value(self, data) -> bytes: + return base64.b64decode(super().to_internal_value(data)) + + def to_representation(self, obj: bytes) -> str: + return base64.b64encode(obj) diff --git a/arkindex/ponos/serializers.py b/arkindex/ponos/serializers.py new file mode 100644 index 0000000000000000000000000000000000000000..d0b067468770072219d7117562cbb9dbdff2669a --- /dev/null +++ b/arkindex/ponos/serializers.py @@ -0,0 +1,626 @@ +import hashlib +import logging +import uuid + +from django.db import transaction +from django.shortcuts import reverse +from django.utils import timezone +from drf_spectacular.utils import extend_schema_field +from rest_framework import serializers +from rest_framework.exceptions import ValidationError + +from ponos.keys import check_agent_key +from ponos.models import ( + ACTIVE_STATES, + FINAL_STATES, + GPU, + ActionType, + Agent, + Artifact, + Farm, + Secret, + State, + Task, + Workflow, +) +from ponos.serializer_fields import Base64Field, EnumField, PublicKeyField +from ponos.signals import task_failure + +logger = logging.getLogger(__name__) + + +class FarmSerializer(serializers.ModelSerializer): + """ + Serializes a :class:`~ponos.models.Farm` instance. + """ + + class Meta: + model = Farm + fields = ( + "id", + "name", + ) + + +class GPUSerializer(serializers.ModelSerializer): + """ + Serializes a :class:`~ponos.models.GPU` instance for public access. + """ + + class Meta: + model = GPU + fields = ( + "id", + "index", + "name", + "ram_total", + ) + read_only_fields = ( + "id", + "index", + "name", + "ram_total", + ) + + +class AgentLightSerializer(serializers.ModelSerializer): + """ + Serializes a :class:`~ponos.models.Agent` instance for public access. 
+ """ + + farm = FarmSerializer() + gpus = GPUSerializer(many=True, read_only=True) + + class Meta: + model = Agent + fields = ( + "id", + "farm", + "hostname", + "cpu_cores", + "cpu_frequency", + "ram_total", + "cpu_load", + "ram_load", + "last_ping", + "gpus", + ) + read_only_fields = ("id", "last_ping") + + +class TaskLightSerializer(serializers.ModelSerializer): + """ + Serializes a :class:`~ponos.models.Task` instance without logs or agent information. + Used to list tasks inside a workflow. + """ + + state = EnumField(State) + url = serializers.HyperlinkedIdentityField(view_name="ponos:task-details") + tags = serializers.ListField( + child=serializers.CharField(), + allow_empty=True, + required=False, + ) + + class Meta: + model = Task + fields = ( + "id", + "url", + "run", + "depth", + "slug", + "state", + "parents", + "tags", + "shm_size", + ) + read_only_fields = ( + "id", + "url", + "run", + "depth", + "slug", + "parents", + "tags", + "shm_size", + ) + + +class TaskSerializer(TaskLightSerializer): + """ + Serializes a :class:`~ponos.models.Task` instance with logs and agent information. + """ + + logs = serializers.CharField(source="short_logs") + full_log = serializers.URLField(source="s3_logs_get_url") + agent = AgentLightSerializer() + gpu = GPUSerializer() + extra_files = serializers.DictField(default=dict) + + class Meta(TaskLightSerializer.Meta): + fields = TaskLightSerializer.Meta.fields + ( + "logs", + "full_log", + "agent", + "gpu", + "extra_files", + ) + read_only_fields = TaskLightSerializer.Meta.read_only_fields + ( + "logs", + "full_log", + "agent", + "gpu", + "extra_files", + ) + + def update(self, instance: Task, validated_data) -> Task: + """ + Perform updates on a task instance. + If a task gets marked as :attr:`~ponos.models.State.Completed`, this will set + its child tasks to :attr:`~ponos.models.State.Pending` if all of their dependencies are met. + """ + + # Free the GPU when a task gets a finished state update + if ( + instance.gpu + and "state" in validated_data + and validated_data["state"] not in ACTIVE_STATES + ): + validated_data["gpu"] = None + + instance = super().update(instance, validated_data) + + if instance.state not in FINAL_STATES: + return instance + + with transaction.atomic(using="default"): + # When a task is completed, set its children to be started next + # We need to use the default Database here to avoid a stale read + # as the parent instance has been updated just before + children = list( + instance.children.using("default").filter(state=State.Unscheduled) + ) + if instance.state == State.Completed: + for child in children: + if ( + child.parents.using("default") + .exclude(state=State.Completed) + .exists() + ): + # This child has another parent that is not completed + continue + logging.info(f"Setting child task {child} to Pending state") + child.state = State.Pending + child.save() + + # This task has no children: this might be the last task of the workflow, so the workflow might be finished. + # If all tasks in the current run are finished, update the completion date of the workflow. + if ( + not children + and not Task.objects.filter( + workflow_id=instance.workflow_id, run=instance.run + ) + .exclude(state__in=FINAL_STATES) + .exists() + ): + instance.workflow.finished = timezone.now() + instance.workflow.save() + + # We already checked earlier that the task was in a final state. + # If this state is both final and not completed, then we should trigger the task failure signal. 
+ if instance.state != State.Completed: + task_failure.send_robust(self.__class__, task=instance) + + return instance + + +class TaskTinySerializer(TaskSerializer): + """ + Serializes a :class:`~ponos.models.Task` instance with only its state. + Used by humans to update a task. + """ + + state = EnumField(State) + + class Meta: + model = Task + fields = ("id", "state") + read_only_fields = ("id",) + + def update(self, instance: Task, validated_data) -> Task: + if "state" in validated_data: + new_state = validated_data["state"] + if new_state == State.Stopping and instance.state != State.Running: + raise ValidationError("You can only stop a 'Running' task") + + if new_state == State.Pending: + if instance.state not in FINAL_STATES: + raise ValidationError( + "You can only restart a task with a state equal to 'Completed', 'Failed', 'Error' or 'Stopped'." + ) + instance.agent = None + + if new_state in FINAL_STATES and new_state != State.Completed: + task_failure.send_robust(self.__class__, task=instance) + + return super().update(instance, validated_data) + + +class AgentStateSerializer(AgentLightSerializer): + """ + Serialize an :class:`~ponos.models.Agent` with its state information + And the GPU state + """ + + running_tasks_count = serializers.IntegerField(min_value=0) + + class Meta(AgentLightSerializer.Meta): + fields = AgentLightSerializer.Meta.fields + ("active", "running_tasks_count") + + +class AgentDetailsSerializer(AgentLightSerializer): + """ + Serialize an :class:`~ponos.models.Agent` with its running tasks + """ + + running_tasks = serializers.SerializerMethodField() + + class Meta(AgentLightSerializer.Meta): + fields = AgentLightSerializer.Meta.fields + ("active", "running_tasks") + + @extend_schema_field(TaskLightSerializer(many=True)) + def get_running_tasks(self, agent): + running_tasks = agent.tasks.filter(state=State.Running) + task_serializer = TaskLightSerializer( + instance=running_tasks, many=True, context=self.context + ) + return task_serializer.data + + +class GPUCreateSerializer(serializers.Serializer): + """ + Serialize a :class:`~ponos.models.GPU` attached to an Agent + """ + + id = serializers.UUIDField() + name = serializers.CharField() + index = serializers.IntegerField() + ram_total = serializers.IntegerField() + + +class AgentCreateSerializer(serializers.ModelSerializer): + """ + Serializer used to register a new :class:`~ponos.models.Agent`. + """ + + public_key = PublicKeyField(write_only=True) + derivation = Base64Field(write_only=True) + access_token = serializers.CharField(read_only=True, source="token.access_token") + refresh_token = serializers.CharField(read_only=True, source="token") + gpus = GPUCreateSerializer(many=True) + + class Meta: + model = Agent + fields = ( + "id", + "farm", + "access_token", + "refresh_token", + "public_key", + "derivation", + "hostname", + "cpu_cores", + "cpu_frequency", + "ram_total", + "cpu_load", + "ram_load", + "last_ping", + "gpus", + ) + read_only_fields = ( + "id", + "access_token", + "refresh_token", + "last_ping", + ) + + def create(self, validated_data): + """Create the required agent and its GPUs""" + gpus = validated_data.pop("gpus", None) + + with transaction.atomic(using="default"): + + # Create Agent as usual + agent = super().create(validated_data) + + # Create or update GPUs + # When an agent's private key is regenerated on the same host, an existing GPU UUID might be sent, + # so we need to handle existing GPUs to avoid errors. 
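Both task serializers above emit the task_failure signal (declared in ponos/signals.py later in this diff) whenever a task reaches a final state other than Completed; consumers can react with a standard Django receiver, for instance (a sketch):

from django.dispatch import receiver
from ponos.signals import task_failure

@receiver(task_failure)
def notify_on_failure(sender, task, **kwargs):
    # `task` is the Task instance that ended up Failed, Error or Stopped
    print(f"Task {task.slug} (run {task.run}) ended in state {task.state.value}")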
+ for gpu_data in gpus: + GPU.objects.update_or_create( + id=gpu_data.pop("id"), defaults=dict(agent=agent, **gpu_data) + ) + + return agent + + def update(self, instance, validated_data): + """Create the required agent and its GPUs""" + gpus = validated_data.pop("gpus", None) + + with transaction.atomic(using="default"): + + # Update Agent as usual + agent = super().update(instance, validated_data) + + # Delete existing GPUs + agent.gpus.all().delete() + + # Create Gpus - the ID should not evolve as it's provided + # by the host + for gpu_data in gpus: + agent.gpus.create(**gpu_data) + + return agent + + def validate(self, data): + if not check_agent_key( + data["public_key"], data["derivation"], data["farm"].seed + ): + raise serializers.ValidationError("Key verification failed") + del data["derivation"] + + # Turn the public key back into a string to be saved in the agent + # TODO: Use a custom Django model field for this? + data["public_key"] = PublicKeyField().to_representation(data["public_key"]) + + # Generate the agent ID as a MD5 hash of its public key + data["id"] = uuid.UUID( + hashlib.md5(data["public_key"].encode("utf-8")).hexdigest() + ) + + # Set last ping as agent is about to be registered + data["last_ping"] = timezone.now() + + return data + + +class WorkflowSerializer(serializers.ModelSerializer): + """ + Serializes a :class:`~ponos.models.Workflow` instance with its tasks. + """ + + tasks = TaskLightSerializer(many=True, read_only=True) + farm = FarmSerializer(read_only=True) + state = EnumField(State) + + class Meta: + model = Workflow + fields = ( + "id", + "created", + "finished", + "state", + "farm", + "tasks", + ) + + def validate_state(self, state: State) -> None: + """ + When performing updates, prevents updating to any state other than :attr:`~ponos.models.State.Stopping` + and restricts to :attr:`~ponos.models.State.Pending` or :attr:`~ponos.models.State.Running` workflows. + """ + if state != State.Stopping: + raise ValidationError("Can only change the state to 'stopping'") + if self.instance.state not in (State.Pending, State.Running): + raise ValidationError( + "Cannot stop a {} workflow".format(self.instance.state.value) + ) + + +class ActionSerializer(serializers.Serializer): + """ + Serializes an :const:`~ponos.models.Action` instance. + """ + + action = EnumField(ActionType) + task_id = serializers.SerializerMethodField() + + @extend_schema_field(serializers.UUIDField(allow_null=True)) + def get_task_id(self, obj): + if not obj.task: + return + return str(obj.task.id) + + +class AgentActionsSerializer(serializers.Serializer): + """ + Serializes multiple next actions for an agent. + """ + + actions = ActionSerializer(many=True, source="next_actions") + + +class TaskDefinitionSerializer(serializers.ModelSerializer): + """ + Serializes a :class:`~ponos.models.Task` instance + to provide startup information to assigned agents. 
+ """ + + env = serializers.DictField(default={}) + workflow_id = serializers.UUIDField() + agent_id = serializers.PrimaryKeyRelatedField(queryset=Agent.objects.all()) + image_artifact_url = serializers.SerializerMethodField() + s3_logs_put_url = serializers.SerializerMethodField() + extra_files = serializers.DictField(default={}) + + @extend_schema_field(serializers.URLField(allow_null=True)) + def get_image_artifact_url(self, task): + """Build url on the API to get a fresh download link""" + if not task.image_artifact: + return + return self.context["request"].build_absolute_uri( + reverse( + "ponos:task-artifact-download", + kwargs={ + "pk": task.image_artifact.task_id, + "path": task.image_artifact.path, + }, + ) + ) + + @extend_schema_field(serializers.URLField(allow_null=True)) + def get_s3_logs_put_url(self, obj): + if "request" not in self.context or self.context["request"].user != obj.agent: + return + return obj.s3_logs_put_url + + class Meta: + model = Task + fields = ( + "id", + "slug", + "image", + "command", + "env", + "shm_size", + "has_docker_socket", + "image_artifact_url", + "agent_id", + "s3_logs_put_url", + "parents", + "workflow_id", + "gpu_id", + "extra_files", + ) + read_only_fields = fields + + +class ArtifactSerializer(serializers.ModelSerializer): + """ + Serializes a :class:`~ponos.models.Artifact` instance to allow + agents to create them, and users to list and retrieve them + """ + + s3_put_url = serializers.SerializerMethodField() + url = serializers.SerializerMethodField() + + class Meta: + model = Artifact + fields = ( + "id", + "path", + "size", + "content_type", + "url", + "s3_put_url", + "created", + "updated", + ) + read_only_fields = ( + "id", + "s3_put_url", + "created", + "updated", + ) + + def get_url(self, obj) -> str: + """Build url on the API to get a fresh download link""" + return self.context["request"].build_absolute_uri( + reverse( + "ponos:task-artifact-download", + kwargs={"pk": obj.task_id, "path": obj.path}, + ) + ) + + @extend_schema_field(serializers.URLField(allow_null=True)) + def get_s3_put_url(self, obj): + """Only add the PUT url to store the file for agents during creation""" + request = self.context.get("request") + if not request or request.method != "POST": + return + return obj.s3_put_url + + def validate_path(self, path): + """Check that no artifacts with this path already exist in DB""" + task = self.context["view"].get_task() + if task.artifacts.filter(path=path).exists(): + raise ValidationError("An artifact with this path already exists") + + return path + + +class NewTaskSerializer(serializers.ModelSerializer): + """ + Serializes a :class:`~ponos.models.Task` instance to permit creation by a parent. 
+ """ + + workflow_id = serializers.UUIDField() + command = serializers.CharField(required=False) + env = serializers.DictField( + child=serializers.CharField(), required=False, default={} + ) + has_docker_socket = serializers.BooleanField(default=False) + + class Meta: + model = Task + fields = ( + "id", + "workflow_id", + "slug", + "parents", + "image", + "command", + "env", + "run", + "depth", + "has_docker_socket", + ) + read_only_fields = ( + "id", + "run", + "depth", + ) + + def validate(self, data): + parents = data["parents"] + + ids = {parent.workflow.id for parent in parents} + if len(ids) != 1 or str(ids.pop()) != str(data["workflow_id"]): + raise ValidationError( + "All parents must be in the same workflow as the child task" + ) + + runs = {parent.run for parent in parents} + if len(runs) != 1: + raise ValidationError( + "All parents must have the same run in the given workflow" + ) + data["run"] = runs.pop() + + if Task.objects.filter( + workflow_id=data["workflow_id"], run=data["run"], slug=data["slug"] + ).exists(): + raise ValidationError( + f"A task with the `{data['slug']}` slug already exists in run {data['run']}." + ) + + data["depth"] = max(parent.depth for parent in parents) + 1 + + return super().validate(data) + + +class ClearTextSecretSerializer(serializers.ModelSerializer): + """ + Serializes a :class:`~ponos.models.Secret` instance with its content in cleartext. + """ + + content = serializers.SerializerMethodField(read_only=True) + + def get_content(self, secret) -> str: + return secret.decrypt() + + class Meta: + model = Secret + fields = ("id", "name", "content") + read_only_fields = ("id", "content") diff --git a/arkindex/ponos/signals.py b/arkindex/ponos/signals.py new file mode 100644 index 0000000000000000000000000000000000000000..54e5ccea2cd85ad98da3c860f6e7afdfa4afd35c --- /dev/null +++ b/arkindex/ponos/signals.py @@ -0,0 +1,14 @@ +from django.dispatch import Signal + +# Sphinx does not detect the docstring for signals when using `task_failure.__doc__`, +# because this is an instance, not a class, and this does not look like a constant. +# We would then normally have to add a string right below the signal. +# This string gets picked up successfully by Sphinx, but pre-commit fails due to a +# false positive: https://github.com/pre-commit/pre-commit-hooks/issues/159 +# So we use an alternative, lesser known syntax that Sphinx supports but most other +# tools (such as IDE autocompletions) don't, with `#:` comments. + +#: A task has reached a final state that is not :attr:`~State.Completed`. +#: +#: This signal will be called with the related :class:`Task` instance as the `task` keyword argument. 
+task_failure = Signal() diff --git a/arkindex/ponos/tests/__init__.py b/arkindex/ponos/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/arkindex/ponos/tests/fixtures/farm.yml b/arkindex/ponos/tests/fixtures/farm.yml new file mode 100644 index 0000000000000000000000000000000000000000..41d5d3d0e46b303eee88dc7ab7dcb95ee95de8ce --- /dev/null +++ b/arkindex/ponos/tests/fixtures/farm.yml @@ -0,0 +1,3 @@ +url: http://localhost:8000/ +farm_id: farm_id +seed: seed_id diff --git a/arkindex/ponos/tests/fixtures/ponos.key b/arkindex/ponos/tests/fixtures/ponos.key new file mode 100644 index 0000000000000000000000000000000000000000..851b7ad4cc38ac575e52679a40dc1f31c8ccb93e --- /dev/null +++ b/arkindex/ponos/tests/fixtures/ponos.key @@ -0,0 +1,6 @@ +-----BEGIN PRIVATE KEY----- +MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDA4FrXdr6wptT05sXQS +3iNqcAPhHe1GCCkgkkqU9iM9Q7reFofdYi6n0GxBxRKQ4QehZANiAAQMd1mENyaI +h9IJwxYTPTUpeS/uB9WbdLpXRqX3lrRfHEEA+DPiN/tMtcHub0tgPuU5+NuRkz85 +1GdhEBeJFv5EkR3ahaAQVsR9NTFu7vw9rpu3V08Rf2PbkVUsf8cXlME= +-----END PRIVATE KEY----- diff --git a/arkindex/ponos/tests/fixtures/rsa.key b/arkindex/ponos/tests/fixtures/rsa.key new file mode 100644 index 0000000000000000000000000000000000000000..471d98d79ab6f887c1eb3cf059a3c0997356acba --- /dev/null +++ b/arkindex/ponos/tests/fixtures/rsa.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDDdwPF9RBxkInF +omGXksi4VXvjxuELTh+Q20ooX+Msd/fVL/rYOFu5sqD+7z1TiYnI/BpkG8/xGuF4 +to8dyZMiNaiBhBh71oqDCXQyRxTeTn0zt0/2dAh5RMvmD/M5bfma3sKoj1HlVK5W +tNN02eyZl7UVRsUycvZTFD4il7cYiWjoCkwQYmQt6sfjfQXb0a4xYX4IrZyLFElR +k73uVdkxqiFHQ5iRgZSxuvLjXVLf6mwydw7lMQfIGIrgf6k5Lsgd6vpDmDB8o5Aq +4LAt/3UAo/IlaUmo6L5aQTQOtE73sYB7kUZnTcKEJU8WNh48qeEPfLtzp+hy+o9z +NLqPEzTdAgMBAAECggEAK5q1QMf2rx7rXSdoVgPsuxJ23M1VDsySxiHlXTRm3Vi1 +9N0LPIj2DWsH91cvex9HmYqD6Kk7rjGs8tzm+GIrbIpD2QC65YVqyOim/0BUK3Of +ApZ2RCiGa2cphV0xiTI7aI0hJ7ExN9O4QKd+NKcY3Pq27tQm0aZLxSTS56yor9Qz +htDvlAZmi+BA6uD9xrJyJrwoz+AIiLgJELzrUZFs6h81PD6Jru6kr4kfeCc82IK5 +woqWyc0UkE5v1JilzvGp1DENq5hPoHcTOqgVw0I9lUGwWOX37VdV3hRr766iNwZg +4EehymKwuXNCBDdnSq0w/7xidYm6ontLUImLL6CpEQKBgQDvSvD44PuKy0y/pDXy +9MAA/PKBgswZ+UktRxzr2n84HJSdyW/kkPkJRowwr01XumxRId2CJC5a/X0r0Zrj +BHPCqQYldhYuJTA24jch4BLzCn3pWAWbc9yPdHNZziYQqWVleM8ypkrgt3UsCaFL +LgCOSO57GVqErMkMqNJLvRRnFwKBgQDRHLSiUJEJ9C8A1ohpIOtHfAuJPBzlzfLc +kfV0Lc+C+P7U44fQn4VUVAmVkJv5wAK5Q+8L8Q+XVafoAV4LED+YcLYC4N/gtVQL +t9RX9srxecS4I8eLaK5Fot/1N1XTvpAMC6T5UCquUSLNraafnFiSW0xxxJ8N+GEJ +jW8LjTe8KwKBgQDqV3T39rTAruoBf9pJjYD/NrhzNtmU0jnkupDLNVaDaBHvGFeY ++pS4jbs67mKK+ImdRtH74lz3ROoxYHsTucd4KjlXtHZySH8YMJ+XcC5+j5bRTx9m +pqeoYX2ZxDYo+QvQvOgFDS+lNGTudJvd2TY4IZpTOXgZGHFoEWipPYlejwKBgAgN +rvdBWxSjDtxdZsuFtQn/wQH8CrDfCadtB6L90KweotHYIXbrbdsdkXDtLNSljHVO +JHq1QgB2EA1jYBfU/F4GmTvrJTQmR6Jb5hWtL4u1QNpGpny7/1o3N6DeDLQm9q1A +FY50g/BKt6hsM6qZ/t9EHOGUzPtgwXv4snojai4ZAoGAW79wPuo+3Dpf3j/cWAQE +q5ud5XxUqKm89TPmJZNjcIKxqpoxce/C/Scm6S0Dledb9RjhSmjVgoLul1z/x884 +3DAU7c/Som24XQUYMozf0lRFXDtDImAlBXO4uhJFQPWRvgtX5IUi94QxMyAgfIIp +VO0VXwpfmRvLdK+dYiE3vSY= +-----END PRIVATE KEY----- diff --git a/arkindex/ponos/tests/fixtures/server.key b/arkindex/ponos/tests/fixtures/server.key new file mode 100644 index 0000000000000000000000000000000000000000..b032349c8e867a193efff59e95e028b98da2810f --- /dev/null +++ b/arkindex/ponos/tests/fixtures/server.key @@ -0,0 +1,5 @@ +-----BEGIN PUBLIC KEY----- +MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAEv7E3hHknMuNNQPiqs1QRVm+U/8fr3s2Z 
+6brdaYgVO2W9yiNA375l3M37EfC4yqnymTNmCOt/cEv4ZuA7d4WN2j6VVHrJR3C0 +azXFmrLZeClPxFL0PwCyDO/esb5m1I0G +-----END PUBLIC KEY----- diff --git a/arkindex/ponos/tests/test_admin.py b/arkindex/ponos/tests/test_admin.py new file mode 100644 index 0000000000000000000000000000000000000000..1a8def09605229cb920d98550bff9d84ab576d46 --- /dev/null +++ b/arkindex/ponos/tests/test_admin.py @@ -0,0 +1,33 @@ +from django.test import TestCase + +from ponos.admin import ClearTextSecretForm, SecretAdmin +from ponos.models import Secret, encrypt + + +class TestAdmin(TestCase): + @classmethod + def setUpTestData(cls): + cls.nonce = b"1337" * 4 + cls.encrypted_content = encrypt(cls.nonce, "Shhhh") + cls.secret = Secret.objects.create( + name="important_secret", nonce=cls.nonce, content=cls.encrypted_content + ) + + def test_admin_read_secret(self): + """ + Admin form display the decrypted content of the stored secrets + """ + self.assertEqual(self.secret.content, b"\xa3\xda\x9b\x91#") + form = ClearTextSecretForm(instance=self.secret) + self.assertEqual(form.initial.get("content"), "Shhhh") + + def test_admin_updates_secret(self): + secret_admin = SecretAdmin(model=Secret, admin_site=None) + form = ClearTextSecretForm( + data={"id": self.secret.id, "name": self.secret.name, "content": "Ah"}, + instance=self.secret, + ) + self.assertEqual(self.secret.content, b"\xa3\xda\x9b\x91#") + secret_admin.save_form(request=None, form=form, change=True) + self.assertEqual(self.secret.content, b"\xb1\xda") + self.assertEqual(b"\xb1\xda", encrypt(self.secret.nonce, "Ah")) diff --git a/arkindex/ponos/tests/test_api.py b/arkindex/ponos/tests/test_api.py new file mode 100644 index 0000000000000000000000000000000000000000..5e8c0a3472522e191f0f40710172dda2a0243241 --- /dev/null +++ b/arkindex/ponos/tests/test_api.py @@ -0,0 +1,1494 @@ +import base64 +import hashlib +import random +import uuid +from io import BytesIO +from unittest.mock import call, patch + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.hashes import SHA256 +from cryptography.hazmat.primitives.kdf.hkdf import HKDF +from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat +from django.conf import settings +from django.contrib.auth.models import User +from django.test import override_settings +from django.urls import reverse +from django.utils import timezone +from rest_framework import status +from rest_framework.test import APITestCase + +from ponos.api import timezone as api_tz +from ponos.authentication import AgentUser +from ponos.models import ( + FINAL_STATES, + GPU, + Agent, + Farm, + Secret, + State, + Task, + Workflow, + encrypt, +) +from tests.helpers import build_public_key + +RECIPE = """ +env: + top_env_variable: workflow_variable + test: test_workflow +tasks: + second-task: + image: alpine + command: echo hello + parents: + - first-task + tags: + - some-tag + first-task: + image: hello-world +""" + + +# Helper to format a datetime as output by DRF +def str_date(d): + return d.isoformat().replace("+00:00", "Z") + + +@override_settings(PONOS_LOG_TAIL=42) +class TestAPI(APITestCase): + @classmethod + def setUpTestData(cls): + + super().setUpTestData() + cls.farm = Farm.objects.create(name="Wheat farm") + pubkey = build_public_key() + cls.agent = AgentUser.objects.create( + id=uuid.UUID(hashlib.md5(pubkey.encode("utf-8")).hexdigest()), + farm=cls.farm, + hostname="ghostname", + cpu_cores=2, + cpu_frequency=1e9, + public_key=pubkey, 
+ ram_total=2e9, + last_ping=timezone.now(), + ) + cls.workflow = Workflow.objects.create(farm=cls.farm, recipe=RECIPE) + cls.workflow.start() + cls.task1, cls.task2 = cls.workflow.tasks.all() + cls.gpu1 = cls.agent.gpus.create( + id="108c6524-c63a-4811-bbed-9723d32a0688", + name="GPU1", + index=0, + ram_total=2 * 1024 * 1024 * 1024, + ) + cls.gpu2 = cls.agent.gpus.create( + id="f30d1407-92bb-484b-84b0-0b8bae41ca91", + name="GPU2", + index=1, + ram_total=8 * 1024 * 1024 * 1024, + ) + + def _build_workflow_response(self, response, **kwargs): + """ + Return the serialization of the test workflow. + Some parameters may be updated with kwargs. + """ + self.task1.refresh_from_db() + self.task2.refresh_from_db() + self.workflow.refresh_from_db() + data = { + "id": str(self.workflow.id), + "created": self.workflow.created.strftime("%G-%m-%dT%H:%M:%S.%fZ"), + "finished": self.workflow.finished.strftime("%G-%m-%dT%H:%M:%S.%fZ") + if self.workflow.finished + else None, + "state": self.workflow.state.value, + "farm": { + "id": str(self.workflow.farm.id), + "name": self.workflow.farm.name, + }, + "tasks": [ + { + "id": str(self.task1.id), + "run": 0, + "depth": 0, + "slug": "first-task", + "state": self.task1.state.value, + "parents": [], + "tags": [], + "shm_size": self.task1.shm_size, + "url": response.wsgi_request.build_absolute_uri( + reverse("ponos:task-details", args=[self.task1.id]) + ), + }, + { + "id": str(self.task2.id), + "run": 0, + "depth": 1, + "slug": "second-task", + "state": self.task2.state.value, + "parents": [str(self.task1.id)], + "tags": ["some-tag"], + "shm_size": self.task2.shm_size, + "url": response.wsgi_request.build_absolute_uri( + reverse("ponos:task-details", args=[self.task2.id]) + ), + }, + ], + } + data.update(kwargs) + return data + + def test_workflow_details(self): + with self.assertNumQueries(4): + resp = self.client.get( + reverse("ponos:workflow-details", args=[self.workflow.id]) + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + data = resp.json() + + self.assertDictEqual( + data, self._build_workflow_response(resp, state="unscheduled") + ) + + def test_stop_workflow(self): + self.task1.state = State.Pending + self.task1.save() + resp = self.client.patch( + reverse("ponos:workflow-details", kwargs={"pk": self.workflow.id}), + {"state": State.Stopping.value}, + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + data = resp.json() + + self.assertDictEqual( + data, self._build_workflow_response(resp, state="stopping") + ) + self.assertEqual(self.task1.state.value, "stopping") + self.assertEqual(self.task2.state.value, "stopped") + + def test_stop_finished_workflow(self): + self.task1.state = State.Completed + self.task1.save() + self.task2.state = State.Completed + self.task2.save() + resp = self.client.patch( + reverse("ponos:workflow-details", kwargs={"pk": self.workflow.id}), + {"state": State.Stopping.value}, + ) + self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) + data = resp.json() + self.assertEqual(data, {"state": ["Cannot stop a completed workflow"]}) + + def test_update_workflow_forbidden_fields(self): + """ + Only workflow state can be updated + """ + new_id = uuid.uuid4() + new_farm_id = uuid.uuid4() + resp = self.client.patch( + reverse("ponos:workflow-details", kwargs={"pk": self.workflow.id}), + { + "id": new_id, + "recipe": {}, + "updated": "2000-01-00T00:00:00.000000Z", + "tasks": [], + "farm": new_farm_id, + }, + format="json", + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + 
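+        # Unknown and read-only fields are silently dropped by the serializer on a
+        # partial update, so the request succeeds but the workflow itself is unchanged.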
self.assertDictEqual(resp.json(), self._build_workflow_response(resp)) + self.assertEqual(self.workflow.tasks.count(), 2) + self.assertNotEqual(self.workflow.id, new_id) + self.assertNotEqual(self.workflow.farm_id, new_farm_id) + + def test_change_state_workflow(self): + resp = self.client.patch( + reverse("ponos:workflow-details", kwargs={"pk": self.workflow.id}), + {"state": State.Completed.value}, + ) + self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) + data = resp.json() + self.assertEqual(data, {"state": ["Can only change the state to 'stopping'"]}) + + @patch("ponos.aws.s3") + @patch("ponos.models.s3") + def test_task_details(self, s3_mock, aws_s3_mock): + s3_mock.Object.return_value.bucket_name = "ponos" + s3_mock.Object.return_value.key = "somelog" + s3_mock.Object.return_value.get.return_value = { + "Body": BytesIO(b"Failed successfully") + } + aws_s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere" + + resp = self.client.get(reverse("ponos:task-details", args=[self.task1.id])) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + data = resp.json() + + self.assertDictEqual( + data, + { + "id": str(self.task1.id), + "run": 0, + "depth": 0, + "slug": "first-task", + "state": "unscheduled", + "parents": [], + "tags": [], + "logs": "Failed successfully", + "url": resp.wsgi_request.build_absolute_uri( + reverse("ponos:task-details", args=[self.task1.id]) + ), + "full_log": "http://somewhere", + "extra_files": {}, + "agent": None, + "gpu": None, + "shm_size": None, + }, + ) + + self.assertEqual(s3_mock.Object.call_count, 1) + self.assertEqual(s3_mock.Object().get.call_count, 1) + self.assertEqual(s3_mock.Object().get.call_args, call(Range="bytes=-42")) + self.assertEqual(aws_s3_mock.meta.client.generate_presigned_url.call_count, 1) + self.assertEqual( + aws_s3_mock.meta.client.generate_presigned_url.call_args, + call("get_object", Params={"Bucket": "ponos", "Key": "somelog"}), + ) + + @patch("ponos.aws.s3") + @patch("ponos.models.s3") + @patch("ponos.serializers.timezone") + def test_task_final_updates_workflow_finished( + self, timezone_mock, s3_mock, aws_s3_mock + ): + """ + Updating a task to a final state should update the workflow's completion date + if it causes the whole workflow to be finished + """ + s3_mock.Object.return_value.bucket_name = "ponos" + s3_mock.Object.return_value.key = "somelog" + s3_mock.Object.return_value.get.return_value = { + "Body": BytesIO(b"Failed successfully") + } + aws_s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere" + + expected_datetime = timezone.datetime(3000, 1, 1, 12).astimezone() + timezone_mock.now.return_value = expected_datetime + + self.task1.agent = self.agent + self.task1.state = State.Completed + self.task1.save() + self.task2.agent = self.agent + self.task2.state = State.Running + self.task2.save() + + for state in FINAL_STATES: + with self.subTest(state=state): + self.workflow.finished = None + self.workflow.save() + + resp = self.client.patch( + reverse("ponos:task-details", args=[self.task2.id]), + data={"state": state.value}, + HTTP_AUTHORIZATION=f"Bearer {self.agent.token.access_token}", + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + + self.task2.refresh_from_db() + self.assertEqual(self.task2.state, state) + + self.workflow.refresh_from_db() + self.assertEqual(self.workflow.finished, expected_datetime) + + @patch("ponos.signals.task_failure.send_robust") + @patch("ponos.aws.s3") + @patch("ponos.models.s3") + def 
test_task_final_not_completed_triggers_signal( + self, s3_mock, aws_s3_mock, send_mock + ): + """ + Updating a task to a final state that is not `State.Completed` + should trigger the task_failure signal + """ + s3_mock.Object.return_value.bucket_name = "ponos" + s3_mock.Object.return_value.key = "somelog" + s3_mock.Object.return_value.get.return_value = { + "Body": BytesIO(b"Failed successfully") + } + aws_s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere" + + self.task1.agent = self.agent + self.task1.save() + + states_expecting_signal = {State.Error, State.Failed, State.Stopped} + + for state in State: + with self.subTest(state=state): + resp = self.client.patch( + reverse("ponos:task-details", args=[self.task1.id]), + data={"state": state.value}, + HTTP_AUTHORIZATION=f"Bearer {self.agent.token.access_token}", + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + + if state in states_expecting_signal: + self.assertEqual(send_mock.call_count, 1) + _, kwargs = send_mock.call_args + self.assertDictEqual(kwargs, {"task": self.task1}) + else: + self.assertFalse(send_mock.called) + + send_mock.reset_mock() + + @patch("django.contrib.auth") + @patch("ponos.aws.s3") + @patch("ponos.models.s3") + def test_task_update_gpu(self, s3_mock, aws_s3_mock, auth_mock): + auth_mock.signals.user_logged_in = None + s3_mock.Object.return_value.bucket_name = "ponos" + s3_mock.Object.return_value.key = "somelog" + s3_mock.Object.return_value.get.return_value = { + "Body": BytesIO(b"Some nice GPU task output...") + } + aws_s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere" + + # All GPUs are available + self.assertFalse(GPU.objects.filter(task__isnull=False).exists()) + + # The task is running on a GPU + self.task1.gpu = self.agent.gpus.first() + self.task1.agent = self.agent + self.task1.state = State.Running + self.task1.save() + + # One GPU is assigned + self.assertTrue(GPU.objects.filter(task__isnull=False).exists()) + + self.maxDiff = None + resp = self.client.get(reverse("ponos:task-details", args=[self.task1.id])) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + data = resp.json() + + # Last ping is updated in the API call + self.agent.refresh_from_db() + + self.assertDictEqual( + data, + { + "id": str(self.task1.id), + "run": 0, + "depth": 0, + "extra_files": {}, + "slug": "first-task", + "state": "running", + "parents": [], + "shm_size": None, + "tags": [], + "logs": "Some nice GPU task output...", + "url": resp.wsgi_request.build_absolute_uri( + reverse("ponos:task-details", args=[self.task1.id]) + ), + "full_log": "http://somewhere", + "agent": { + "cpu_cores": 2, + "cpu_frequency": 1000000000, + "cpu_load": None, + "farm": {"id": str(self.agent.farm_id), "name": "Wheat farm"}, + "gpus": [ + { + "id": "108c6524-c63a-4811-bbed-9723d32a0688", + "index": 0, + "name": "GPU1", + "ram_total": 2147483648, + }, + { + "id": "f30d1407-92bb-484b-84b0-0b8bae41ca91", + "index": 1, + "name": "GPU2", + "ram_total": 8589934592, + }, + ], + "hostname": "ghostname", + "id": str(self.agent.id), + "last_ping": str_date(self.agent.last_ping), + "ram_load": None, + "ram_total": 2000000000, + }, + "gpu": { + "id": "108c6524-c63a-4811-bbed-9723d32a0688", + "index": 0, + "name": "GPU1", + "ram_total": 2147483648, + }, + }, + ) + + # Now let's complete that task, and check that the GPU is not assigned anymore + resp = self.client.patch( + reverse("ponos:task-details", args=[self.task1.id]), + data={"state": "completed"}, + HTTP_AUTHORIZATION="Bearer 
{}".format(self.agent.token.access_token), + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + data = resp.json() + + self.assertDictEqual( + data, + { + "id": str(self.task1.id), + "run": 0, + "depth": 0, + "extra_files": {}, + "slug": "first-task", + "state": "completed", + "parents": [], + "shm_size": None, + "tags": [], + "logs": "", + "url": resp.wsgi_request.build_absolute_uri( + reverse("ponos:task-details", args=[self.task1.id]) + ), + "full_log": "http://somewhere", + "agent": { + "cpu_cores": 2, + "cpu_frequency": 1000000000, + "cpu_load": None, + "farm": {"id": str(self.agent.farm_id), "name": "Wheat farm"}, + "gpus": [ + { + "id": "108c6524-c63a-4811-bbed-9723d32a0688", + "index": 0, + "name": "GPU1", + "ram_total": 2147483648, + }, + { + "id": "f30d1407-92bb-484b-84b0-0b8bae41ca91", + "index": 1, + "name": "GPU2", + "ram_total": 8589934592, + }, + ], + "hostname": "ghostname", + "id": str(self.agent.id), + "last_ping": str_date(self.agent.last_ping), + "ram_load": None, + "ram_total": 2000000000, + }, + # GPU is now un-assigned + "gpu": None, + }, + ) + + # All GPUs are available + self.assertFalse(GPU.objects.filter(task__isnull=False).exists()) + + @patch("ponos.aws.s3") + @patch("ponos.models.s3") + def test_task_logs_unicode_error(self, s3_mock, aws_s3_mock): + """ + Ensure the Task.short_logs property is able to handle sliced off Unicode characters. + Since we fetch the latest logs from S3 using `Range: bytes=-N`, sometimes the characters + can have a missing byte and cause a UnicodeDecodeError, such as U+00A0 (non-breaking space) + or U+00A9 (copyright symbol). + """ + with self.assertRaises(UnicodeDecodeError): + b"\xa0Failed successfully".decode("utf-8") + + s3_mock.Object.return_value.bucket_name = "ponos" + s3_mock.Object.return_value.key = "somelog" + s3_mock.Object.return_value.get.return_value = { + "Body": BytesIO(b"\xa0Failed successfully") + } + aws_s3_mock.meta.client.generate_presigned_url.return_value = "http://somewhere" + + resp = self.client.get(reverse("ponos:task-details", args=[self.task1.id])) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + data = resp.json() + + self.assertEqual(data["logs"], "�Failed successfully") + + @patch("ponos.serializers.check_agent_key") + @patch("ponos.serializers.timezone") + def test_agent_create(self, timezone_mock, check_mock): + check_mock.return_value = True + timezone_mock.now.return_value = timezone.datetime(2000, 1, 1, 12).astimezone() + + resp = self.client.post( + reverse("ponos:agent-register"), + { + "hostname": "toastname", + "cpu_cores": 42, + "cpu_frequency": 1337e6, + "ram_total": 16e9, + "farm": str(self.farm.id), + "public_key": build_public_key(), + "derivation": "{:032x}".format(random.getrandbits(128)), + "gpus": [ + { + "id": "b6b8d7c1-c6bd-4de6-ae92-866a270be36f", + "name": "A", + "index": 0, + "ram_total": 512 * 1024 * 1024, + }, + { + "id": "14fe053a-8014-46a3-ad7b-99fa4389d74c", + "name": "B", + "index": 1, + "ram_total": 2 * 1024 * 1024 * 1024, + }, + ], + }, + format="json", + ) + + self.assertEqual(resp.status_code, status.HTTP_201_CREATED) + self.assertEqual(check_mock.call_count, 1) + data = resp.json() + agent = Agent.objects.get(id=data["id"]) + self.assertEqual(agent.hostname, "toastname") + self.assertEqual(agent.cpu_cores, 42) + self.assertEqual(agent.cpu_frequency, 1337e6) + self.assertEqual(agent.gpus.count(), 2) + gpu_a = agent.gpus.get(name="A") + gpu_b = agent.gpus.get(name="B") + self.assertEqual(gpu_a.index, 0) + self.assertEqual(gpu_b.index, 1) + 
self.assertEqual(gpu_a.ram_total, 512 * 1024 * 1024) + self.assertEqual(gpu_b.ram_total, 2 * 1024 * 1024 * 1024) + + self.assertIn("access_token", data) + self.assertIn("refresh_token", data) + del data["access_token"] + del data["refresh_token"] + + self.assertDictEqual( + data, + { + "id": str(agent.id), + "hostname": "toastname", + "cpu_cores": 42, + "cpu_frequency": int(1337e6), + "farm": str(self.farm.id), + "ram_total": 16_000_000_000, + "cpu_load": None, + "ram_load": None, + "last_ping": "2000-01-01T12:00:00Z", + "gpus": [ + { + "id": "b6b8d7c1-c6bd-4de6-ae92-866a270be36f", + "index": 0, + "name": "A", + "ram_total": 512 * 1024 * 1024, + }, + { + "id": "14fe053a-8014-46a3-ad7b-99fa4389d74c", + "index": 1, + "name": "B", + "ram_total": 2 * 1024 * 1024 * 1024, + }, + ], + }, + ) + + @patch("ponos.serializers.check_agent_key") + @patch("ponos.serializers.timezone") + def test_agent_update(self, timezone_mock, check_mock): + check_mock.return_value = True + timezone_mock.now.return_value = timezone.datetime(2000, 1, 1, 12).astimezone() + + # Still a POST, but on an existing agent + resp = self.client.post( + reverse("ponos:agent-register"), + { + "hostname": self.agent.hostname, + "cpu_cores": 12, + "cpu_frequency": 1e9, + "ram_total": 32e9, + "farm": str(self.farm.id), + "public_key": self.agent.public_key, + "derivation": "{:032x}".format(random.getrandbits(128)), + "gpus": [ + { + "id": "deadbeef-c6bd-4de6-ae92-866a270be36f", + "name": "new gpu", + "index": 0, + "ram_total": 32 * 1024 * 1024 * 1024, + }, + ], + }, + format="json", + ) + + self.assertEqual(resp.status_code, status.HTTP_201_CREATED) + self.assertEqual(check_mock.call_count, 1) + data = resp.json() + self.agent.refresh_from_db() + self.assertEqual(self.agent.hostname, "ghostname") + self.assertEqual(self.agent.cpu_cores, 12) + self.assertEqual(self.agent.cpu_frequency, 1e9) + self.assertEqual(self.agent.gpus.count(), 1) + gpu = self.agent.gpus.get(name="new gpu") + self.assertEqual(gpu.index, 0) + self.assertEqual(gpu.ram_total, 32 * 1024 * 1024 * 1024) + self.assertEqual(gpu.id, uuid.UUID("deadbeef-c6bd-4de6-ae92-866a270be36f")) + + self.assertIn("access_token", data) + self.assertIn("refresh_token", data) + del data["access_token"] + del data["refresh_token"] + + self.assertDictEqual( + data, + { + "cpu_cores": 12, + "cpu_frequency": 1000000000, + "cpu_load": None, + "farm": str(self.farm.id), + "gpus": [ + { + "id": "deadbeef-c6bd-4de6-ae92-866a270be36f", + "index": 0, + "name": "new gpu", + "ram_total": 34359738368, + } + ], + "hostname": "ghostname", + "id": str(self.agent.id), + "last_ping": "2000-01-01T12:00:00Z", + "ram_load": None, + "ram_total": 32000000000, + }, + ) + + @patch("ponos.serializers.check_agent_key") + @patch("ponos.serializers.timezone") + def test_agent_create_existing_gpu(self, timezone_mock, check_mock): + check_mock.return_value = True + timezone_mock.now.return_value = timezone.datetime(2000, 1, 1, 12).astimezone() + + resp = self.client.post( + reverse("ponos:agent-register"), + { + "hostname": "toastname", + "cpu_cores": 42, + "cpu_frequency": 1337e6, + "ram_total": 16e9, + "farm": str(self.farm.id), + "public_key": build_public_key(), + "derivation": "{:032x}".format(random.getrandbits(128)), + "gpus": [ + { + "id": str(self.gpu1.id), + "name": "A", + "index": 0, + "ram_total": 512 * 1024 * 1024, + }, + { + "id": str(self.gpu2.id), + "name": "B", + "index": 1, + "ram_total": 2 * 1024 * 1024 * 1024, + }, + ], + }, + format="json", + ) + self.assertEqual(resp.status_code, 
status.HTTP_201_CREATED) + self.assertEqual(check_mock.call_count, 1) + data = resp.json() + + self.gpu1.refresh_from_db() + self.gpu2.refresh_from_db() + self.assertEqual(self.gpu1.name, "A") + self.assertEqual(self.gpu2.name, "B") + self.assertEqual(self.gpu1.ram_total, 512 * 1024**2) + self.assertEqual(self.gpu2.ram_total, 2 * 1024**3) + + new_agent = Agent.objects.get(id=data["id"]) + self.assertEqual(self.gpu1.agent, new_agent) + self.assertEqual(self.gpu2.agent, new_agent) + + # Existing agent no longer has any assigned GPUs, since the new agent stole them + self.assertFalse(self.agent.gpus.exists()) + + @patch("ponos.serializers.check_agent_key") + @patch("ponos.serializers.timezone") + def test_agent_update_existing_gpu(self, timezone_mock, check_mock): + check_mock.return_value = True + timezone_mock.now.return_value = timezone.datetime(2000, 1, 1, 12).astimezone() + self.assertTrue(self.agent.gpus.count(), 2) + + # Still a POST, but on an existing agent + resp = self.client.post( + reverse("ponos:agent-register"), + { + "hostname": self.agent.hostname, + "cpu_cores": 12, + "cpu_frequency": 1e9, + "ram_total": 32e9, + "farm": str(self.farm.id), + "public_key": self.agent.public_key, + "derivation": "{:032x}".format(random.getrandbits(128)), + "gpus": [ + { + "id": "108c6524-c63a-4811-bbed-9723d32a0688", + "name": "GPU1", + "index": 0, + "ram_total": 2 * 1024 * 1024 * 1024, + }, + { + "id": "f30d1407-92bb-484b-84b0-0b8bae41ca91", + "name": "GPU2", + "index": 1, + "ram_total": 8 * 1024 * 1024 * 1024, + }, + ], + }, + format="json", + ) + + self.assertEqual(resp.status_code, status.HTTP_201_CREATED) + + self.agent.refresh_from_db() + self.assertTrue(self.agent.gpus.count(), 2) + + def test_agent_create_bad_seed(self): + # Build keys + server_public_key = ec.generate_private_key( + ec.SECP384R1(), + default_backend(), + ).public_key() + + agent_private_key = ec.generate_private_key( + ec.SECP384R1(), + default_backend(), + ) + agent_public_bytes = agent_private_key.public_key().public_bytes( + Encoding.PEM, + PublicFormat.SubjectPublicKeyInfo, + ) + + # Add 1 to the farm's seed + wrong_seed = "{:064x}".format(int(self.farm.seed, 16) + 1) + + # Perform derivation with the wrong seed + shared_key = agent_private_key.exchange(ec.ECDH(), server_public_key) + derived_key = HKDF( + algorithm=SHA256(), + backend=default_backend(), + length=32, + salt=None, + info=wrong_seed.encode("utf-8"), + ).derive(shared_key) + + resp = self.client.post( + reverse("ponos:agent-register"), + { + "hostname": "toastname", + "cpu_cores": 42, + "cpu_frequency": 1337e6, + "ram_total": 16e9, + "farm": str(self.farm.id), + "public_key": agent_public_bytes.decode("utf-8"), + "derivation": base64.b64encode(derived_key).decode("utf-8"), + "gpus": [], + }, + format="json", + ) + self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual( + resp.json(), {"non_field_errors": ["Key verification failed"]} + ) + + def test_agent_actions_requires_token(self): + resp = self.client.get(reverse("ponos:agent-actions")) + self.assertEqual(resp.status_code, status.HTTP_401_UNAUTHORIZED) + + def test_agent_actions_no_user(self): + user = User.objects.create_superuser("root", "root@root.fr", "hunter2") + self.client.force_login(user) + try: + resp = self.client.get(reverse("ponos:agent-actions")) + self.assertEqual(resp.status_code, status.HTTP_401_UNAUTHORIZED) + finally: + self.client.logout() + + def test_agent_actions_query_params_required(self): + resp = self.client.get( + 
reverse("ponos:agent-actions"), + HTTP_AUTHORIZATION="Bearer {}".format(self.agent.token.access_token), + ) + self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual( + resp.json(), + { + "cpu_load": ["This query parameter is required."], + "ram_load": ["This query parameter is required."], + }, + ) + + def test_agent_actions_query_params_validation(self): + resp = self.client.get( + reverse("ponos:agent-actions"), + HTTP_AUTHORIZATION="Bearer {}".format(self.agent.token.access_token), + data={"cpu_load": "high", "ram_load": "low"}, + ) + self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual( + resp.json(), + { + "cpu_load": ["A valid number is required."], + "ram_load": ["A valid number is required."], + }, + ) + + def test_agent_non_pending_actions(self): + """ + Only pending tasks may be retrieved as new actions + """ + self.workflow.tasks.update(state=State.Error) + with self.assertNumQueries(7): + resp = self.client.get( + reverse("ponos:agent-actions"), + HTTP_AUTHORIZATION="Bearer {}".format(self.agent.token.access_token), + data={"cpu_load": 0.9, "ram_load": 0.49}, + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + self.assertDictEqual(resp.json(), {"actions": []}) + + def test_agent_actions(self): + """ + Agent may retrieve one tasks using the API due to its resources limitations + """ + self.workflow.tasks.update(state=State.Pending) + now = timezone.now() + + with patch.object(api_tz, "now") as api_now_mock: + api_now_mock.return_value = now + with self.assertNumQueries(11): + resp = self.client.get( + reverse("ponos:agent-actions"), + HTTP_AUTHORIZATION="Bearer {}".format( + self.agent.token.access_token + ), + data={"cpu_load": 0.9, "ram_load": 0.49}, + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + + self.agent.refresh_from_db() + self.assertEqual(self.agent._estimate_new_tasks_cost(), 0.99) + + self.assertDictEqual( + resp.json(), + { + "actions": [ + { + "action": "start_task", + "task_id": str(self.task2.id), + } + ] + }, + ) + # Agent load and last ping attributes have been updated + self.assertEqual( + (self.agent.cpu_load, self.agent.ram_load, self.agent.last_ping), + (0.9, 0.49, now), + ) + + def test_artifacts_list(self): + url = reverse("ponos:task-artifacts", args=[self.task1.id]) + + # Not accessible when logged out + resp = self.client.get(url) + self.assertEqual(resp.status_code, status.HTTP_401_UNAUTHORIZED) + + # Accessible but empty when logged in + resp = self.client.get( + url, + HTTP_AUTHORIZATION="Bearer {}".format(self.agent.token.access_token), + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + data = resp.json() + self.assertListEqual(data, []) + + # Add a few artifacts + xxx = self.task1.artifacts.create( + path="path/to/xxx", content_type="application/json", size=123 + ) + demo = self.task1.artifacts.create( + path="demo.txt", content_type="text/plain", size=12 + ) + self.task2.artifacts.create(path="demo.txt", content_type="text/plain", size=1) + + resp = self.client.get( + url, + HTTP_AUTHORIZATION="Bearer {}".format(self.agent.token.access_token), + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + data = resp.json() + + # We only get the 2 artifacts on task1, ordered by path + self.assertListEqual( + data, + [ + { + "content_type": "text/plain", + "created": str_date(demo.created), + "id": str(demo.id), + "path": "demo.txt", + "s3_put_url": None, + "size": 12, + "updated": str_date(demo.updated), + "url": 
f"http://testserver/ponos/v1/task/{self.task1.id}/artifact/demo.txt", + }, + { + "content_type": "application/json", + "created": str_date(xxx.created), + "id": str(xxx.id), + "path": "path/to/xxx", + "s3_put_url": None, + "size": 123, + "updated": str_date(xxx.updated), + "url": f"http://testserver/ponos/v1/task/{self.task1.id}/artifact/path/to/xxx", + }, + ], + ) + + def test_artifact_creation(self): + + # No artifacts in DB at first + self.assertFalse(self.task1.artifacts.exists()) + + # Create an artifact through the API + url = reverse("ponos:task-artifacts", args=[self.task1.id]) + resp = self.client.post( + url, + data={"path": "some/path.txt", "content_type": "text/plain", "size": 1000}, + HTTP_AUTHORIZATION="Bearer {}".format(self.agent.token.access_token), + ) + self.assertEqual(resp.status_code, status.HTTP_201_CREATED) + + # Check response has a valid S3 put URL, without matching the parameters in the querystring + data = resp.json() + s3_put_url = data.get("s3_put_url") + self.assertIsNotNone(s3_put_url) + del data["s3_put_url"] + self.assertTrue( + s3_put_url.startswith( + f"http://somewhere/ponos-artifacts/{self.workflow.id}/{self.task1.id}/some/path.txt" + ) + ) + + # An artifact has been created + artifact = self.task1.artifacts.get() + + self.assertDictEqual( + resp.json(), + { + "content_type": "text/plain", + "created": str_date(artifact.created), + "id": str(artifact.id), + "path": "some/path.txt", + "size": 1000, + "updated": str_date(artifact.updated), + "url": f"http://testserver/ponos/v1/task/{self.task1.id}/artifact/some/path.txt", + }, + ) + + # Creating another artifact with the same path will fail cleanly + resp = self.client.post( + url, + data={ + "path": "some/path.txt", + "content_type": "text/plain", + "size": 10243, + }, + HTTP_AUTHORIZATION="Bearer {}".format(self.agent.token.access_token), + ) + self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual( + resp.json(), {"path": ["An artifact with this path already exists"]} + ) + + @override_settings() + def test_create_artifact_size_limits(self): + params = [ + (None, -42, "Ensure this value is greater than or equal to 1."), + (None, 0, "Ensure this value is greater than or equal to 1."), + ( + None, + 5 * 1024**3 + 1, + "Ensure this value is less than or equal to 5368709120.", + ), + (123456789000, -42, "Ensure this value is greater than or equal to 1."), + (123456789000, 0, "Ensure this value is greater than or equal to 1."), + ( + 123456789000, + 987654321000, + "Ensure this value is less than or equal to 123456789000.", + ), + ] + + for size_setting, size, expected_error in params: + with self.subTest(size=size): + settings.PONOS_ARTIFACT_MAX_SIZE = size_setting + url = reverse("ponos:task-artifacts", args=[self.task1.id]) + resp = self.client.post( + url, + data={ + "path": "some/path.txt", + "content_type": "text/plain", + "size": size, + }, + HTTP_AUTHORIZATION=f"Bearer {self.agent.token.access_token}", + ) + self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(resp.json(), {"size": [expected_error]}) + + def test_artifact_download(self): + + # Add a new artifact + self.task1.artifacts.create( + path="path/to/file.json", + content_type="application/json", + size=42, + ) + + # Request to get the real download link will fail as anonymous + url = reverse( + "ponos:task-artifact-download", args=[self.task1.id, "path/to/file.json"] + ) + resp = self.client.get(url) + self.assertEqual(resp.status_code, status.HTTP_401_UNAUTHORIZED) + + # 
Request to a missing artifact will lead to a 404 + bad_url = reverse( + "ponos:task-artifact-download", args=[self.task1.id, "nope.xxx"] + ) + resp = self.client.get( + bad_url, + HTTP_AUTHORIZATION="Bearer {}".format(self.agent.token.access_token), + ) + self.assertEqual(resp.status_code, status.HTTP_404_NOT_FOUND) + + # Valid request will redirect to the S3 url + resp = self.client.get( + url, + HTTP_AUTHORIZATION="Bearer {}".format(self.agent.token.access_token), + ) + self.assertEqual(resp.status_code, status.HTTP_302_FOUND) + + # Check the S3 download link + self.assertTrue(resp.has_header("Location")) + self.assertTrue( + resp["Location"].startswith( + f"http://somewhere/ponos-artifacts/{self.workflow.id}/{self.task1.id}/path/to/file.json" + ) + ) + + def test_task_create_empty_body(self): + response = self.client.post(reverse("ponos:task-create")) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual( + response.json(), + { + "image": ["This field is required."], + "parents": ["This field is required."], + "slug": ["This field is required."], + "workflow_id": ["This field is required."], + }, + ) + + def test_task_create_no_parent(self): + response = self.client.post( + reverse("ponos:task-create"), + data={ + "workflow_id": str(self.workflow.id), + "slug": "test_task", + "image": "registry.gitlab.com/test", + "parents": [], + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual( + response.json(), {"parents": ["This list may not be empty."]} + ) + + def test_task_create_distinct_workflows_on_parents(self): + workflow2 = Workflow.objects.create(farm=self.farm, recipe=RECIPE) + task3 = workflow2.tasks.create( + run=0, + depth=1, + slug="task_parent", + image="registry.gitlab.com/test", + ) + + response = self.client.post( + reverse("ponos:task-create"), + data={ + "workflow_id": str(self.workflow.id), + "slug": "test_task", + "image": "registry.gitlab.com/test", + "parents": [str(self.task1.id), str(self.task2.id), str(task3.id)], + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual( + response.json(), + { + "non_field_errors": [ + "All parents must be in the same workflow as the child task" + ] + }, + ) + + def test_task_create_distinct_runs_on_parents(self): + task3 = self.workflow.tasks.create( + run=1, + depth=1, + slug="task_parent", + image="registry.gitlab.com/test", + ) + + response = self.client.post( + reverse("ponos:task-create"), + data={ + "workflow_id": str(self.workflow.id), + "slug": "test_task", + "image": "registry.gitlab.com/test", + "parents": [str(self.task1.id), str(self.task2.id), str(task3.id)], + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual( + response.json(), + { + "non_field_errors": [ + "All parents must have the same run in the given workflow" + ] + }, + ) + + def test_task_create_duplicate(self): + self.workflow.tasks.create( + run=0, + depth=3, + slug="sibling", + image="registry.gitlab.com/test", + ) + + response = self.client.post( + reverse("ponos:task-create"), + data={ + "workflow_id": str(self.workflow.id), + "slug": "sibling", + "image": "registry.gitlab.com/test", + "parents": [str(self.task1.id), str(self.task2.id)], + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual( + response.json(), + { + "non_field_errors": [ + "A task with the 
`sibling` slug already exists in run 0." + ] + }, + ) + + def test_task_create(self): + task3 = self.workflow.tasks.create( + run=0, + depth=3, + slug="task_parent", + image="registry.gitlab.com/test", + ) + + response = self.client.post( + reverse("ponos:task-create"), + data={ + "workflow_id": str(self.workflow.id), + "slug": "test_task", + "image": "registry.gitlab.com/test", + "parents": [str(self.task1.id), str(self.task2.id), str(task3.id)], + "command": "echo Test", + "env": {"test": "test", "test2": "test2"}, + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + data = response.json() + del data["id"] + self.assertDictEqual( + data, + { + "workflow_id": str(self.workflow.id), + "slug": "test_task", + "parents": [str(self.task1.id), str(self.task2.id), str(task3.id)], + "image": "registry.gitlab.com/test", + "command": "echo Test", + "env": { + "test": "test", + "top_env_variable": "workflow_variable", + "test2": "test2", + }, + "run": 0, + "depth": 4, + "has_docker_socket": False, + }, + ) + + def test_task_create_has_docker_socket_true(self): + task3 = self.workflow.tasks.create( + run=0, + depth=3, + slug="task_parent", + image="registry.gitlab.com/test", + ) + + response = self.client.post( + reverse("ponos:task-create"), + data={ + "workflow_id": str(self.workflow.id), + "slug": "test_task", + "image": "registry.gitlab.com/test", + "parents": [str(self.task1.id), str(self.task2.id), str(task3.id)], + "command": "echo Test", + "env": {"test": "test", "test2": "test2"}, + "has_docker_socket": True, + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + data = response.json() + del data["id"] + self.assertDictEqual( + data, + { + "workflow_id": str(self.workflow.id), + "slug": "test_task", + "parents": [str(self.task1.id), str(self.task2.id), str(task3.id)], + "image": "registry.gitlab.com/test", + "command": "echo Test", + "env": { + "test": "test", + "top_env_variable": "workflow_variable", + "test2": "test2", + }, + "run": 0, + "depth": 4, + "has_docker_socket": True, + }, + ) + + def test_retrieve_secret_requires_auth(self): + """ + Only agents may access a secret details + """ + response = self.client.get( + reverse("ponos:secret-details", kwargs={"name": "abc"}) + ) + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertDictEqual( + response.json(), {"detail": "Authentication credentials were not provided."} + ) + + def test_retrieve_secret_not_found(self): + """ + A 404 should be raised when no secret match query name + """ + response = self.client.get( + reverse( + "ponos:secret-details", kwargs={"name": "the_most_important_secret"} + ), + HTTP_AUTHORIZATION="Bearer {}".format(self.agent.token.access_token), + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + self.assertDictEqual(response.json(), {"detail": "Not found."}) + + def test_agent_retrieve_secret(self): + """ + Agent should be able to retrieve a secret content as cleartext + """ + account_name = "bank_account/0001/private" + secret = Secret.objects.create( + name=account_name, + nonce=b"1337" * 4, + content=encrypt(b"1337" * 4, "1337$"), + ) + self.assertEqual(secret.content, b"\xc1\x81\xc0\xceo") + response = self.client.get( + reverse("ponos:secret-details", kwargs={"name": account_name}), + HTTP_AUTHORIZATION="Bearer {}".format(self.agent.token.access_token), + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual( + response.json(), + { + "id": 
str(secret.id), + "name": account_name, + "content": "1337$", + }, + ) + + def test_list_agents_state(self): + """ + Lists agents from all farms with their status + """ + # Add tasks with different states on the agent + self.agent.tasks.bulk_create( + [ + Task( + run=0, + depth=0, + workflow=self.workflow, + slug=state.value, + state=state, + agent=self.agent, + ) + for state in State + ] + ) + with self.assertNumQueries(4): + response = self.client.get(reverse("ponos:agents-state")) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + self.assertEqual(data["count"], 1) + agent_state = data["results"][0] + del agent_state["last_ping"] + self.assertDictEqual( + agent_state, + { + "id": str(self.agent.id), + "active": True, + "cpu_cores": 2, + "cpu_frequency": 1000000000, + "cpu_load": None, + "farm": { + "id": str(self.farm.id), + "name": "Wheat farm", + }, + "gpus": [ + { + "id": "108c6524-c63a-4811-bbed-9723d32a0688", + "index": 0, + "name": "GPU1", + "ram_total": 2147483648, + }, + { + "id": "f30d1407-92bb-484b-84b0-0b8bae41ca91", + "index": 1, + "name": "GPU2", + "ram_total": 8589934592, + }, + ], + "hostname": "ghostname", + "ram_load": None, + "ram_total": 2000000000, + "running_tasks_count": 1, + }, + ) + + def test_retrieve_agent_details(self): + """ + The view returns an agents with its details and associated running tasks + """ + # Add tasks with different states on the agent + Task.objects.bulk_create( + [ + Task( + run=0, + depth=0, + workflow=self.workflow, + slug=state.value, + state=state, + agent=self.agent, + ) + for state in State + ] + ) + running_task = self.agent.tasks.get(state=State.Running) + with self.assertNumQueries(5): + response = self.client.get( + reverse("ponos:agent-details", kwargs={"pk": str(self.agent.id)}) + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + del data["last_ping"] + self.maxDiff = None + self.assertDictEqual( + data, + { + "id": str(self.agent.id), + "active": True, + "cpu_cores": 2, + "cpu_frequency": 1000000000, + "cpu_load": None, + "farm": { + "id": str(self.farm.id), + "name": "Wheat farm", + }, + "gpus": [ + { + "id": "108c6524-c63a-4811-bbed-9723d32a0688", + "index": 0, + "name": "GPU1", + "ram_total": 2147483648, + }, + { + "id": "f30d1407-92bb-484b-84b0-0b8bae41ca91", + "index": 1, + "name": "GPU2", + "ram_total": 8589934592, + }, + ], + "hostname": "ghostname", + "ram_load": None, + "ram_total": 2000000000, + "running_tasks": [ + { + "id": str(running_task.id), + "run": 0, + "depth": 0, + "parents": [], + "slug": "running", + "state": "running", + "shm_size": None, + "tags": [], + "url": response.wsgi_request.build_absolute_uri( + reverse("ponos:task-details", args=[running_task.id]) + ), + } + ], + }, + ) + + def test_list_farms(self): + """ + Any user is able to list farms basic information + """ + barley_farm = Farm.objects.create(name="Barley") + with self.assertNumQueries(2): + response = self.client.get(reverse("ponos:farm-list")) + self.maxDiff = None + self.assertDictEqual( + response.json(), + { + "count": 2, + "previous": None, + "next": None, + "results": [ + {"id": str(barley_farm.id), "name": "Barley"}, + {"id": str(self.farm.id), "name": "Wheat farm"}, + ], + }, + ) diff --git a/arkindex/ponos/tests/test_keys.py b/arkindex/ponos/tests/test_keys.py new file mode 100644 index 0000000000000000000000000000000000000000..d25cb5a9420cf51311d0f8a891791f886b46b31c --- /dev/null +++ b/arkindex/ponos/tests/test_keys.py @@ -0,0 +1,89 @@ +import os 
+import tempfile + +import cryptography +from django.test import TestCase, override_settings + +from ponos.keys import gen_private_key, load_private_key + +BAD_KEY = """-----BEGIN RSA PRIVATE KEY----- +MCoCAQACBGvoyx0CAwEAAQIEUg2X0QIDAMUbAgMAjCcCAmr9AgJMawICKvo= +-----END RSA PRIVATE KEY-----""" + + +@override_settings(PONOS_PRIVATE_KEY=None) +class KeysTestCase(TestCase): + """ + Test ECDH keys + """ + + def test_no_key(self): + """ + No key by default with those settings + """ + with self.settings(DEBUG=True): + # On debug, that works, but with a warning + self.assertIsNone(load_private_key()) + + with self.settings(DEBUG=False): + # On prod, that fails + with self.assertRaisesRegex( + Exception, r"Missing setting PONOS_PRIVATE_KEY" + ): + load_private_key() + + with self.settings(DEBUG=True, PONOS_PRIVATE_KEY="/tmp/nope"): + # On debug, that works, but with a warning + self.assertIsNone(load_private_key()) + + with self.settings(DEBUG=False, PONOS_PRIVATE_KEY="/tmp/nope"): + # On prod, that fails + with self.assertRaisesRegex(Exception, r"Invalid PONOS_PRIVATE_KEY path"): + load_private_key() + + # Test with a valid RSA key (not ECDH) + _, path = tempfile.mkstemp() + open(path, "w").write(BAD_KEY) + with self.settings(DEBUG=True, PONOS_PRIVATE_KEY=path): + # On debug, that fails too ! + with self.assertRaisesRegex(Exception, r"not an ECDH"): + load_private_key() + + with self.settings(DEBUG=False, PONOS_PRIVATE_KEY=path): + # On prod, that fails + with self.assertRaisesRegex(Exception, r"not an ECDH"): + load_private_key() + os.unlink(path) + + def test_private_key(self): + """ + Test private key writing and loading + """ + + # Generate some key + _, path = tempfile.mkstemp() + gen_private_key(path) + key = open(path).read().splitlines() + self.assertTrue(len(key) > 2) + self.assertEqual(key[0], "-----BEGIN PRIVATE KEY-----") + self.assertEqual(key[-1], "-----END PRIVATE KEY-----") + + # Load it through settings + with self.settings(PONOS_PRIVATE_KEY=path): + priv = load_private_key() + self.assertTrue( + isinstance( + priv, + cryptography.hazmat.backends.openssl.ec._EllipticCurvePrivateKey, + ) + ) + self.assertTrue(priv.key_size > 256) + + # When messed up, nothing works + with open(path, "w") as f: + f.seek(100) + f.write("coffee") + with self.settings(PONOS_PRIVATE_KEY=path): + with self.assertRaisesRegex(Exception, r"Could not deserialize key data"): + load_private_key() + os.unlink(path) diff --git a/arkindex/ponos/tests/test_models.py b/arkindex/ponos/tests/test_models.py new file mode 100644 index 0000000000000000000000000000000000000000..8c3e20e85fbcdbf32734230a27ce901a8d866831 --- /dev/null +++ b/arkindex/ponos/tests/test_models.py @@ -0,0 +1,255 @@ +import tempfile +from unittest.mock import patch + +from django.core.exceptions import ValidationError +from django.db.models import prefetch_related_objects +from django.test import TestCase, override_settings +from django.utils import timezone + +from ponos.models import ( + FINAL_STATES, + Agent, + Farm, + Secret, + State, + Workflow, + build_aes_cipher, + encrypt, +) + +RECIPE = """ +tasks: + task1: + image: hello-world +""" + + +class TestModels(TestCase): + @classmethod + def setUpTestData(cls): + cls.farm = Farm.objects.create(name="Cryptominers") + cls.workflow = Workflow.objects.create(farm=cls.farm, recipe=RECIPE) + cls.workflow.start() + cls.task1 = cls.workflow.tasks.get() + cls.nonce = b"42" + b"0" * 14 + + def setUp(self): + self.agent = Agent.objects.create( + farm=self.farm, + hostname="roastname", + cpu_cores=100, 
+ cpu_frequency=1e9, + public_key="Mamoswine", + ram_total=10e9, + last_ping=timezone.now(), + ) + self.agent.gpus.create( + id="108c6524-c63a-4811-bbed-9723d32a0688", + name="GPU1", + index=0, + ram_total=2 * 1024 * 1024 * 1024, + ) + self.agent.gpus.create( + id="f30d1407-92bb-484b-84b0-0b8bae41ca91", + name="GPU2", + index=1, + ram_total=8 * 1024 * 1024 * 1024, + ) + + def test_is_final(self): + for state in State: + self.task1.state = state + self.task1.save() + if state in FINAL_STATES: + self.assertTrue( + self.task1.is_final(), msg="{} should be final".format(state) + ) + self.assertTrue( + self.workflow.is_final(), msg="{} should be final".format(state) + ) + else: + self.assertFalse( + self.task1.is_final(), msg="{} should not be final".format(state) + ) + self.assertFalse( + self.workflow.is_final(), msg="{} should not be final".format(state) + ) + + def test_delete_agent_non_final(self): + """ + Agent deletion should be prevented when it has non-final tasks + """ + self.task1.agent = self.agent + self.task1.state = State.Pending + self.task1.save() + with self.assertRaisesRegex(ValidationError, "non-final"): + self.agent.delete() + + self.task1.state = State.Error + self.task1.save() + # Should no longer fail + self.agent.delete() + + self.task1.refresh_from_db() + self.assertIsNone(self.task1.agent) + + @override_settings(PONOS_PRIVATE_KEY=None) + def test_aes_missing_key(self): + with self.assertRaisesRegex(Exception, r"Missing a PONOS_PRIVATE_KEY"): + build_aes_cipher(nonce=self.nonce) + + @patch("ponos.models.settings") + def test_build_aes_cipher(self, settings_mock): + """ + AES encryption key should be a derivate of the Ponos server secret key + """ + _, path = tempfile.mkstemp() + settings_mock.PONOS_PRIVATE_KEY = path + with open(path, "wb") as f: + f.write(b"pikachu") + cipher = build_aes_cipher(nonce=self.nonce) + self.assertEqual(cipher.encryptor().update(b"hey"), b"lM\x8d") + with open(path, "wb") as f: + f.write(b"bulbasaur") + cipher = build_aes_cipher(nonce=self.nonce) + self.assertNotEqual(cipher.encryptor().update(b"hey"), b"lM\x8d") + + @patch("ponos.models.settings") + def test_secret_encrypt_decrypt(self, settings_mock): + _, path = tempfile.mkstemp() + settings_mock.PONOS_PRIVATE_KEY = path + with open(path, "wb") as f: + f.write(b"pikachu") + secret = Secret( + name="Test secret", + nonce=self.nonce, + content=encrypt(self.nonce, "secret_m3ssage"), + ) + self.assertEqual(secret.content, b"wM\x97\n\xadS\x13\x8a\x89&ZF\xbd\xee") + self.assertEqual(secret.decrypt(), "secret_m3ssage") + + def test_agent_estimate_new_tasks_cost(self): + """ + Agent has 100 cores and 10GB of RAM + One task is estimated to use 1 CPU (1%) and 1GB of RAM (10%) + """ + self.agent.cpu_load = 50 + self.agent.ram_load = 0.35 + self.agent.save() + # The CPU will define the agent load for the first task reaching 51% occupancy + self.assertEqual(self.agent._estimate_new_tasks_cost(tasks=1), 0.51) + # For the second task, the RAM will reach 55% occupancy overtaking the CPU load (52%) + self.assertEqual(self.agent._estimate_new_tasks_cost(tasks=2), 0.55) + + def test_requires_gpu(self): + """ + Check the GPU generated requirements rules + """ + # Default task does not need GPU + self.assertFalse(self.task1.requires_gpu) + + def test_activate_requires_gpu_from_recipe(self): + """ + Check that tasks.requires_gpu is enabled when enabled in recipe + """ + recipe_with_gpu = """ + tasks: + initialisation: + command: null + image: hello-world + worker: + image: hello-world + parents: + - 
initialisation + requires_gpu: true + """ + workflow = Workflow.objects.create(farm=self.farm, recipe=recipe_with_gpu) + init_task, worker_task = workflow.build_tasks().values() + self.assertFalse(init_task.requires_gpu) + self.assertTrue(worker_task.requires_gpu) + + @patch("ponos.models.timezone") + def test_task_expiry_default(self, timezone_mock): + timezone_mock.now.return_value = timezone.datetime(3000, 1, 12).astimezone() + # Expecting a default expiry 30 days after timezone.now + expected_expiry = timezone.datetime(3000, 2, 11).astimezone() + workflow = Workflow.objects.create( + farm=self.farm, + recipe=""" + tasks: + task1: + image: hello-world + task2: + image: hello-world + """, + ) + + # A workflow with no tasks has no expiry + self.assertFalse(workflow.tasks.exists()) + self.assertIsNone(workflow.expiry) + + task1, task2 = workflow.build_tasks().values() + + self.assertEqual(task1.expiry, expected_expiry) + self.assertEqual(task2.expiry, expected_expiry) + self.assertEqual(workflow.expiry, expected_expiry) + + # Override a task's expiry + custom_expiry = timezone.datetime(3000, 4, 20).astimezone() + task2.expiry = custom_expiry + task2.save() + + # The workflow's expiry should be the latest expiry + self.assertEqual(workflow.expiry, custom_expiry) + + def test_workflow_expiry_query_count(self): + """ + Workflow.expiry causes an SQL query only when tasks are not prefetched + """ + with self.assertNumQueries(1): + self.assertEqual(self.workflow.expiry, self.task1.expiry) + + # Request the expiry again: it is not cached, there still is an SQL query + with self.assertNumQueries(1): + self.assertEqual(self.workflow.expiry, self.task1.expiry) + + prefetch_related_objects([self.workflow], "tasks") + with self.assertNumQueries(0): + self.assertEqual(self.workflow.expiry, self.task1.expiry) + + def test_workflow_get_state(self): + with self.assertNumQueries(1): + self.assertEqual(self.workflow.state, State.Unscheduled) + + with self.assertNumQueries(1): + self.assertEqual(self.workflow.get_state(0), State.Unscheduled) + + with self.assertNumQueries(1): + self.assertEqual(self.workflow.get_state(1), State.Unscheduled) + + # Negative run numbers should not result in any SQL query, since we know they are always empty + with self.assertNumQueries(0): + self.assertEqual(self.workflow.get_state(-1), State.Unscheduled) + + self.task1.state = State.Running + self.task1.save() + + with self.assertNumQueries(1): + self.assertEqual(self.workflow.state, State.Running) + + with self.assertNumQueries(1): + self.assertEqual(self.workflow.get_state(0), State.Running) + + with self.assertNumQueries(1): + self.assertEqual(self.workflow.get_state(1), State.Unscheduled) + + with self.assertNumQueries(0): + self.assertEqual(self.workflow.get_state(-1), State.Unscheduled) + + prefetch_related_objects([self.workflow], "tasks") + + with self.assertNumQueries(0): + self.assertEqual(self.workflow.state, State.Running) + self.assertEqual(self.workflow.get_state(0), State.Running) + self.assertEqual(self.workflow.get_state(1), State.Unscheduled) + self.assertEqual(self.workflow.get_state(-1), State.Unscheduled) diff --git a/arkindex/ponos/tests/test_recipe.py b/arkindex/ponos/tests/test_recipe.py new file mode 100644 index 0000000000000000000000000000000000000000..18ef2dbe9aaf879659cb2e4e71164b9248d30a45 --- /dev/null +++ b/arkindex/ponos/tests/test_recipe.py @@ -0,0 +1,42 @@ +from textwrap import dedent + +from django.test import TestCase + +from ponos.recipe import parse_recipe + +# List of (broken recipe, 
expected AssertionError message) tuples +ERROR_CASES = [ + ("[]", "Recipe should be a dict"), + ("tasks: {}", "No tasks"), + ("tasks: [{image: lol}]", "Tasks should be a dict"), + ("tasks: {a: {}, '': {}}", "Tasks should have non-blank slugs"), + ("tasks: {a: []}", "Task should be a dict"), + ("tasks: {a: {}}", "Missing image"), + ( + """ + tasks: + lol: + image: blah + artifact: 42 + """, + "Task artifact should be a string", + ), + ( + """ + tasks: + lol: + image: blah + artifact: philosophers_stone + """, + "Task artifact should be a valid UUID string", + ), +] + + +class TestRecipe(TestCase): + def test_recipe_errors(self): + for recipe, expected_message in ERROR_CASES: + with self.subTest(recipe=recipe), self.assertRaisesMessage( + AssertionError, expected_message + ): + parse_recipe(dedent(recipe)) diff --git a/arkindex/ponos/tests/test_schema.py b/arkindex/ponos/tests/test_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..6cfe6571985cdf279b1ef0ef4f81257fef8544cd --- /dev/null +++ b/arkindex/ponos/tests/test_schema.py @@ -0,0 +1,15 @@ +from django.core.management import call_command +from django.test import TestCase + + +class TestSchema(TestCase): + def test_generate_schema(self): + """ + Runs the OpenAPI schema generation command, causing this test to fail if something goes wrong with the schema validation. + """ + call_command( + "spectacular", + format="openapi-json", + fail_on_warn=True, + validate=True, + ) diff --git a/arkindex/ponos/tests/test_tasks_attribution.py b/arkindex/ponos/tests/test_tasks_attribution.py new file mode 100644 index 0000000000000000000000000000000000000000..84738085852ef00087c23e77b067cc1ca45b6f07 --- /dev/null +++ b/arkindex/ponos/tests/test_tasks_attribution.py @@ -0,0 +1,373 @@ +import uuid +from datetime import timedelta +from unittest.mock import patch + +from django.test import TestCase +from django.utils import timezone + +from ponos.models import Agent, Farm, State, Task, Workflow +from ponos.models import timezone as model_tz + + +class TasksAttributionTestCase(TestCase): + """ + The Ponos server distributes tasks equally among agents.
+ """ + + @classmethod + def setUpTestData(cls): + cls.farm = Farm.objects.create(name="testfarm") + cls.workflow = Workflow.objects.create( + farm=cls.farm, recipe="tasks: {charmander: {image: alpine}}" + ) + + def _run_tasks(self, tasks): + """ + Mark a list of tasks as running + """ + for t in tasks: + t.state = State.Running + t.save() + + def _build_agent(self, **kwargs): + """ + Creates an agent + Default values may be overridden if passed as kwargs + """ + params = { + "hostname": "test_host", + "cpu_cores": 2, + "cpu_frequency": 4.2e9, + "public_key": "", + "farm": self.farm, + "ram_total": 2e9, + "last_ping": timezone.now(), + "ram_load": 0.49, + "cpu_load": 0.99, + } + params.update(kwargs) + return Agent.objects.create(**params) + + def _add_pending_tasks(self, number, slug_ext="", **kwargs): + """ + Creates pending tasks on the system + """ + params = { + "run": 0, + "depth": 0, + "workflow": self.workflow, + "state": State.Pending, + } + params.update(kwargs) + return Task.objects.bulk_create( + Task(**params, slug=f"slug{slug_ext}_{i}") for i in range(1, number + 1) + ) + + def _active_agent(self, agent): + from ponos.models import AGENT_TIMEOUT + + return timezone.now() - agent.last_ping < AGENT_TIMEOUT + + def test_distribute_tasks(self): + """ + Multiple tasks are attributed to 3 registered agents with the same capacity + """ + agent_1, agent_2, agent_3 = [ + # Build three agents with 10 free CPU cores and 10Go of available RAM + self._build_agent( + hostname=f"agent_{i}", + cpu_cores=11, + ram_total=20e9, + ) + for i in range(1, 4) + ] + self.assertEqual( + [a.hostname for a in Agent.objects.all() if self._active_agent(a) is True], + ["agent_1", "agent_2", "agent_3"], + ) + # Add 9 pending tasks to the system + self._add_pending_tasks(9) + + # As agents are registered, they should retrieve 3 tasks each + tasks = agent_1.next_tasks() + self.assertEqual(len(tasks), 3) + self._run_tasks(tasks) + + # Agent 1 claims tasks before agent 2 but has a too high load to retrieve any task + agent_1.ram_load = 0.75 + agent_1.cpu_load = 3.99 + agent_1.save() + tasks = agent_1.next_tasks() + self.assertEqual(len(tasks), 0) + bg_tasks = agent_2.next_tasks() + self.assertEqual(len(bg_tasks), 3) + + # Agent 1 claims tasks before agent 3 + tasks = agent_1.next_tasks() + self.assertEqual(len(tasks), 0) + tasks = agent_3.next_tasks() + self.assertEqual(len(tasks), 3) + + # No agent updated its state before agent_2 retrieved its tasks + self.assertEqual(len(set(t.id for t in [*bg_tasks, *tasks])), 6) + + self._run_tasks(bg_tasks) + self._run_tasks(tasks) + self.assertEqual(Task.objects.filter(state=State.Pending).count(), 0) + + def test_distribute_tasks_asymetric(self): + """ + Use case when two agents with different capacity are attributed + tasks equivalent to a third of the system capacity + """ + # Agent 1 has 10 free "slots" + agent_1 = self._build_agent(hostname="agent_1", cpu_cores=11, ram_total=20e9) + # Agent 2 has only 3 free "slots" + agent_2 = self._build_agent(hostname="agent_2", cpu_cores=4, ram_total=6e9) + # Add 5 pending tasks to the system + self._add_pending_tasks(5) + + # The best strategy is to attribute 4 tasks to the first agent and 1 to the second + tasks = agent_1.next_tasks() + self.assertEqual(len(tasks), 4) + self._run_tasks(tasks) + + agent_1.ram_load = 0.7 + agent_1.cpu_load = 3.99 + agent_1.save() + tasks = agent_1.next_tasks() + self.assertEqual(len(tasks), 0) + tasks = agent_2.next_tasks() + self.assertEqual(len(tasks), 1) + self._run_tasks(tasks) + 
self.assertEqual(Task.objects.filter(state=State.Pending).count(), 0) + + def test_distribute_tasks_determinist(self): + """ + Tasks should be distributed in a deterministic way depending + on the overall load and the pending task queue + Tests that an agent polling tasks twice retrieves the same tasks + """ + _, _, agent = [ + self._build_agent( + hostname=f"agent_{i}", + cpu_cores=11, + ram_total=20e9, + ) + for i in range(1, 4) + ] + self._add_pending_tasks(9) + self.assertEqual( + [a.hostname for a in Agent.objects.all() if self._active_agent(a) is True], + ["agent_1", "agent_2", "agent_3"], + ) + + tasks = agent.next_tasks() + self.assertEqual(len(tasks), 3) + + new_tasks = agent.next_tasks() + self.assertCountEqual(tasks, new_tasks) + + @patch("ponos.models.AGENT_TIMEOUT", timedelta(seconds=1)) + def test_non_active_agent(self): + """ + If an agent does not respond, tasks should be attributed to other running agents + """ + agent_1, agent_2, agent_3 = [ + self._build_agent( + hostname=f"agent_{i}", + cpu_cores=11, + ram_total=20e9, + ) + for i in range(1, 4) + ] + self.assertEqual( + [a.hostname for a in Agent.objects.all() if self._active_agent(a) is True], + ["agent_1", "agent_2", "agent_3"], + ) + # Add 9 pending tasks to the system + self._add_pending_tasks(9) + + # As agents are registered, they should retrieve 3 tasks each + tasks_agent_1 = agent_1.next_tasks() + self.assertEqual(len(tasks_agent_1), 3) + + tasks = agent_2.next_tasks() + self.assertEqual(len(tasks), 3) + + self._run_tasks([*tasks_agent_1, *tasks]) + + # Jump into the future. Agent 3 never replied + future_now = timezone.now() + timedelta(seconds=2) + Agent.objects.filter(hostname__in=["agent_1", "agent_2"]).update( + ram_load=0.64, + cpu_load=3.99, + last_ping=future_now, + ) + + with patch.object(model_tz, "now") as now_mock: + now_mock.return_value = future_now + self.assertEqual( + [ + a.hostname + for a in Agent.objects.all() + if self._active_agent(a) is True + ], + ["agent_1", "agent_2"], + ) + + # 3 tasks should be distributed among 2 agents with similar load + tasks_agent_1 = agent_1.next_tasks() + self.assertEqual(len(tasks_agent_1), 2) + + tasks = agent_2.next_tasks() + self.assertEqual(len(tasks), 1) + + self._run_tasks([*tasks_agent_1, *tasks]) + + self.assertEqual(Task.objects.filter(state=State.Pending).count(), 0) + + def test_filter_active_agents(self): + """ + Assert that the DB cost to attribute tasks does + not increase with the number of tasks or agents + """ + agent = self._build_agent() + tasks = self._add_pending_tasks(20) + + with self.assertNumQueries(4): + tasks = agent.next_tasks() + self.assertEqual(len(tasks), 1) + + # Build some inefficient agents + for i in range(1, 20): + self._build_agent( + hostname=f"agent_{i}", + cpu_cores=2, + ram_total=2e9, + ) + + self.assertEqual(Agent.objects.count(), 20) + with self.assertNumQueries(4): + tasks = agent.next_tasks() + self.assertEqual(len(tasks), 1) + + def test_gpu_assignation(self): + """ + Check that a GPU-enabled agent gets a GPU task + """ + # The normal agent has 20 free "slots" + # It should get all the tasks + agent_normal = self._build_agent( + hostname="agent_normal", cpu_cores=20, ram_total=20e9 + ) + + # But this one has a GPU: it will get GPU tasks first + agent_gpu = self._build_agent(hostname="agent_2", cpu_cores=10, ram_total=6e9) + agent_gpu.gpus.create(id=uuid.uuid4(), index=0, ram_total=8e9, name="Fake GPU") + + # Add 6 normal + 1 GPU pending tasks to the system + tasks = self._add_pending_tasks(7) + task_gpu = tasks[-1] + + # Require 
gpu on that task + task_gpu.command = "/usr/bin/nvidia-smi" + task_gpu.requires_gpu = True + task_gpu.save() + + # Normal agent should eat up most of the normal tasks + tasks = agent_normal.next_tasks() + self.assertEqual(len(tasks), 5) + self.assertFalse(any(t.requires_gpu for t in tasks)) + self.assertFalse(any(t.gpu for t in tasks)) + self._run_tasks(tasks) + + # GPU agent should then get the GPU task + tasks = agent_gpu.next_tasks() + self.assertEqual(len(tasks), 1) + self.assertTrue(all(t.requires_gpu for t in tasks)) + self.assertTrue(all(t.gpu for t in tasks)) + self._run_tasks(tasks) + + # Consume the last normal task + tasks = agent_normal.next_tasks() + self.assertEqual(len(tasks), 1) + self._run_tasks(tasks) + + # No more tasks left + self.assertEqual(Task.objects.filter(state=State.Pending).count(), 0) + + def test_multiple_farm_task_assignation(self): + """ + Distribute tasks depending on their farm + """ + # Create one big agent on the test farm + test_agent = self._build_agent(cpu_cores=20, ram_total=64e9) + self._add_pending_tasks(3) + # and 2 small agents on the corn farm with a capacity for one task each + corn_farm = Farm.objects.create(name="Corn farm") + corn_agent_1, corn_agent_2 = [ + self._build_agent( + hostname=f"agent_{i}", + cpu_cores=2, + ram_total=2e9, + farm=corn_farm, + ) + for i in range(1, 3) + ] + + corn_workflow = corn_farm.workflows.create( + recipe="tasks: {inkay: {image: alpine}}" + ) + tasks = self._add_pending_tasks(3, workflow=corn_workflow) + + self.assertEqual(Task.objects.count(), 6) + self.assertEqual( + len([agent for agent in Agent.objects.all() if self._active_agent(agent)]), + 3, + ) + + # All except one task in the corn farm will be distributed + tasks = test_agent.next_tasks() + self.assertEqual(len(tasks), 3) + self.assertEqual( + set([task.workflow.farm_id for task in tasks]), set([self.farm.id]) + ) + + corn_tasks_1 = corn_agent_1.next_tasks() + self.assertEqual(len(corn_tasks_1), 1) + corn_tasks_2 = corn_agent_2.next_tasks() + self.assertEqual(len(corn_tasks_2), 1) + + self._run_tasks([*tasks, *corn_tasks_1, *corn_tasks_2]) + + # Update the corn agents' loads + Agent.objects.filter(farm=corn_farm).update(ram_load=0.95, cpu_load=1.9) + + # No agent can retrieve the last pending task + self.assertEqual(Task.objects.filter(state=State.Pending).count(), 1) + for agent in Agent.objects.all(): + self.assertEqual(len(agent.next_tasks()), 0) + + def test_next_tasks_ordering(self): + agent_1 = self._build_agent(hostname="agent_1", cpu_cores=11, ram_total=20e9) + self.assertEqual( + [a.hostname for a in Agent.objects.all() if self._active_agent(a) is True], + ["agent_1"], + ) + # Add 3 pending low priority tasks to the system + lp_tasks = self._add_pending_tasks(3, slug_ext="_lp", priority="1") + # Add 3 pending normal priority tasks to the system + np_tasks = self._add_pending_tasks(3, slug_ext="_np") + # Add 3 pending high priority tasks to the system + hp_tasks = self._add_pending_tasks(3, slug_ext="_hp", priority="100") + + # The single registered agent should retrieve all 9 pending tasks + tasks = agent_1.next_tasks() + self.assertEqual(len(tasks), 9) + # First three tasks should be the high priority ones + self.assertEqual(tasks[:3], hp_tasks) + # Second three tasks should be the normal priority ones + self.assertEqual(tasks[3:6], np_tasks) + # Last three tasks should be the low priority ones + self.assertEqual(tasks[6:], lp_tasks) diff --git a/arkindex/ponos/tests/test_util.py b/arkindex/ponos/tests/test_util.py new file mode 100644 index 
0000000000000000000000000000000000000000..634aaa746d91ee06d1a1fd62f9c5804590891821 --- /dev/null +++ b/arkindex/ponos/tests/test_util.py @@ -0,0 +1,33 @@ +import uuid +from unittest import TestCase +from unittest.mock import patch + +from ponos import get_ponos_run, get_ponos_task_id, is_ponos_task + + +class TestUtil(TestCase): + def test_is_ponos_task(self): + with patch.dict("ponos.helpers.os.environ", clear=True): + self.assertFalse(is_ponos_task()) + + with patch.dict( + "ponos.helpers.os.environ", PONOS_TASK=str(uuid.uuid4()), clear=True + ): + self.assertTrue(is_ponos_task()) + + def test_get_ponos_task_id(self): + with patch.dict("ponos.helpers.os.environ", clear=True): + self.assertIsNone(get_ponos_task_id()) + + task_id = uuid.uuid4() + with patch.dict( + "ponos.helpers.os.environ", PONOS_TASK=str(task_id), clear=True + ): + self.assertEqual(get_ponos_task_id(), task_id) + + def test_get_ponos_run(self): + with patch.dict("ponos.helpers.os.environ", clear=True): + self.assertIsNone(get_ponos_run()) + + with patch.dict("ponos.helpers.os.environ", PONOS_RUN="42", clear=True): + self.assertEqual(get_ponos_run(), 42) diff --git a/arkindex/ponos/tests/test_workflow.py b/arkindex/ponos/tests/test_workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..c6dd0a5642e9f76180a70c3c60ef169dfaf1449d --- /dev/null +++ b/arkindex/ponos/tests/test_workflow.py @@ -0,0 +1,263 @@ +from unittest.mock import patch + +from django.test import TestCase +from django.urls import reverse +from django.utils import timezone +from rest_framework import status + +from ponos.models import Agent, Farm, State, Workflow + +ONE_TASK = """ +tasks: + test: + image: hello-world +""" + +THREE_TASKS = """ +tasks: + test: + image: hello-world + test-subsub: + image: XXX + parents: + - test-sub + test-sub: + image: YYY:latest + parents: + - test +""" + +MULTIPLE_PARENTS_TASK = """ +tasks: + test: + image: hello-world + parents: + - test_aaa + - test_bbb + test_aaa: + image: a + test_bbb: + image: b +""" + +TASK_SHM_SIZE = """ +tasks: + test: + image: hello-world + shm_size: 505m +""" + +TASK_INVALID_SHM_SIZE = """ +tasks: + test: + image: hello-world + shm_size: douze +""" + +TASK_SHM_SIZE_NONE = """ +tasks: + test: + image: hello-world + shm_size: null +""" + + +class WorkflowTestCase(TestCase): + """ + Create some workflows & tasks + """ + + @classmethod + def setUpTestData(cls): + cls.farm = Farm.objects.create(name="testfarm") + # Create a fake Agent instance + cls.agent = Agent.objects.create( + hostname="test_agent", + cpu_cores=2, + cpu_frequency=1e9, + public_key="", + farm=cls.farm, + ram_total=2e9, + last_ping=timezone.now(), + ) + + def test_workflow_start_one_task(self): + w = Workflow.objects.create(farm=self.farm, recipe=ONE_TASK) + tasks = w.tasks + self.assertEqual(tasks.count(), 0) + + w.start() + self.assertEqual(tasks.count(), 1) + task = tasks.first() + self.assertFalse(task.parents.exists()) + self.assertEqual(task.run, 0) + self.assertEqual(task.depth, 0) + + def test_workflow_start_three_tasks(self): + w = Workflow.objects.create(farm=self.farm, recipe=THREE_TASKS) + tasks = w.tasks + self.assertEqual(tasks.count(), 0) + + w.start() + self.assertEqual(tasks.count(), 3) + t1, t2, t3 = w.tasks.all() + + self.assertFalse(t1.parents.exists()) + self.assertEqual(t1.run, 0) + self.assertEqual(t1.depth, 0) + + self.assertEqual(t2.parents.get(), t1) + self.assertEqual(t2.run, 0) + self.assertEqual(t2.depth, 1) + + self.assertEqual(t3.parents.get(), t2) + self.assertEqual(t3.run, 0) + 
self.assertEqual(t3.depth, 2) + + def test_multiple_parents_task(self): + w = Workflow.objects.create(farm=self.farm, recipe=MULTIPLE_PARENTS_TASK) + + w.start() + self.assertEqual(w.tasks.count(), 3) + + task = w.tasks.get(slug="test") + self.assertEqual([t.slug for t in task.parents.all()], ["test_aaa", "test_bbb"]) + + def test_workflow_running_override(self): + """ + Test that a single running task in a workflow will override any other state + """ + w = Workflow.objects.create(farm=self.farm, recipe=THREE_TASKS) + w.start() + + self.assertEqual(w.tasks.count(), 3) + self.assertEqual(w.state, State.Unscheduled) + t1, t2, t3 = w.tasks.all() + + t1.state = State.Running + t1.save() + self.assertEqual(w.state, State.Running) + + t2.state = State.Error + t2.save() + self.assertEqual(w.state, State.Running) + + t2.state = State.Failed + t2.save() + self.assertEqual(w.state, State.Running) + + t2.state = State.Stopping + t2.save() + self.assertEqual(w.state, State.Running) + + t2.state = State.Stopped + t2.save() + self.assertEqual(w.state, State.Running) + + t2.state = State.Pending + t2.save() + self.assertEqual(w.state, State.Running) + + t2.state = State.Completed + t3.state = State.Completed + t2.save() + t3.save() + self.assertEqual(w.state, State.Running) + + @patch("ponos.models.s3") + @patch("ponos.models.Task.s3_logs_get_url") + def test_task_parents_update(self, s3_mock, s3_logs_mock): + w = Workflow.objects.create(farm=self.farm, recipe=MULTIPLE_PARENTS_TASK) + w.start() + + child = w.tasks.get(slug="test") + parent1, parent2 = w.tasks.exclude(slug="test") + + response = self.client.get( + reverse("ponos:agent-actions"), + HTTP_AUTHORIZATION="Bearer {}".format(self.agent.token.access_token), + data={"cpu_load": 0.99, "ram_load": 0.49}, + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + actions = response.json()["actions"] + self.assertEqual(len(actions), 1) + self.assertNotEqual(actions[0]["task_id"], str(child.id)) + + parent1.state = State.Completed + parent1.save() + + parent2.agent = self.agent + parent2.save() + + self.client.get( + reverse("ponos:agent-actions"), + HTTP_AUTHORIZATION="Bearer {}".format(self.agent.token.access_token), + ) + # Completing 2nd parent should change test_task state to pending + response = self.client.patch( + reverse("ponos:task-details", kwargs={"pk": str(parent2.id)}), + HTTP_AUTHORIZATION="Bearer {}".format(self.agent.token.access_token), + data={"state": State.Completed.value}, + content_type="application/json", + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + child.refresh_from_db() + self.assertEqual(child.state, State.Pending) + self.assertEqual(w.state, State.Pending) + + def test_workflow_retry_updates_finished(self): + w = Workflow.objects.create(farm=self.farm, recipe=ONE_TASK) + w.start() + w.tasks.update(state=State.Completed) + w.finished = timezone.now() + w.save() + + with self.assertNumQueries(7): + w.retry() + self.assertEqual(w.tasks.count(), 2) + self.assertIsNone(w.finished) + + def test_workflow_stop_unscheduled_updates_finished(self): + w = Workflow.objects.create(farm=self.farm, recipe=ONE_TASK) + w.start() + # All tasks should be now unscheduled. 
+ self.assertFalse(w.tasks.exclude(state=State.Unscheduled).exists()) + + # Stopping the workflow causes all tasks to immediately be Stopped + w.stop() + self.assertFalse(w.tasks.exclude(state=State.Stopped).exists()) + self.assertIsNotNone(w.finished) + + def test_workflow_start_shm_size(self): + w = Workflow.objects.create(farm=self.farm, recipe=TASK_SHM_SIZE) + tasks = w.tasks + self.assertEqual(tasks.count(), 0) + + w.start() + self.assertEqual(tasks.count(), 1) + task = tasks.first() + self.assertFalse(task.parents.exists()) + self.assertEqual(task.shm_size, "505m") + self.assertEqual(task.run, 0) + self.assertEqual(task.depth, 0) + + def test_workflow_start_invalid_shm_size(self): + w = Workflow.objects.create(farm=self.farm, recipe=TASK_INVALID_SHM_SIZE) + tasks = w.tasks + self.assertEqual(tasks.count(), 0) + with self.assertRaises(AssertionError) as e: + w.start() + self.assertEqual(str(e.exception), "douze is not a valid value for shm_size") + + def test_workflow_start_none_shm_size(self): + w = Workflow.objects.create(farm=self.farm, recipe=TASK_SHM_SIZE_NONE) + tasks = w.tasks + self.assertEqual(tasks.count(), 0) + w.start() + self.assertEqual(tasks.count(), 1) + task = tasks.first() + self.assertFalse(task.parents.exists()) + self.assertEqual(task.shm_size, None) + self.assertEqual(task.run, 0) + self.assertEqual(task.depth, 0) diff --git a/arkindex/ponos/urls.py b/arkindex/ponos/urls.py new file mode 100644 index 0000000000000000000000000000000000000000..7917b0dbb27fddfd08fd79b5a7eef1a288e1238d --- /dev/null +++ b/arkindex/ponos/urls.py @@ -0,0 +1,52 @@ +from django.urls import path + +from ponos.api import ( + AgentActions, + AgentDetails, + AgentRegister, + AgentsState, + AgentTokenRefresh, + FarmList, + PublicKeyEndpoint, + SecretDetails, + TaskArtifactDownload, + TaskArtifacts, + TaskCreate, + TaskDefinition, + TaskDetailsFromAgent, + TaskUpdate, + WorkflowDetails, +) + +app_name = "ponos" +urlpatterns = [ + path("v1/workflow/<uuid:pk>/", WorkflowDetails.as_view(), name="workflow-details"), + path("v1/task/", TaskCreate.as_view(), name="task-create"), + path("v1/task/<uuid:pk>/", TaskUpdate.as_view(), name="task-update"), + path( + "v1/task/<uuid:pk>/from-agent/", + TaskDetailsFromAgent.as_view(), + name="task-details", + ), + path( + "v1/task/<uuid:pk>/definition/", + TaskDefinition.as_view(), + name="task-definition", + ), + path( + "v1/task/<uuid:pk>/artifacts/", TaskArtifacts.as_view(), name="task-artifacts" + ), + path( + "v1/task/<uuid:pk>/artifact/<path:path>", + TaskArtifactDownload.as_view(), + name="task-artifact-download", + ), + path("v1/agent/", AgentRegister.as_view(), name="agent-register"), + path("v1/agent/<uuid:pk>/", AgentDetails.as_view(), name="agent-details"), + path("v1/agent/refresh/", AgentTokenRefresh.as_view(), name="agent-token-refresh"), + path("v1/agent/actions/", AgentActions.as_view(), name="agent-actions"), + path("v1/agents/", AgentsState.as_view(), name="agents-state"), + path("v1/public-key/", PublicKeyEndpoint.as_view(), name="public-key"), + path("v1/secret/<path:name>", SecretDetails.as_view(), name="secret-details"), + path("v1/farms/", FarmList.as_view(), name="farm-list"), +] diff --git a/arkindex/ponos/validators.py b/arkindex/ponos/validators.py new file mode 100644 index 0000000000000000000000000000000000000000..d04e99551af79b3b5a4aa8612c835a5fbf04242f --- /dev/null +++ b/arkindex/ponos/validators.py @@ -0,0 +1,28 @@ +from django.core import validators + + +class HiddenCallableValidatorMixin(object): + """ + Implements a 
workaround for some issues with error messages in DRF + and with drf-spectacular OpenAPI schema generation when the `limit_value` + of any validator extending django.core.validators.BaseValidator is + a callable. This rewrites `self.limit_value` as a property, + which calls the original limit value when it is callable while making + Django, DRF and Spectacular believe it isn't callable. + + https://github.com/encode/django-rest-framework/discussions/8833 + https://github.com/tfranzel/drf-spectacular/issues/913 + """ + + def __init__(self, limit_value, message=None): + self._limit_value = limit_value + if message: + self.message = message + + @property + def limit_value(self): + return self._limit_value() if callable(self._limit_value) else self._limit_value + + +class MaxValueValidator(HiddenCallableValidatorMixin, validators.MaxValueValidator): + pass
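Note (not part of the patch): a minimal usage sketch of the MaxValueValidator defined above with a callable limit. The PONOS_MAXIMUM_TASKS setting and ExampleSerializer are hypothetical names used only for illustration; the point is that the callable is resolved lazily through the limit_value property, so DRF error messages and drf-spectacular schema generation see a plain value.

from django.conf import settings
from rest_framework import serializers

from ponos.validators import MaxValueValidator


def max_tasks():
    # Hypothetical setting, read each time the validator runs
    return getattr(settings, "PONOS_MAXIMUM_TASKS", 100)


class ExampleSerializer(serializers.Serializer):
    # The callable limit is only evaluated at validation time
    task_count = serializers.IntegerField(validators=[MaxValueValidator(max_tasks)])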