diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py index 5d32d1ae397127c6cb17b0a52eea24d79cdd79b5..46bf599c152811abf05ca05e176696d81e722c33 100644 --- a/arkindex/project/api_v1.py +++ b/arkindex/project/api_v1.py @@ -95,7 +95,13 @@ from arkindex.process.api import ( WorkerVersionRetrieve, ) from arkindex.project.openapi import OpenApiSchemaView -from arkindex.training.api import ModelsList, ModelVersionDownload, ModelVersionsList, ModelVersionsRetrieve +from arkindex.training.api import ( + ModelRetrieve, + ModelsList, + ModelVersionDownload, + ModelVersionsList, + ModelVersionsRetrieve, +) from arkindex.users.api import ( CredentialsList, CredentialsRetrieve, @@ -252,6 +258,7 @@ api = [ # ML models training path('modelversion/<uuid:pk>/', ModelVersionsRetrieve.as_view(), name='model-version-retrieve'), path('models/', ModelsList.as_view(), name='models'), + path('model/<uuid:pk>/', ModelRetrieve.as_view(), name='model-retrieve'), path('model/<uuid:pk>/versions/', ModelVersionsList.as_view(), name='model-versions'), path('modelversion/<uuid:pk>/download/', ModelVersionDownload.as_view(), name='model-version-download'), diff --git a/arkindex/training/api.py b/arkindex/training/api.py index d1277f5c64ea55fa76dbb5cabb8436aaba1436d1..b43ef8081a1fa7a041efa34b07e4df952bf7b880 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -204,6 +204,20 @@ class ModelsList(TrainingModelMixin, ListCreateAPIView): return serializer.save() +@extend_schema(tags=['training']) +class ModelRetrieve(TrainingModelMixin, RetrieveAPIView): + """ + Retrieve a Machine Learning model. + + Requires a **guest** access to the model. + """ + permission_classes = (IsVerified, ) + serializer_class = ModelSerializer + + def get_queryset(self): + return self.readable_models + + @extend_schema(tags=['training']) @extend_schema_view( get=extend_schema( diff --git a/arkindex/training/tests/test_model_api.py b/arkindex/training/tests/test_model_api.py index fb63a575da6749117c53ca313531efc5296492e6..afa9d321724a94a6dbd856e4591de93c626ec90d 100644 --- a/arkindex/training/tests/test_model_api.py +++ b/arkindex/training/tests/test_model_api.py @@ -1,5 +1,3 @@ - - from unittest.mock import patch from uuid import uuid4 @@ -12,39 +10,6 @@ from arkindex.training.models import Model, ModelVersion, ModelVersionState from arkindex.users.models import Group, Right, Role, User -def _format_datetime(date): - return str(date.isoformat().replace('+00:00', 'Z')).replace(' ', 'T') - - -def _deserialize_model(model, access_rights): - return { - 'id': str(model.id), - 'created': _format_datetime(model.created), - 'updated': _format_datetime(model.updated), - 'name': model.name, - 'description': model.description, - 'rights': access_rights - } - - -def _deserialize_model_version(model_version): - # Only stringify if we get an ID - parent = str(model_version.parent) if model_version.parent else model_version.parent - return { - 'id': str(model_version.id), - 'model_id': str(model_version.model_id), - 'parent': parent, - 'description': model_version.description, - 'tag': model_version.tag, - 'hash': model_version.hash, - 'archive_hash': model_version.archive_hash, - 'state': model_version.state.value, - 'size': model_version.size, - 'configuration': model_version.configuration, - 's3_url': model_version.s3_url, - } - - class TestModelAPI(FixtureAPITestCase): """ Test model and model version api @@ -72,11 +37,48 @@ class TestModelAPI(FixtureAPITestCase): cls.model3 = Model.objects.create(name="Third Model") # Create some Model Versions - cls.model_version1 = ModelVersion.objects.create(model=cls.model1, hash="aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", archive_hash="5eca9bd3eb07c006cd43ae48dfde7fd3", size=8) - cls.model_version2 = ModelVersion.objects.create(model=cls.model1, description="some description", tag="tagged", configuration={"n_epochs": '10'}, hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", archive_hash="8b4f9ea16de4bcf5bbfc0ff1ea237934", size=8) - cls.model_version3 = ModelVersion.objects.create(model=cls.model2, state="available", tag="tagged", hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbba", archive_hash="a501b1ae6f6bc833551245f8328590b8", size=8) - cls.model_version4 = ModelVersion.objects.create(model=cls.model2, description="some description", tag="taggedv2", configuration={"n_epochs": '10'}, hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbaa", archive_hash="d67459a391b228cced507f068d4a570a", size=8) - cls.model_version5 = ModelVersion.objects.create(model=cls.model1, description="some description", tag="available_version", state='available', configuration={"n_epochs": 10}, hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbdd", archive_hash="3d87fd18e03fcbbfb64a381ad9472596", size=8) + cls.model_version1 = ModelVersion.objects.create( + model=cls.model1, + hash="aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + archive_hash="5eca9bd3eb07c006cd43ae48dfde7fd3", + size=8, + ) + cls.model_version2 = ModelVersion.objects.create( + model=cls.model1, + description="some description", + tag="tagged", + configuration={"n_epochs": '10'}, + hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", + archive_hash="8b4f9ea16de4bcf5bbfc0ff1ea237934", + size=8, + ) + cls.model_version3 = ModelVersion.objects.create( + model=cls.model2, + state=ModelVersionState.Available, + tag="tagged", + hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbba", + archive_hash="a501b1ae6f6bc833551245f8328590b8", + size=8, + ) + cls.model_version4 = ModelVersion.objects.create( + model=cls.model2, + description="some description", + tag="taggedv2", + configuration={"n_epochs": '10'}, + hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbaa", + archive_hash="d67459a391b228cced507f068d4a570a", + size=8, + ) + cls.model_version5 = ModelVersion.objects.create( + model=cls.model1, + description="some description", + tag="available_version", + state=ModelVersionState.Available, + configuration={"n_epochs": 10}, + hash="bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbdd", + archive_hash="3d87fd18e03fcbbfb64a381ad9472596", + size=8, + ) # Create three users with different access rights on each model cls.user1 = User.objects.create(email='user1@test.test', display_name='User 1', verified_email=True) @@ -118,7 +120,7 @@ class TestModelAPI(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertDictEqual(response.json(), {"detail": "You do not have permission to perform this action."}) - def test_model_version_requires_contributor(self): + def test_create_model_version_requires_contributor(self): """ Can't create model version as guest """ @@ -131,7 +133,7 @@ class TestModelAPI(FixtureAPITestCase): @patch('arkindex.training.models.logger') @patch('arkindex.project.aws.S3FileMixin.s3_object') - def test_model_version_check_hash_ignores_multipart(self, s3_object_mock, logger_mock): + def test_create_model_version_check_hash_ignores_multipart(self, s3_object_mock, logger_mock): s3_object_mock.e_tag = '"badbadbadbad-5"' version = ModelVersion.objects.create(model_id=self.model1.id, archive_hash='huehuehuehue', size=8, hash='huehuehue') @@ -140,7 +142,7 @@ class TestModelAPI(FixtureAPITestCase): self.assertEqual(logger_mock.warning.call_count, 1) @patch('arkindex.project.aws.s3.meta.client.generate_presigned_url') - def test_model_version_creation_no_tag(self, s3_presigned_url_mock): + def test_create_model_version_no_tag(self, s3_presigned_url_mock): """ Creates a new model version without setting a tag """ @@ -154,7 +156,7 @@ class TestModelAPI(FixtureAPITestCase): mock_now.return_value = fake_now with self.assertNumQueries(8): response = self.client.post(reverse('api:model-versions', kwargs={"pk": str(self.model1.id)}), request, format='json') - self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) data = response.json() self.assertIn('id', data) df = ModelVersion.objects.get(id=data['id']) @@ -172,7 +174,7 @@ class TestModelAPI(FixtureAPITestCase): 'size': request['size'], 'hash': request['hash'], 'archive_hash': request['archive_hash'], - 'created': _format_datetime(fake_now), + 'created': fake_now.isoformat().replace('+00:00', 'Z'), 's3_url': s3_presigned_url_mock.return_value, 's3_put_url': s3_presigned_url_mock.return_value } @@ -192,7 +194,7 @@ class TestModelAPI(FixtureAPITestCase): self.assertDictEqual(response.json(), {"tag": ["This field may not be blank."]}) @patch('arkindex.project.aws.s3.meta.client.generate_presigned_url') - def test_model_version_creation_with_tag(self, s3_presigned_url_mock): + def test_create_model_version_with_tag(self, s3_presigned_url_mock): """ Creates a new model version with a tag """ @@ -228,7 +230,7 @@ class TestModelAPI(FixtureAPITestCase): 'size': request['size'], 'hash': request['hash'], 'archive_hash': request['archive_hash'], - 'created': _format_datetime(fake_now), + 'created': fake_now.isoformat().replace('+00:00', 'Z'), 's3_url': s3_presigned_url_mock.return_value, 's3_put_url': s3_presigned_url_mock.return_value } @@ -278,7 +280,7 @@ class TestModelAPI(FixtureAPITestCase): 'state': self.model_version2.state.value, 'configuration': self.model_version2.configuration, 'tag': self.model_version2.tag, - 'created': _format_datetime(self.model_version2.created), + 'created': self.model_version2.created.isoformat().replace('+00:00', 'Z'), 'size': str(self.model_version2.size), 'hash': self.model_version2.hash, 'archive_hash': self.model_version2.archive_hash, @@ -636,8 +638,16 @@ class TestModelAPI(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) models = response.json()['results'] - self.assertEqual(len(models), 1) - self.assertListEqual(models, [_deserialize_model(self.model2, access_rights=['read'])]) + self.assertListEqual(models, [ + { + 'id': str(self.model2.id), + 'created': self.model2.created.isoformat().replace('+00:00', 'Z'), + 'updated': self.model2.updated.isoformat().replace('+00:00', 'Z'), + 'name': 'Second Model', + 'description': '', + 'rights': ['read'] + } + ]) def test_list_models_contrib_access(self): """User 2 has contributor access to Model1 and Model2. @@ -650,10 +660,23 @@ class TestModelAPI(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) models = response.json()['results'] - self.assertEqual(len(models), 2) self.assertListEqual(models, [ - _deserialize_model(self.model1, access_rights=['read', 'write']), - _deserialize_model(self.model2, access_rights=['read', 'write', 'admin']), + { + 'id': str(self.model1.id), + 'created': self.model1.created.isoformat().replace('+00:00', 'Z'), + 'updated': self.model1.updated.isoformat().replace('+00:00', 'Z'), + 'name': 'First Model', + 'description': '', + 'rights': ['read', 'write'] + }, + { + 'id': str(self.model2.id), + 'created': self.model2.created.isoformat().replace('+00:00', 'Z'), + 'updated': self.model2.updated.isoformat().replace('+00:00', 'Z'), + 'name': 'Second Model', + 'description': '', + 'rights': ['read', 'write', 'admin'] + } ]) def test_list_models_filter_name(self): @@ -665,8 +688,76 @@ class TestModelAPI(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) models = response.json()['results'] - self.assertEqual(len(models), 1) - self.assertListEqual(models, [_deserialize_model(self.model2, access_rights=['read', 'write', 'admin'])]) + self.assertListEqual(models, [ + { + 'id': str(self.model2.id), + 'created': self.model2.created.isoformat().replace('+00:00', 'Z'), + 'updated': self.model2.updated.isoformat().replace('+00:00', 'Z'), + 'name': 'Second Model', + 'description': '', + 'rights': ['read', 'write', 'admin'] + } + ]) + + def test_retrieve_model_requires_login(self): + with self.assertNumQueries(0): + response = self.client.get(reverse('api:model-retrieve', kwargs={'pk': str(self.model2.id)})) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_retrieve_model_requires_verified(self): + self.user3.verified_email = False + self.user3.save() + self.client.force_login(self.user3) + + with self.assertNumQueries(2): + response = self.client.get(reverse('api:model-retrieve', kwargs={'pk': str(self.model2.id)})) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_retrieve_model_requires_guest(self): + self.assertFalse(self.model1.public) + self.assertFalse(self.model1.memberships.filter(user=self.user3).exists()) + self.client.force_login(self.user3) + + with self.assertNumQueries(4): + response = self.client.get(reverse('api:model-retrieve', kwargs={'pk': str(self.model1.id)})) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_retrieve_model_public(self): + self.assertFalse(self.model1.memberships.filter(user=self.user3).exists()) + self.model1.public = True + self.model1.save() + self.client.force_login(self.user3) + + with self.assertNumQueries(5): + response = self.client.get(reverse('api:model-retrieve', kwargs={'pk': str(self.model1.id)})) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + self.assertDictEqual(response.json(), { + 'id': str(self.model1.id), + 'created': self.model1.created.isoformat().replace('+00:00', 'Z'), + 'updated': self.model1.updated.isoformat().replace('+00:00', 'Z'), + 'name': 'First Model', + 'description': '', + 'rights': ['read'] + }) + + def test_retrieve_model(self): + self.assertFalse(self.model2.public) + self.assertTrue(self.model2.memberships.filter(user=self.user3).exists()) + self.client.force_login(self.user3) + + with self.assertNumQueries(5): + response = self.client.get(reverse('api:model-retrieve', kwargs={'pk': str(self.model2.id)})) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + self.assertDictEqual(response.json(), { + 'id': str(self.model2.id), + 'created': self.model2.created.isoformat().replace('+00:00', 'Z'), + 'updated': self.model2.updated.isoformat().replace('+00:00', 'Z'), + 'name': 'Second Model', + 'description': '', + 'rights': ['read'] + }) def test_list_model_versions_requires_logged_in(self): """To list a model's versions, you need to be logged in. @@ -733,15 +824,30 @@ class TestModelAPI(FixtureAPITestCase): def test_retrieve_model_versions_tag_available(self): """Retrieve a model version with a set tag and state==Available with guest rights on the model. """ + self.assertIsNotNone(self.model_version3.tag) + self.assertEqual(self.model_version3.state, ModelVersionState.Available) self.client.force_login(self.user3) with self.assertNumQueries(7): response = self.client.get(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version3.id)})) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertDictEqual(response.json(), _deserialize_model_version(self.model_version3)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual(response.json(), { + 'id': str(self.model_version3.id), + 'model_id': str(self.model2.id), + 'parent': None, + 'description': '', + 'tag': 'tagged', + 'hash': 'bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbba', + 'archive_hash': 'a501b1ae6f6bc833551245f8328590b8', + 'state': 'available', + 'size': 8, + 'configuration': {}, + 's3_url': self.model_version3.s3_url, + }) def test_retrieve_model_versions_require_contributor(self): """To retrieve a model version with no set tag or state!=Available, you need contributor rights on the model. """ + self.assertNotEqual(self.model_version4.state, ModelVersionState.Available) self.client.force_login(self.user3) with self.assertNumQueries(7): response = self.client.get(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version4.id)})) @@ -754,8 +860,20 @@ class TestModelAPI(FixtureAPITestCase): self.client.force_login(self.user2) with self.assertNumQueries(7): response = self.client.get(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version4.id)})) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertDictEqual(response.json(), _deserialize_model_version(self.model_version4)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual(response.json(), { + 'id': str(self.model_version4.id), + 'model_id': str(self.model2.id), + 'parent': None, + 'description': 'some description', + 'tag': 'taggedv2', + 'hash': 'bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbaa', + 'archive_hash': 'd67459a391b228cced507f068d4a570a', + 'state': 'created', + 'size': 8, + 'configuration': {'n_epochs': '10'}, + 's3_url': self.model_version4.s3_url, + }) def test_download_model_version_wrong_token(self): with self.assertNumQueries(1):