@extend_schema(tags=["worker versions"])
@extend_schema_view(
    get=extend_schema(
        operation_id="RetrieveRecommendedWorkerVersion",
        parameters=[
            OpenApiParameter(
                "id",
                type=UUID,
                location=OpenApiParameter.PATH,
                description="UUID of the worker to get a recommended worker version for.",
                required=True,
            )
        ],
    )
)
class RecommendedWorkerVersionRetrieve(RetrieveAPIView):
    """
    Get the recommended worker version for a given worker.

    Requires an execution access to the worker.

    Only worker versions in an `available` state are considered.

    This recommended worker version is the most recently created version with a `main` or `master` **branch**,
    if it exists.
    If it does not, then the recommended worker version is the version with the highest **tag** that starts with
    a number and is not labelled as an unstable version (alpha, beta, release candidate, preview or development
    release) according to the Python version specifiers rules
    (https://packaging.python.org/en/latest/specifications/version-specifiers/)
    """
    permission_classes = (IsVerified, )
    serializer_class = WorkerVersionSerializer

    # Case-insensitive pattern of the tags that must never be recommended. A tag is rejected when either
    # alternative matches:
    #   - `^[^\d]`: the tag does not start with a digit, so it is not a version number at all;
    #   - `[\d._-]+(?:...)[._-]?\d*$`: the tag ends with a pre-release marker (a/alpha, b/beta, c/rc,
    #     pre/preview, dev), optionally followed by a separator and a number.
    # Developmental releases (`.devN`) are pre-releases under the version specifiers specification,
    # so they are excluded along with the other unstable suffixes.
    _unstable_tag_regex = r"^[^\d]|[\d._-]+(?:a(?:lpha)?|b(?:eta)?|r?c|pre(?:view)?|dev)[._-]?\d*$"

    @cached_property
    def worker(self):
        """
        Worker designated by the `pk` URL argument.

        Raises a 404 when the worker does not exist, and `PermissionDenied`
        when the requesting user cannot execute it.
        """
        worker = get_object_or_404(Worker, pk=self.kwargs["pk"])
        if not worker.is_executable(self.request.user):
            raise PermissionDenied(detail="You do not have an execution access to this worker.")
        return worker

    def _available_versions(self):
        """
        Versions of the worker in the Available state, queried on the `default`
        database alias with the worker and its type joined in.
        """
        return (
            self.worker.versions
            .filter(state=WorkerVersionState.Available)
            .using("default")
            .select_related("worker__type")
        )

    def get_object(self):
        """
        Return the recommended worker version, or raise a 404 when the worker has none.
        """
        available_versions = self._available_versions()

        # If it exists, return the most recent version on the master or main branch
        main_branch_version = (
            available_versions
            .filter(branch__in=["master", "main"])
            .distinct()
            .order_by("-created")
            .first()
        )
        if main_branch_version is not None:
            return main_branch_version

        # Otherwise, return the highest tagged version whose tag looks like a stable version number.
        # The `tag` column uses the non-deterministic `english_numeric` collation, and PostgreSQL refuses to
        # run regular expressions against non-deterministic collations: the regex is therefore applied on a
        # copy of the tag re-collated with the deterministic "C" collation. The ordering still relies on the
        # column's own collation, which compares digit sequences numerically (0.10.0 sorts after 0.9.0).
        tagged_versions = (
            available_versions
            .filter(tag__isnull=False)
            .annotate(filterable_tag=Collate("tag", "C"))
            .exclude(filterable_tag__iregex=self._unstable_tag_regex)
            .distinct()
            .order_by("-tag", "-created")
        )
        return get_object_or_404(tagged_versions[:1])
class Migration(migrations.Migration):
    """Create the `english_numeric` ICU collation and apply it to the `tag` field of `workerversion`."""

    dependencies = [("process", "0046_workerrun_ttl")]

    operations = [
        # This collation lets the "tag" field be sorted like semantic version tags: strings are compared with
        # US English rules, and a run of digits is compared as one number instead of digit by digit (so that
        # 10.2.1 is not the same as 1.0.21).
        # The collation is declared non-deterministic because two different strings can compare as equal
        # (1.2 == 01.2).
        CreateCollation(
            name="english_numeric",
            locale="en-US-u-kn-true-u-ks-level2",
            provider="icu",
            deterministic=False,
        ),
        migrations.AlterField(
            model_name="workerversion",
            name="tag",
            field=models.CharField(
                max_length=512,
                null=True,
                blank=True,
                default=None,
                db_collation="english_numeric",
            ),
        ),
    ]
