diff --git a/arkindex/training/api.py b/arkindex/training/api.py index 4a55630f419774505c6c2cd6df92b507df85f5e4..f6a5cbea6161a8c0142ffcc7b4bd399612ecc0a4 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -129,15 +129,10 @@ class ModelVersionsRetrieve(TrainingModelMixin, RetrieveUpdateDestroyAPIView): def get_serializer_context(self): context = super().get_serializer_context() context.update({ - 'model': self.model_version.model, - 'is_contributor': self.access_level and self.access_level >= Role.Contributor.value, + 'is_contributor': getattr(self, 'access_level', 0) >= Role.Contributor.value, }) return context - def get_object(self, *args, **kwargs): - self.model_version = super().get_object(*args, **kwargs) - return self.model_version - def check_object_permissions(self, request, model_version): self.access_level = get_max_level(self.request.user, model_version.model) diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py index 9af64253b2d28aeb282092148d9cb531bceb5129..b79d14bffd722e71a01493429ab5110371c79e29 100644 --- a/arkindex/training/serializers.py +++ b/arkindex/training/serializers.py @@ -17,6 +17,8 @@ from arkindex.users.utils import get_max_level def _model_from_context(serializer_field): + if isinstance(serializer_field.parent.instance, ModelVersion): + return serializer_field.parent.instance.model return serializer_field.context.get('model') @@ -114,7 +116,10 @@ class ModelVersionSerializer(serializers.ModelSerializer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - model = self.context.get('model') + if isinstance(self.instance, ModelVersion): + model = self.instance.model + else: + model = self.context.get('model') if model: qs = ModelVersion.objects.filter(model_id=model.id) if getattr(self.instance, 'id', None): diff --git a/arkindex/training/tests/test_model_api.py b/arkindex/training/tests/test_model_api.py index 32f1aa8af04f3bcad6278c42529f41112e3a689a..4c5ea73537c3fafd0b012ba10345ec40ef25fd72 100644 --- a/arkindex/training/tests/test_model_api.py +++ b/arkindex/training/tests/test_model_api.py @@ -496,14 +496,15 @@ class TestModelAPI(FixtureAPITestCase): 'hash': 'n' * 32, 'archive_hash': 'n' * 32, } + with self.assertNumQueries(9): response = self.client.patch( reverse('api:model-version-retrieve', kwargs={'pk': str(self.model_version1.id)}), request, format='json' ) + self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertDictEqual(response.json(), { 'id': str(self.model_version1.id), 'model_id': str(self.model_version1.model_id), @@ -544,14 +545,15 @@ class TestModelAPI(FixtureAPITestCase): 'hash': 'n' * 32, 'archive_hash': 'n' * 32, } + with self.assertNumQueries(9): response = self.client.patch( reverse('api:model-version-retrieve', kwargs={'pk': str(self.model_version3.id)}), request, format='json' ) + self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) - self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertDictEqual(response.json(), { 'id': str(self.model_version3.id), 'model_id': str(self.model_version3.model_id),