diff --git a/arkindex/training/api.py b/arkindex/training/api.py
index 4a55630f419774505c6c2cd6df92b507df85f5e4..f6a5cbea6161a8c0142ffcc7b4bd399612ecc0a4 100644
--- a/arkindex/training/api.py
+++ b/arkindex/training/api.py
@@ -129,15 +129,10 @@ class ModelVersionsRetrieve(TrainingModelMixin, RetrieveUpdateDestroyAPIView):
     def get_serializer_context(self):
         context = super().get_serializer_context()
         context.update({
-            'model': self.model_version.model,
-            'is_contributor': self.access_level and self.access_level >= Role.Contributor.value,
+            'is_contributor': getattr(self, 'access_level', 0) >= Role.Contributor.value,
         })
         return context
 
-    def get_object(self, *args, **kwargs):
-        self.model_version = super().get_object(*args, **kwargs)
-        return self.model_version
-
     def check_object_permissions(self, request, model_version):
         self.access_level = get_max_level(self.request.user, model_version.model)
 
diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index 9af64253b2d28aeb282092148d9cb531bceb5129..b79d14bffd722e71a01493429ab5110371c79e29 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -17,6 +17,8 @@ from arkindex.users.utils import get_max_level
 
 
 def _model_from_context(serializer_field):
+    if isinstance(serializer_field.parent.instance, ModelVersion):
+        return serializer_field.parent.instance.model
     return serializer_field.context.get('model')
 
 
@@ -114,7 +116,10 @@ class ModelVersionSerializer(serializers.ModelSerializer):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        model = self.context.get('model')
+        if isinstance(self.instance, ModelVersion):
+            model = self.instance.model
+        else:
+            model = self.context.get('model')
         if model:
             qs = ModelVersion.objects.filter(model_id=model.id)
             if getattr(self.instance, 'id', None):
diff --git a/arkindex/training/tests/test_model_api.py b/arkindex/training/tests/test_model_api.py
index 32f1aa8af04f3bcad6278c42529f41112e3a689a..4c5ea73537c3fafd0b012ba10345ec40ef25fd72 100644
--- a/arkindex/training/tests/test_model_api.py
+++ b/arkindex/training/tests/test_model_api.py
@@ -496,14 +496,15 @@ class TestModelAPI(FixtureAPITestCase):
             'hash': 'n' * 32,
             'archive_hash': 'n' * 32,
         }
+
         with self.assertNumQueries(9):
             response = self.client.patch(
                 reverse('api:model-version-retrieve', kwargs={'pk': str(self.model_version1.id)}),
                 request,
                 format='json'
             )
+            self.assertEqual(response.status_code, status.HTTP_200_OK)
 
-        self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertDictEqual(response.json(), {
             'id': str(self.model_version1.id),
             'model_id': str(self.model_version1.model_id),
@@ -544,14 +545,15 @@ class TestModelAPI(FixtureAPITestCase):
             'hash': 'n' * 32,
             'archive_hash': 'n' * 32,
         }
+
         with self.assertNumQueries(9):
             response = self.client.patch(
                 reverse('api:model-version-retrieve', kwargs={'pk': str(self.model_version3.id)}),
                 request,
                 format='json'
             )
+            self.assertEqual(response.status_code, status.HTTP_200_OK, response.json())
 
-        self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertDictEqual(response.json(), {
             'id': str(self.model_version3.id),
             'model_id': str(self.model_version3.model_id),