Skip to content
Snippets Groups Projects
Commit e705687f authored by ml bonhomme's avatar ml bonhomme :bee:
Browse files

Create RecommendedWorkerVersionRetrieve endpoint

parent 8ed021de
No related branches found
No related tags found
No related merge requests found
This commit is part of merge request !2509. Comments created here will be created in the context of that merge request.
......@@ -21,6 +21,7 @@ from django.db.models import (
Q,
Value,
)
from django.db.models.expressions import RawSQL
from django.db.models.functions import Coalesce, Now
from django.db.models.query import Prefetch
from django.shortcuts import get_object_or_404, redirect
......@@ -68,6 +69,7 @@ from arkindex.process.models import (
WorkerRun,
WorkerType,
WorkerVersion,
WorkerVersionState,
)
from arkindex.process.serializers.files import DataFileCreateSerializer, DataFileSerializer
from arkindex.process.serializers.imports import (
......@@ -1133,6 +1135,66 @@ class FeatureWorkerVersionRetrieve(RetrieveAPIView):
raise NotFound
@extend_schema(tags=["worker versions"])
@extend_schema_view(
get=extend_schema(
operation_id="RetrieveRecommendedWorkerVersion",
parameters=[
OpenApiParameter(
"id",
type=UUID,
location=OpenApiParameter.PATH,
description="UUID of the worker to get a recommended worker version for.",
required=True,
)
],
)
)
class RecommendedWorkerVersionRetrieve(RetrieveAPIView):
permission_classes = (IsVerified, )
serializer_class = WorkerVersionCreateSerializer
@cached_property
def worker(self):
worker = get_object_or_404(Worker, pk=self.kwargs["pk"])
if not worker.is_executable(self.request.user):
raise PermissionDenied(detail="You do not have an execution access to this worker.")
return worker
def get_object(self):
try:
# If it exists, return the most recent version on the master or main branch
main_branch_versions = (
self.worker.versions.filter(branch__in=["master", "main"])
.filter(state=WorkerVersionState.Available) \
.using("default") \
.select_related("worker__type") \
.distinct() \
.order_by("-created")
)
if main_branch_versions.exists():
return main_branch_versions.first()
# The first ordering by tag_as_int converts a tag that looks like "0.3.2" into 32 and ensures that versions
# are ordered in a way that makes sense (if both '0.1.2' and '0.1.3' tags exist, '0.1.3' is returned as 13 > 12).
# tag_as_int is created by splitting on an eventual '-' so if we have '0.3.2' and '0.3.2-rc2' they both have 32
# as `tag_as_int`. The second ordering by tag (regular alphabetical ordering) ensures that the main version (0.3.2)
# is returned in this case.
# This will not order such version tags satisfactorily if the '-rc{number}' part goes beyond 9, as this is
# alphabetical ordering so 1 > 10 > 2 > 3 > 4 etc.
return (
self.worker.versions.filter(tag__isnull=False)
.filter(state=WorkerVersionState.Available) \
.annotate(tag_as_int=RawSQL("split_part(regexp_replace(tag, '[^0-9-]', '', 'g'), '-', 1)::int", []))
.using("default") \
.select_related("worker__type") \
.distinct() \
.order_by("-tag_as_int", "tag", "-created")
.first()
)
except WorkerVersion.DoesNotExist:
raise NotFound
@extend_schema(tags=["workers"])
@extend_schema_view(
get=extend_schema(
......
from unittest.mock import call, patch
from django.urls import reverse
from rest_framework import status
from arkindex.process.models import FeatureUsage, Worker, WorkerVersionState
from arkindex.project.tests import FixtureAPITestCase
from arkindex.users.models import Role
class TestRecommendedWorkerVersion(FixtureAPITestCase):
@classmethod
def setUpTestData(cls):
super().setUpTestData()
cls.worker_reco = Worker.objects.get(slug="reco")
cls.version_1 = cls.worker_reco.versions.get()
cls.main_branch_version = cls.worker_reco.versions.create(
configuration={},
branch="main",
revision_url="https://gitlab.com/NERV/eva/commit/12",
state=WorkerVersionState.Available,
docker_image_iid="registry.somewhere.com/something:latest"
)
cls.other_branch_version = cls.worker_reco.versions.create(
configuration={},
branch="eva-00",
revision_url="https://gitlab.com/NERV/eva/commit/00002",
tag="0.3.0rc-2",
state=WorkerVersionState.Available,
docker_image_iid="registry.somewhere.com/something:latest"
)
cls.tagged_version_1 = cls.worker_reco.versions.create(
configuration={},
revision_url="https://gitlab.com/NERV/eva/commit/1234",
tag="0.2.8",
state=WorkerVersionState.Available,
docker_image_iid="registry.somewhere.com/something:latest"
)
cls.tagged_version_2 = cls.worker_reco.versions.create(
configuration={},
revision_url="https://gitlab.com/NERV/eva/commit/5678",
tag="0.1.4",
state=WorkerVersionState.Available,
docker_image_iid="registry.somewhere.com/something:latest"
)
def test_requires_login(self):
with self.assertNumQueries(0):
response = self.client.get(reverse("api:recommended-worker-version", kwargs={"pk": str(self.worker_reco.id)}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
@patch("arkindex.users.utils.get_max_level", return_value=Role.Guest.value)
def test_requires_contributor(self, max_level_mock):
self.client.force_login(self.user)
with self.assertNumQueries(3):
response = self.client.get(reverse("api:recommended-worker-version", kwargs={"pk": str(self.worker_reco.id)}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {
"detail": "You do not have an execution access to this worker.",
})
self.assertEqual(max_level_mock.call_count, 1)
self.assertEqual(max_level_mock.call_args_list, [
call(self.user, self.worker_reco)
])
def test_main_branch_version(self):
self.client.force_login(self.user)
with self.assertNumQueries(5):
response = self.client.get(reverse("api:recommended-worker-version", kwargs={"pk": str(self.worker_reco.id)}))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertDictEqual(response.json(), {
"id": str(self.main_branch_version.id),
"configuration": {},
"docker_image_iid": self.main_branch_version.docker_image_iid,
"state": "available",
"gpu_usage": FeatureUsage.Disabled.value,
"model_usage": FeatureUsage.Disabled.value,
"worker": {
"id": str(self.worker_reco.id),
"name": self.worker_reco.name,
"type": self.worker_reco.type.slug,
"slug": self.worker_reco.slug,
"description": self.worker_reco.description,
"archived": bool(self.worker_reco.archived),
"repository_url": self.worker_reco.repository_url,
},
"version": None,
"tag": None,
"branch": "main",
"revision_url": "https://gitlab.com/NERV/eva/commit/12",
"created": self.main_branch_version.created.isoformat().replace("+00:00", "Z"),
})
def test_tagged_version(self):
self.main_branch_version.delete()
self.client.force_login(self.user)
with self.assertNumQueries(5):
response = self.client.get(reverse("api:recommended-worker-version", kwargs={"pk": str(self.worker_reco.id)}))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertDictEqual(response.json(), {
"id": str(self.other_branch_version.id),
"configuration": {},
"docker_image_iid": self.other_branch_version.docker_image_iid,
"state": "available",
"gpu_usage": FeatureUsage.Disabled.value,
"model_usage": FeatureUsage.Disabled.value,
"worker": {
"id": str(self.worker_reco.id),
"name": self.worker_reco.name,
"type": self.worker_reco.type.slug,
"slug": self.worker_reco.slug,
"description": self.worker_reco.description,
"archived": bool(self.worker_reco.archived),
"repository_url": self.worker_reco.repository_url,
},
"version": None,
"tag": "0.3.0rc-2",
"branch": "eva-00",
"revision_url": "https://gitlab.com/NERV/eva/commit/00002",
"created": self.other_branch_version.created.isoformat().replace("+00:00", "Z"),
})
def test_tagged_version_rc(self):
"""
The version with the main tag is returned when release candidate versions (or anything with a tag
that looks like x.x.x-something) also exist
"""
self.main_branch_version.delete()
self.tagged_version_1.tag = "0.3.0"
self.tagged_version_1.save()
self.client.force_login(self.user)
with self.assertNumQueries(5):
response = self.client.get(reverse("api:recommended-worker-version", kwargs={"pk": str(self.worker_reco.id)}))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertDictEqual(response.json(), {
"id": str(self.tagged_version_1.id),
"configuration": {},
"docker_image_iid": self.tagged_version_1.docker_image_iid,
"state": "available",
"gpu_usage": FeatureUsage.Disabled.value,
"model_usage": FeatureUsage.Disabled.value,
"worker": {
"id": str(self.worker_reco.id),
"name": self.worker_reco.name,
"type": self.worker_reco.type.slug,
"slug": self.worker_reco.slug,
"description": self.worker_reco.description,
"archived": bool(self.worker_reco.archived),
"repository_url": self.worker_reco.repository_url,
},
"version": None,
"tag": "0.3.0",
"branch": None,
"revision_url": "https://gitlab.com/NERV/eva/commit/1234",
"created": self.tagged_version_1.created.isoformat().replace("+00:00", "Z"),
})
......@@ -82,6 +82,7 @@ from arkindex.process.api import (
ProcessList,
ProcessRetry,
ProcessWorkersActivity,
RecommendedWorkerVersionRetrieve,
S3ImportCreate,
SelectProcessFailures,
StartProcess,
......@@ -242,6 +243,7 @@ api = [
path("workers/<uuid:pk>/", WorkerRetrieve.as_view(), name="worker-retrieve"),
path("workers/<uuid:pk>/configurations/", WorkerConfigurationList.as_view(), name="worker-configurations"),
path("workers/<uuid:pk>/versions/", WorkerVersionList.as_view(), name="worker-versions"),
path("workers/<uuid:pk>/versions/recommended/", RecommendedWorkerVersionRetrieve.as_view(), name="recommended-worker-version"),
path("workers/versions/<uuid:pk>/", WorkerVersionRetrieve.as_view(), name="version-retrieve"),
path("workers/versions/<uuid:pk>/activity/", UpdateWorkerActivity.as_view(), name="update-worker-activity"),
path("workers/versions/feature/<feature>/", FeatureWorkerVersionRetrieve.as_view(), name="feature-worker-version"),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment