diff --git a/arkindex/process/builder.py b/arkindex/process/builder.py index dd3445337f83bb0c55cfdf514be4636feca1b67f..93e918f6aa62d4852bf854937c4fb209884d2aa2 100644 --- a/arkindex/process/builder.py +++ b/arkindex/process/builder.py @@ -1,5 +1,6 @@ import shlex from collections import defaultdict +from datetime import timedelta from functools import wraps from os import path from typing import Dict, List, Sequence, Tuple @@ -7,10 +8,12 @@ from uuid import UUID from django.conf import settings from django.db.models import Prefetch, prefetch_related_objects +from django.utils import timezone +from django.utils.functional import cached_property from rest_framework.exceptions import ValidationError from arkindex.images.models import ImageServer -from arkindex.ponos.models import Task, task_token_default +from arkindex.ponos.models import GPU, Task, task_token_default class ProcessBuilder(object): @@ -162,6 +165,10 @@ class ProcessBuilder(object): env["ARKINDEX_CORPUS_ID"] = str(self.process.corpus_id) return env + @cached_property + def active_gpu_agents(self) -> bool: + return GPU.objects.filter(agent__last_ping__gt=timezone.now() - timedelta(seconds=30)).exists() + @prefetch_worker_runs def validate_gpu_requirement(self): from arkindex.process.models import FeatureUsage @@ -298,6 +305,7 @@ class ProcessBuilder(object): chunk=index if len(chunks) > 1 else None, workflow_runs=worker_runs, run=self.run, + active_gpu_agents=self.active_gpu_agents, ) self.tasks.append(task) self.tasks_parents[task.slug].extend(parent_slugs) diff --git a/arkindex/process/models.py b/arkindex/process/models.py index 714e382c911fb479acf4b43837cda616a5b2c127..3acbede35ef467e3d796b044fc0ac552f02ed3e2 100644 --- a/arkindex/process/models.py +++ b/arkindex/process/models.py @@ -911,7 +911,7 @@ class WorkerRun(models.Model): # we add the WorkerRun ID at the end of the slug return f"{self.version.worker.slug}_{str(self.id)[:6]}" - def build_task(self, process, env, import_task_name, 
elements_path, run=0, chunk=None, workflow_runs=None): + def build_task(self, process, env, import_task_name, elements_path, run=0, chunk=None, workflow_runs=None, active_gpu_agents=False): """ Build the Task that will represent this WorkerRun in ponos using : - the docker image name given by the WorkerVersion @@ -967,6 +967,12 @@ class WorkerRun(models.Model): assert self.model_version.state == ModelVersionState.Available, f"ModelVersion {self.model_version.id} is not available and cannot be used to build a task." extra_files = {"model": settings.PUBLIC_HOSTNAME + reverse("api:model-version-download", kwargs={"pk": self.model_version.id}) + f"?token={self.model_version.build_authentication_token_hash()}"} + requires_gpu = process.use_gpu and self.version.gpu_usage in (FeatureUsage.Required, FeatureUsage.Supported) + # Do not require a GPU when no active agent has a GPU and the worker version only supports (but does not require) a GPU; + # this downgrade is skipped for RQ tasks execution, where checking for active agents does not make sense + if not settings.PONOS_RQ_EXECUTION and not active_gpu_agents and self.version.gpu_usage != FeatureUsage.Required: + requires_gpu = False + task = Task( command=self.version.docker_command, image=self.version.docker_image_iid or self.version.docker_image_name, @@ -981,7 +987,7 @@ class WorkerRun(models.Model): process=process, worker_run=self, extra_files=extra_files, - requires_gpu=process.use_gpu and self.version.gpu_usage in (FeatureUsage.Required, FeatureUsage.Supported) + requires_gpu=requires_gpu ) return task, parents diff --git a/arkindex/process/tests/test_create_process.py b/arkindex/process/tests/test_create_process.py index e15b55f9947473a23471c2fb95f35d61660d3a32..cfb1f59958dd66a00fe54fd99f65bce7fdf9e388 100644 --- a/arkindex/process/tests/test_create_process.py +++ b/arkindex/process/tests/test_create_process.py @@ -1,3 +1,4 @@ +import uuid from collections import namedtuple from datetime import datetime, timezone from unittest.mock import call, patch 
@@ -7,9 +8,10 @@ from rest_framework import status from rest_framework.reverse import reverse from arkindex.documents.models import Corpus, Element -from arkindex.ponos.models import Farm, State +from arkindex.ponos.models import GPU, Agent, Farm, State from arkindex.process.models import ( ActivityState, + FeatureUsage, Process, ProcessDataset, ProcessMode, @@ -31,6 +33,21 @@ class TestCreateProcess(FixtureAPITestCase): @classmethod def setUpTestData(cls): super().setUpTestData() + cls.agent = Agent.objects.create( + farm=Farm.objects.first(), + hostname="claude", + cpu_cores=42, + cpu_frequency=1e15, + ram_total=99e9, + last_ping=datetime.now(timezone.utc), + ) + cls.agent.gpus.create( + id=uuid.uuid4(), + name="claudette", + index=2, + ram_total=12 + ) + cls.volume = Element.objects.get(name="Volume 1") cls.pages = Element.objects.get_descending(cls.volume.id).filter(type__slug="page", polygon__isnull=False) cls.ml_class = cls.corpus.ml_classes.create(name="bretzel") @@ -585,7 +602,7 @@ class TestCreateProcess(FixtureAPITestCase): self.assertFalse(self.corpus.worker_versions.exists()) self.client.force_login(self.user) - with self.assertNumQueries(14): + with self.assertNumQueries(15): response = self.client.post( reverse("api:process-start", kwargs={"pk": str(process_2.id)}), {"worker_activity": True}, @@ -676,7 +693,7 @@ class TestCreateProcess(FixtureAPITestCase): ) self.client.force_login(self.user) - with self.assertNumQueries(14): + with self.assertNumQueries(15): response = self.client.post( reverse("api:process-start", kwargs={"pk": str(process_2.id)}), {"use_cache": True}, @@ -714,7 +731,7 @@ class TestCreateProcess(FixtureAPITestCase): @patch("arkindex.ponos.models.base64.encodebytes") def test_create_process_use_gpu_option(self, token_mock): """ - A process with the `use_gpu` parameter enables the `requires_gpu` attribute of tasks than need one + A process with the `use_gpu` parameter enables the `requires_gpu` attribute of tasks that need one """ 
token_mock.side_effect = [b"12345", b"67891"] process_2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Workers) @@ -724,7 +741,7 @@ class TestCreateProcess(FixtureAPITestCase): ) self.client.force_login(self.user) - with self.assertNumQueries(14): + with self.assertNumQueries(15): response = self.client.post( reverse("api:process-start", kwargs={"pk": str(process_2.id)}), {"use_gpu": True}, @@ -755,6 +772,60 @@ class TestCreateProcess(FixtureAPITestCase): self.assertEqual(len(worker_task.parents.all()), 1) self.assertEqual(worker_task.parents.first(), init_task) + @override_settings( + ARKINDEX_TASKS_IMAGE="registry.teklia.com/tasks", + PONOS_DEFAULT_ENV={} + ) + @patch("arkindex.ponos.models.base64.encodebytes") + def test_create_process_use_gpu_option_no_available_gpus(self, token_mock): + """ + If there are no available Agents with a GPU, then requires_gpu is not set if the worker + version only supports a GPU and does not require it + """ + self.agent.gpus.all().delete() + token_mock.side_effect = [b"12345", b"67891", b"54321", b"19876"] + + for feature_usage, requires_gpu, task_token in [ + (FeatureUsage.Supported, False, "67891"), + (FeatureUsage.Required, True, "19876") + ]: + with self.subTest(feature_usage=feature_usage, requires_gpu=requires_gpu): + + process = self.corpus.processes.create( + creator=self.user, + mode=ProcessMode.Workers, + farm=Farm.objects.first(), + ) + run = process.worker_runs.create( + version=self.version_3, + parents=[], + ) + process.use_gpu = True + self.assertEqual(GPU.objects.count(), 0) + self.version_3.gpu_usage = feature_usage + self.version_3.save() + process.run() + + init_task = process.tasks.get(slug="initialisation") + self.assertEqual(init_task.command, f"python -m arkindex_tasks.init_elements {process.id} --chunks-number 1") + self.assertEqual(init_task.image, "registry.teklia.com/tasks") + + worker_task = process.tasks.get(slug=run.task_slug) + self.assertEqual(worker_task.command, None) + 
self.assertEqual(worker_task.image, f"my_repo.fake/workers/worker/worker-gpu:{self.version_3.id}") + self.assertEqual(worker_task.image_artifact.id, self.version_3.docker_image.id) + self.assertEqual(worker_task.shm_size, None) + self.assertEqual(worker_task.env, { + "TASK_ELEMENTS": "/data/initialisation/elements.json", + "ARKINDEX_CORPUS_ID": str(self.corpus.id), + "ARKINDEX_PROCESS_ID": str(process.id), + "ARKINDEX_WORKER_RUN_ID": str(process.worker_runs.get().id), + "ARKINDEX_TASK_TOKEN": task_token + }) + self.assertEqual(worker_task.requires_gpu, requires_gpu) + self.assertEqual(len(worker_task.parents.all()), 1) + self.assertEqual(worker_task.parents.first(), init_task) + def test_retry_keeps_requires_gpu(self): """ When a process is retried, the newly created tasks keep the same requires_gpu values @@ -819,7 +890,7 @@ class TestCreateProcess(FixtureAPITestCase): process.use_gpu = True process.save() self.client.force_login(self.user) - with self.assertNumQueries(14): + with self.assertNumQueries(15): response = self.client.post( reverse("api:process-start", kwargs={"pk": str(process.id)}), {"use_gpu": "true"} @@ -907,7 +978,7 @@ class TestCreateProcess(FixtureAPITestCase): process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Workers) process.versions.add(custom_version) - with self.assertNumQueries(14): + with self.assertNumQueries(15): response = self.client.post(reverse("api:process-start", kwargs={"pk": str(process.id)})) self.assertEqual(response.status_code, status.HTTP_201_CREATED) diff --git a/arkindex/process/tests/test_processes.py b/arkindex/process/tests/test_processes.py index 38d06536e2e1100675d3201fa2c0ee905327ca7e..1af5d378662ac38aa8ecf1b4cc4a65b6d2a206a2 100644 --- a/arkindex/process/tests/test_processes.py +++ b/arkindex/process/tests/test_processes.py @@ -1749,7 +1749,7 @@ class TestProcesses(FixtureAPITestCase): self.workers_process.activity_state = ActivityState.Error self.workers_process.save() - with 
self.assertNumQueries(13): + with self.assertNumQueries(14): response = self.client.post(reverse("api:process-retry", kwargs={"pk": self.workers_process.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -2126,7 +2126,7 @@ class TestProcesses(FixtureAPITestCase): with ( self.settings(IMPORTS_WORKER_VERSION=str(self.version_with_model.id)), - self.assertNumQueries(8) + self.assertNumQueries(9) ): response = self.client.post(reverse("api:files-process"), { "files": [str(self.img_df.id)], @@ -2217,7 +2217,7 @@ class TestProcesses(FixtureAPITestCase): self.assertFalse(process2.tasks.exists()) self.client.force_login(self.user) - with self.assertNumQueries(14): + with self.assertNumQueries(15): response = self.client.post( reverse("api:process-start", kwargs={"pk": str(process2.id)}) ) @@ -2362,7 +2362,7 @@ class TestProcesses(FixtureAPITestCase): self.client.force_login(self.user) - with self.assertNumQueries(14): + with self.assertNumQueries(15): response = self.client.post( reverse("api:process-start", kwargs={"pk": str(process2.id)}) ) @@ -2477,7 +2477,7 @@ class TestProcesses(FixtureAPITestCase): self.assertFalse(process.tasks.exists()) self.client.force_login(self.user) - with self.assertNumQueries(14): + with self.assertNumQueries(15): response = self.client.post( reverse("api:process-start", kwargs={"pk": str(process.id)}) ) @@ -2499,7 +2499,7 @@ class TestProcesses(FixtureAPITestCase): farm = Farm.objects.get(name="Wheat farm") self.client.force_login(self.user) - with self.assertNumQueries(15): + with self.assertNumQueries(16): response = self.client.post( reverse("api:process-start", kwargs={"pk": str(workers_process.id)}), {"farm": str(farm.id)} @@ -2697,7 +2697,7 @@ class TestProcesses(FixtureAPITestCase): self.client.force_login(self.user) - with self.assertNumQueries(15): + with self.assertNumQueries(16): response = self.client.post( reverse("api:process-start", kwargs={"pk": str(process.id)}), {"use_cache": "true", "worker_activity": 
"true", "use_gpu": "true"} @@ -2733,7 +2733,7 @@ class TestProcesses(FixtureAPITestCase): self.assertNotEqual(run_1.task_slug, run_2.task_slug) self.client.force_login(self.user) - with self.assertNumQueries(14): + with self.assertNumQueries(15): response = self.client.post(reverse("api:process-start", kwargs={"pk": str(process.id)})) self.assertEqual(response.status_code, status.HTTP_201_CREATED)