Skip to content
Snippets Groups Projects
Commit 52e8ff49 authored by Erwan Rouchet's avatar Erwan Rouchet Committed by Bastien Abadie
Browse files

Add agent token field

parent 37965e7e
No related branches found
No related tags found
1 merge request!2523Add agent token field
...@@ -12,7 +12,7 @@ from rest_framework.generics import CreateAPIView, ListCreateAPIView, RetrieveUp ...@@ -12,7 +12,7 @@ from rest_framework.generics import CreateAPIView, ListCreateAPIView, RetrieveUp
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from arkindex.ponos.models import FINAL_STATES, Artifact, State, Task, task_token_default from arkindex.ponos.models import FINAL_STATES, Artifact, State, Task, token_default
from arkindex.ponos.permissions import ( from arkindex.ponos.permissions import (
IsAgentOrArtifactGuest, IsAgentOrArtifactGuest,
IsAgentOrTaskGuest, IsAgentOrTaskGuest,
...@@ -234,7 +234,7 @@ class TaskRestart(ProcessACLMixin, CreateAPIView): ...@@ -234,7 +234,7 @@ class TaskRestart(ProcessACLMixin, CreateAPIView):
copy.id = uuid.uuid4() copy.id = uuid.uuid4()
copy.slug = basename copy.slug = basename
copy.state = State.Pending copy.state = State.Pending
copy.token = task_token_default() copy.token = token_default()
copy.agent_id = None copy.agent_id = None
copy.gpu_id = None copy.gpu_id = None
copy.started = None copy.started = None
......
...@@ -114,7 +114,7 @@ class Migration(migrations.Migration): ...@@ -114,7 +114,7 @@ class Migration(migrations.Migration):
("updated", models.DateTimeField(auto_now=True)), ("updated", models.DateTimeField(auto_now=True)),
("expiry", models.DateTimeField(default=arkindex.ponos.models.expiry_default)), ("expiry", models.DateTimeField(default=arkindex.ponos.models.expiry_default)),
("extra_files", django.contrib.postgres.fields.hstore.HStoreField(default=dict, blank=True)), ("extra_files", django.contrib.postgres.fields.hstore.HStoreField(default=dict, blank=True)),
("token", models.CharField(default=arkindex.ponos.models.task_token_default, max_length=52, unique=True)), ("token", models.CharField(default=arkindex.ponos.models.token_default, max_length=52, unique=True)),
("agent", models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name="tasks", to="ponos.agent")), ("agent", models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name="tasks", to="ponos.agent")),
("gpu", models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name="tasks", to="ponos.gpu")), ("gpu", models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name="tasks", to="ponos.gpu")),
("image_artifact", models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name="tasks_using_image", to="ponos.artifact")), ("image_artifact", models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name="tasks_using_image", to="ponos.artifact")),
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from django.core.validators import RegexValidator from django.core.validators import RegexValidator
from django.db import migrations, models from django.db import migrations, models
from arkindex.ponos.models import generate_seed, task_token_default from arkindex.ponos.models import generate_seed, token_default
class Migration(migrations.Migration): class Migration(migrations.Migration):
...@@ -89,7 +89,7 @@ class Migration(migrations.Migration): ...@@ -89,7 +89,7 @@ class Migration(migrations.Migration):
model_name="task", model_name="task",
name="token", name="token",
field=models.CharField( field=models.CharField(
default=task_token_default, default=token_default,
max_length=52, max_length=52,
), ),
), ),
......
# Generated by Django 5.0.8 on 2025-02-17 13:50
from django.db import migrations, models
from arkindex.ponos.models import token_default
def add_agent_tokens(apps, schema_editor):
Agent = apps.get_model("ponos", "Agent")
to_update = []
for agent in Agent.objects.filter(token=None).only("id").iterator():
agent.token = token_default()
to_update.append(agent)
Agent.objects.bulk_update(to_update, ["token"], batch_size=100)
class Migration(migrations.Migration):
dependencies = [
("ponos", "0014_task_task_finished_requires_final_state"),
]
operations = [
migrations.AddField(
model_name="agent",
name="token",
field=models.CharField(
max_length=52,
# Make the field temporarily nullable and not unique, so that we can
# fill the tokens on existing agents before adding the constraints.
null=True,
),
),
migrations.RunPython(
add_agent_tokens,
reverse_code=migrations.RunPython.noop,
),
]
# Generated by Django 5.0.8 on 2025-02-17 14:26
from django.db import migrations, models
import arkindex.ponos.models
class Migration(migrations.Migration):
dependencies = [
("ponos", "0015_agent_token"),
]
operations = [
migrations.AlterField(
model_name="agent",
name="token",
field=models.CharField(default=arkindex.ponos.models.token_default, max_length=52),
),
migrations.AddConstraint(
model_name="agent",
constraint=models.UniqueConstraint(models.F("token"), name="unique_agent_token"),
),
]
...@@ -34,6 +34,15 @@ def gen_nonce(size=16): ...@@ -34,6 +34,15 @@ def gen_nonce(size=16):
return urandom(size) return urandom(size)
def token_default():
"""
Default value for task and agent tokens.
:rtype: str
"""
return base64.encodebytes(uuid.uuid4().bytes + uuid.uuid4().bytes).strip().decode("utf-8")
class Farm(models.Model): class Farm(models.Model):
""" """
A group of agents, whose ID and seed can be used to register new agents A group of agents, whose ID and seed can be used to register new agents
...@@ -95,12 +104,22 @@ class Agent(models.Model): ...@@ -95,12 +104,22 @@ class Agent(models.Model):
ram_load = models.FloatField(null=True, blank=True) ram_load = models.FloatField(null=True, blank=True)
last_ping = models.DateTimeField(editable=False) last_ping = models.DateTimeField(editable=False)
token = models.CharField(
default=token_default,
# The token generation always returns 52 characters
max_length=52,
)
class Meta: class Meta:
constraints = [ constraints = [
models.CheckConstraint( models.CheckConstraint(
check=Q(mode=AgentMode.Slurm) | Q(cpu_cores__isnull=False, cpu_frequency__isnull=False, ram_total__isnull=False), check=Q(mode=AgentMode.Slurm) | Q(cpu_cores__isnull=False, cpu_frequency__isnull=False, ram_total__isnull=False),
name="slurm_or_hardware_requirements", name="slurm_or_hardware_requirements",
), ),
models.UniqueConstraint(
"token",
name="unique_agent_token",
),
] ]
def __str__(self) -> str: def __str__(self) -> str:
...@@ -224,15 +243,6 @@ def expiry_default(): ...@@ -224,15 +243,6 @@ def expiry_default():
return timezone.now() + timedelta(days=settings.PONOS_TASK_EXPIRY) return timezone.now() + timedelta(days=settings.PONOS_TASK_EXPIRY)
def task_token_default():
"""
Default value for Task.token.
:rtype: str
"""
return base64.encodebytes(uuid.uuid4().bytes + uuid.uuid4().bytes).strip().decode("utf-8")
class TaskLogs(S3FileMixin): class TaskLogs(S3FileMixin):
s3_bucket = settings.PONOS_S3_LOGS_BUCKET s3_bucket = settings.PONOS_S3_LOGS_BUCKET
...@@ -357,7 +367,7 @@ class Task(models.Model): ...@@ -357,7 +367,7 @@ class Task(models.Model):
extra_files = HStoreField(default=dict, blank=True) extra_files = HStoreField(default=dict, blank=True)
token = models.CharField( token = models.CharField(
default=task_token_default, default=token_default,
# The token generation always returns 52 characters # The token generation always returns 52 characters
max_length=52, max_length=52,
) )
......
...@@ -13,7 +13,7 @@ from django.utils.functional import cached_property ...@@ -13,7 +13,7 @@ from django.utils.functional import cached_property
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from arkindex.images.models import ImageServer from arkindex.images.models import ImageServer
from arkindex.ponos.models import GPU, Task, task_token_default from arkindex.ponos.models import GPU, Task, token_default
class ProcessBuilder: class ProcessBuilder:
...@@ -79,7 +79,7 @@ class ProcessBuilder: ...@@ -79,7 +79,7 @@ class ProcessBuilder:
Build a Task with default attributes and add it to the current stack. Build a Task with default attributes and add it to the current stack.
Depth is not set while building individual Task instances. Depth is not set while building individual Task instances.
""" """
token = task_token_default() token = token_default()
env = { env = {
**self.base_env, **self.base_env,
......
...@@ -15,7 +15,7 @@ from enumfields import Enum, EnumField ...@@ -15,7 +15,7 @@ from enumfields import Enum, EnumField
import pgtrigger import pgtrigger
from arkindex.documents.models import Classification, Element from arkindex.documents.models import Classification, Element
from arkindex.ponos.models import FINAL_STATES, STATES_ORDERING, State, Task, task_token_default from arkindex.ponos.models import FINAL_STATES, STATES_ORDERING, State, Task, token_default
from arkindex.process.builder import ProcessBuilder from arkindex.process.builder import ProcessBuilder
from arkindex.process.managers import ( from arkindex.process.managers import (
ActivityManager, ActivityManager,
...@@ -1012,7 +1012,7 @@ class WorkerRun(models.Model): ...@@ -1012,7 +1012,7 @@ class WorkerRun(models.Model):
) )
task_env = env.copy() task_env = env.copy()
token = task_token_default() token = token_default()
task_env["ARKINDEX_TASK_TOKEN"] = token task_env["ARKINDEX_TASK_TOKEN"] = token
task_env["TASK_ELEMENTS"] = elements_path task_env["TASK_ELEMENTS"] = elements_path
task_env["ARKINDEX_WORKER_RUN_ID"] = str(self.id) task_env["ARKINDEX_WORKER_RUN_ID"] = str(self.id)
......
...@@ -25,7 +25,7 @@ class TestProcessRun(FixtureTestCase): ...@@ -25,7 +25,7 @@ class TestProcessRun(FixtureTestCase):
) )
@override_settings(PONOS_DEFAULT_ENV={"ARKINDEX_API_TOKEN": "testToken"}) @override_settings(PONOS_DEFAULT_ENV={"ARKINDEX_API_TOKEN": "testToken"})
@patch("arkindex.process.builder.task_token_default") @patch("arkindex.process.builder.token_default")
def test_pdf_import_run(self, token_mock): def test_pdf_import_run(self, token_mock):
process = self.corpus.processes.create( process = self.corpus.processes.create(
creator=self.user, creator=self.user,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment