From 1356a2ded3106acbb95092c9f8135d318abf5655 Mon Sep 17 00:00:00 2001 From: Erwan Rouchet <rouchet@teklia.com> Date: Tue, 25 Feb 2025 09:49:27 +0000 Subject: [PATCH] Use a machine fingerprint for agent identification --- arkindex/metrics/tests/test_metrics_api.py | 1 + .../ponos/migrations/0017_remove_agents.py | 38 +++++++++++++ .../migrations/0018_agent_fingerprint.py | 32 +++++++++++ arkindex/ponos/models.py | 6 ++- arkindex/ponos/serializer_fields.py | 53 ------------------- .../ponos/tests/tasks/test_partial_update.py | 2 + arkindex/ponos/tests/tasks/test_retrieve.py | 2 + arkindex/ponos/tests/tasks/test_update.py | 4 +- arkindex/ponos/tests/test_models.py | 18 +++---- arkindex/process/tests/process/test_create.py | 1 + .../process/tests/worker_runs/test_delete.py | 1 + .../tests/worker_runs/test_partial_update.py | 1 + .../tests/worker_runs/test_retrieve.py | 1 + .../process/tests/worker_runs/test_update.py | 1 + 14 files changed, 97 insertions(+), 64 deletions(-) create mode 100644 arkindex/ponos/migrations/0017_remove_agents.py create mode 100644 arkindex/ponos/migrations/0018_agent_fingerprint.py diff --git a/arkindex/metrics/tests/test_metrics_api.py b/arkindex/metrics/tests/test_metrics_api.py index 917979e597..4a0719a5f5 100644 --- a/arkindex/metrics/tests/test_metrics_api.py +++ b/arkindex/metrics/tests/test_metrics_api.py @@ -35,6 +35,7 @@ class TestMetricsAPI(FixtureAPITestCase): mode=AgentMode.Docker, hostname="Demo Agent", farm=farm, + fingerprint="demo" * 16, last_ping=datetime.now(), cpu_cores=42, cpu_frequency=42e8, diff --git a/arkindex/ponos/migrations/0017_remove_agents.py b/arkindex/ponos/migrations/0017_remove_agents.py new file mode 100644 index 0000000000..87be71102c --- /dev/null +++ b/arkindex/ponos/migrations/0017_remove_agents.py @@ -0,0 +1,38 @@ +# Generated by Django 5.0.8 on 2025-02-18 11:48 + +from django.core.management.base import CommandError +from django.db import migrations + +from arkindex.ponos.models import State + + +def remove_agents(apps, schema_editor): + Agent = apps.get_model("ponos", "Agent") + GPU = apps.get_model("ponos", "GPU") + Task = apps.get_model("ponos", "Task") + + if Task.objects.exclude(agent=None, gpu=None).filter(state=State.Running).exists(): + raise CommandError( + "All existing Ponos agents and GPUs are about to be deleted, but some are currently assigned to running tasks.\n" + "Wait for the tasks to finish or stop them before running this migration." + ) + + Task.objects.exclude(gpu=None).update(gpu=None) + Task.objects.exclude(agent=None).update(agent=None) + GPU.objects.all().delete() + Agent.objects.all().delete() + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0016_agent_token_constraints"), + ] + + operations = [ + migrations.RunPython( + remove_agents, + reverse_code=migrations.RunPython.noop, + elidable=True, + ), + ] diff --git a/arkindex/ponos/migrations/0018_agent_fingerprint.py b/arkindex/ponos/migrations/0018_agent_fingerprint.py new file mode 100644 index 0000000000..75c7a285dc --- /dev/null +++ b/arkindex/ponos/migrations/0018_agent_fingerprint.py @@ -0,0 +1,32 @@ +# Generated by Django 5.0.8 on 2025-02-18 11:48 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("ponos", "0017_remove_agents"), + ] + + operations = [ + # Set a default value for the public key so that the migration is reversible + migrations.AlterField( + model_name="agent", + name="public_key", + field=models.TextField(default=""), + ), + migrations.RemoveField( + model_name="agent", + name="public_key", + ), + migrations.AddField( + model_name="agent", + name="fingerprint", + field=models.CharField(max_length=64), + ), + migrations.AddConstraint( + model_name="agent", + constraint=models.UniqueConstraint(models.F("fingerprint"), name="unique_agent_fingerprint"), + ), + ] diff --git a/arkindex/ponos/models.py b/arkindex/ponos/models.py index d2671227b8..2211c3e89d 100644 --- a/arkindex/ponos/models.py +++ b/arkindex/ponos/models.py @@ -89,7 +89,7 @@ class Agent(models.Model): created = models.DateTimeField(auto_now_add=True) updated = models.DateTimeField(auto_now=True) farm = models.ForeignKey(Farm, on_delete=models.PROTECT) - public_key = models.TextField() + fingerprint = models.CharField(max_length=64) mode = EnumField(AgentMode, default=AgentMode.Docker, max_length=20) accept_tasks = models.BooleanField(default=True) @@ -120,6 +120,10 @@ class Agent(models.Model): "token", name="unique_agent_token", ), + models.UniqueConstraint( + "fingerprint", + name="unique_agent_fingerprint", + ) ] def __str__(self) -> str: diff --git a/arkindex/ponos/serializer_fields.py b/arkindex/ponos/serializer_fields.py index 1bd83378c8..80eef856a8 100644 --- a/arkindex/ponos/serializer_fields.py +++ b/arkindex/ponos/serializer_fields.py @@ -1,61 +1,8 @@ -import base64 -from cryptography.exceptions import UnsupportedAlgorithm -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives.asymmetric import ec -from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat, load_pem_public_key -from rest_framework import serializers from arkindex.ponos.utils import get_process_from_task_auth -class PublicKeyField(serializers.CharField): - """ - An EC public key, serialized in PEM format - """ - - default_error_messages = { - "invalid_pem": "Incorrect PEM data", - "unsupported_algorithm": "Key algorithm is not supported", - "not_ec": "Key is not an EC public key", - } - - def to_internal_value(self, data) -> ec.EllipticCurvePublicKey: - data = super().to_internal_value(data) - try: - key = load_pem_public_key( - data.encode("utf-8"), - backend=default_backend(), - ) - except ValueError: - self.fail("invalid_pem") - except UnsupportedAlgorithm: - self.fail("unsupported_algorithm") - - if not isinstance(key, ec.EllipticCurvePublicKey): - self.fail("not_ec") - - return key - - def to_representation(self, key: ec.EllipticCurvePublicKey) -> str: - return key.public_bytes( - Encoding.PEM, - PublicFormat.SubjectPublicKeyInfo, - ).decode("utf-8") - - -class Base64Field(serializers.CharField): - """ - A base64-encoded bytestring. - """ - - def to_internal_value(self, data) -> bytes: - return base64.b64decode(super().to_internal_value(data)) - - def to_representation(self, obj: bytes) -> str: - return base64.b64encode(obj) - - class CurrentProcessDefault: """ Use the process of the currently authenticated task as a default value. diff --git a/arkindex/ponos/tests/tasks/test_partial_update.py b/arkindex/ponos/tests/tasks/test_partial_update.py index f4c8adbe30..45646f7edd 100644 --- a/arkindex/ponos/tests/tasks/test_partial_update.py +++ b/arkindex/ponos/tests/tasks/test_partial_update.py @@ -29,6 +29,7 @@ class TestTaskPartialUpdate(FixtureAPITestCase): cls.docker_agent = Agent.objects.create( mode=AgentMode.Docker, farm=cls.farm, + fingerprint="a" * 64, last_ping=datetime.now(timezone.utc), cpu_cores=42, cpu_frequency=42e8, @@ -37,6 +38,7 @@ class TestTaskPartialUpdate(FixtureAPITestCase): cls.slurm_agent = Agent.objects.create( mode=AgentMode.Slurm, farm=cls.farm, + fingerprint="b" * 64, last_ping=datetime.now(timezone.utc), ) diff --git a/arkindex/ponos/tests/tasks/test_retrieve.py b/arkindex/ponos/tests/tasks/test_retrieve.py index 5fe9581a41..7643f6d853 100644 --- a/arkindex/ponos/tests/tasks/test_retrieve.py +++ b/arkindex/ponos/tests/tasks/test_retrieve.py @@ -51,6 +51,7 @@ class TestTaskRetrieve(FixtureAPITestCase): cls.docker_agent = Agent.objects.create( mode=AgentMode.Docker, farm=cls.farm, + fingerprint="a" * 64, last_ping=datetime.now(timezone.utc), cpu_cores=42, cpu_frequency=42e8, @@ -59,6 +60,7 @@ class TestTaskRetrieve(FixtureAPITestCase): cls.slurm_agent = Agent.objects.create( mode=AgentMode.Slurm, farm=cls.farm, + fingerprint="b" * 64, last_ping=datetime.now(timezone.utc), ) diff --git a/arkindex/ponos/tests/tasks/test_update.py b/arkindex/ponos/tests/tasks/test_update.py index 619c8bbfec..bd93e8766e 100644 --- a/arkindex/ponos/tests/tasks/test_update.py +++ b/arkindex/ponos/tests/tasks/test_update.py @@ -39,14 +39,16 @@ class TestTaskUpdate(FixtureAPITestCase): cls.docker_agent = Agent.objects.create( mode=AgentMode.Docker, farm=cls.farm, + fingerprint="a" * 64, last_ping=datetime.now(timezone.utc), cpu_cores=42, cpu_frequency=42e8, - ram_total=42e3 + ram_total=42e3, ) cls.slurm_agent = Agent.objects.create( mode=AgentMode.Slurm, farm=cls.farm, + fingerprint="b" * 64, last_ping=datetime.now(timezone.utc), ) diff --git a/arkindex/ponos/tests/test_models.py b/arkindex/ponos/tests/test_models.py index a94dc73734..ce83600bac 100644 --- a/arkindex/ponos/tests/test_models.py +++ b/arkindex/ponos/tests/test_models.py @@ -145,7 +145,7 @@ class TestModels(FixtureAPITestCase): hostname="agent_smith", cpu_cores=2, cpu_frequency=4.2e9, - public_key="", + fingerprint="a" * 64, farm=self.farm, ram_total=2e9, last_ping=timezone.now(), @@ -159,7 +159,7 @@ class TestModels(FixtureAPITestCase): def test_agent_slurm_mode(self): Agent.objects.create( hostname="agent_smith", - public_key="", + fingerprint="a" * 64, farm=self.farm, last_ping=timezone.now(), mode=AgentMode.Slurm.value @@ -174,7 +174,7 @@ class TestModels(FixtureAPITestCase): hostname="agent_smith", cpu_cores=2, cpu_frequency=4.2e9, - public_key="", + fingerprint="a" * 64, farm=self.farm, ram_total=2e9, last_ping=timezone.now(), @@ -195,7 +195,7 @@ class TestModels(FixtureAPITestCase): "params": { "hostname": "agent_smith", "cpu_frequency": 4.2e9, - "public_key": "", + "fingerprint": "a" * 64, "farm": self.farm, "ram_total": 2e9, "last_ping": timezone.now(), @@ -210,7 +210,7 @@ class TestModels(FixtureAPITestCase): "params": { "hostname": "agent_smith", "cpu_cores": 2, - "public_key": "", + "fingerprint": "a" * 64, "farm": self.farm, "ram_total": 2e9, "last_ping": timezone.now(), @@ -226,7 +226,7 @@ class TestModels(FixtureAPITestCase): "hostname": "agent_smith", "cpu_cores": 2, "cpu_frequency": 4.2e9, - "public_key": "", + "fingerprint": "a" * 64, "farm": self.farm, "last_ping": timezone.now(), "ram_load": 0.49, @@ -241,7 +241,7 @@ class TestModels(FixtureAPITestCase): "hostname": "agent_smith", "cpu_cores": None, "cpu_frequency": None, - "public_key": "", + "fingerprint": "a" * 64, "farm": self.farm, "ram_total": None, "last_ping": timezone.now(), @@ -257,7 +257,7 @@ class TestModels(FixtureAPITestCase): "hostname": "agent_smith", "cpu_cores": 2, "cpu_frequency": 4.2e9, - "public_key": "", + "fingerprint": "a" * 64, "farm": self.farm, "ram_total": 2e9, "last_ping": timezone.now(), @@ -271,7 +271,7 @@ class TestModels(FixtureAPITestCase): "mode": AgentMode.Slurm, "params": { "hostname": "agent_smith", - "public_key": "", + "fingerprint": "b" * 64, "farm": self.farm, "last_ping": timezone.now(), "ram_load": 0.49, diff --git a/arkindex/process/tests/process/test_create.py b/arkindex/process/tests/process/test_create.py index b11f6da848..17278e26e3 100644 --- a/arkindex/process/tests/process/test_create.py +++ b/arkindex/process/tests/process/test_create.py @@ -33,6 +33,7 @@ class TestCreateProcess(FixtureAPITestCase): super().setUpTestData() cls.agent = Agent.objects.create( farm=Farm.objects.first(), + fingerprint="a" * 64, hostname="claude", cpu_cores=42, cpu_frequency=1e15, diff --git a/arkindex/process/tests/worker_runs/test_delete.py b/arkindex/process/tests/worker_runs/test_delete.py index e01006daa0..364cda2564 100644 --- a/arkindex/process/tests/worker_runs/test_delete.py +++ b/arkindex/process/tests/worker_runs/test_delete.py @@ -32,6 +32,7 @@ class TestWorkerRunsDelete(FixtureAPITestCase): cls.agent = Agent.objects.create( farm=cls.farm, + fingerprint="a" * 64, hostname="claude", cpu_cores=42, cpu_frequency=1e15, diff --git a/arkindex/process/tests/worker_runs/test_partial_update.py b/arkindex/process/tests/worker_runs/test_partial_update.py index e6a4606b5a..91c635b57c 100644 --- a/arkindex/process/tests/worker_runs/test_partial_update.py +++ b/arkindex/process/tests/worker_runs/test_partial_update.py @@ -95,6 +95,7 @@ class TestWorkerRunsPartialUpdate(FixtureAPITestCase): cls.agent = Agent.objects.create( farm=cls.farm, + fingerprint="a" * 64, hostname="claude", cpu_cores=42, cpu_frequency=1e15, diff --git a/arkindex/process/tests/worker_runs/test_retrieve.py b/arkindex/process/tests/worker_runs/test_retrieve.py index 42e7146a4a..0ffe02b790 100644 --- a/arkindex/process/tests/worker_runs/test_retrieve.py +++ b/arkindex/process/tests/worker_runs/test_retrieve.py @@ -41,6 +41,7 @@ class TestWorkerRunsRetrieve(FixtureAPITestCase): cls.agent = Agent.objects.create( farm=cls.farm, + fingerprint="a" * 64, hostname="claude", cpu_cores=42, cpu_frequency=1e15, diff --git a/arkindex/process/tests/worker_runs/test_update.py b/arkindex/process/tests/worker_runs/test_update.py index 54226950e5..a79cc7059a 100644 --- a/arkindex/process/tests/worker_runs/test_update.py +++ b/arkindex/process/tests/worker_runs/test_update.py @@ -69,6 +69,7 @@ class TestWorkerRunsUpdate(FixtureAPITestCase): cls.agent = Agent.objects.create( farm=cls.farm, + fingerprint="a" * 64, hostname="claude", cpu_cores=42, cpu_frequency=1e15, -- GitLab