diff --git a/arkindex/training/admin.py b/arkindex/training/admin.py index 1d38c2bef5d57b4e022e24ad3a47dd52737b3cf3..8d668f3427eb3be54919b91d324a35c8fd908b87 100644 --- a/arkindex/training/admin.py +++ b/arkindex/training/admin.py @@ -13,8 +13,8 @@ class ModelAdmin(admin.ModelAdmin): class ModelVersionAdmin(admin.ModelAdmin): list_display = ('id', 'model', 'tag', 'size', 'state') list_filter = ('model__name', ('state', EnumFieldListFilter), ) - fields = ('model', 'parent', 'description', 'state', 'tag', 'hash', 'size', 'configuration',) - readonly_fields = ('hash', 'size', ) + fields = ('model', 'parent', 'description', 'state', 'tag', 'hash', 'archive_hash', 'size', 'configuration',) + readonly_fields = ('hash', 'archive_hash', 'size', ) admin.site.register(Model, ModelAdmin) diff --git a/arkindex/training/migrations/0004_modelversion_archive_hash.py b/arkindex/training/migrations/0004_modelversion_archive_hash.py new file mode 100644 index 0000000000000000000000000000000000000000..32de137a13ca19d6374918072c62c1abb26fffb5 --- /dev/null +++ b/arkindex/training/migrations/0004_modelversion_archive_hash.py @@ -0,0 +1,26 @@ +# Generated by Django 4.0.2 on 2022-04-14 11:02 + +from django.db import migrations + +import arkindex.project.fields + + +class Migration(migrations.Migration): + + dependencies = [ + ('training', '0003_alter_modelversion_unique_together_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='modelversion', + name='archive_hash', + field=arkindex.project.fields.MD5HashField(default='00000000000000000000000000000000', help_text="hash of the archive which contains the model version's data", max_length=32), + preserve_default=False, + ), + migrations.AlterField( + model_name='modelversion', + name='hash', + field=arkindex.project.fields.MD5HashField(help_text="hash of the content of the archive which contains the model version's data", max_length=32), + ), + ] diff --git a/arkindex/training/models.py b/arkindex/training/models.py index bc7494be08a0c0cd6fd70ce78a4341817a957d4b..578e852c08ccbc365e147bc15dedad1200709bd9 100644 --- a/arkindex/training/models.py +++ b/arkindex/training/models.py @@ -1,3 +1,5 @@ +import logging + from django.conf import settings from django.contrib.contenttypes.fields import GenericRelation from django.db import models @@ -8,6 +10,8 @@ from arkindex.project.aws import S3FileMixin from arkindex.project.fields import MD5HashField from arkindex.project.models import IndexableModel +logger = logging.getLogger(__name__) + class Model(IndexableModel): """ @@ -53,8 +57,11 @@ class ModelVersion(S3FileMixin, IndexableModel): state = EnumField(ModelVersionState, default=ModelVersionState.Created) + # Hash of the archive's content + hash = MD5HashField(help_text="hash of the content of the archive which contains the model version's data") + # Hash of the archive - hash = MD5HashField() + archive_hash = MD5HashField(help_text="hash of the archive which contains the model version's data") # Size of the archive size = models.PositiveIntegerField(help_text='file size in bytes') @@ -82,9 +89,9 @@ class ModelVersion(S3FileMixin, IndexableModel): assert self.exists(), 'Archive has not been uploaded' assert self.s3_object.content_length == self.size, \ f'Uploaded file size is {self.s3_object.content_length} bytes, expected {self.size} bytes' - self.status = ModelVersionState.Available + self.state = ModelVersionState.Available except AssertionError: - self.status = ModelVersionState.Error + self.state = ModelVersionState.Error if raise_exc: raise finally: @@ -95,14 +102,22 @@ class ModelVersion(S3FileMixin, IndexableModel): """ Checks the MD5 hash against the hash from Amazon S3 """ - assert self.hash, 'File has no hash' + assert self.archive_hash, 'File has no hash' assert self.exists(), 'No file content, assert file has been correctly uploaded' # The hash given by Boto seems to be surrounded by double quotes - if self.s3_object.e_tag.strip('"') == self.hash: - self.status = ModelVersionState.Available + if self.s3_object.e_tag.strip('"') == self.archive_hash: + self.state = ModelVersionState.Available + elif '-' in self.s3_object.e_tag: + # Multipart hash: a hash of each part's hash, + # combined with the number of parts, separated by a dash + logger.warning('Could not check remote multipart hash {!r} against local hash {!r}'.format( + self.s3_object.e_tag, + self.archive_hash, + )) + self.state = ModelVersionState.Available else: - self.status = ModelVersionState.Error + self.state = ModelVersionState.Error if save: self.save() - if self.status == ModelVersionState.Error and raise_exc: + if self.state == ModelVersionState.Error and raise_exc: raise ValueError('MD5 hashes do not match') diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py index 825942a4b64c93acee8924af8fd0291f84ba88ed..371167f089ee1b1d345a7876acb9e51da1db4eac 100644 --- a/arkindex/training/serializers.py +++ b/arkindex/training/serializers.py @@ -66,8 +66,8 @@ class ModelVersionSerializer(TrainingModelMixin, serializers.ModelSerializer): class Meta: model = ModelVersion - fields = ('id', 'model', 'model_id', 'parent', 'description', 'tag', 'hash', 'state', 'size', 'configuration', 's3_url') - read_only_fields = ('id', 'model', 'parent', 'state', 'hash', 'size', 's3_url') + fields = ('id', 'model', 'model_id', 'parent', 'description', 'tag', 'hash', 'archive_hash', 'state', 'size', 'configuration', 's3_url') + read_only_fields = ('id', 'model', 'parent', 'state', 'hash', 'archive_hash', 'size', 's3_url') validators = [ UniqueTogetherValidator( queryset=ModelVersion.objects.filter(tag__isnull=False), @@ -107,6 +107,7 @@ class ModelVersionCreateSerializer(ModelVersionSerializer): """ parent = serializers.PrimaryKeyRelatedField(queryset=ModelVersion.objects.none(), default=None) hash = serializers.RegexField(re.compile(r'[0-9A-Fa-f]{32}'), min_length=32, max_length=32) + archive_hash = serializers.RegexField(re.compile(r'[0-9A-Fa-f]{32}'), min_length=32, max_length=32) description = serializers.CharField(required=False, style={'base_template': 'textarea.html'}) size = serializers.IntegerField(min_value=0) s3_put_url = serializers.SerializerMethodField() diff --git a/arkindex/training/tests/test_model_api.py b/arkindex/training/tests/test_model_api.py index e1f3f449f5bb45068d8d790e37b927efa4635e0d..7cedcf3405265c8712412f448c259e13173e24dc 100644 --- a/arkindex/training/tests/test_model_api.py +++ b/arkindex/training/tests/test_model_api.py @@ -53,11 +53,11 @@ class TestModelAPI(FixtureAPITestCase): cls.model3 = Model.objects.create(name="Third Model") # Create some Model Versions - cls.model_version1 = ModelVersion.objects.create(model=cls.model1, hash="aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", size=8) - cls.model_version2 = ModelVersion.objects.create(model=cls.model1, description="some description", tag="tagged", configuration={"n_epochs": '10'}, hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", size=8) - cls.model_version3 = ModelVersion.objects.create(model=cls.model2, state="available", tag="tagged", hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbba", size=8) - cls.model_version4 = ModelVersion.objects.create(model=cls.model2, description="some description", tag="taggedv2", configuration={"n_epochs": '10'}, hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbaa", size=8) - cls.model_version5 = ModelVersion.objects.create(model=cls.model1, description="some description", tag="available_version", state='available', configuration={"n_epochs": 10}, hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbdd", size=8) + cls.model_version1 = ModelVersion.objects.create(model=cls.model1, hash="aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", archive_hash="5eca9bd3eb07c006cd43ae48dfde7fd3", size=8) + cls.model_version2 = ModelVersion.objects.create(model=cls.model1, description="some description", tag="tagged", configuration={"n_epochs": '10'}, hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", archive_hash="8b4f9ea16de4bcf5bbfc0ff1ea237934", size=8) + cls.model_version3 = ModelVersion.objects.create(model=cls.model2, state="available", tag="tagged", hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbba", archive_hash="a501b1ae6f6bc833551245f8328590b8", size=8) + cls.model_version4 = ModelVersion.objects.create(model=cls.model2, description="some description", tag="taggedv2", configuration={"n_epochs": '10'}, hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbaa", archive_hash="d67459a391b228cced507f068d4a570a", size=8) + cls.model_version5 = ModelVersion.objects.create(model=cls.model1, description="some description", tag="available_version", state='available', configuration={"n_epochs": 10}, hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbdd", archive_hash="3d87fd18e03fcbbfb64a381ad9472596", size=8) # Create three users with different access rights on each model cls.user1 = User.objects.create(email='user1@test.test', display_name='User 1', verified_email=True) @@ -79,6 +79,7 @@ class TestModelAPI(FixtureAPITestCase): def build_model_version_create_request(self): return { 'hash': '94274e84f3de91d1645b1e082b5f3466', + 'archive_hash': '0958a74b060a89fc38318a9a96aef32a', 'size': 8 } @@ -109,6 +110,16 @@ class TestModelAPI(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertDictEqual(response.json(), {"detail": "You need a Contributor access to the model to create a new version."}) + @patch('arkindex.training.models.logger') + @patch('arkindex.project.aws.S3FileMixin.s3_object') + def test_model_version_check_hash_ignores_multipart(self, s3_object_mock, logger_mock): + s3_object_mock.e_tag = '"badbadbadbad-5"' + version = ModelVersion.objects.create(model_id=self.model1.id, archive_hash='huehuehuehue', size=8, hash='huehuehue') + + version.check_hash(save=True, raise_exc=True) + self.assertEqual(version.state, ModelVersionState.Available) + self.assertEqual(logger_mock.warning.call_count, 1) + @patch('arkindex.project.aws.s3.meta.client.generate_presigned_url') def test_model_version_creation_no_tag(self, s3_presigned_url_mock): """ @@ -141,6 +152,7 @@ class TestModelAPI(FixtureAPITestCase): 'tag': None, 'size': request['size'], 'hash': request['hash'], + 'archive_hash': request['archive_hash'], 'created': _format_datetime(fake_now), 's3_url': s3_presigned_url_mock.return_value, 's3_put_url': s3_presigned_url_mock.return_value @@ -170,6 +182,7 @@ class TestModelAPI(FixtureAPITestCase): request = { 'tag': 'TAG', 'hash': '05a5cdbaf05d3b6cc51fcb173d0057c0', + 'archive_hash': '05a5cdbaf05d3b6cc51fcb173d0057c0', 'size': 8, } fake_now = timezone.now() @@ -195,6 +208,7 @@ class TestModelAPI(FixtureAPITestCase): 'tag': request['tag'], 'size': request['size'], 'hash': request['hash'], + 'archive_hash': request['archive_hash'], 'created': _format_datetime(fake_now), 's3_url': s3_presigned_url_mock.return_value, 's3_put_url': s3_presigned_url_mock.return_value @@ -219,7 +233,7 @@ class TestModelAPI(FixtureAPITestCase): Raises 400 when creating a model version that already exists, same model_id and hash and state==Available """ self.client.force_login(self.user1) - request = {'tag': 'production', 'hash': self.model_version5.hash, 'size': 32} + request = {'tag': 'production', 'hash': self.model_version5.hash, 'archive_hash': self.model_version5.archive_hash, 'size': 32} with self.assertNumQueries(7): response = self.client.post(reverse('api:model-versions', kwargs={"pk": str(self.model1.id)}), request, format='json') @@ -248,6 +262,7 @@ class TestModelAPI(FixtureAPITestCase): 'created': _format_datetime(self.model_version2.created), 'size': str(self.model_version2.size), 'hash': self.model_version2.hash, + 'archive_hash': self.model_version2.archive_hash, 's3_url': self.model_version2.s3_url, 's3_put_url': self.model_version2.s3_put_url } @@ -272,7 +287,7 @@ class TestModelAPI(FixtureAPITestCase): s3_presigned_url.return_value = "http://s3/upload_url" exists.return_value = True s3_object().content_length = self.model_version1.size - s3_object().e_tag = self.model_version1.hash + s3_object().e_tag = self.model_version1.archive_hash self.client.force_login(self.user1) request = { "description" : "A very long description", @@ -291,6 +306,7 @@ class TestModelAPI(FixtureAPITestCase): "description": request.get('description'), "tag": request.get('tag'), "hash": self.model_version1.hash, + "archive_hash": self.model_version1.archive_hash, "state": "available", "size": self.model_version1.size, "configuration": request.get('configuration'), @@ -354,7 +370,7 @@ class TestModelAPI(FixtureAPITestCase): s3_presigned_url.return_value = "http://s3/upload_put_url" exists.return_value = True s3_object().content_length = self.model_version1.size - s3_object().e_tag = self.model_version1.hash + s3_object().e_tag = self.model_version1.archive_hash self.client.force_login(self.user1) request = { "tag": self.model_version2.tag, @@ -397,7 +413,7 @@ class TestModelAPI(FixtureAPITestCase): s3_presigned_url.return_value = "http://s3/upload_url" exists.return_value = True s3_object().content_length = self.model_version1.size - s3_object().e_tag = self.model_version1.hash + s3_object().e_tag = self.model_version1.archive_hash self.client.force_login(self.user1) params = self.build_model_version_update_request() @@ -415,7 +431,7 @@ class TestModelAPI(FixtureAPITestCase): s3_presigned_url.return_value = "http://s3/upload_url" exists.return_value = True s3_object().content_length = self.model_version2.size - s3_object().e_tag = self.model_version2.hash + s3_object().e_tag = self.model_version2.archive_hash self.client.force_login(self.user1) request = { @@ -433,6 +449,7 @@ class TestModelAPI(FixtureAPITestCase): "description": request.get('description'), "tag": self.model_version2.tag, "hash": self.model_version2.hash, + "archive_hash": self.model_version2.archive_hash, "state": "available", "size": self.model_version2.size, "configuration": request.get('configuration'), @@ -474,7 +491,7 @@ class TestModelAPI(FixtureAPITestCase): """ exists.return_value = True s3_object().content_length = self.model_version1.size - s3_object().e_tag = f'"{self.model_version2.hash}"' + s3_object().e_tag = f'"{self.model_version2.archive_hash}"' self.client.force_login(self.user1) request = self.build_model_version_update_request() with self.assertNumQueries(9):