class TestRecommendedWorkerVersion(FixtureAPITestCase):
    """API tests for the `recommended-worker-version` endpoint."""

    @classmethod
    def setUpTestData(cls):
        super().setUpTestData()

        cls.worker_reco = Worker.objects.get(slug="reco")
        # The fixture worker is expected to come with exactly one version
        cls.version_1 = cls.worker_reco.versions.get()

        def add_available_version(**fields):
            # Every extra version of the fixture worker shares the same state, image and configuration
            return cls.worker_reco.versions.create(
                configuration={},
                state=WorkerVersionState.Available,
                docker_image_iid="registry.somewhere.com/something:latest",
                **fields,
            )

        # Creation order is kept as-is since the endpoint sorts on the creation date
        cls.main_branch_version = add_available_version(
            branch="main",
            revision_url="https://gitlab.com/NERV/eva/commit/12",
        )
        cls.other_branch_version = add_available_version(
            branch="eva-00",
            revision_url="https://gitlab.com/NERV/eva/commit/00002",
            tag="0.3.2rc-2",
        )
        cls.tagged_version_1 = add_available_version(
            revision_url="https://gitlab.com/NERV/eva/commit/1234",
            tag="0.2.8",
        )
        cls.tagged_version_2 = add_available_version(
            revision_url="https://gitlab.com/NERV/eva/commit/5678",
            tag="0.3.1",
        )
        cls.bad_tag_version = add_available_version(
            revision_url="https://gitlab.com/NERV/eva/commit/0246",
            tag="shiny-rock",
        )

    def _endpoint(self, worker_id):
        """URL of the recommended version endpoint for the given worker ID."""
        return reverse("api:recommended-worker-version", kwargs={"pk": str(worker_id)})

    def _expected_payload(self, version, *, tag, branch, revision_url):
        """Serialized form expected for a version of the `reco` worker."""
        worker = self.worker_reco
        return {
            "id": str(version.id),
            "configuration": {},
            "docker_image_iid": version.docker_image_iid,
            "state": "available",
            "gpu_usage": FeatureUsage.Disabled.value,
            "model_usage": FeatureUsage.Disabled.value,
            "worker": {
                "id": str(worker.id),
                "name": worker.name,
                "type": worker.type.slug,
                "slug": worker.slug,
                "description": worker.description,
                "archived": bool(worker.archived),
                "repository_url": worker.repository_url,
            },
            "version": None,
            "tag": tag,
            "branch": branch,
            "revision_url": revision_url,
            "created": version.created.isoformat().replace("+00:00", "Z"),
        }

    def test_requires_login(self):
        """Anonymous requests are rejected without any database query."""
        with self.assertNumQueries(0):
            response = self.client.get(self._endpoint(self.worker_reco.id))
            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_requires_verified(self):
        """Users without a verified email are rejected."""
        self.user.verified_email = False
        self.user.save()
        self.client.force_login(self.user)

        with self.assertNumQueries(2):
            response = self.client.get(self._endpoint(self.worker_reco.id))
            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    @patch("arkindex.users.utils.get_max_level", return_value=Role.Guest.value)
    def test_requires_contributor(self, max_level_mock):
        """A guest level on the worker is not enough to get a recommended version."""
        self.client.force_login(self.user)

        with self.assertNumQueries(3):
            response = self.client.get(self._endpoint(self.worker_reco.id))
            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

        self.assertDictEqual(response.json(), {
            "detail": "You do not have an execution access to this worker.",
        })
        self.assertEqual(max_level_mock.call_count, 1)
        self.assertEqual(max_level_mock.call_args_list, [
            call(self.user, self.worker_reco)
        ])

    def test_worker_doesnt_exist(self):
        """An unknown worker ID results in a 404."""
        self.client.force_login(self.user)

        with self.assertNumQueries(3):
            response = self.client.get(self._endpoint("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"))
            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

    def test_no_recommended_version(self):
        """
        A worker with no main/master branch version and no available
        tagged version has no recommended version.
        """
        test_worker = Worker.objects.create(
            name="Test worker",
            slug="testworker",
            type=self.worker_reco.type,
        )
        image = "registry.somewhere.com/something:latest"
        # Available, but with neither a branch nor a tag
        test_worker.versions.create(
            configuration={},
            revision_url="https://gitlab.com/SEELE/scrolls/commit/1234",
            state=WorkerVersionState.Available,
            docker_image_iid=image,
        )
        # Available, but on a branch that is neither main nor master
        test_worker.versions.create(
            configuration={},
            revision_url="https://gitlab.com/SEELE/scrolls/commit/5678",
            branch="deadsea",
            state=WorkerVersionState.Available,
            docker_image_iid=image,
        )
        # Tagged, but in an Error state
        test_worker.versions.create(
            configuration={},
            revision_url="https://gitlab.com/SEELE/scrolls/commit/9123",
            tag="0.2.5",
            state=WorkerVersionState.Error,
            docker_image_iid=image,
        )

        self.client.force_login(self.user)

        with self.assertNumQueries(5):
            response = self.client.get(self._endpoint(test_worker.id))
            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

        self.assertDictEqual(response.json(), {"detail": "No WorkerVersion matches the given query."})

    def test_main_branch_version(self):
        """The version on the main branch is returned when there is one."""
        self.client.force_login(self.user)

        with self.assertNumQueries(4):
            response = self.client.get(self._endpoint(self.worker_reco.id))
            self.assertEqual(response.status_code, status.HTTP_200_OK)

        expected = self._expected_payload(
            self.main_branch_version,
            tag=None,
            branch="main",
            revision_url="https://gitlab.com/NERV/eva/commit/12",
        )
        self.assertDictEqual(response.json(), expected)

    def test_tagged_version(self):
        """Without a main branch version, the highest stable tag is returned."""
        self.main_branch_version.delete()
        self.client.force_login(self.user)

        with self.assertNumQueries(5):
            response = self.client.get(self._endpoint(self.worker_reco.id))
            self.assertEqual(response.status_code, status.HTTP_200_OK)

        expected = self._expected_payload(
            self.tagged_version_2,
            tag="0.3.1",
            branch=None,
            revision_url="https://gitlab.com/NERV/eva/commit/5678",
        )
        self.assertDictEqual(response.json(), expected)
name="recommended-worker-version"), path("workers/versions/<uuid:pk>/", WorkerVersionRetrieve.as_view(), name="version-retrieve"), path("workers/versions/<uuid:pk>/activity/", UpdateWorkerActivity.as_view(), name="update-worker-activity"), path("workers/versions/feature/<feature>/", FeatureWorkerVersionRetrieve.as_view(), name="feature-worker-version"),