diff --git a/arkindex/process/api.py b/arkindex/process/api.py index ccbf8b7c5f67091d774039ebe074dce8029eb13e..4891f3dab2eafa5128b07e675113243c494eeb68 100644 --- a/arkindex/process/api.py +++ b/arkindex/process/api.py @@ -1,4 +1,5 @@ import logging +from collections import defaultdict from datetime import timedelta from textwrap import dedent from uuid import UUID @@ -128,7 +129,7 @@ from arkindex.project.pagination import CountCursorPagination from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly from arkindex.project.tools import PercentileCont from arkindex.project.triggers import process_delete -from arkindex.training.models import Dataset +from arkindex.training.models import Dataset, Model from arkindex.training.serializers import DatasetSerializer from arkindex.users.models import OAuthCredentials, Role, Scope @@ -962,6 +963,12 @@ class RevisionRetrieve(RepositoryACLMixin, RetrieveAPIView): description='Filter by or exclude workers that have been marked as archived.', required=False, ), + OpenApiParameter( + 'compatible_model', + type=UUID, + description='Filter workers compatible with a specific model.', + required=False, + ), ] ), post=extend_schema( @@ -992,14 +999,28 @@ class WorkerList(ListCreateAPIView): def filter_queryset(self, qs): filters = Q() + errors = defaultdict(list) + + compatible_model_id = self.request.query_params.get('compatible_model') + if compatible_model_id: + try: + compatible_model_id = UUID(compatible_model_id) + compatible_model = Model.objects.readable(self.request.user).get(id=compatible_model_id) + except (TypeError, ValueError): + errors['compatible_model'].append('Invalid UUID') + except Model.DoesNotExist: + errors['compatible_model'].append("This model does not exist or you don't have access to it") + else: + filters &= Q(models=compatible_model) repo_id = self.request.query_params.get('repository_id') if repo_id: try: repo_id = UUID(repo_id) except (TypeError, ValueError): - raise ValidationError({'repository_id': ['Invalid UUID']}) - filters &= Q(repository_id=repo_id) + errors['repository_id'].append('Invalid UUID') + else: + filters &= Q(repository_id=repo_id) name_filter = self.request.query_params.get('name') if name_filter: @@ -1015,12 +1036,16 @@ class WorkerList(ListCreateAPIView): try: worker_type = WorkerType.objects.get(slug=worker_type) except WorkerType.DoesNotExist: - raise ValidationError({'type': ['No registered worker type with that slug.']}) - filters &= Q(type_id=worker_type.id) + errors['type'].append('No registered worker type with that slug.') + else: + filters &= Q(type_id=worker_type.id) if 'archived' in self.request.query_params: filters &= Q(archived__isnull=self.request.query_params['archived'].lower().strip() in ('false', '0')) + if errors: + raise ValidationError(errors) + return super().filter_queryset(qs.filter(filters)) def get_queryset(self): diff --git a/arkindex/process/tests/test_workers.py b/arkindex/process/tests/test_workers.py index 9d5f603c6589c0d1235891232fe76e90d8f35c4a..b77525c42e4395d99887be965cc0de80fac0454f 100644 --- a/arkindex/process/tests/test_workers.py +++ b/arkindex/process/tests/test_workers.py @@ -158,6 +158,95 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): ] }) + def test_workers_list_compatible_model(self): + """ + User is able to filter workers by a compatible model. + """ + self.model.memberships.create(user=self.user, level=Role.Guest.value) + self.worker_generic.models.set([self.model]) + + self.client.force_login(self.user) + with self.assertNumQueries(8): + response = self.client.get( + reverse('api:workers-list'), + {'compatible_model': str(self.model.id)}, + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual(response.json(), { + 'count': 1, + 'next': None, + 'number': 1, + 'previous': None, + 'results': [ + { + 'id': str(self.worker_generic.id), + 'repository_id': str(self.repo.id), + 'name': 'Generic worker with a Model', + 'description': '', + 'slug': 'generic', + 'type': 'recognizer', + 'archived': False, + }, + ] + }) + + def test_workers_list_compatible_model_public(self): + public_model = Model.objects.create(name='Public', public=True) + self.worker_reco.models.set([public_model]) + + self.client.force_login(self.user) + with self.assertNumQueries(9): + response = self.client.get( + reverse('api:workers-list'), + {'compatible_model': str(public_model.id)}, + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.json(), { + 'count': 1, + 'next': None, + 'number': 1, + 'previous': None, + 'results': [{ + 'id': str(self.worker_reco.id), + 'repository_id': str(self.repo.id), + 'name': 'Recognizer', + 'description': '', + 'slug': 'reco', + 'type': 'recognizer', + 'archived': False, + }], + }) + + def test_worker_list_compatible_model_private(self): + private_model = Model.objects.create(name='Private') + self.worker_reco.models.set([private_model]) + self.client.force_login(self.user) + with self.assertNumQueries(5): + response = self.client.get( + reverse('api:workers-list'), + {'compatible_model': str(private_model.id)}, + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + 'compatible_model': ["This model does not exist or you don't have access to it"] + }) + + def test_workers_list_filter_invalid_query_params(self): + self.client.force_login(self.user) + with self.assertNumQueries(2): + response = self.client.get( + reverse('api:workers-list'), + { + 'compatible_model': 'A', + 'repository_id': 'A', + }, + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), { + 'compatible_model': ['Invalid UUID'], + 'repository_id': ['Invalid UUID'], + }) + def test_workers_list_filter_type_slug(self): """ User is able to filter workers on the repository by worker type slug diff --git a/arkindex/training/managers.py b/arkindex/training/managers.py index dc6f263a55513829a8f6a9fa2ef568053906c709..53670b62cf03dcf459e0ddb380581dd280dee3e7 100644 --- a/arkindex/training/managers.py +++ b/arkindex/training/managers.py @@ -42,6 +42,17 @@ class ModelVersionManager(BaseACLManager): class ModelManager(BaseACLManager): + def readable(self, user): + """ + Models that can be listed by a user. + """ + if user.is_anonymous or getattr(user, 'is_agent', False): + return self.none() + if user.is_admin: + return self.all() + # Allow viewing any model marked as public or with guest access + return self.filter(id__in=self.filter_rights(user, self.model, Role.Guest.value).values('id')) + def editable(self, user): """ Models whose attributes can be modified by a user.