Skip to content
Snippets Groups Projects
Commit 8c5d5f4a authored by Erwan Rouchet's avatar Erwan Rouchet
Browse files

Merge branch 'list-models-endpoint' into 'master'

List models endpoint

Closes #978

See merge request !1648
parents 50101273 141c8371
No related branches found
No related tags found
1 merge request!1648List models endpoint
......@@ -88,7 +88,7 @@ from arkindex.documents.api.ml import (
from arkindex.documents.api.search import CorpusSearch
from arkindex.images.api import IIIFInformationCreate, IIIFURLCreate, ImageCreate, ImageElements, ImageRetrieve
from arkindex.project.openapi import OpenApiSchemaView
from arkindex.training.api import ModelCreate, ModelVersionsCreate, ModelVersionsUpdate
from arkindex.training.api import ModelsList, ModelVersionsCreate, ModelVersionsUpdate
from arkindex.users.api import (
CredentialsList,
CredentialsRetrieve,
......@@ -239,7 +239,7 @@ api = [
# ML models training
path('model/<uuid:pk>/versions/', ModelVersionsCreate.as_view(), name='model-version-create'),
path('modelversion/<uuid:pk>/', ModelVersionsUpdate.as_view(), name='model-version-update'),
path('model/', ModelCreate.as_view(), name='model-create'),
path('models/', ModelsList.as_view(), name='models'),
# Image management
path('image/', ImageCreate.as_view(), name='image-create'),
......
from drf_spectacular.utils import extend_schema, extend_schema_view
from django.db.models import Q
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view
from rest_framework.exceptions import PermissionDenied, ValidationError
from rest_framework.generics import CreateAPIView, UpdateAPIView
from rest_framework.generics import CreateAPIView, ListCreateAPIView, UpdateAPIView
from arkindex.project.mixins import TrainingModelMixin
from arkindex.project.permissions import IsVerified
......@@ -81,6 +82,22 @@ class ModelVersionsUpdate(TrainingModelMixin, UpdateAPIView):
@extend_schema(tags=['training'])
@extend_schema_view(
get=extend_schema(
operation_id='ListModels',
description=(
'List available Machine Learning models.'
),
responses={
200: ModelSerializer,
},
parameters=[
OpenApiParameter(
'name',
description='Filter models whose name contains the given string (case insensitive)',
required=False,
),
]
),
post=extend_schema(
description=(
'Create a new Machine Learning model.'
......@@ -92,10 +109,17 @@ class ModelVersionsUpdate(TrainingModelMixin, UpdateAPIView):
},
),
)
class ModelCreate(TrainingModelMixin, CreateAPIView):
class ModelsList(TrainingModelMixin, ListCreateAPIView):
permission_classes = (IsVerified, )
serializer_class = ModelSerializer
def get_queryset(self):
filters = Q()
if 'name' in self.request.query_params:
filters &= Q(name__icontains=self.request.query_params['name'])
return self.readable_models.filter(filters).order_by('name')
def perform_create(self, serializer):
model_name = serializer.validated_data.get('name')
existing_model = Model.objects.using('default').filter(name=model_name).first()
......
......@@ -10,9 +10,9 @@ from arkindex.training.models import Model, ModelVersion, ModelVersionState
from arkindex.users.models import Group, Right, Role, User
class TestModelVersion(FixtureAPITestCase):
class TestModelAPI(FixtureAPITestCase):
"""
Test model versions
Test model and model version api
"""
@classmethod
......@@ -39,6 +39,8 @@ class TestModelVersion(FixtureAPITestCase):
# Create some Model Versions
cls.model_version1 = ModelVersion.objects.create(model=cls.model1, hash="aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", size=8)
cls.model_version2 = ModelVersion.objects.create(model=cls.model1, description="some description", tag="tagged", configuration={"n_epochs": 10}, hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", size=8)
cls.model_version3 = ModelVersion.objects.create(model=cls.model2, state="available", tag="tagged", hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbba", size=8)
cls.model_version4 = ModelVersion.objects.create(model=cls.model2, description="some description", tag="taggedv2", configuration={"n_epochs": 10}, hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbaa", size=8)
# Create three users with different access rights on each model
cls.user1 = User.objects.create(email='user1@test.test', display_name='User 1', verified_email=True)
......@@ -410,14 +412,14 @@ class TestModelVersion(FixtureAPITestCase):
"""
user = User.objects.create(display_name="Not Verified", verified_email=False)
self.client.force_login(user)
response = self.client.post(reverse('api:model-create'), {"name": "The Best Classifier"})
response = self.client.post(reverse('api:models'), {"name": "The Best Classifier"})
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "You do not have permission to perform this action."})
def test_create_model_requires_logged_in(self):
"""Model creation requires being logged in
"""
response = self.client.post(reverse('api:model-create'), {"name": "The Best Classifier"})
response = self.client.post(reverse('api:models'), {"name": "The Best Classifier"})
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "Authentication credentials were not provided."})
......@@ -426,7 +428,7 @@ class TestModelVersion(FixtureAPITestCase):
"""
self.client.force_login(self.user1)
with self.assertNumQueries(2):
response = self.client.post(reverse('api:model-create'))
response = self.client.post(reverse('api:models'))
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"name": ["This field is required."]})
......@@ -436,7 +438,7 @@ class TestModelVersion(FixtureAPITestCase):
"""
self.client.force_login(self.user1)
with self.assertNumQueries(2):
response = self.client.post(reverse('api:model-create'), {"name": ''})
response = self.client.post(reverse('api:models'), {"name": ''})
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"name": ["This field may not be blank."]})
......@@ -450,7 +452,7 @@ class TestModelVersion(FixtureAPITestCase):
"description": "This is actually the best model ever."
}
with self.assertNumQueries(6):
response = self.client.post(reverse('api:model-create'), request)
response = self.client.post(reverse('api:models'), request)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
model = Model.objects.get(id=response.json()['id'])
......@@ -470,7 +472,7 @@ class TestModelVersion(FixtureAPITestCase):
"name": self.model1.name,
}
with self.assertNumQueries(6):
response = self.client.post(reverse('api:model-create'), request)
response = self.client.post(reverse('api:models'), request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"id": str(self.model1.id), "name": 'A model with this name already exists'})
......@@ -483,6 +485,57 @@ class TestModelVersion(FixtureAPITestCase):
"name": self.model1.name,
}
with self.assertNumQueries(6):
response = self.client.post(reverse('api:model-create'), request)
response = self.client.post(reverse('api:models'), request)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "You do not have permission to perform this action."})
def test_list_models_requires_verified(self):
user = User.objects.create(display_name="Not Verified", verified_email=False)
self.client.force_login(user)
response = self.client.get(reverse('api:models'))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "You do not have permission to perform this action."})
def test_list_models_requires_logged_in(self):
response = self.client.get(reverse('api:models'))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "Authentication credentials were not provided."})
def test_list_models_low_access(self):
"""User 3 has only access to Model2 with a guest access.
He only sees model_version3 of Model2 since it's the only one that
has a set tag and is in Available state.
"""
self.client.force_login(self.user3)
with self.assertNumQueries(6):
response = self.client.get(reverse('api:models'))
self.assertEqual(response.status_code, status.HTTP_200_OK)
models = response.json()['results']
self.assertEqual(len(models), 1)
self.assertListEqual([model['id'] for model in models], [str(self.model2.id)])
def test_list_models_contrib_access(self):
"""User 2 has contributor access to Model1 and Model2.
He has contributor access on both that's why he sees all related versions.
Models list is ordered by name, first Model1 (named 'First Model') then Model2 (named 'Second Model').
"""
self.client.force_login(self.user2)
with self.assertNumQueries(6):
response = self.client.get(reverse('api:models'))
self.assertEqual(response.status_code, status.HTTP_200_OK)
models = response.json()['results']
self.assertEqual(len(models), 2)
self.assertListEqual([model['id'] for model in models], [str(self.model1.id), str(self.model2.id)])
def test_list_models_filter_name(self):
"""User 2 has access to both Models, use search parameter name=Second returns only Model2
"""
self.client.force_login(self.user2)
with self.assertNumQueries(6):
response = self.client.get(reverse('api:models'), {'name': "second"})
self.assertEqual(response.status_code, status.HTTP_200_OK)
models = response.json()['results']
self.assertEqual(len(models), 1)
self.assertListEqual([model['id'] for model in models], [str(self.model2.id)])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment