Skip to content
Snippets Groups Projects
Commit 30c1efbe authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Erwan Rouchet
Browse files

Add model rights on models

parent 543abee0
No related branches found
No related tags found
1 merge request!1663Add model rights on models
......@@ -10,6 +10,7 @@ from arkindex.project.mixins import TrainingModelMixin
from arkindex.project.serializer_fields import EnumField
from arkindex.training.models import Model, ModelVersion, ModelVersionState
from arkindex.users.models import Role
from arkindex.users.utils import get_max_level
def _model_from_context(serializer_field):
......@@ -22,10 +23,11 @@ _model_from_context.requires_context = True
class ModelSerializer(TrainingModelMixin, serializers.ModelSerializer):
# Actually define the field to avoid the field-level automatically generated UniqueValidator
name = serializers.CharField(max_length=100)
rights = serializers.SerializerMethodField(read_only=True)
class Meta:
model = Model
fields = ('id', 'created', 'updated', 'name', 'description')
fields = ('id', 'created', 'updated', 'name', 'description', 'rights')
def create(self, validated_data):
instance = super().create(validated_data)
......@@ -37,6 +39,17 @@ class ModelSerializer(TrainingModelMixin, serializers.ModelSerializer):
)
return instance
@extend_schema_field(serializers.ListField(child=serializers.ChoiceField(['read', 'write', 'admin'])))
def get_rights(self, model):
level = get_max_level(self.context['request'].user, model)
rights = ['read']
if level >= Role.Contributor.value:
rights.append('write')
if level >= Role.Admin.value:
rights.append('admin')
return rights
class CreateModelErrorResponseSerializer(serializers.Serializer):
id = serializers.UUIDField(required=False, help_text="UUID of an existing model, if the error comes from a duplicate name.")
......
......@@ -15,6 +15,17 @@ def _format_datetime(date):
return str(date.isoformat().replace('+00:00', 'Z')).replace(' ', 'T')
def _deserialize_model(model, access_rights):
return {
'id': str(model.id),
'created': _format_datetime(model.created),
'updated': _format_datetime(model.updated),
'name': model.name,
'description': model.description,
'rights': access_rights
}
class TestModelAPI(FixtureAPITestCase):
"""
Test model and model version api
......@@ -516,10 +527,12 @@ class TestModelAPI(FixtureAPITestCase):
"name": 'The Best Model Ever',
"description": "This is actually the best model ever."
}
with self.assertNumQueries(6):
with self.assertNumQueries(8):
response = self.client.post(reverse('api:models'), request)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.json()['rights'], ['read', 'write', 'admin'])
model = Model.objects.get(id=response.json()['id'])
self.assertEqual(model.name, request.get("name"))
self.assertEqual(model.description, request.get("description"))
......@@ -572,13 +585,13 @@ class TestModelAPI(FixtureAPITestCase):
has a set tag and is in Available state.
"""
self.client.force_login(self.user3)
with self.assertNumQueries(6):
with self.assertNumQueries(7):
response = self.client.get(reverse('api:models'))
self.assertEqual(response.status_code, status.HTTP_200_OK)
models = response.json()['results']
self.assertEqual(len(models), 1)
self.assertListEqual([model['id'] for model in models], [str(self.model2.id)])
self.assertListEqual(models, [_deserialize_model(self.model2, access_rights=['read'])])
def test_list_models_contrib_access(self):
"""User 2 has contributor access to Model1 and Model2.
......@@ -586,25 +599,28 @@ class TestModelAPI(FixtureAPITestCase):
Models list is ordered by name, first Model1 (named 'First Model') then Model2 (named 'Second Model').
"""
self.client.force_login(self.user2)
with self.assertNumQueries(6):
with self.assertNumQueries(8):
response = self.client.get(reverse('api:models'))
self.assertEqual(response.status_code, status.HTTP_200_OK)
models = response.json()['results']
self.assertEqual(len(models), 2)
self.assertListEqual([model['id'] for model in models], [str(self.model1.id), str(self.model2.id)])
self.assertListEqual(models, [
_deserialize_model(self.model1, access_rights=['read', 'write']),
_deserialize_model(self.model2, access_rights=['read', 'write', 'admin']),
])
def test_list_models_filter_name(self):
"""User 2 has access to both Models, use search parameter name=Second returns only Model2
"""
self.client.force_login(self.user2)
with self.assertNumQueries(6):
with self.assertNumQueries(7):
response = self.client.get(reverse('api:models'), {'name': "second"})
self.assertEqual(response.status_code, status.HTTP_200_OK)
models = response.json()['results']
self.assertEqual(len(models), 1)
self.assertListEqual([model['id'] for model in models], [str(self.model2.id)])
self.assertListEqual(models, [_deserialize_model(self.model2, access_rights=['read', 'write', 'admin'])])
def test_list_model_versions_requires_logged_in(self):
"""To list a model's versions, you need to be logged in.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment