From 454da644c37451b9dc84ad12e112120a03bacb49 Mon Sep 17 00:00:00 2001
From: Erwan Rouchet <rouchet@teklia.com>
Date: Thu, 30 Mar 2023 17:35:29 +0200
Subject: [PATCH] Fix AttributeError on RetrieveModelVersion HTML 404

---
 arkindex/training/api.py                  | 7 +------
 arkindex/training/serializers.py          | 7 ++++++-
 arkindex/training/tests/test_model_api.py | 6 ++++--
 3 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/arkindex/training/api.py b/arkindex/training/api.py
index 4a55630f41..f6a5cbea61 100644
--- a/arkindex/training/api.py
+++ b/arkindex/training/api.py
@@ -129,15 +129,10 @@ class ModelVersionsRetrieve(TrainingModelMixin, RetrieveUpdateDestroyAPIView):
     def get_serializer_context(self):
         context = super().get_serializer_context()
         context.update({
-            'model': self.model_version.model,
-            'is_contributor': self.access_level and self.access_level >= Role.Contributor.value,
+            'is_contributor': getattr(self, 'access_level', 0) >= Role.Contributor.value,
         })
         return context
 
-    def get_object(self, *args, **kwargs):
-        self.model_version = super().get_object(*args, **kwargs)
-        return self.model_version
-
     def check_object_permissions(self, request, model_version):
         self.access_level = get_max_level(self.request.user, model_version.model)
 
diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index 9af64253b2..b79d14bffd 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -17,6 +17,8 @@ from arkindex.users.utils import get_max_level
 
 
 def _model_from_context(serializer_field):
+    if isinstance(serializer_field.parent.instance, ModelVersion):
+        return serializer_field.parent.instance.model
     return serializer_field.context.get('model')
 
 
@@ -114,7 +116,10 @@ class ModelVersionSerializer(serializers.ModelSerializer):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        model = self.context.get('model')
+        if isinstance(self.instance, ModelVersion):
+            model = self.instance.model
+        else:
+            model = self.context.get('model')
         if model:
             qs = ModelVersion.objects.filter(model_id=model.id)
             if getattr(self.instance, 'id', None):
diff --git a/arkindex/training/tests/test_model_api.py b/arkindex/training/tests/test_model_api.py
index 32f1aa8af0..4c5ea73537 100644
--- a/arkindex/training/tests/test_model_api.py
+++ b/arkindex/training/tests/test_model_api.py
@@ -496,14 +496,15 @@ class TestModelAPI(FixtureAPITestCase):
             'hash': 'n' * 32,
             'archive_hash': 'n' * 32,
         }
+
         with self.assertNumQueries(9):
             response = self.client.patch(
                 reverse('api:model-version-retrieve', kwargs={'pk': str(self.model_version1.id)}),
                 request,
                 format='json'
             )
+            self.assertEqual(response.status_code, status.HTTP_200_OK)
 
-        self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertDictEqual(response.json(), {
             'id': str(self.model_version1.id),
             'model_id': str(self.model_version1.model_id),
@@ -544,14 +545,15 @@ class TestModelAPI(FixtureAPITestCase):
             'hash': 'n' * 32,
             'archive_hash': 'n' * 32,
         }
+
         with self.assertNumQueries(9):
             response = self.client.patch(
                 reverse('api:model-version-retrieve', kwargs={'pk': str(self.model_version3.id)}),
                 request,
                 format='json'
             )
+            self.assertEqual(response.status_code, status.HTTP_200_OK, response.json())
 
-        self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertDictEqual(response.json(), {
             'id': str(self.model_version3.id),
             'model_id': str(self.model_version3.model_id),
-- 
GitLab