From 454da644c37451b9dc84ad12e112120a03bacb49 Mon Sep 17 00:00:00 2001 From: Erwan Rouchet <rouchet@teklia.com> Date: Thu, 30 Mar 2023 17:35:29 +0200 Subject: [PATCH] Fix AttributeError on RetrieveModelVersion HTML 404 --- arkindex/training/api.py | 7 +------ arkindex/training/serializers.py | 7 ++++++- arkindex/training/tests/test_model_api.py | 6 ++++-- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/arkindex/training/api.py b/arkindex/training/api.py index 4a55630f41..f6a5cbea61 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -129,15 +129,10 @@ class ModelVersionsRetrieve(TrainingModelMixin, RetrieveUpdateDestroyAPIView): def get_serializer_context(self): context = super().get_serializer_context() context.update({ - 'model': self.model_version.model, - 'is_contributor': self.access_level and self.access_level >= Role.Contributor.value, + 'is_contributor': getattr(self, 'access_level', 0) >= Role.Contributor.value, }) return context - def get_object(self, *args, **kwargs): - self.model_version = super().get_object(*args, **kwargs) - return self.model_version - def check_object_permissions(self, request, model_version): self.access_level = get_max_level(self.request.user, model_version.model) diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py index 9af64253b2..b79d14bffd 100644 --- a/arkindex/training/serializers.py +++ b/arkindex/training/serializers.py @@ -17,6 +17,8 @@ from arkindex.users.utils import get_max_level def _model_from_context(serializer_field): + if isinstance(serializer_field.parent.instance, ModelVersion): + return serializer_field.parent.instance.model return serializer_field.context.get('model') @@ -114,7 +116,10 @@ class ModelVersionSerializer(serializers.ModelSerializer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - model = self.context.get('model') + if isinstance(self.instance, ModelVersion): + model = self.instance.model + else: + model = self.context.get('model') if model: qs = ModelVersion.objects.filter(model_id=model.id) if getattr(self.instance, 'id', None): diff --git a/arkindex/training/tests/test_model_api.py b/arkindex/training/tests/test_model_api.py index 32f1aa8af0..4c5ea73537 100644 --- a/arkindex/training/tests/test_model_api.py +++ b/arkindex/training/tests/test_model_api.py @@ -496,14 +496,15 @@ class TestModelAPI(FixtureAPITestCase): 'hash': 'n' * 32, 'archive_hash': 'n' * 32, } + with self.assertNumQueries(9): response = self.client.patch( reverse('api:model-version-retrieve', kwargs={'pk': str(self.model_version1.id)}), request, format='json' ) + self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertDictEqual(response.json(), { 'id': str(self.model_version1.id), 'model_id': str(self.model_version1.model_id), @@ -544,14 +545,15 @@ class TestModelAPI(FixtureAPITestCase): 'hash': 'n' * 32, 'archive_hash': 'n' * 32, } + with self.assertNumQueries(9): response = self.client.patch( reverse('api:model-version-retrieve', kwargs={'pk': str(self.model_version3.id)}), request, format='json' ) + self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) - self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertDictEqual(response.json(), { 'id': str(self.model_version3.id), 'model_id': str(self.model_version3.model_id), -- GitLab