diff --git a/arkindex/process/serializers/worker_runs.py b/arkindex/process/serializers/worker_runs.py index 7ccb1afec4d13e259180d97265dc1299ce67edab..690bd9b3c4e051b60650e0bf9e2a8b0debb3a4ca 100644 --- a/arkindex/process/serializers/worker_runs.py +++ b/arkindex/process/serializers/worker_runs.py @@ -4,7 +4,7 @@ from rest_framework.exceptions import ValidationError from arkindex.process.models import WorkerConfiguration, WorkerRun from arkindex.process.serializers.imports import ProcessTrainingSerializer from arkindex.process.serializers.workers import WorkerConfigurationSerializer, WorkerVersionSerializer -from arkindex.training.models import Model, ModelVersion +from arkindex.training.models import ModelVersion, ModelVersionState from arkindex.training.serializers import ModelVersionLightSerializer from arkindex.users.models import Role from arkindex.users.utils import get_max_level @@ -57,7 +57,7 @@ class WorkerRunEditSerializer(WorkerRunSerializer): Parents, model version and configuration can be edited. """ model_version_id = serializers.PrimaryKeyRelatedField( - queryset=ModelVersion.objects.all(), + queryset=ModelVersion.objects.all().select_related('model'), required=False, allow_null=True, write_only=True, @@ -82,11 +82,15 @@ class WorkerRunEditSerializer(WorkerRunSerializer): model_usage = self.context.get('model_usage') if not model_usage: raise ValidationError("This worker version does not support model usage.") - model = Model.objects.get(id=model_version.model_id) + + if model_version.state != ModelVersionState.Available: + raise ValidationError("This model version is not in an available state.") + # Check access rights on model version - access_level = get_max_level(self.context["request"].user, model) + access_level = get_max_level(self.context["request"].user, model_version.model) if not access_level or access_level < Role.Contributor.value: raise ValidationError('You do not have access to this model version.') + return model_version def validate(self, data): diff --git a/arkindex/process/tests/test_workerruns.py b/arkindex/process/tests/test_workerruns.py index 5aed67e4ce4bb3ca9f035681f07131f6bfdb4e50..b834a2eb09c01db0b20df6a297bacfa1c34bddaa 100644 --- a/arkindex/process/tests/test_workerruns.py +++ b/arkindex/process/tests/test_workerruns.py @@ -1077,7 +1077,7 @@ class TestWorkerRuns(FixtureAPITestCase): model_no_access = Model.objects.create(name='Secret model') model_version_no_access = ModelVersion.objects.create(model=model_no_access, state=ModelVersionState.Available, size=8, hash='aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', archive_hash='bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb') - with self.assertNumQueries(10): + with self.assertNumQueries(9): response = self.client.put( reverse('api:worker-run-details', kwargs={'pk': str(run_2.id)}), data={ @@ -1090,7 +1090,40 @@ class TestWorkerRuns(FixtureAPITestCase): 'model_version_id': ['You do not have access to this model version.'] }) - def test_update_run_model_version(self): + def test_update_run_model_version_unavailable(self): + self.client.force_login(self.user) + rev_2 = self.repo.revisions.create( + hash='2', + message='beep boop', + author='bob', + ) + version = WorkerVersion.objects.create( + worker=self.worker_1, + revision=rev_2, + configuration={"test": "test2"}, + model_usage=True + ) + run = self.process_1.worker_runs.create( + version=version, + parents=[], + ) + self.model_version_1.state = ModelVersionState.Error + self.model_version_1.save() + + with self.assertNumQueries(7): + response = self.client.put( + reverse('api:worker-run-details', kwargs={'pk': str(run.id)}), + data={ + 'model_version_id': str(self.model_version_1.id), + 'parents': [] + }, format='json' + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), { + 'model_version_id': ['This model version is not in an available state.'] + }) + + def test_update_run_model_version_id(self): """ Update the worker run by adding a model_version with a worker version that supports it """ @@ -1113,7 +1146,7 @@ class TestWorkerRuns(FixtureAPITestCase): self.assertEqual(run.model_version, None) # Check generated summary, before updating, there should be only information about the worker version self.assertEqual(run.summary, f"Worker {self.worker_1.name} @ {str(version_with_model.id)[:6]}") - with self.assertNumQueries(15): + with self.assertNumQueries(13): response = self.client.put( reverse('api:worker-run-details', kwargs={'pk': str(run.id)}), data={ @@ -1206,7 +1239,7 @@ class TestWorkerRuns(FixtureAPITestCase): self.assertIsNone(run.configuration) # Check generated summary, before updating, there should be only information about the worker version self.assertEqual(run.summary, f"Worker {self.worker_1.name} @ {str(version_with_model.id)[:6]}") - with self.assertNumQueries(16): + with self.assertNumQueries(14): response = self.client.put( reverse('api:worker-run-details', kwargs={'pk': str(run.id)}), data={ @@ -1740,7 +1773,7 @@ class TestWorkerRuns(FixtureAPITestCase): model_no_access = Model.objects.create(name='Secret model') model_version_no_access = ModelVersion.objects.create(model=model_no_access, state=ModelVersionState.Available, size=8, hash='aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', archive_hash='bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb') - with self.assertNumQueries(10): + with self.assertNumQueries(9): response = self.client.patch( reverse('api:worker-run-details', kwargs={'pk': str(run_2.id)}), data={ @@ -1752,6 +1785,39 @@ class TestWorkerRuns(FixtureAPITestCase): 'model_version_id': ['You do not have access to this model version.'] }) + def test_partial_update_run_model_version_unavailable(self): + self.client.force_login(self.user) + rev_2 = self.repo.revisions.create( + hash='2', + message='beep boop', + author='bob', + ) + version = WorkerVersion.objects.create( + worker=self.worker_1, + revision=rev_2, + configuration={"test": "test2"}, + model_usage=True + ) + run = self.process_1.worker_runs.create( + version=version, + parents=[], + ) + self.model_version_1.state = ModelVersionState.Error + self.model_version_1.save() + + with self.assertNumQueries(7): + response = self.client.patch( + reverse('api:worker-run-details', kwargs={'pk': str(run.id)}), + data={ + 'model_version_id': str(self.model_version_1.id), + 'parents': [] + }, format='json' + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), { + 'model_version_id': ['This model version is not in an available state.'] + }) + def test_partial_update_run_model_version(self): """ Update the worker run by adding a model_version with a worker version that supports it @@ -1774,7 +1840,7 @@ class TestWorkerRuns(FixtureAPITestCase): ) self.assertIsNone(run.model_version_id) self.assertEqual(run.summary, f"Worker {version_with_model.worker.name} @ {str(version_with_model.id)[:6]}") - with self.assertNumQueries(15): + with self.assertNumQueries(13): response = self.client.patch( reverse('api:worker-run-details', kwargs={'pk': str(run.id)}), data={ @@ -1863,14 +1929,14 @@ class TestWorkerRuns(FixtureAPITestCase): configuration=self.configuration_1 ) self.assertEqual(run.model_version_id, None) - with self.assertNumQueries(15): + with self.assertNumQueries(13): response = self.client.patch( reverse('api:worker-run-details', kwargs={'pk': str(run.id)}), data={ 'model_version_id': str(self.model_version_1.id), }, format='json' ) - self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.status_code, status.HTTP_200_OK) run.refresh_from_db() self.assertEqual(response.json(), { 'id': str(run.id),