Skip to content
Snippets Groups Projects
Commit a8baf141 authored by Valentin Rigal's avatar Valentin Rigal Committed by Bastien Abadie
Browse files

Split model version API for validation

parent f4f4b74e
No related branches found
No related tags found
1 merge request!1885Split model version API for validation
......@@ -103,6 +103,7 @@ from arkindex.training.api import (
ModelVersionDownload,
ModelVersionsList,
ModelVersionsRetrieve,
ValidateModelVersion,
)
from arkindex.users.api import (
CredentialsList,
......@@ -259,6 +260,7 @@ api = [
# ML models training
path('modelversion/<uuid:pk>/', ModelVersionsRetrieve.as_view(), name='model-version-retrieve'),
path('modelversion/<uuid:pk>/validate/', ValidateModelVersion.as_view(), name='model-version-validate'),
path('models/', ModelsList.as_view(), name='models'),
path('model/<uuid:pk>/', ModelRetrieve.as_view(), name='model-retrieve'),
path('model/<uuid:pk>/versions/', ModelVersionsList.as_view(), name='model-versions'),
......
from django.db.models import Q
from django.shortcuts import get_object_or_404
from django.utils.functional import cached_property
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view
from rest_framework import serializers, status
from rest_framework import permissions, serializers, status
from rest_framework.exceptions import PermissionDenied, ValidationError
from rest_framework.generics import CreateAPIView, ListCreateAPIView, RetrieveAPIView, RetrieveUpdateDestroyAPIView
from rest_framework.generics import (
CreateAPIView,
GenericAPIView,
ListCreateAPIView,
RetrieveAPIView,
RetrieveUpdateDestroyAPIView,
)
from rest_framework.response import Response
from arkindex.project.mixins import TrainingModelMixin
......@@ -15,8 +20,8 @@ from arkindex.training.serializers import (
MetricValueBulkSerializer,
MetricValueCreateSerializer,
ModelSerializer,
ModelVersionCreateSerializer,
ModelVersionSerializer,
ModelVersionValidateSerializer,
)
from arkindex.users.models import Role
from arkindex.users.utils import get_max_level
......@@ -28,69 +33,78 @@ from arkindex.users.utils import get_max_level
operation_id='ListModelVersions',
description=(
'List available versions of a Machine Learning model.\n\n'
'Requires a **guest** access to the model.'
'Requires a **guest** access to the model, for versions that are available and have a tag,'
'and a **contributor** access otherwise.\n'
'`s3_put_url` is always set to null, `s3_url` is set for **contributors** on available models.'
),
),
post=extend_schema(
operation_id='CreateModelVersion',
description=(
'Create a new version for a Machine Learning model.\n\n'
'Create a new version for a Machine Learning model.\n'
'`s3_put_url` is always set to null, `s3_url` is set for users with a **contributor** access level.'
'The response includes an S3 URL that you can use to upload the archive containing the model.\n'
'Once the archive is successfully uploaded, you can set the size and hash attributes using '
'`PartialUpdateModelVersion`. The state of the model version will then be updated from `created` '
'to `available`, meaning it can be used in a process.\n\n'
'Requires a **contributor** access to the model.'
),
responses={
200: ModelVersionCreateSerializer, 400: None, 403: None
200: ModelVersionSerializer, 400: None, 403: None
},
),
)
class ModelVersionsList(TrainingModelMixin, ListCreateAPIView):
permission_classes = (IsVerified, )
serializer_class = ModelVersionCreateSerializer
queryset = Model.objects.none()
serializer_class = ModelVersionSerializer
def get_queryset(self):
return Model.objects.filter(pk=self.kwargs['pk'])
def perform_create(self, serializer):
if not self.model_access_rights or self.model_access_rights < Role.Contributor.value:
raise PermissionDenied(detail='You need a Contributor access to the model to create a new version.')
serializer.save()
@cached_property
def model(self):
return get_object_or_404(self.get_queryset(), id=self.kwargs['pk'])
@cached_property
def model_access_rights(self):
return get_max_level(self.request.user, self.model)
self.model = self.get_object()
qs = self.model.versions.order_by('-created')
if self.access_level < Role.Contributor.value:
qs = qs.filter(tag__isnull=False, state=ModelVersionState.Available)
return qs
def get_object(self):
model = get_object_or_404(Model.objects.all(), id=self.kwargs['pk'])
self.check_object_permissions(self.request, model)
return model
def check_object_permissions(self, request, model):
self.access_level = get_max_level(self.request.user, model)
needed_level, error_msg = (
Role.Guest.value,
'You need a Guest access to list versions of a model.',
)
if request.method not in permissions.SAFE_METHODS:
needed_level, error_msg = (
Role.Contributor.value,
'You need a Contributor access to the model to create a new version.',
)
if not self.access_level or self.access_level < needed_level:
raise PermissionDenied(detail=error_msg)
return super().check_object_permissions(request, model)
def get_serializer_context(self):
context = super().get_serializer_context()
context['model'] = self.model
context['model_rights'] = self.model_access_rights
context.update({
'model': self.model,
'is_contributor': self.access_level and self.access_level >= Role.Contributor.value,
})
return context
def filter_queryset(self, queryset):
filters = Q(model=self.kwargs['pk'])
access_level = self.model_access_rights
if not access_level or access_level < Role.Guest.value:
raise PermissionDenied(detail='You need a guest access to the model to list its versions.')
# Guest access allow only versions with a set tag and state = available
if access_level < Role.Contributor.value:
filters &= Q(tag__isnull=False, state=ModelVersionState.Available)
# Order version by creation date
queryset = ModelVersion.objects.filter(filters).order_by('-created')
return super().filter_queryset(queryset)
def create(self, request, *args, **kwargs):
self.model = self.get_object()
return super().create(request, *args, **kwargs)
@extend_schema(tags=['training'])
@extend_schema_view(
put=extend_schema(
description=(
'Update a version of a Machine Learning model.\n\n'
'Update a version of a Machine Learning model.\n'
'If the model version is updated with the hash of another available model version, '
'an HTTP error 400 is returned containing data of the available model.\n\n'
'Requires a **contributor** access to the model.'
),
),
......@@ -102,54 +116,133 @@ class ModelVersionsList(TrainingModelMixin, ListCreateAPIView):
)
)
class ModelVersionsRetrieve(TrainingModelMixin, RetrieveUpdateDestroyAPIView):
"""Retrieve a version of Machine Learning model
"""Retrieve a version of a Machine Learning model.
Requires a **guest** access to the model, for versions that are available and have a tag,
and a **contributor** access otherwise.
`s3_url` and `s3_put_url` fields are only set for users with a **contributor** access level.
"""
permission_classes = (IsVerified, )
serializer_class = ModelVersionSerializer
queryset = ModelVersion.objects.none()
def get_queryset(self):
return ModelVersion.objects \
.filter(id=self.kwargs['pk']) \
.select_related('model')
queryset = ModelVersion.objects.select_related("model")
def get_serializer_context(self):
context = super().get_serializer_context()
context['model'] = self.model
context.update({
'model': self.model_version.model,
'is_contributor': self.access_level and self.access_level >= Role.Contributor.value,
})
return context
@cached_property
def model(self):
return get_object_or_404(self.get_queryset(), id=self.kwargs['pk']).model
@cached_property
def model_access_rights(self):
return get_max_level(self.request.user, self.model)
def get_object(self, *args, **kwargs):
self.model_version = super().get_object(*args, **kwargs)
return self.model_version
def check_object_permissions(self, request, model_version):
access_level = self.model_access_rights
if model_version.state == ModelVersionState.Available and model_version.tag is not None:
needed_level, error_msg = Role.Guest.value, 'Guest'
self.access_level = get_max_level(self.request.user, model_version.model)
if request.method in permissions.SAFE_METHODS:
# A Contributor access is required to retrieve an untagged or non available version
if model_version.state != ModelVersionState.Available or model_version.tag is None:
needed_level, error_msg = (
Role.Contributor.value,
'You need a Contributor access to the model to retrieve this version.',
)
else:
needed_level, error_msg = (
Role.Guest.value,
'You need a Guest access to the model to retrieve this version.',
)
elif request.method.lower() == 'delete':
needed_level, error_msg = (
Role.Admin.value,
'You need an Admin access to the model to destroy this version.',
)
else:
needed_level, error_msg = Role.Contributor.value, 'Contributor'
needed_level, error_msg = (
Role.Contributor.value,
'You need a Contributor access to the model to update this version.',
)
if not access_level or access_level < needed_level:
raise PermissionDenied(detail=f'You need a {error_msg} access to the model to retrieve this version.')
if not self.access_level or self.access_level < needed_level:
raise PermissionDenied(detail=error_msg)
return super().check_object_permissions(request, model_version)
def perform_update(self, serializer):
access_level = self.model_access_rights
if not access_level or access_level < Role.Contributor.value:
raise PermissionDenied(detail='You need a Contributor access to the model to edit this version.')
return super().perform_update(serializer)
def perform_destroy(self, instance):
access_level = self.model_access_rights
if not access_level or access_level < Role.Admin.value:
raise PermissionDenied(detail='You need an Admin access to the model to destroy this version.')
class ValidateModelVersion(TrainingModelMixin, GenericAPIView):
"""
Checks the archive of a model version has correctly been uploaded to S3.
If the verification succeeds, archive's attributes are updated and the state is set to `available`.
If the verification fails, the state is set to `erroneous`.
Once available, model versions cannot be validated anymore.
If the hash conflicts with another existing model version, an HTTP error 409 is returned
with the existing model version it the response body.
Requires a **contributor** access to the model.
"""
serializer_class = ModelVersionValidateSerializer
queryset = ModelVersion.objects.select_related('model')
def check_object_permissions(self, request, model_version):
if not self.has_write_access(model_version.model):
raise PermissionDenied(detail='You need a Contributor access to the model to validate this version.')
if model_version.state == ModelVersionState.Available:
raise PermissionDenied(detail='This model version is already marked as available.')
return super().check_object_permissions(request, model_version)
super().perform_destroy(instance)
@extend_schema(
operation_id='ValidateModelVersion',
responses={
200: ModelVersionSerializer,
400: None,
403: None,
409: ModelVersionSerializer,
},
tags=['training'],
)
def post(self, request, *args, **kwargs):
instance = self.get_object()
serializer = self.get_serializer(instance, data=request.data)
serializer.is_valid(raise_exception=True)
# Ensure hash unicity among versions of a model
existing_model_version = (
instance.model.versions
.filter(
state=ModelVersionState.Available,
hash=serializer.validated_data['hash'],
).first()
)
if existing_model_version:
# Set the current model version as erroneous and return the available one
instance.state = ModelVersionState.Error
instance.save(update_fields=["state"])
return Response(
ModelVersionSerializer(
existing_model_version,
context={'is_contributor': True, 'model': instance},
).data,
status=status.HTTP_409_CONFLICT,
)
for attr, value in serializer.validated_data.items():
setattr(instance, attr, value)
# Perform archive's validation
try:
instance.perform_check(save=False, raise_exc=True)
instance.check_hash(save=False, raise_exc=True)
except (AssertionError, ValueError) as e:
# Set the model version as erroneous
instance.save(update_fields=["state"])
raise ValidationError({'detail': [str(e)]})
instance.save()
return Response(
ModelVersionSerializer(instance).data,
status=status.HTTP_201_CREATED,
)
@extend_schema(tags=['training'])
......@@ -244,9 +337,11 @@ class ModelVersionDownload(TrainingModelMixin, RetrieveAPIView):
queryset = ModelVersion.objects.all()
def check_object_permissions(self, request, model_version):
if model_version.state != ModelVersionState.Available:
raise PermissionDenied(detail='This model version is not available yet.')
token = model_version.build_authentication_token_hash()
if token != request.query_params['token']:
raise PermissionDenied()
raise PermissionDenied(detail='Wrong token.')
def get(self, request, *args, **kwargs):
return Response(status=status.HTTP_302_FOUND, headers={'Location': self.get_object().s3_url})
......
# Generated by Django 4.1.4 on 2023-01-03 13:34
from django.db import migrations, models
import arkindex.project.fields
class Migration(migrations.Migration):
dependencies = [
('training', '0005_metrics_metrickey_metricvalue'),
]
operations = [
migrations.AlterUniqueTogether(
name='modelversion',
unique_together=set(),
),
migrations.AlterField(
model_name='modelversion',
name='archive_hash',
field=arkindex.project.fields.MD5HashField(blank=True, help_text="hash of the archive which contains the model version's data", max_length=32, null=True),
),
migrations.AlterField(
model_name='modelversion',
name='hash',
field=arkindex.project.fields.MD5HashField(blank=True, help_text="hash of the content of the archive which contains the model version's data", max_length=32, null=True),
),
migrations.AlterField(
model_name='modelversion',
name='size',
field=models.PositiveIntegerField(blank=True, help_text='file size in bytes', null=True),
),
migrations.AddConstraint(
model_name='modelversion',
constraint=models.UniqueConstraint(fields=('model', 'tag'), name='modelversion_unique_tag'),
),
migrations.AddConstraint(
model_name='modelversion',
constraint=models.UniqueConstraint(condition=models.Q(('hash__isnull', False)), fields=('model', 'hash'), name='modelversion_unique_hash'),
),
]
......@@ -61,24 +61,40 @@ class ModelVersion(S3FileMixin, IndexableModel):
state = EnumField(ModelVersionState, default=ModelVersionState.Created)
# Hash of the archive's content
hash = MD5HashField(help_text="hash of the content of the archive which contains the model version's data")
hash = MD5HashField(
null=True,
blank=True,
help_text="hash of the content of the archive which contains the model version's data",
)
# Hash of the archive
archive_hash = MD5HashField(help_text="hash of the archive which contains the model version's data")
archive_hash = MD5HashField(
null=True,
blank=True,
help_text="hash of the archive which contains the model version's data",
)
# Size of the archive
size = models.PositiveIntegerField(help_text='file size in bytes')
size = models.PositiveIntegerField(null=True, blank=True, help_text='file size in bytes')
# Store dictionary of paramseters given by the ML developer
configuration = models.JSONField(default=dict)
class Meta:
unique_together = (
('model', 'tag'),
('model', 'hash'),
)
s3_bucket = settings.AWS_TRAINING_BUCKET
class Meta:
constraints = [
models.UniqueConstraint(
fields=['model', 'tag'],
name='modelversion_unique_tag',
),
models.UniqueConstraint(
fields=['model', 'hash'],
name='modelversion_unique_hash',
condition=Q(hash__isnull=False),
),
]
@cached_property
def s3_key(self):
return f'{self.id}.zst'
......
import re
import uuid
from collections import defaultdict
from textwrap import dedent
......@@ -85,18 +84,26 @@ class ModelVersionLightSerializer(serializers.ModelSerializer):
class Meta:
model = ModelVersion
fields = ('id', 'model', 'tag', 'state', 'size', 'configuration')
read_only_fields = ('id', 'model', 'tag', 'state', 'size', 'configuration')
read_only_fields = ('id', 'model', 'state', 'size', 'hash', 'archive_hash')
class ModelVersionSerializer(TrainingModelMixin, ModelVersionLightSerializer):
description = serializers.CharField(allow_blank=True, style={'base_template': 'textarea.html'})
s3_url = serializers.SerializerMethodField()
class ModelVersionSerializer(serializers.ModelSerializer):
"""
Serialize a model version with fields that can be updated regardless of its state.
"""
model = serializers.HiddenField(default=_model_from_context)
parent = serializers.PrimaryKeyRelatedField(queryset=ModelVersion.objects.none(), default=None)
description = serializers.CharField(required=False, style={'base_template': 'textarea.html'})
configuration = serializers.JSONField(required=False, decoder=None, encoder=None, style={'base_template': 'textarea.html'})
tag = serializers.CharField(allow_null=True, max_length=50, required=False, default=None)
state = EnumField(ModelVersionState, read_only=True)
s3_url = serializers.SerializerMethodField(read_only=True)
s3_put_url = serializers.SerializerMethodField(read_only=True)
class Meta:
model = ModelVersion
fields = ModelVersionLightSerializer.Meta.fields + ('model_id', 'parent', 'description', 'hash', 'archive_hash', 's3_url')
read_only_fields = ('id', 'model', 'parent', 'state', 'hash', 'archive_hash', 'size', 's3_url')
fields = ('id', 'model', 'model_id', 'parent', 'description', 'tag', 'state', 'size', 'hash', 'configuration', 's3_url', 's3_put_url', 'created')
read_only_fields = ('id', 'model', 'state', 'size', 'hash', 'archive_hash')
validators = [
UniqueTogetherValidator(
queryset=ModelVersion.objects.filter(tag__isnull=False),
......@@ -105,63 +112,41 @@ class ModelVersionSerializer(TrainingModelMixin, ModelVersionLightSerializer):
)
]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
model = self.context.get('model')
if model:
qs = ModelVersion.objects.filter(model_id=model.id)
if getattr(self.instance, 'id', None):
qs = qs.exclude(id=self.instance.id)
self.fields['parent'].queryset = qs
@extend_schema_field(serializers.CharField(allow_null=True))
def get_s3_put_url(self, obj):
if not self.context.get('is_contributor') or obj.state == ModelVersionState.Available:
return None
return obj.s3_put_url
@extend_schema_field(serializers.CharField(allow_null=True))
def get_s3_url(self, obj):
# Only display s3_url to contributors, when the state is `Available`
if not self.context.get('is_contributor') or obj.state != ModelVersionState.Available:
return None
return obj.s3_url
def validate_state(self, state):
# If User wants to update the state to Available, check the file on s3
if state == ModelVersionState.Available:
try:
# Save after hash has been checked, no need to do it twice
self.instance.perform_check(save=False, raise_exc=True)
self.instance.check_hash(raise_exc=True)
except (AssertionError, ValueError) as e:
raise ValidationError(str(e))
return state
def validate_hash(self, hash):
existing_modelversion = self.context['model'].versions.filter(hash=hash).first()
if existing_modelversion:
if existing_modelversion.state != ModelVersionState.Available:
raise ValidationError(ModelVersionCreateSerializer(existing_modelversion).data)
else:
raise ValidationError(detail="A version for this model with this hash already exists.")
return hash
class ModelVersionCreateSerializer(ModelVersionSerializer):
class ModelVersionValidateSerializer(serializers.ModelSerializer):
"""
Create a new Model Version by providing the hash and the size of the archive
A serializer used to update a model version to `available` state.
"""
parent = serializers.PrimaryKeyRelatedField(queryset=ModelVersion.objects.none(), default=None)
hash = serializers.RegexField(re.compile(r'[0-9A-Fa-f]{32}'), min_length=32, max_length=32)
archive_hash = serializers.RegexField(re.compile(r'[0-9A-Fa-f]{32}'), min_length=32, max_length=32)
description = serializers.CharField(required=False, style={'base_template': 'textarea.html'})
size = serializers.IntegerField(min_value=0)
s3_put_url = serializers.SerializerMethodField()
state = EnumField(ModelVersionState, default=ModelVersionState.Created, read_only=True)
configuration = serializers.JSONField(required=False, decoder=None, encoder=None, style={'base_template': 'textarea.html'})
tag = serializers.CharField(allow_null=True, max_length=50, required=False, default=None)
class Meta(ModelVersionSerializer.Meta):
fields = ModelVersionSerializer.Meta.fields + ('s3_put_url', 'created', )
read_only_fields = ModelVersionSerializer.Meta.read_only_fields + ('s3_put_url', 'created', )
@extend_schema_field(serializers.CharField(allow_null=True))
def get_s3_put_url(self, obj):
return obj.s3_put_url
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# The API adds the provided model to the context
if self.context.get('model') is not None:
self.fields['parent'].queryset = ModelVersion.objects.filter(model_id=self.context['model'].id)
# If user doesn't have a contributor access to the model, don't show s3_url and s3_put_url
access_level = self.context.get('model_rights')
if not access_level or access_level < Role.Contributor.value:
del self.fields['s3_url']
del self.fields['s3_put_url']
class Meta:
model = ModelVersion
fields = ('size', 'hash', 'archive_hash')
extra_kwargs = {
'size': {'required': True},
'hash': {'required': True},
'archive_hash': {'required': True},
}
class MetricValueSerializer(serializers.ModelSerializer):
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment