From 52e8ff495b36d0bbdb46c2413df81765a41fe585 Mon Sep 17 00:00:00 2001 From: Erwan Rouchet <rouchet@teklia.com> Date: Tue, 18 Feb 2025 11:31:09 +0000 Subject: [PATCH] Add agent token field --- arkindex/ponos/api.py | 4 +- arkindex/ponos/migrations/0001_initial.py | 2 +- .../ponos/migrations/0004_index_cleanup.py | 4 +- arkindex/ponos/migrations/0015_agent_token.py | 38 +++++++++++++++++++ .../0016_agent_token_constraints.py | 24 ++++++++++++ arkindex/ponos/models.py | 30 ++++++++++----- arkindex/process/builder.py | 4 +- arkindex/process/models.py | 4 +- arkindex/process/tests/process/test_run.py | 2 +- 9 files changed, 92 insertions(+), 20 deletions(-) create mode 100644 arkindex/ponos/migrations/0015_agent_token.py create mode 100644 arkindex/ponos/migrations/0016_agent_token_constraints.py diff --git a/arkindex/ponos/api.py b/arkindex/ponos/api.py index 91f6d607da..7616e944de 100644 --- a/arkindex/ponos/api.py +++ b/arkindex/ponos/api.py @@ -12,7 +12,7 @@ from rest_framework.generics import CreateAPIView, ListCreateAPIView, RetrieveUp from rest_framework.response import Response from rest_framework.views import APIView -from arkindex.ponos.models import FINAL_STATES, Artifact, State, Task, task_token_default +from arkindex.ponos.models import FINAL_STATES, Artifact, State, Task, token_default from arkindex.ponos.permissions import ( IsAgentOrArtifactGuest, IsAgentOrTaskGuest, @@ -234,7 +234,7 @@ class TaskRestart(ProcessACLMixin, CreateAPIView): copy.id = uuid.uuid4() copy.slug = basename copy.state = State.Pending - copy.token = task_token_default() + copy.token = token_default() copy.agent_id = None copy.gpu_id = None copy.started = None diff --git a/arkindex/ponos/migrations/0001_initial.py b/arkindex/ponos/migrations/0001_initial.py index 7a72645956..0c9d9a2378 100644 --- a/arkindex/ponos/migrations/0001_initial.py +++ b/arkindex/ponos/migrations/0001_initial.py @@ -114,7 +114,7 @@ class Migration(migrations.Migration): ("updated", models.DateTimeField(auto_now=True)), ("expiry", models.DateTimeField(default=arkindex.ponos.models.expiry_default)), ("extra_files", django.contrib.postgres.fields.hstore.HStoreField(default=dict, blank=True)), - ("token", models.CharField(default=arkindex.ponos.models.task_token_default, max_length=52, unique=True)), + ("token", models.CharField(default=arkindex.ponos.models.token_default, max_length=52, unique=True)), ("agent", models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name="tasks", to="ponos.agent")), ("gpu", models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name="tasks", to="ponos.gpu")), ("image_artifact", models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name="tasks_using_image", to="ponos.artifact")), diff --git a/arkindex/ponos/migrations/0004_index_cleanup.py b/arkindex/ponos/migrations/0004_index_cleanup.py index 6e25ec7c74..939e253118 100644 --- a/arkindex/ponos/migrations/0004_index_cleanup.py +++ b/arkindex/ponos/migrations/0004_index_cleanup.py @@ -3,7 +3,7 @@ from django.core.validators import RegexValidator from django.db import migrations, models -from arkindex.ponos.models import generate_seed, task_token_default +from arkindex.ponos.models import generate_seed, token_default class Migration(migrations.Migration): @@ -89,7 +89,7 @@ class Migration(migrations.Migration): model_name="task", name="token", field=models.CharField( - default=task_token_default, + default=token_default, max_length=52, ), ), diff --git a/arkindex/ponos/migrations/0015_agent_token.py b/arkindex/ponos/migrations/0015_agent_token.py new file mode 100644 index 0000000000..b3f118b8f0 --- /dev/null +++ b/arkindex/ponos/migrations/0015_agent_token.py @@ -0,0 +1,38 @@ +# Generated by Django 5.0.8 on 2025-02-17 13:50 + +from django.db import migrations, models + +from arkindex.ponos.models import token_default + + +def add_agent_tokens(apps, schema_editor): + Agent = apps.get_model("ponos", "Agent") + to_update = [] + for agent in Agent.objects.filter(token=None).only("id").iterator(): + agent.token = token_default() + to_update.append(agent) + Agent.objects.bulk_update(to_update, ["token"], batch_size=100) + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0014_task_task_finished_requires_final_state"), + ] + + operations = [ + migrations.AddField( + model_name="agent", + name="token", + field=models.CharField( + max_length=52, + # Make the field temporarily nullable and not unique, so that we can + # fill the tokens on existing agents before adding the constraints. + null=True, + ), + ), + migrations.RunPython( + add_agent_tokens, + reverse_code=migrations.RunPython.noop, + ), + ] diff --git a/arkindex/ponos/migrations/0016_agent_token_constraints.py b/arkindex/ponos/migrations/0016_agent_token_constraints.py new file mode 100644 index 0000000000..c0ae4fd402 --- /dev/null +++ b/arkindex/ponos/migrations/0016_agent_token_constraints.py @@ -0,0 +1,24 @@ +# Generated by Django 5.0.8 on 2025-02-17 14:26 + +from django.db import migrations, models + +import arkindex.ponos.models + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0015_agent_token"), + ] + + operations = [ + migrations.AlterField( + model_name="agent", + name="token", + field=models.CharField(default=arkindex.ponos.models.token_default, max_length=52), + ), + migrations.AddConstraint( + model_name="agent", + constraint=models.UniqueConstraint(models.F("token"), name="unique_agent_token"), + ), + ] diff --git a/arkindex/ponos/models.py b/arkindex/ponos/models.py index e77e10df10..d2671227b8 100644 --- a/arkindex/ponos/models.py +++ b/arkindex/ponos/models.py @@ -34,6 +34,15 @@ def gen_nonce(size=16): return urandom(size) +def token_default(): + """ + Default value for task and agent tokens. + + :rtype: str + """ + return base64.encodebytes(uuid.uuid4().bytes + uuid.uuid4().bytes).strip().decode("utf-8") + + class Farm(models.Model): """ A group of agents, whose ID and seed can be used to register new agents @@ -95,12 +104,22 @@ class Agent(models.Model): ram_load = models.FloatField(null=True, blank=True) last_ping = models.DateTimeField(editable=False) + token = models.CharField( + default=token_default, + # The token generation always returns 52 characters + max_length=52, + ) + class Meta: constraints = [ models.CheckConstraint( check=Q(mode=AgentMode.Slurm) | Q(cpu_cores__isnull=False, cpu_frequency__isnull=False, ram_total__isnull=False), name="slurm_or_hardware_requirements", ), + models.UniqueConstraint( + "token", + name="unique_agent_token", + ), ] def __str__(self) -> str: @@ -224,15 +243,6 @@ def expiry_default(): return timezone.now() + timedelta(days=settings.PONOS_TASK_EXPIRY) -def task_token_default(): - """ - Default value for Task.token. - - :rtype: str - """ - return base64.encodebytes(uuid.uuid4().bytes + uuid.uuid4().bytes).strip().decode("utf-8") - - class TaskLogs(S3FileMixin): s3_bucket = settings.PONOS_S3_LOGS_BUCKET @@ -357,7 +367,7 @@ class Task(models.Model): extra_files = HStoreField(default=dict, blank=True) token = models.CharField( - default=task_token_default, + default=token_default, # The token generation always returns 52 characters max_length=52, ) diff --git a/arkindex/process/builder.py b/arkindex/process/builder.py index 6db8b7afc9..9756ba6944 100644 --- a/arkindex/process/builder.py +++ b/arkindex/process/builder.py @@ -13,7 +13,7 @@ from django.utils.functional import cached_property from rest_framework.exceptions import ValidationError from arkindex.images.models import ImageServer -from arkindex.ponos.models import GPU, Task, task_token_default +from arkindex.ponos.models import GPU, Task, token_default class ProcessBuilder: @@ -79,7 +79,7 @@ class ProcessBuilder: Build a Task with default attributes and add it to the current stack. Depth is not set while building individual Task instances. """ - token = task_token_default() + token = token_default() env = { **self.base_env, diff --git a/arkindex/process/models.py b/arkindex/process/models.py index f9fa1efaa5..16e89ce509 100644 --- a/arkindex/process/models.py +++ b/arkindex/process/models.py @@ -15,7 +15,7 @@ from enumfields import Enum, EnumField import pgtrigger from arkindex.documents.models import Classification, Element -from arkindex.ponos.models import FINAL_STATES, STATES_ORDERING, State, Task, task_token_default +from arkindex.ponos.models import FINAL_STATES, STATES_ORDERING, State, Task, token_default from arkindex.process.builder import ProcessBuilder from arkindex.process.managers import ( ActivityManager, @@ -1012,7 +1012,7 @@ class WorkerRun(models.Model): ) task_env = env.copy() - token = task_token_default() + token = token_default() task_env["ARKINDEX_TASK_TOKEN"] = token task_env["TASK_ELEMENTS"] = elements_path task_env["ARKINDEX_WORKER_RUN_ID"] = str(self.id) diff --git a/arkindex/process/tests/process/test_run.py b/arkindex/process/tests/process/test_run.py index 8ca2347b3e..99d606d864 100644 --- a/arkindex/process/tests/process/test_run.py +++ b/arkindex/process/tests/process/test_run.py @@ -25,7 +25,7 @@ class TestProcessRun(FixtureTestCase): ) @override_settings(PONOS_DEFAULT_ENV={"ARKINDEX_API_TOKEN": "testToken"}) - @patch("arkindex.process.builder.task_token_default") + @patch("arkindex.process.builder.token_default") def test_pdf_import_run(self, token_mock): process = self.corpus.processes.create( creator=self.user, -- GitLab