Skip to content
Snippets Groups Projects
Commit 0002b002 authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Erwan Rouchet
Browse files

Model hash must be unique

parent 8c5d5f4a
No related branches found
No related tags found
1 merge request: !1650 "Model hash must be unique"
......@@ -23,7 +23,7 @@ from arkindex.training.serializers import (
'Requires a **contributor** access to the model.'
),
responses={
200: ModelVersionCreateSerializer, 403: None
200: ModelVersionCreateSerializer, 400: None, 403: None
},
),
)
......
# Generated by Django 4.0.2 on 2022-03-31 08:51
from django.db import migrations, models
class Migration(migrations.Migration):
    """
    Schema migration for the ``modelversion`` model of the ``training`` app.

    The operations run in the order listed: the unique constraint is first
    set to ``(model, tag)``, the ``tag`` column is then altered, and finally
    the unique constraints are set to both ``(model, tag)`` and
    ``(model, hash)``.
    """

    dependencies = [
        ('training', '0002_alter_modelversion_tag'),
    ]

    operations = [
        # Step 1: unique constraint on (model, tag) only.
        migrations.AlterUniqueTogether(
            name='modelversion',
            unique_together={('model', 'tag')},
        ),
        # Step 2: tag is a nullable 50-character string defaulting to NULL.
        migrations.AlterField(
            model_name='modelversion',
            name='tag',
            field=models.CharField(max_length=50, null=True, default=None),
        ),
        # Step 3: keep (model, tag) and additionally enforce (model, hash).
        migrations.AlterUniqueTogether(
            name='modelversion',
            unique_together={('model', 'hash'), ('model', 'tag')},
        ),
    ]
......@@ -49,7 +49,7 @@ class ModelVersion(S3FileMixin, IndexableModel):
description = models.TextField(default="")
tag = models.CharField(null=True, max_length=50, blank=True, default=None)
tag = models.CharField(null=True, max_length=50, default=None)
state = EnumField(ModelVersionState, default=ModelVersionState.Created)
......@@ -65,6 +65,7 @@ class ModelVersion(S3FileMixin, IndexableModel):
class Meta:
unique_together = (
('model', 'tag'),
('model', 'hash'),
)
s3_bucket = settings.AWS_TRAINING_BUCKET
......
......@@ -47,7 +47,7 @@ class CreateModelErrorResponseSerializer(serializers.Serializer):
class ModelVersionSerializer(TrainingModelMixin, serializers.ModelSerializer):
model = serializers.HiddenField(default=_model_from_context)
description = serializers.CharField(allow_blank=True, style={'base_template': 'textarea.html'})
tag = serializers.CharField(allow_blank=True, max_length=50)
tag = serializers.CharField(max_length=50)
state = EnumField(ModelVersionState)
configuration = serializers.JSONField(style={'base_template': 'textarea.html'})
s3_url = serializers.SerializerMethodField()
......@@ -79,6 +79,15 @@ class ModelVersionSerializer(TrainingModelMixin, serializers.ModelSerializer):
raise ValidationError(str(e))
return state
def validate_hash(self, hash):
    """
    Ensure no other version of the model in the serializer context uses this hash.

    The version bound to the serializer (``self.instance``, set when an
    existing version is being updated) is ignored, so re-submitting a
    version's own hash is not reported as a duplicate of itself.

    :param hash: The hash value submitted by the client.
    :returns: The unchanged hash when no other version of the model uses it.
    :raises ValidationError: When another version of the model has this hash.
        If that version is not in the ``Available`` state, the error detail is
        its ``ModelVersionCreateSerializer`` representation; otherwise it is a
        plain error message.
    """
    duplicates = self.context['model'].versions.filter(hash=hash)

    # On update, the edited version legitimately carries its own hash:
    # leave it out of the duplicate lookup.
    bound_pk = getattr(getattr(self, 'instance', None), 'pk', None)
    if bound_pk is not None:
        duplicates = duplicates.exclude(pk=bound_pk)

    existing_modelversion = duplicates.first()
    if existing_modelversion is None:
        return hash

    if existing_modelversion.state != ModelVersionState.Available:
        # Hand back the serialized existing version as the error detail
        raise ValidationError(ModelVersionCreateSerializer(existing_modelversion).data)
    raise ValidationError(detail="A version for this model with this hash already exists.")
