From 581c2d323879e8bf8c0cb34ae5a01490187ca1f3 Mon Sep 17 00:00:00 2001 From: Erwan Rouchet <rouchet@teklia.com> Date: Tue, 26 Sep 2023 07:39:17 +0000 Subject: [PATCH] Allow using available model versions with a tag on models with guest access in WorkerRuns --- arkindex/documents/managers.py | 8 +- arkindex/process/serializers/worker_runs.py | 30 +- arkindex/process/tests/test_workerruns.py | 466 +++++++++++++------- arkindex/training/api.py | 6 +- arkindex/training/managers.py | 40 ++ arkindex/training/models.py | 3 + arkindex/training/serializers.py | 3 +- arkindex/users/managers.py | 12 + 8 files changed, 392 insertions(+), 176 deletions(-) create mode 100644 arkindex/training/managers.py diff --git a/arkindex/documents/managers.py b/arkindex/documents/managers.py index 298a82a1c9..bfc6cf94b7 100644 --- a/arkindex/documents/managers.py +++ b/arkindex/documents/managers.py @@ -4,6 +4,7 @@ import django from django.db import DJANGO_VERSION_PICKLE_KEY, connections, models from arkindex.project.fields import Unnest +from arkindex.users.managers import BaseACLManager from arkindex.users.models import Role @@ -216,16 +217,11 @@ class ElementManager(models.Manager): return paths -class CorpusManager(models.Manager): +class CorpusManager(BaseACLManager): ''' Add ACL functions to corpus listing ''' - def filter_rights(self, *args, **kwargs): - # Avoid circular dependencies as this module is imported by documents.models - from arkindex.users.utils import filter_rights - return filter_rights(*args, **kwargs) - def readable(self, user): return super().get_queryset().filter( id__in=(self.filter_rights(user, self.model, Role.Guest.value).values('id')) diff --git a/arkindex/process/serializers/worker_runs.py b/arkindex/process/serializers/worker_runs.py index 28dff3bfc1..2dc3161fe1 100644 --- a/arkindex/process/serializers/worker_runs.py +++ b/arkindex/process/serializers/worker_runs.py @@ -1,4 +1,5 @@ from collections import defaultdict +from textwrap import dedent from django.db.models import Prefetch from rest_framework import serializers @@ -17,8 +18,6 @@ from arkindex.process.serializers.workers import WorkerConfigurationSerializer, from arkindex.project.mixins import WorkerACLMixin from arkindex.training.models import ModelVersion, ModelVersionState from arkindex.training.serializers import ModelVersionLightSerializer -from arkindex.users.models import Role -from arkindex.users.utils import get_max_level # To prevent each element worker to retrieve contextual information # (process, worker version, model version…) with extra GET requests, we @@ -63,15 +62,22 @@ class WorkerRunSerializer(WorkerACLMixin, serializers.ModelSerializer): model_version = ModelVersionLightSerializer(read_only=True) model_version_id = serializers.PrimaryKeyRelatedField( - queryset=ModelVersion.objects.all().select_related('model'), + queryset=ModelVersion.objects.none(), required=False, allow_null=True, write_only=True, source='model_version', style={'base_template': 'input.html'}, - help_text='UUID of the ModelVersion for this WorkerRun, or `null` if none is set. ' - 'Only ModelVersions in an `available` state may be set. ' - 'Contributor access to the model is required.', + help_text=dedent( + """ + UUID of the ModelVersion for this WorkerRun, or `null` if none is set. + + Only ModelVersions in an `available` state may be set. + + Guest access to the model is required. + For model versions without a `tag`, contributor access is required. + """ + ), ) configuration = WorkerConfigurationSerializer(read_only=True) @@ -111,6 +117,12 @@ class WorkerRunSerializer(WorkerACLMixin, serializers.ModelSerializer): 'summary', ) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.context.get('request'): + user = self.context['request'].user + self.fields['model_version_id'].queryset = ModelVersion.objects.executable(user) + def validate(self, data): data = super().validate(data) errors = defaultdict(list) @@ -150,11 +162,7 @@ class WorkerRunSerializer(WorkerACLMixin, serializers.ModelSerializer): errors['configuration_id'].append('The configuration must be part of the same worker.') if model_version: - # We cannot use both the WorkerACLMixin and the TrainingModelMixin at once! - access_level = get_max_level(self.context["request"].user, model_version.model) - if not access_level or access_level < Role.Contributor.value: - errors['model_version_id'].append('You do not have contributor access to this model.') - elif model_version.state != ModelVersionState.Available: + if model_version.state != ModelVersionState.Available: errors['model_version_id'].append('This ModelVersion is not in an Available state.') if not worker_version.model_usage: diff --git a/arkindex/process/tests/test_workerruns.py b/arkindex/process/tests/test_workerruns.py index 802bafcff8..5901bcf9f6 100644 --- a/arkindex/process/tests/test_workerruns.py +++ b/arkindex/process/tests/test_workerruns.py @@ -46,7 +46,32 @@ class TestWorkerRuns(FixtureAPITestCase): # Model and Model version setup cls.model_1 = Model.objects.create(name='My model') cls.model_1.memberships.create(user=cls.user, level=Role.Contributor.value) - cls.model_version_1 = ModelVersion.objects.create(model=cls.model_1, state=ModelVersionState.Available, size=8, hash='aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', archive_hash='bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb') + cls.model_version_1 = ModelVersion.objects.create( + model=cls.model_1, + state=ModelVersionState.Available, + size=8, + hash='aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', + archive_hash='bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', + ) + + cls.model_2 = Model.objects.create(name='Their model') + cls.model_2.memberships.create(user=cls.user, level=Role.Guest.value) + cls.model_version_2 = cls.model_2.versions.create( + state=ModelVersionState.Available, + tag='blah', + size=8, + hash='aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', + archive_hash='bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', + ) + + cls.model_3 = Model.objects.create(name='Our model', public=True) + cls.model_version_3 = cls.model_3.versions.create( + state=ModelVersionState.Available, + tag='blah', + size=8, + hash='aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', + archive_hash='bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', + ) cls.agent = Agent.objects.create( farm=Farm.objects.first(), @@ -205,8 +230,8 @@ class TestWorkerRuns(FixtureAPITestCase): self.run_1.configuration = configuration self.run_1.save() - # Having a model version set adds three queries, having a configuration adds one - query_count = 11 + bool(model_version) * 3 + bool(configuration) + # Having a model version set adds two queries, having a configuration adds one + query_count = 11 + bool(model_version) * 2 + bool(configuration) with self.assertNumQueries(query_count): response = self.client.post( @@ -1348,7 +1373,7 @@ class TestWorkerRuns(FixtureAPITestCase): parents=[], ) - with self.assertNumQueries(11): + with self.assertNumQueries(10): response = self.client.put( reverse('api:worker-run-details', kwargs={'pk': str(run_2.id)}), data={ @@ -1385,7 +1410,7 @@ class TestWorkerRuns(FixtureAPITestCase): ) random_model_version_uuid = str(uuid.uuid4()) - with self.assertNumQueries(8): + with self.assertNumQueries(9): response = self.client.put( reverse('api:worker-run-details', kwargs={'pk': str(run_2.id)}), data={ @@ -1431,7 +1456,7 @@ class TestWorkerRuns(FixtureAPITestCase): archive_hash='bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', ) - with self.assertNumQueries(11): + with self.assertNumQueries(9): response = self.client.put( reverse('api:worker-run-details', kwargs={'pk': str(run_2.id)}), data={ @@ -1443,9 +1468,67 @@ class TestWorkerRuns(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.json(), { - 'model_version_id': ['You do not have contributor access to this model.'] + 'model_version_id': [f'Invalid pk "{model_version_no_access.id}" - object does not exist.'], }) + def test_update_model_version_guest(self): + """ + Cannot update a worker run with a model_version when you only have guest access to the model, + and the model version has no tag or is not available + """ + self.client.force_login(self.user) + rev_2 = self.repo.revisions.create( + hash='2', + message='beep boop', + author='bob', + ) + version_no_model = WorkerVersion.objects.create( + worker=self.worker_1, + revision=rev_2, + configuration={"test": "test2"}, + model_usage=True + ) + run_2 = self.process_1.worker_runs.create( + version=version_no_model, + parents=[], + ) + + cases = [ + # On a model with a membership giving guest access + (self.model_version_2, None, ModelVersionState.Created), + (self.model_version_2, None, ModelVersionState.Error), + (self.model_version_2, None, ModelVersionState.Available), + (self.model_version_2, 'blah', ModelVersionState.Created), + (self.model_version_2, 'blah', ModelVersionState.Error), + # On a public model with no membership + (self.model_version_3, None, ModelVersionState.Created), + (self.model_version_3, None, ModelVersionState.Error), + (self.model_version_3, None, ModelVersionState.Available), + (self.model_version_3, 'blah', ModelVersionState.Created), + (self.model_version_3, 'blah', ModelVersionState.Error), + ] + + for model_version, tag, state in cases: + with self.subTest(model_version=model_version, tag=tag, state=state): + model_version.tag = tag + model_version.state = state + model_version.save() + + with self.assertNumQueries(9): + response = self.client.put( + reverse('api:worker-run-details', kwargs={'pk': str(run_2.id)}), + data={ + 'model_version_id': str(model_version.id), + 'parents': [], + }, + format='json', + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertEqual(response.json(), { + 'model_version_id': [f'Invalid pk "{model_version.id}" - object does not exist.'], + }) + def test_update_model_version_unavailable(self): self.client.force_login(self.user) rev_2 = self.repo.revisions.create( @@ -1466,7 +1549,7 @@ class TestWorkerRuns(FixtureAPITestCase): self.model_version_1.state = ModelVersionState.Error self.model_version_1.save() - with self.assertNumQueries(11): + with self.assertNumQueries(10): response = self.client.put( reverse('api:worker-run-details', kwargs={'pk': str(run.id)}), data={ @@ -1505,77 +1588,85 @@ class TestWorkerRuns(FixtureAPITestCase): # Check generated summary, before updating, there should be only information about the worker version self.assertEqual(run.summary, f"Worker {self.worker_1.name} @ {str(version_with_model.id)[:6]}") - with self.assertNumQueries(14): - response = self.client.put( - reverse('api:worker-run-details', kwargs={'pk': str(run.id)}), - data={ - 'model_version_id': str(self.model_version_1.id), - 'parents': [] - }, - format='json', - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - - run.refresh_from_db() + model_versions = [ + # Version on a model with contributor access + self.model_version_1, + # Available version with tag on a model with guest access + self.model_version_2, + # Available version with tag on a public model + self.model_version_3, + ] + for model_version in model_versions: + with self.subTest(model_version=model_version): + with self.assertNumQueries(14): + response = self.client.put( + reverse('api:worker-run-details', kwargs={'pk': str(run.id)}), + data={ + 'model_version_id': str(model_version.id), + 'parents': [], + }, + format='json', + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.json(), { - 'id': str(run.id), - 'configuration': None, - 'model_version': { - 'id': str(self.model_version_1.id), - 'configuration': {}, - 'model': { - 'id': str(self.model_1.id), - 'name': 'My model' - }, - 'size': 8, - 'state': 'available', - 'tag': None - }, - 'parents': [], - 'process': { - 'id': str(self.process_1.id), - 'activity_state': 'disabled', - 'corpus': str(self.corpus.id), - 'mode': 'workers', - 'model_id': None, - 'name': None, - 'state': 'unscheduled', - 'test_folder_id': None, - 'train_folder_id': None, - 'use_cache': False, - 'validation_folder_id': None, - }, - 'worker_version': { - 'id': str(version_with_model.id), - 'configuration': {'test': 'test2'}, - 'docker_image': None, - 'docker_image_iid': None, - 'docker_image_name': f'my_repo.fake/workers/worker/reco:{version_with_model.id}', - 'gpu_usage': 'disabled', - 'model_usage': True, - 'revision': { - 'id': str(rev.id), - 'author': 'bob', - 'commit_url': 'http://my_repo.fake/workers/worker/commit/2', - 'created': rev.created.isoformat().replace('+00:00', 'Z'), - 'hash': '2', - 'message': 'beep boop', - 'refs': [] - }, - 'state': 'created', - 'worker': { - 'id': str(self.worker_1.id), - 'name': 'Recognizer', - 'slug': 'reco', - 'type': 'recognizer' - } - }, - 'summary': f'Worker Recognizer @ {str(version_with_model.id)[:6]} with model My model @ {str(self.model_version_1.id)[:6]}', - }) - self.assertEqual(run.model_version_id, self.model_version_1.id) - # Check generated summary, after updating, there should be information about the model loaded - self.assertEqual(run.summary, f"Worker {self.worker_1.name} @ {str(version_with_model.id)[:6]} with model {self.model_version_1.model.name} @ {str(self.model_version_1.id)[:6]}") + run.refresh_from_db() + self.assertEqual(response.json(), { + 'id': str(run.id), + 'configuration': None, + 'model_version': { + 'id': str(model_version.id), + 'configuration': {}, + 'model': { + 'id': str(model_version.model.id), + 'name': model_version.model.name + }, + 'size': 8, + 'state': 'available', + 'tag': model_version.tag, + }, + 'parents': [], + 'process': { + 'id': str(self.process_1.id), + 'activity_state': 'disabled', + 'corpus': str(self.corpus.id), + 'mode': 'workers', + 'model_id': None, + 'name': None, + 'state': 'unscheduled', + 'test_folder_id': None, + 'train_folder_id': None, + 'use_cache': False, + 'validation_folder_id': None, + }, + 'worker_version': { + 'id': str(version_with_model.id), + 'configuration': {'test': 'test2'}, + 'docker_image': None, + 'docker_image_iid': None, + 'docker_image_name': f'my_repo.fake/workers/worker/reco:{version_with_model.id}', + 'gpu_usage': 'disabled', + 'model_usage': True, + 'revision': { + 'id': str(rev.id), + 'author': 'bob', + 'commit_url': 'http://my_repo.fake/workers/worker/commit/2', + 'created': rev.created.isoformat().replace('+00:00', 'Z'), + 'hash': '2', + 'message': 'beep boop', + 'refs': [] + }, + 'state': 'created', + 'worker': { + 'id': str(self.worker_1.id), + 'name': 'Recognizer', + 'slug': 'reco', + 'type': 'recognizer' + } + }, + 'summary': f'Worker Recognizer @ {str(version_with_model.id)[:6]} with model {model_version.model.name} @ {str(model_version.id)[:6]}', + }) + self.assertEqual(run.model_version_id, model_version.id) + self.assertEqual(run.summary, f"Worker {version_with_model.worker.name} @ {str(version_with_model.id)[:6]} with model {model_version.model.name} @ {str(model_version.id)[:6]}") def test_update_configuration_and_model_version(self): """ @@ -1779,8 +1870,8 @@ class TestWorkerRuns(FixtureAPITestCase): configuration=None if configuration else self.configuration_1, ) - # Having a model version set adds three queries, having a configuration adds one - query_count = 8 + bool(model_version) * 3 + bool(configuration) + # Having a model version set adds two queries, having a configuration adds one + query_count = 8 + bool(model_version) * 2 + bool(configuration) with self.assertNumQueries(query_count): response = self.client.put( @@ -2175,7 +2266,7 @@ class TestWorkerRuns(FixtureAPITestCase): parents=[], ) - with self.assertNumQueries(11): + with self.assertNumQueries(10): response = self.client.patch( reverse('api:worker-run-details', kwargs={'pk': str(run_2.id)}), data={ @@ -2211,7 +2302,7 @@ class TestWorkerRuns(FixtureAPITestCase): ) random_model_version_uuid = str(uuid.uuid4()) - with self.assertNumQueries(8): + with self.assertNumQueries(9): response = self.client.patch( reverse('api:worker-run-details', kwargs={'pk': str(run_2.id)}), data={ @@ -2250,7 +2341,7 @@ class TestWorkerRuns(FixtureAPITestCase): model_no_access = Model.objects.create(name='Secret model') model_version_no_access = ModelVersion.objects.create(model=model_no_access, state=ModelVersionState.Available, size=8, hash='aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', archive_hash='bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb') - with self.assertNumQueries(11): + with self.assertNumQueries(9): response = self.client.patch( reverse('api:worker-run-details', kwargs={'pk': str(run_2.id)}), data={ @@ -2261,9 +2352,66 @@ class TestWorkerRuns(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.json(), { - 'model_version_id': ['You do not have contributor access to this model.'], + 'model_version_id': [f'Invalid pk "{model_version_no_access.id}" - object does not exist.'], }) + def test_partial_update_model_version_guest(self): + """ + Cannot update a worker run with a model_version when you only have guest access to the model, + and the model version has no tag or is not available + """ + self.client.force_login(self.user) + rev_2 = self.repo.revisions.create( + hash='2', + message='beep boop', + author='bob', + ) + version_no_model = WorkerVersion.objects.create( + worker=self.worker_1, + revision=rev_2, + configuration={"test": "test2"}, + model_usage=True + ) + run_2 = self.process_1.worker_runs.create( + version=version_no_model, + parents=[], + ) + + cases = [ + # On a model with a membership giving guest access + (self.model_version_2, None, ModelVersionState.Created), + (self.model_version_2, None, ModelVersionState.Error), + (self.model_version_2, None, ModelVersionState.Available), + (self.model_version_2, 'blah', ModelVersionState.Created), + (self.model_version_2, 'blah', ModelVersionState.Error), + # On a public model with no membership + (self.model_version_3, None, ModelVersionState.Created), + (self.model_version_3, None, ModelVersionState.Error), + (self.model_version_3, None, ModelVersionState.Available), + (self.model_version_3, 'blah', ModelVersionState.Created), + (self.model_version_3, 'blah', ModelVersionState.Error), + ] + + for model_version, tag, state in cases: + with self.subTest(model_version=model_version, tag=tag, state=state): + model_version.tag = tag + model_version.state = state + model_version.save() + + with self.assertNumQueries(9): + response = self.client.patch( + reverse('api:worker-run-details', kwargs={'pk': str(run_2.id)}), + data={ + 'model_version_id': str(model_version.id), + }, + format='json', + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertEqual(response.json(), { + 'model_version_id': [f'Invalid pk "{model_version.id}" - object does not exist.'], + }) + def test_partial_update_model_version_unavailable(self): self.client.force_login(self.user) rev_2 = self.repo.revisions.create( @@ -2284,7 +2432,7 @@ class TestWorkerRuns(FixtureAPITestCase): self.model_version_1.state = ModelVersionState.Error self.model_version_1.save() - with self.assertNumQueries(11): + with self.assertNumQueries(10): response = self.client.patch( reverse('api:worker-run-details', kwargs={'pk': str(run.id)}), data={ @@ -2322,74 +2470,84 @@ class TestWorkerRuns(FixtureAPITestCase): self.assertIsNone(run.model_version_id) self.assertEqual(run.summary, f"Worker Recognizer @ {str(version_with_model.id)[:6]}") - with self.assertNumQueries(14): - response = self.client.patch( - reverse('api:worker-run-details', kwargs={'pk': str(run.id)}), - data={ - 'model_version_id': str(self.model_version_1.id), - }, - format='json', - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) + model_versions = [ + # Version on a model with contributor access + self.model_version_1, + # Available version with tag on a model with guest access + self.model_version_2, + # Available version with tag on a public model + self.model_version_3, + ] + for model_version in model_versions: + with self.subTest(model_version=model_version): + with self.assertNumQueries(14): + response = self.client.patch( + reverse('api:worker-run-details', kwargs={'pk': str(run.id)}), + data={ + 'model_version_id': str(model_version.id), + }, + format='json', + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) - run.refresh_from_db() - self.assertEqual(response.json(), { - 'id': str(run.id), - 'configuration': None, - 'model_version': { - 'id': str(self.model_version_1.id), - 'configuration': {}, - 'model': { - 'id': str(self.model_1.id), - 'name': 'My model' - }, - 'size': 8, - 'state': 'available', - 'tag': None - }, - 'parents': [], - 'process': { - 'id': str(self.process_1.id), - 'activity_state': 'disabled', - 'corpus': str(self.corpus.id), - 'mode': 'workers', - 'model_id': None, - 'name': None, - 'state': 'unscheduled', - 'test_folder_id': None, - 'train_folder_id': None, - 'use_cache': False, - 'validation_folder_id': None, - }, - 'worker_version': { - 'id': str(version_with_model.id), - 'configuration': {'test': 'test2'}, - 'docker_image': None, - 'docker_image_iid': None, - 'docker_image_name': f'my_repo.fake/workers/worker/reco:{version_with_model.id}', - 'gpu_usage': 'disabled', - 'model_usage': True, - 'revision': { - 'id': str(rev.id), - 'author': 'bob', - 'commit_url': 'http://my_repo.fake/workers/worker/commit/2', - 'created': rev.created.isoformat().replace('+00:00', 'Z'), - 'hash': '2', - 'message': 'beep boop', - 'refs': [] - }, - 'state': 'created', - 'worker': { - 'id': str(self.worker_1.id), - 'name': 'Recognizer', - 'slug': 'reco', - 'type': 'recognizer' - } - }, - 'summary': f'Worker Recognizer @ {str(version_with_model.id)[:6]} with model My model @ {str(self.model_version_1.id)[:6]}', - }) - self.assertEqual(run.model_version_id, self.model_version_1.id) - self.assertEqual(run.summary, f"Worker {version_with_model.worker.name} @ {str(version_with_model.id)[:6]} with model {self.model_version_1.model.name} @ {str(self.model_version_1.id)[:6]}") + run.refresh_from_db() + self.assertEqual(response.json(), { + 'id': str(run.id), + 'configuration': None, + 'model_version': { + 'id': str(model_version.id), + 'configuration': {}, + 'model': { + 'id': str(model_version.model.id), + 'name': model_version.model.name + }, + 'size': 8, + 'state': 'available', + 'tag': model_version.tag, + }, + 'parents': [], + 'process': { + 'id': str(self.process_1.id), + 'activity_state': 'disabled', + 'corpus': str(self.corpus.id), + 'mode': 'workers', + 'model_id': None, + 'name': None, + 'state': 'unscheduled', + 'test_folder_id': None, + 'train_folder_id': None, + 'use_cache': False, + 'validation_folder_id': None, + }, + 'worker_version': { + 'id': str(version_with_model.id), + 'configuration': {'test': 'test2'}, + 'docker_image': None, + 'docker_image_iid': None, + 'docker_image_name': f'my_repo.fake/workers/worker/reco:{version_with_model.id}', + 'gpu_usage': 'disabled', + 'model_usage': True, + 'revision': { + 'id': str(rev.id), + 'author': 'bob', + 'commit_url': 'http://my_repo.fake/workers/worker/commit/2', + 'created': rev.created.isoformat().replace('+00:00', 'Z'), + 'hash': '2', + 'message': 'beep boop', + 'refs': [] + }, + 'state': 'created', + 'worker': { + 'id': str(self.worker_1.id), + 'name': 'Recognizer', + 'slug': 'reco', + 'type': 'recognizer' + } + }, + 'summary': f'Worker Recognizer @ {str(version_with_model.id)[:6]} with model {model_version.model.name} @ {str(model_version.id)[:6]}', + }) + self.assertEqual(run.model_version_id, model_version.id) + self.assertEqual(run.summary, f"Worker {version_with_model.worker.name} @ {str(version_with_model.id)[:6]} with model {model_version.model.name} @ {str(model_version.id)[:6]}") def test_partial_update_model_version_with_configuration(self): """ @@ -2589,8 +2747,8 @@ class TestWorkerRuns(FixtureAPITestCase): configuration=None if configuration else self.configuration_1, ) - # Having a model version set adds three queries, having a configuration adds one - query_count = 8 + bool(model_version) * 3 + bool(configuration) + # Having a model version set adds two queries, having a configuration adds one + query_count = 8 + bool(model_version) * 2 + bool(configuration) with self.assertNumQueries(query_count): response = self.client.patch( diff --git a/arkindex/training/api.py b/arkindex/training/api.py index 912215ed6f..60858e77b4 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -72,7 +72,7 @@ from arkindex.users.utils import get_max_level }, ), ) -class ModelVersionsList(TrainingModelMixin, ListCreateAPIView): +class ModelVersionsList(ListCreateAPIView): permission_classes = (IsVerified, ) serializer_class = ModelVersionSerializer @@ -133,7 +133,7 @@ class ModelVersionsList(TrainingModelMixin, ListCreateAPIView): description='Delete a model version.\n\nRequires an **admin** access on the related model.' ) ) -class ModelVersionsRetrieve(TrainingModelMixin, RetrieveUpdateDestroyAPIView): +class ModelVersionsRetrieve(RetrieveUpdateDestroyAPIView): """Retrieve a version of a Machine Learning model. Requires a **guest** access to the model, for versions that are available and have a tag, @@ -360,7 +360,7 @@ class ModelRetrieve(TrainingModelMixin, RetrieveAPIView): ] ), ) -class ModelVersionDownload(TrainingModelMixin, RetrieveAPIView): +class ModelVersionDownload(RetrieveAPIView): queryset = ModelVersion.objects.all() def check_object_permissions(self, request, model_version): diff --git a/arkindex/training/managers.py b/arkindex/training/managers.py new file mode 100644 index 0000000000..f14776bc57 --- /dev/null +++ b/arkindex/training/managers.py @@ -0,0 +1,40 @@ +from django.db.models import Q + +from arkindex.users.managers import BaseACLManager +from arkindex.users.models import Role + + +class ModelVersionManager(BaseACLManager): + + def executable(self, user): + """ + Model versions that are allowed to be executed on any model by a user. + + This can includes model versions that may not be `available`, + to allow for more specific error messages to be shown, as well as for users + to see the model versions that they are allowed to execute, but cannot + for reasons other than permissions. + """ + # Without authentication, or as a Ponos agent, nothing is executable + # Agents are not users, filtering on them will make Django raise exceptions, so we have to handle them separately. + if user.is_anonymous or getattr(user, 'is_agent', False): + return self.none() + + # Admins can execute anything + if user.is_admin: + return self.all() + + from arkindex.training.models import Model, ModelVersionState + + # We have no choice but to make two separate subqueries because Django does not handle + # custom subqueries in a FROM clause, so we cannot tell which level the user has on + # each model via the `max_level` annotation included by `filter_rights`. + guest_model_ids = self.filter_rights(user, Model, Role.Guest.value).values('id') + contributor_model_ids = self.filter_rights(user, Model, Role.Contributor.value).values('id') + + return self.filter( + # With contributor access to the model, all versions are executable + Q(model_id__in=contributor_model_ids) + # With guest access, only available versions with tags are executable + | Q(model_id__in=guest_model_ids, tag__isnull=False, state=ModelVersionState.Available) + ) diff --git a/arkindex/training/models.py b/arkindex/training/models.py index 8cca9f93de..ccecef001f 100644 --- a/arkindex/training/models.py +++ b/arkindex/training/models.py @@ -15,6 +15,7 @@ from enumfields import Enum, EnumField from arkindex.project.aws import S3FileMixin from arkindex.project.fields import MD5HashField from arkindex.project.models import IndexableModel +from arkindex.training.managers import ModelVersionManager logger = logging.getLogger(__name__) @@ -91,6 +92,8 @@ class ModelVersion(S3FileMixin, IndexableModel): # Store dictionary of paramseters given by the ML developer configuration = models.JSONField(default=dict) + objects = ModelVersionManager() + s3_bucket = settings.AWS_TRAINING_BUCKET class Meta: diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py index 4ed2fd5bbc..1472b77ecc 100644 --- a/arkindex/training/serializers.py +++ b/arkindex/training/serializers.py @@ -11,7 +11,6 @@ from rest_framework.validators import UniqueTogetherValidator from arkindex.documents.serializers.elements import ElementListSerializer from arkindex.ponos.models import Task -from arkindex.project.mixins import TrainingModelMixin from arkindex.project.serializer_fields import EnumField from arkindex.training.models import ( Dataset, @@ -53,7 +52,7 @@ class ModelLightSerializer(serializers.ModelSerializer): fields = ('id', 'name') -class ModelSerializer(TrainingModelMixin, ModelLightSerializer): +class ModelSerializer(ModelLightSerializer): # Actually define the field to avoid the field-level automatically generated UniqueValidator rights = serializers.SerializerMethodField(read_only=True) diff --git a/arkindex/users/managers.py b/arkindex/users/managers.py index 0f4f195fb5..00e43fa92e 100644 --- a/arkindex/users/managers.py +++ b/arkindex/users/managers.py @@ -1,4 +1,5 @@ from django.contrib.auth.models import BaseUserManager +from django.db.models import Manager class UserManager(BaseUserManager): @@ -37,3 +38,14 @@ class UserManager(BaseUserManager): user.is_admin = True user.save(using=self._db) return user + + +class BaseACLManager(Manager): + + def filter_rights(self, *args, **kwargs): + # Avoid circular dependencies: + # models with ACL set up are imported in arkindex.users.utils to be added to an enum, + # and models with ACL will also be using managers that inherit from this manager, + # with most ACL checks depending on arkindex.users.utils.filter_rights. + from arkindex.users.utils import filter_rights + return filter_rights(*args, **kwargs) -- GitLab