class ModelVersionCreateSerializer(ModelVersionSerializer):
"""
......@@ -91,7 +100,7 @@ class ModelVersionCreateSerializer(ModelVersionSerializer):
s3_put_url = serializers.SerializerMethodField()
state = EnumField(ModelVersionState, default=ModelVersionState.Created, read_only=True)
configuration = serializers.JSONField(required=False, decoder=None, encoder=None, style={'base_template': 'textarea.html'})
tag = serializers.CharField(allow_blank=True, allow_null=True, max_length=50, required=False, default=None)
tag = serializers.CharField(allow_null=True, max_length=50, required=False, default=None)
class Meta(ModelVersionSerializer.Meta):
fields = ModelVersionSerializer.Meta.fields + ('s3_put_url',)
......
......@@ -38,9 +38,10 @@ class TestModelAPI(FixtureAPITestCase):
# Create some Model Versions
cls.model_version1 = ModelVersion.objects.create(model=cls.model1, hash="aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", size=8)
cls.model_version2 = ModelVersion.objects.create(model=cls.model1, description="some description", tag="tagged", configuration={"n_epochs": 10}, hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", size=8)
cls.model_version2 = ModelVersion.objects.create(model=cls.model1, description="some description", tag="tagged", configuration={"n_epochs": '10'}, hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", size=8)
cls.model_version3 = ModelVersion.objects.create(model=cls.model2, state="available", tag="tagged", hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbba", size=8)
cls.model_version4 = ModelVersion.objects.create(model=cls.model2, description="some description", tag="taggedv2", configuration={"n_epochs": 10}, hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbaa", size=8)
cls.model_version4 = ModelVersion.objects.create(model=cls.model2, description="some description", tag="taggedv2", configuration={"n_epochs": '10'}, hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbaa", size=8)
cls.model_version5 = ModelVersion.objects.create(model=cls.model1, description="some description", tag="available_version", state='available', configuration={"n_epochs": 10}, hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbdd", size=8)
# Create three users with different access rights on each model
cls.user1 = User.objects.create(email='user1@test.test', display_name='User 1', verified_email=True)
......@@ -61,7 +62,7 @@ class TestModelAPI(FixtureAPITestCase):
def build_model_version_create_request(self):
return {
'hash': 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
'hash': '94274e84f3de91d1645b1e082b5f3466',
'size': 8
}
......@@ -101,7 +102,7 @@ class TestModelAPI(FixtureAPITestCase):
s3_presigned_url_mock.return_value = 'http://s3/upload_put_url'
self.client.force_login(self.user1)
request = self.build_model_version_create_request()
with self.assertNumQueries(7):
with self.assertNumQueries(8):
response = self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), request, format='json')
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
data = response.json()
......@@ -126,6 +127,19 @@ class TestModelAPI(FixtureAPITestCase):
}
)
def test_create_model_version_blank_tag_model(self):
    """
    An empty string is rejected as a tag: the creation endpoint answers 400
    """
    self.client.force_login(self.user1)
    payload = {**self.build_model_version_create_request(), 'tag': ''}
    url = reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)})

    with self.assertNumQueries(7):
        response = self.client.post(url, payload, format='json')

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    self.assertDictEqual(response.json(), {"tag": ["This field may not be blank."]})
@patch('arkindex.project.aws.s3.meta.client.generate_presigned_url')
def test_model_version_creation_with_tag(self, s3_presigned_url_mock):
"""
......@@ -138,7 +152,7 @@ class TestModelAPI(FixtureAPITestCase):
'hash': '05a5cdbaf05d3b6cc51fcb173d0057c0',
'size': 8,
}
with self.assertNumQueries(8):
with self.assertNumQueries(9):
response = self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), request, format='json')
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
data = response.json()
......@@ -162,19 +176,57 @@ class TestModelAPI(FixtureAPITestCase):
}
)
def test_create_model_version_unique(self):
def test_create_model_version_unique_tag_model(self):
"""
Raises 400 when creating a model version that already exists, same model_id and tag
"""
self.client.force_login(self.user1)
ModelVersion.objects.create(model=self.model1, tag='production', hash='5a50cdbaf05d3b6cc51fcb173d0057c0', size=16)
request = {'tag': 'production', 'hash': 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'size': 32}
with self.assertNumQueries(7):
request = self.build_model_version_create_request()
request['tag'] = self.model_version2.tag
with self.assertNumQueries(8):
response = self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), request, format='json')
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"non_field_errors": ["A version for this model with this tag already exists."]})
def test_create_model_version_unique_hash_model_available(self):
    """
    Creating a version whose hash matches an existing version of the same
    model in the Available state is refused with a 400 and a plain message
    """
    self.client.force_login(self.user1)
    payload = {
        'tag': 'production',
        'hash': self.model_version5.hash,
        'size': 32,
    }
    url = reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)})

    with self.assertNumQueries(7):
        response = self.client.post(url, payload, format='json')

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    self.assertDictEqual(response.json(), {"hash": ["A version for this model with this hash already exists."]})
def test_create_model_version_unique_hash_model_not_available(self):
    """
    Creating a version whose hash matches an existing version of the same
    model that is not Available is refused with a 400; the error carried
    under the ``hash`` key is the payload of the matching version,
    including its s3_put_url
    """
    self.client.force_login(self.user1)
    payload = {
        'tag': 'production',
        'hash': self.model_version2.hash,
        'size': 32,
    }
    url = reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)})

    with self.assertNumQueries(7):
        response = self.client.post(url, payload, format='json')

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    # Expected values are read from the fixture only after the request,
    # exactly as they are compared against the response.
    match = self.model_version2
    self.assertDictEqual(
        response.json()['hash'],
        {
            'id': str(match.id),
            'model_id': str(self.model1.id),
            'parent': str(match.parent),
            'description': match.description,
            'state': match.state.value,
            'configuration': match.configuration,
            'tag': match.tag,
            'size': str(match.size),
            'hash': match.hash,
            's3_url': match.s3_url,
            's3_put_url': match.s3_put_url,
        },
    )
def test_partial_update_model_version_requires_contributor(self):
"""
Can't partial update a model version as guest
......@@ -256,7 +308,7 @@ class TestModelAPI(FixtureAPITestCase):
"""
exists.return_value = True
s3_object().content_length = self.model_version1.size
s3_object().e_tag = '"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"'
s3_object().e_tag = f'"{self.model_version2.hash}"'
self.client.force_login(self.user1)
with self.assertNumQueries(9):
response = self.client.patch(reverse('api:model-version-update', kwargs={"pk": str(self.model_version1.id)}), {
......@@ -398,7 +450,7 @@ class TestModelAPI(FixtureAPITestCase):
"""
exists.return_value = True
s3_object().content_length = self.model_version1.size
s3_object().e_tag = '"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"'
s3_object().e_tag = f'"{self.model_version2.hash}"'
self.client.force_login(self.user1)
request = self.build_model_version_update_request()
with self.assertNumQueries(9):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment