diff --git a/arkindex/process/api.py b/arkindex/process/api.py index e13b106f3edf54b59cb752adf8fe1f6f42e0c692..ccbf8b7c5f67091d774039ebe074dce8029eb13e 100644 --- a/arkindex/process/api.py +++ b/arkindex/process/api.py @@ -573,7 +573,7 @@ class StartProcess(CorpusACLMixin, CreateAPIView): .filter(corpus_id__isnull=False) .prefetch_related(Prefetch('worker_runs', queryset=WorkerRun.objects.select_related( 'version__worker__repository', - 'model_version', + 'model_version__model', 'configuration', ))) .prefetch_related('datasets') @@ -2117,7 +2117,7 @@ class ApplyProcessTemplate(ProcessACLMixin, WorkerACLMixin, CreateAPIView): def get_queryset(self): return Process.objects \ .filter(mode=ProcessMode.Template) \ - .prefetch_related(Prefetch('worker_runs', queryset=WorkerRun.objects.select_related('version__worker__type', 'model_version'))) + .prefetch_related(Prefetch('worker_runs', queryset=WorkerRun.objects.select_related('version__worker__type', 'model_version__model'))) def check_object_permissions(self, request, template): access_level = self.process_access_level(template) diff --git a/arkindex/process/builder.py b/arkindex/process/builder.py index 4100afc07b6694bacf6e81fa02a9338aa6451dc9..8469402e2de3766e1072980f8cdcc23428557bba 100644 --- a/arkindex/process/builder.py +++ b/arkindex/process/builder.py @@ -38,7 +38,7 @@ class ProcessBuilder(object): queryset=( WorkerRun.objects .using('default') - .select_related('version__worker__repository', 'model_version') + .select_related('version__worker__repository', 'model_version__model') ), ) ) @@ -173,10 +173,17 @@ class ProcessBuilder(object): raise ValidationError("Some worker versions require a GPU and the `use_gpu` option is disabled.") @prefetch_worker_runs - def validate_archived_workers(self): + def validate_archived(self): if any(run.version.worker.archived for run in self.process.worker_runs.all()): raise ValidationError("Some worker versions are on archived workers and cannot be executed.") + if any( + run.model_version.model.archived + for run in self.process.worker_runs.all() + if run.model_version is not None + ): + raise ValidationError("Some model versions are on archived models and cannot be executed.") + def validate_repository(self) -> None: if self.process.revision is None: raise ValidationError('A revision is required to create an import workflow from GitLab repository') @@ -193,11 +200,11 @@ class ProcessBuilder(object): def validate_workers(self) -> None: self.validate_gpu_requirement() - self.validate_archived_workers() + self.validate_archived() def validate_dataset(self) -> None: self.validate_gpu_requirement() - self.validate_archived_workers() + self.validate_archived() if self.process.generate_thumbnails: raise ValidationError('Thumbnails generation is incompatible with dataset mode processes.') diff --git a/arkindex/process/serializers/imports.py b/arkindex/process/serializers/imports.py index 4e4eaa4b48fa552fa45bdcd87aa3159620b01d06..e5742e15eb77bd88ce276f738b718da7cde4f6db 100644 --- a/arkindex/process/serializers/imports.py +++ b/arkindex/process/serializers/imports.py @@ -600,6 +600,13 @@ class ApplyProcessTemplateSerializer(ProcessACLMixin, serializers.Serializer): ): raise ValidationError(detail='This template contains one or more unavailable model versions and cannot be applied.') + if any( + run.model_version.model.archived + for run in template_process.worker_runs.all() + if run.model_version_id is not None + ): + raise ValidationError(detail='This template contains one or more model versions from archived models and cannot be applied.') + return data diff --git a/arkindex/process/serializers/worker_runs.py b/arkindex/process/serializers/worker_runs.py index c04d0d0e4e800c1b68440d549af5c8d91acb08fd..a1cf053653aaee8404d9f708b42b86fc07bc5a68 100644 --- a/arkindex/process/serializers/worker_runs.py +++ b/arkindex/process/serializers/worker_runs.py @@ -123,7 +123,7 @@ class WorkerRunSerializer(WorkerACLMixin, serializers.ModelSerializer): super().__init__(*args, **kwargs) if self.context.get('request'): user = self.context['request'].user - self.fields['model_version_id'].queryset = ModelVersion.objects.executable(user) + self.fields['model_version_id'].queryset = ModelVersion.objects.executable(user).select_related('model') def validate(self, data): data = super().validate(data) @@ -169,6 +169,9 @@ class WorkerRunSerializer(WorkerACLMixin, serializers.ModelSerializer): if model_version.state != ModelVersionState.Available: errors['model_version_id'].append('This ModelVersion is not in an Available state.') + if model_version.model.archived: + errors['model_version_id'].append('This ModelVersion is part of an archived model.') + if worker_version.model_usage == FeatureUsage.Disabled: errors['model_version_id'].append('This worker version does not support models.') @@ -258,6 +261,9 @@ class UserWorkerRunSerializer(serializers.ModelSerializer): if not model_version.is_executable(self.context['request'].user): raise ValidationError(detail='You do not have guest access to this model.') + if model_version.model.archived: + raise ValidationError(detail='The model used to create a local worker run must not be archived.') + return model_version def validate(self, data): diff --git a/arkindex/process/tests/test_processes.py b/arkindex/process/tests/test_processes.py index 58f3e5525818db9c9a67a6907b57d202e3d763a2..80125798e18460ef51ec427c5b2b96e67986cd85 100644 --- a/arkindex/process/tests/test_processes.py +++ b/arkindex/process/tests/test_processes.py @@ -1789,7 +1789,7 @@ class TestProcesses(FixtureAPITestCase): 'farm': ['You do not have access to this farm.'], }) - def test_retry_archived(self): + def test_retry_archived_worker(self): self.elts_process.run() self.elts_process.tasks.all().update(state=State.Error) self.elts_process.finished = timezone.now() @@ -1810,6 +1810,33 @@ class TestProcesses(FixtureAPITestCase): 'Some worker versions are on archived workers and cannot be executed.', ]) + def test_retry_archived_model(self): + process = self.corpus.processes.create( + creator=self.user, + mode=ProcessMode.Workers, + farm=self.default_farm, + ) + process.worker_runs.create( + version=self.version_with_model, + model_version=self.model_version_1, + ) + process.run() + process.tasks.all().update(state=State.Error) + process.finished = timezone.now() + process.save() + self.model_1.archived = timezone.now() + self.model_1.save() + + self.client.force_login(self.user) + + with self.assertNumQueries(15): + response = self.client.post(reverse('api:process-retry', kwargs={'pk': process.id})) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertEqual(response.json(), [ + 'Some model versions are on archived models and cannot be executed.', + ]) + @patch('arkindex.project.triggers.process_tasks.initialize_activity.delay') def test_retry_no_tasks(self, delay_mock): self.client.force_login(self.user) @@ -2410,6 +2437,24 @@ class TestProcesses(FixtureAPITestCase): {'model_version': ['This process contains one or more unavailable model versions and cannot be started.']}, ) + def test_start_process_archived_models(self): + process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Workers) + process2.worker_runs.create(version=self.recognizer, parents=[], configuration=None, model_version=self.model_version_1) + self.model_1.archived = timezone.now() + self.model_1.save() + self.assertFalse(process2.tasks.exists()) + + self.client.force_login(self.user) + with self.assertNumQueries(15): + response = self.client.post( + reverse('api:process-start', kwargs={'pk': str(process2.id)}) + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual( + response.json(), + ['Some model versions are on archived models and cannot be executed.'], + ) + def test_start_process_required_fields_no_config(self): # Both workers now have a required field without a default value self.dla.configuration['user_configuration'] = { diff --git a/arkindex/process/tests/test_templates.py b/arkindex/process/tests/test_templates.py index 400097d526f7c2d96a1cfb6ca4c7217a39c77295..899f9771a1d69ae399bd86b437cb6961e6021cfa 100644 --- a/arkindex/process/tests/test_templates.py +++ b/arkindex/process/tests/test_templates.py @@ -67,8 +67,8 @@ class TestTemplates(FixtureAPITestCase): parents=[run_1.id], ) - model = Model.objects.create(name='moo') - cls.model_version = model.versions.create(state=ModelVersionState.Available) + cls.model = Model.objects.create(name='moo') + cls.model_version = cls.model.versions.create(state=ModelVersionState.Available) run_1 = cls.template.worker_runs.create( version=cls.version_1, parents=[], configuration=cls.worker_configuration @@ -408,6 +408,23 @@ class TestTemplates(FixtureAPITestCase): self.process.refresh_from_db() self.assertEqual(self.process.template, None) + def test_apply_archived_model(self): + self.model.archived = datetime.now(timezone.utc) + self.model.save() + self.client.force_login(self.user) + + with self.assertNumQueries(14): + response = self.client.post( + reverse('api:apply-process-template', kwargs={'pk': str(self.template.id)}), + data=json.dumps({"process_id": str(self.process.id)}), + content_type='application/json', + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertDictEqual(response.json(), {'non_field_errors': ['This template contains one or more model versions from archived models and cannot be applied.']}) + self.process.refresh_from_db() + self.assertEqual(self.process.template, None) + def test_apply_unsupported_mode(self): self.client.force_login(self.user) for mode in set(ProcessMode) - {ProcessMode.Workers, ProcessMode.Dataset, ProcessMode.Local, ProcessMode.Repository}: diff --git a/arkindex/process/tests/test_user_workerruns.py b/arkindex/process/tests/test_user_workerruns.py index c7c2918d38c1b81e4e8dd73b011b29e0e64fa3fe..b3c32e01cd4fafcd3fe6e260734361a05be67e14 100644 --- a/arkindex/process/tests/test_user_workerruns.py +++ b/arkindex/process/tests/test_user_workerruns.py @@ -15,7 +15,7 @@ from arkindex.process.models import ( WorkerVersionState, ) from arkindex.project.tests import FixtureAPITestCase -from arkindex.training.models import Model, ModelVersion, ModelVersionState +from arkindex.training.models import Model, ModelVersionState from arkindex.users.models import Right, Role, User @@ -36,6 +36,10 @@ class TestUserWorkerRuns(FixtureAPITestCase): }, state=WorkerVersionState.Created, ) + + cls.model = Model.objects.create(name='Some model', public=False) + cls.model_version = cls.model.versions.create(hash='bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', archive_hash='aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', size=8) + # Local worker run cls.local_process = Process.objects.get(mode=ProcessMode.Local, creator=cls.user) cls.local_run = WorkerRun.objects.get(process=cls.local_process) @@ -353,28 +357,35 @@ class TestUserWorkerRuns(FixtureAPITestCase): self.assertEqual(response.json(), {'configuration_id': ['Invalid pk "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" - object does not exist.']}) def test_create_user_run_model_no_access(self): - test_model = Model.objects.create(name='Some model', public=False) - model_version = ModelVersion.objects.create(model_id=test_model.id, hash='bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', archive_hash='aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', size=8) self.client.force_login(self.user) with self.assertNumQueries(10): response = self.client.post( reverse('api:user-worker-run-create'), - data={'worker_version_id': str(self.other_version.id), 'model_version_id': str(model_version.id)} + data={'worker_version_id': str(self.other_version.id), 'model_version_id': str(self.model_version.id)} ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.json(), {'model_version_id': ['You do not have guest access to this model.']}) + def test_create_user_run_model_archived(self): + self.model.archived = datetime.now(timezone.utc) + self.model.save() + self.model.memberships.create(user=self.user, level=Role.Contributor.value) + self.client.force_login(self.user) + + with self.assertNumQueries(9): + response = self.client.post( + reverse('api:user-worker-run-create'), + data={'worker_version_id': str(self.other_version.id), 'model_version_id': str(self.model_version.id)} + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertEqual(response.json(), {'model_version_id': ['The model used to create a local worker run must not be archived.']}) + def test_create_user_run_full(self): - test_model = Model.objects.create(name='Some model', public=False) - model_version = ModelVersion.objects.create( - state=ModelVersionState.Available, - model_id=test_model.id, - hash='bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', - archive_hash='aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', - size=8, - tag='0.1.3' - ) - Right.objects.create(user=self.user, content_object=test_model, level=Role.Guest.value) + self.model_version.tag = '0.1.3' + self.model_version.state = ModelVersionState.Available + self.model_version.save() + self.model.memberships.create(user=self.user, level=Role.Guest.value) test_configuration = WorkerConfiguration.objects.create(worker=self.other_version.worker, name="Some configuration", configuration={'param': 'value'}) self.client.force_login(self.user) @@ -384,7 +395,7 @@ class TestUserWorkerRuns(FixtureAPITestCase): reverse('api:user-worker-run-create'), data={ 'worker_version_id': str(self.other_version.id), - 'model_version_id': str(model_version.id), + 'model_version_id': str(self.model_version.id), 'configuration_id': str(test_configuration.id) } ) @@ -401,9 +412,9 @@ class TestUserWorkerRuns(FixtureAPITestCase): }, 'model_version': { 'configuration': {}, - 'id': str(model_version.id), + 'id': str(self.model_version.id), 'model': { - 'id': str(test_model.id), + 'id': str(self.model.id), 'name': 'Some model' }, 'size': 8, @@ -432,7 +443,7 @@ class TestUserWorkerRuns(FixtureAPITestCase): 'type': 'custom' } }, - 'summary': f"Worker Custom worker @ version 2 with model Some model @ {str(model_version.id)[:6]} using configuration 'Some configuration'", + 'summary': f"Worker Custom worker @ version 2 with model Some model @ {str(self.model_version.id)[:6]} using configuration 'Some configuration'", 'process': { 'activity_state': 'disabled', 'corpus': None, diff --git a/arkindex/process/tests/test_workerruns.py b/arkindex/process/tests/test_workerruns.py index 83973e88a23ee60bda78a93619c8f4646313de4f..cc08879e405276a4dbbef5e30d1f785b47ccbcd3 100644 --- a/arkindex/process/tests/test_workerruns.py +++ b/arkindex/process/tests/test_workerruns.py @@ -299,6 +299,23 @@ class TestWorkerRuns(FixtureAPITestCase): self.assertEqual(response.json(), {'worker_version_id': ['This WorkerVersion is part of an archived worker.']}) + def test_create_archived_model(self): + self.model_1.archived = datetime.now(timezone.utc) + self.model_1.save() + self.version_1.model_usage = FeatureUsage.Supported + self.version_1.save() + self.client.force_login(self.user) + + with self.assertNumQueries(13): + response = self.client.post( + reverse('api:worker-run-list', kwargs={'pk': str(self.process_2.id)}), + data={'worker_version_id': str(self.version_1.id), 'model_version_id': str(self.model_version_1.id)}, + format='json', + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertEqual(response.json(), {'model_version_id': ['This ModelVersion is part of an archived model.']}) + def test_create_invalid_process_id(self): self.client.force_login(self.user) @@ -1664,6 +1681,37 @@ class TestWorkerRuns(FixtureAPITestCase): 'model_version_id': ['This ModelVersion is not in an Available state.'] }) + def test_update_model_archived(self): + self.client.force_login(self.user) + rev_2 = self.repo.revisions.create( + hash='2', + message='beep boop', + author='bob', + ) + version = WorkerVersion.objects.create( + worker=self.worker_1, + revision=rev_2, + configuration={"test": "test2"}, + model_usage=FeatureUsage.Required + ) + run = self.process_1.worker_runs.create(version=version) + self.model_1.archived = datetime.now(timezone.utc) + self.model_1.save() + + with self.assertNumQueries(10): + response = self.client.put( + reverse('api:worker-run-details', kwargs={'pk': str(run.id)}), + data={ + 'model_version_id': str(self.model_version_1.id), + }, + format='json', + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertEqual(response.json(), { + 'model_version_id': ['This ModelVersion is part of an archived model.'], + }) + def test_update_model_version_id(self): """ Update the worker run by adding a model_version with a worker version that supports it @@ -1698,7 +1746,7 @@ class TestWorkerRuns(FixtureAPITestCase): ] for model_version in model_versions: with self.subTest(model_version=model_version): - with self.assertNumQueries(14): + with self.assertNumQueries(13): response = self.client.put( reverse('api:worker-run-details', kwargs={'pk': str(run.id)}), data={ @@ -1795,7 +1843,7 @@ class TestWorkerRuns(FixtureAPITestCase): # Check generated summary, before updating, there should be only information about the worker version self.assertEqual(run.summary, f"Worker {self.worker_1.name} @ {str(version_with_model.id)[:6]}") - with self.assertNumQueries(15): + with self.assertNumQueries(14): response = self.client.put( reverse('api:worker-run-details', kwargs={'pk': str(run.id)}), data={ @@ -2558,6 +2606,37 @@ class TestWorkerRuns(FixtureAPITestCase): 'model_version_id': ['This ModelVersion is not in an Available state.'], }) + def test_partial_update_model_archived(self): + self.client.force_login(self.user) + rev_2 = self.repo.revisions.create( + hash='2', + message='beep boop', + author='bob', + ) + version = WorkerVersion.objects.create( + worker=self.worker_1, + revision=rev_2, + configuration={"test": "test2"}, + model_usage=FeatureUsage.Required + ) + run = self.process_1.worker_runs.create(version=version) + self.model_1.archived = datetime.now(timezone.utc) + self.model_1.save() + + with self.assertNumQueries(10): + response = self.client.patch( + reverse('api:worker-run-details', kwargs={'pk': str(run.id)}), + data={ + 'model_version_id': str(self.model_version_1.id), + }, + format='json', + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertEqual(response.json(), { + 'model_version_id': ['This ModelVersion is part of an archived model.'], + }) + def test_partial_update_model_version(self): """ Update the worker run by adding a model_version with a worker version that supports it @@ -2591,7 +2670,7 @@ class TestWorkerRuns(FixtureAPITestCase): ] for model_version in model_versions: with self.subTest(model_version=model_version): - with self.assertNumQueries(14): + with self.assertNumQueries(13): response = self.client.patch( reverse('api:worker-run-details', kwargs={'pk': str(run.id)}), data={ @@ -2686,7 +2765,7 @@ class TestWorkerRuns(FixtureAPITestCase): ) self.assertEqual(run.model_version_id, None) - with self.assertNumQueries(14): + with self.assertNumQueries(13): response = self.client.patch( reverse('api:worker-run-details', kwargs={'pk': str(run.id)}), data={ diff --git a/arkindex/training/admin.py b/arkindex/training/admin.py index e70c4c8bba72b969750f1866e1dcb238a4b8af57..7e5dea52806e3d44dcebab58fe8878bd114e735c 100644 --- a/arkindex/training/admin.py +++ b/arkindex/training/admin.py @@ -1,13 +1,15 @@ from django.contrib import admin from enumfields.admin import EnumFieldListFilter +from arkindex.project.admin import ArchivedListFilter from arkindex.training.models import Dataset, MetricKey, MetricValue, Model, ModelVersion class ModelAdmin(admin.ModelAdmin): - list_display = ('name', 'created', ) + list_display = ('name', 'created', 'archived') + list_filter = (ArchivedListFilter, ) search_fields = ('name', 'description', ) - fields = ('name', 'description', 'public', 'compatible_workers') + fields = ('name', 'description', 'public', 'archived', 'compatible_workers') class ModelVersionAdmin(admin.ModelAdmin): diff --git a/arkindex/training/api.py b/arkindex/training/api.py index cdecc15cd2fc4b67cd74a1429b51c2bdddd23b50..5b6e72505e0913ddc1ef17eed67c7eeeca771856 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -43,6 +43,7 @@ from arkindex.training.serializers import ( MetricValueBulkSerializer, MetricValueCreateSerializer, ModelCompatibleWorkerSerializer, + ModelCreateSerializer, ModelSerializer, ModelVersionSerializer, ModelVersionValidateSerializer, @@ -188,6 +189,11 @@ class ModelVersionsRetrieve(RetrieveUpdateDestroyAPIView): raise PermissionDenied(detail=error_msg) return super().check_object_permissions(request, model_version) + def perform_destroy(self, instance): + if instance.model.archived: + raise ValidationError('This model is archived.') + return super().perform_destroy(instance) + class ValidateModelVersion(TrainingModelMixin, GenericAPIView): """ @@ -206,8 +212,13 @@ class ValidateModelVersion(TrainingModelMixin, GenericAPIView): def check_object_permissions(self, request, model_version): if not self.has_write_access(model_version.model): raise PermissionDenied(detail='You need a Contributor access to the model to validate this version.') + if model_version.state == ModelVersionState.Available: raise PermissionDenied(detail='This model version is already marked as available.') + + if model_version.model.archived: + raise ValidationError('This model version is part of an archived model.') + return super().check_object_permissions(request, model_version) @extend_schema( @@ -302,12 +313,14 @@ class ValidateModelVersion(TrainingModelMixin, GenericAPIView): ) class ModelsList(TrainingModelMixin, ListCreateAPIView): permission_classes = (IsVerified, ) - serializer_class = ModelSerializer + serializer_class = ModelCreateSerializer + queryset = Model.objects.none() def get_queryset(self): filters = Q() if 'name' in self.request.query_params: filters &= Q(name__icontains=self.request.query_params['name']) + if 'compatible_worker' in self.request.query_params: data = self.request.query_params['compatible_worker'] try: @@ -316,6 +329,9 @@ class ModelsList(TrainingModelMixin, ListCreateAPIView): raise serializers.ValidationError({'compatible_worker': [f'{data} is not a valid UUID.']}) filters &= Q(compatible_workers__id=data) + if 'archived' in self.request.query_params: + filters &= Q(archived__isnull=self.request.query_params['archived'].lower().strip() in ('false', '0')) + # Use the default database to prevent a stale read when a model has just been created return self.readable_models.using('default').filter(filters).order_by('name') diff --git a/arkindex/training/migrations/0006_model_archived.py b/arkindex/training/migrations/0006_model_archived.py new file mode 100644 index 0000000000000000000000000000000000000000..250ed999f0158e0a77a6ad1a526dad2dbf453dd4 --- /dev/null +++ b/arkindex/training/migrations/0006_model_archived.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.7 on 2023-12-05 13:55 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('training', '0005_modelversion_configuration_object'), + ] + + operations = [ + migrations.AddField( + model_name='model', + name='archived', + field=models.DateTimeField(blank=True, default=None, null=True), + ), + ] diff --git a/arkindex/training/models.py b/arkindex/training/models.py index 95d356eb0487e096763bc9920d5d74f3c8ca6ec7..2904a067292a660bc5166de847b0334b84946190 100644 --- a/arkindex/training/models.py +++ b/arkindex/training/models.py @@ -16,6 +16,7 @@ from arkindex.project.aws import S3FileMixin from arkindex.project.fields import MD5HashField from arkindex.project.models import IndexableModel from arkindex.training.managers import ModelManager, ModelVersionManager +from arkindex.users.models import Role logger = logging.getLogger(__name__) @@ -31,6 +32,8 @@ class Model(IndexableModel): public = models.BooleanField(default=False) + archived = models.DateTimeField(null=True, blank=True, default=None) + # Link to the workers that are able to use this model compatible_workers = models.ManyToManyField('process.Worker', related_name='models', blank=True) @@ -50,6 +53,21 @@ class Model(IndexableModel): def __str__(self): return self.name + def is_archivable(self, user) -> bool: + """ + Whether the user can archive or unarchive this model + """ + if user.is_anonymous or getattr(user, 'is_agent', False): + return False + + if user.is_admin: + return True + + from arkindex.users.utils import get_max_level + level = get_max_level(user, self) + + return level is not None and level >= Role.Admin.value + class ModelVersionState(Enum): """ diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py index 0f8f5096439c33eaf5b76704e75674763be99e75..09698df482765c3401ecb5876e42b687d93aa9c0 100644 --- a/arkindex/training/serializers.py +++ b/arkindex/training/serializers.py @@ -14,7 +14,7 @@ from arkindex.documents.models import Element from arkindex.documents.serializers.elements import ElementListSerializer from arkindex.ponos.models import Task from arkindex.process.models import Worker -from arkindex.project.serializer_fields import EnumField +from arkindex.project.serializer_fields import ArchivedField, EnumField from arkindex.training.models import ( Dataset, DatasetElement, @@ -67,9 +67,18 @@ class ModelSerializer(ModelLightSerializer): # Actually define the field to avoid the field-level automatically generated UniqueValidator rights = serializers.SerializerMethodField(read_only=True) - class Meta: - model = Model - fields = ModelLightSerializer.Meta.fields + ('created', 'updated', 'description', 'rights') + archived = ArchivedField( + required=False, + help_text=dedent(""" + Whether this model is archived. + Model versions cannot be created on archived models, and they cannot be used in processes. + + Archiving or unarchiving an existing model requires admin access to the model. + """), + ) + + class Meta(ModelLightSerializer.Meta): + fields = ModelLightSerializer.Meta.fields + ('created', 'updated', 'archived', 'description', 'rights') def create(self, validated_data): instance = super().create(validated_data) @@ -98,6 +107,11 @@ class ModelSerializer(ModelLightSerializer): return rights +class ModelCreateSerializer(ModelSerializer): + # Exclude archived from model creation, since creating an archived model makes no sense + archived = ArchivedField(read_only=True) + + class CreateModelErrorResponseSerializer(serializers.Serializer): id = serializers.UUIDField(required=False, help_text="UUID of an existing model, if the error comes from a duplicate name.") name = serializers.CharField(required=False, help_text="Name of an existing model, if the error comes from a duplicate name.") @@ -125,6 +139,11 @@ class ModelCompatibleWorkerSerializer(serializers.ModelSerializer): # Any worker that the user can browse can be set self.fields['worker'].queryset = Worker.objects.executable(user) + def validate_model(self, model): + if model.archived: + raise ValidationError('This model is archived.') + return model + def validate(self, data): data = super().validate(data) # Ignore the uniqueness check for deletions @@ -194,6 +213,16 @@ class ModelVersionSerializer(serializers.ModelSerializer): qs = qs.exclude(id=self.instance.id) self.fields['parent'].queryset = qs + def validate(self, data): + model = data.get('model') + if not model: + model = self.instance.model + + if model.archived: + raise ValidationError({'model': ['This model is archived.']}) + + return data + @extend_schema_field(serializers.CharField(allow_null=True)) def get_s3_put_url(self, obj): if not self.context.get('is_contributor') or obj.state == ModelVersionState.Available: @@ -233,6 +262,7 @@ class MetricValueCreateSerializer(MetricValueSerializer): """ model_version_id = serializers.PrimaryKeyRelatedField( queryset=ModelVersion.objects.select_related('model'), + source='model_version', write_only=True ) name = serializers.CharField(write_only=True) @@ -242,12 +272,16 @@ class MetricValueCreateSerializer(MetricValueSerializer): model = MetricValue fields = MetricValueSerializer.Meta.fields + ('model_version_id', 'name', 'mode') - def validate_model_version_id(self, model_version_id): + def validate_model_version_id(self, model_version): # Assert user has a contributor access to the model - access_level = get_max_level(self.context['request'].user, model_version_id.model) + access_level = get_max_level(self.context['request'].user, model_version.model) if not access_level or access_level < Role.Contributor.value: raise PermissionDenied(detail='You do not have contributor access to this model.') - return model_version_id + + if model_version.model.archived: + raise ValidationError('This ModelVersion is part of an archived model.') + + return model_version def validate(self, data): errors = defaultdict(list) @@ -261,7 +295,7 @@ class MetricValueCreateSerializer(MetricValueSerializer): mode = self.initial_data.get('mode') step = data.get('step') value = data.get('value') - model_version = data.get('model_version_id') + model_version = data.get('model_version') # Retrieve or create the metric key. # Annotating the Metric Key with the count of associated Metric Values (as metric_values_count) in order to check @@ -320,7 +354,11 @@ class MetricValueBulkItemSerializer(serializers.Serializer): class MetricValueBulkSerializer(serializers.Serializer): - model_version_id = serializers.PrimaryKeyRelatedField(queryset=ModelVersion.objects.select_related('model').all(), write_only=True) + model_version_id = serializers.PrimaryKeyRelatedField( + queryset=ModelVersion.objects.select_related('model').all(), + source='model_version', + write_only=True, + ) step = serializers.IntegerField(min_value=0, required=False) metrics = MetricValueBulkItemSerializer( many=True, @@ -332,17 +370,21 @@ class MetricValueBulkSerializer(serializers.Serializer): """), ) - def validate_model_version_id(self, model_version_id): + def validate_model_version_id(self, model_version): # Assert user has a contributor access to the model - access_level = get_max_level(self.context['request'].user, model_version_id.model) + access_level = get_max_level(self.context['request'].user, model_version.model) if not access_level or access_level < Role.Contributor.value: raise PermissionDenied(detail='You do not have contributor access to this model.') - return model_version_id + + if model_version.model.archived: + raise ValidationError('This ModelVersion is part of an archived model.') + + return model_version def validate(self, data): errors = defaultdict(lambda: defaultdict(list)) data = super().validate(data) - model_version_id = data.get('model_version_id') + model_version = data.get('model_version') step = data.get('step') # Retrieve the existing Metric Keys corresponding to the request metric names @@ -350,7 +392,7 @@ class MetricValueBulkSerializer(serializers.Serializer): queryset = MetricKey.objects \ .annotate(metric_values_count=Count('values')) \ .annotate(values_with_steps=Count('values', filter=Q(values__step__isnull=False))) \ - .filter(name__in=metric_names, model_version=model_version_id) + .filter(name__in=metric_names, model_version=model_version) if step is not None: queryset = queryset.annotate(current_step_values=Count('values', filter=Q(values__step=step))) metrickeys = { @@ -409,7 +451,7 @@ class MetricValueBulkSerializer(serializers.Serializer): metric_key = MetricKey( id=uuid.uuid4(), name=item['name'], - model_version=validated_data['model_version_id'], + model_version=validated_data['model_version'], mode=item.get('mode') ) metric_keys.append(metric_key) diff --git a/arkindex/training/tests/test_metrics_api.py b/arkindex/training/tests/test_metrics_api.py index a27d9916339229a1fe5c9383800245f449a4fd12..460cb9aa9c439f79f03d1b2a7b78bb4d4bc87027 100644 --- a/arkindex/training/tests/test_metrics_api.py +++ b/arkindex/training/tests/test_metrics_api.py @@ -89,6 +89,25 @@ class TestMetricsAPI(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertDictEqual(response.json(), {"detail": "You do not have contributor access to this model."}) + def test_create_metric_value_archived(self): + self.model.archived = datetime.now(timezone.utc) + self.model.save() + self.client.force_login(self.user) + + with self.assertNumQueries(6): + response = self.client.post( + reverse('api:metric-create'), + data={ + 'model_version_id': str(self.model_version.id), + 'name': 'a test metric', + 'value': 1.2 + }, + format='json' + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertDictEqual(response.json(), {"model_version_id": ["This ModelVersion is part of an archived model."]}) + @patch('django.utils.timezone.now') def test_create_metric_value_existing_metric_key(self, datetime_mock): """ @@ -441,6 +460,34 @@ class TestMetricsAPI(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertDictEqual(response.json(), {"detail": "You do not have contributor access to this model."}) + def test_bulk_create_metric_value_archived(self): + self.model.archived = datetime.now(timezone.utc) + self.model.save() + self.client.force_login(self.user) + + with self.assertNumQueries(6): + response = self.client.post( + reverse('api:metrics-create'), + data={ + 'model_version_id': str(self.model_version.id), + 'metrics': [ + { + 'name': 'a test metric', + 'value': 12 + }, + { + 'name': 'another metric', + 'value': 3, + 'mode': 'point' + } + ] + }, + format='json' + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertDictEqual(response.json(), {"model_version_id": ["This ModelVersion is part of an archived model."]}) + def test_bulk_create_metric_value_duplicate(self): """ Cannot send two metrics values for the same metric key / name diff --git a/arkindex/training/tests/test_model_api.py b/arkindex/training/tests/test_model_api.py index 1b4e2c80e677c157198c11986d92f4ecefbe19b1..0d52aab947ec45dc11f68f691b7b1f1e58d6ac1a 100644 --- a/arkindex/training/tests/test_model_api.py +++ b/arkindex/training/tests/test_model_api.py @@ -91,32 +91,17 @@ class TestModelAPI(FixtureAPITestCase): Right(group=cls.group1, content_object=cls.model2, level=100), ]) - @property - def model_version_create_request(self): - return { - 'tag': 'TAG', - 'description': 'description', - 'configuration': {'hello': 'this is me'}, - } - - @property - def model_version_update_request(self): - return { - 'hash': '94274e84f3de91d1645b1e082b5f3466', - 'archive_hash': '0958a74b060a89fc38318a9a96aef32a', - 'size': 8, - 'state': ModelVersionState.Available.value, - 'description': 'other description', - 'configuration': {'hi': 'Who am I ?'}, - } - def test_create_model_version_requires_verified(self): user = User.objects.create(email='not_verified@mail.com', display_name='Not Verified', verified_email=False) self.client.force_login(user) with self.assertNumQueries(2): response = self.client.post( reverse('api:model-versions', kwargs={'pk': str(self.model2.id)}), - self.model_version_create_request, + { + 'tag': 'TAG', + 'description': 'description', + 'configuration': {'hello': 'this is me'}, + }, format='json', ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @@ -132,6 +117,17 @@ class TestModelAPI(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertDictEqual(response.json(), {'detail': 'You need a Contributor access to the model to create a new version.'}) + def test_create_model_version_archived(self): + self.model1.archived = timezone.now() + self.model1.save() + self.client.force_login(self.user1) + + with self.assertNumQueries(6): + response = self.client.post(reverse('api:model-versions', kwargs={'pk': str(self.model1.id)})) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertDictEqual(response.json(), {'model': ['This model is archived.']}) + @patch('arkindex.project.aws.s3.meta.client.generate_presigned_url') def test_create_model_version_empty_fields(self, s3_presigned_url_mock): """ @@ -218,7 +214,11 @@ class TestModelAPI(FixtureAPITestCase): with self.assertNumQueries(8): response = self.client.post( reverse('api:model-versions', kwargs={'pk': str(self.model1.id)}), - self.model_version_create_request, + { + 'tag': 'TAG', + 'description': 'description', + 'configuration': {'hello': 'this is me'}, + }, format='json', ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -298,7 +298,8 @@ class TestModelAPI(FixtureAPITestCase): 'updated': self.model1.updated.isoformat().replace('+00:00', 'Z'), 'name': 'First Model', 'description': 'first', - 'rights': ['read'] + 'rights': ['read'], + 'archived': False, }) def test_retrieve_model(self): @@ -316,7 +317,8 @@ class TestModelAPI(FixtureAPITestCase): 'updated': self.model2.updated.isoformat().replace('+00:00', 'Z'), 'name': 'Second Model', 'description': '', - 'rights': ['read'] + 'rights': ['read'], + 'archived': False, }) def test_update_model_requires_login(self): @@ -360,6 +362,66 @@ class TestModelAPI(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), {'name': ['A model with this name already exists']}) + def test_update_model_archived_requires_archivable(self): + self.assertFalse(self.model1.is_archivable(self.user2)) + self.client.force_login(self.user2) + + cases = [ + (timezone.now(), False), + (None, True), + ] + + for current_value, new_value in cases: + with self.subTest(current_value=current_value, new_value=new_value): + self.model1.archived = current_value + self.model1.save() + + with self.assertNumQueries(8): + response = self.client.put( + reverse('api:model-retrieve', kwargs={'pk': str(self.model1.id)}), + {'name': 'new name', 'description': 'test', 'archived': new_value}, + format='json', + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertDictEqual(response.json(), {'archived': ['You are not allowed to archive or unarchive this model.']}) + + def test_update_model_archived(self): + self.assertTrue(self.model1.is_archivable(self.user1)) + self.client.force_login(self.user1) + + cases = [ + (timezone.now(), False), + (None, True), + ] + + for current_value, new_value in cases: + with self.subTest(current_value=current_value, new_value=new_value): + self.model1.archived = current_value + self.model1.save() + + with self.assertNumQueries(10): + response = self.client.put( + reverse('api:model-retrieve', kwargs={'pk': str(self.model1.id)}), + {'name': 'new name', 'description': 'test', 'archived': new_value}, + format='json', + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + self.model1.refresh_from_db() + self.assertDictEqual( + response.json(), + { + 'id': str(self.model1.id), + 'description': 'test', + 'name': 'new name', + 'rights': ['read', 'write', 'admin'], + 'created': self.model1.created.isoformat().replace('+00:00', 'Z'), + 'updated': self.model1.updated.isoformat().replace('+00:00', 'Z'), + 'archived': new_value, + } + ) + def test_update_model(self): self.assertFalse(self.model1.public) self.client.force_login(self.user1) @@ -383,16 +445,17 @@ class TestModelAPI(FixtureAPITestCase): 'rights': ['read', 'write', 'admin'], 'created': self.model1.created.isoformat().replace('+00:00', 'Z'), 'updated': self.model1.updated.isoformat().replace('+00:00', 'Z'), + 'archived': False, } ) - def test_patch_model_requires_login(self): + def test_partial_update_model_requires_login(self): with self.assertNumQueries(0): response = self.client.patch(reverse('api:model-retrieve', kwargs={'pk': str(self.model2.id)})) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertDictEqual(response.json(), {'detail': 'Authentication credentials were not provided.'}) - def test_patch_model_requires_verified(self): + def test_partial_update_model_requires_verified(self): self.user3.verified_email = False self.user3.save() self.client.force_login(self.user3) @@ -402,7 +465,7 @@ class TestModelAPI(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertDictEqual(response.json(), {'detail': 'You do not have permission to perform this action.'}) - def test_patch_model_requires_contrib(self): + def test_partial_update_model_requires_contrib(self): self.assertFalse(self.model1.public) self.model1.memberships.create(user=self.user3, level=Role.Guest.value) self.client.force_login(self.user3) @@ -417,7 +480,7 @@ class TestModelAPI(FixtureAPITestCase): self.assertEqual((self.model1.name, self.model1.description), ('First Model', 'first')) self.assertDictEqual(response.json(), {'detail': 'You do not have a contributor access to this model.'}) - def test_patch_model(self): + def test_partial_update_model(self): self.client.force_login(self.user1) with self.assertNumQueries(8): response = self.client.patch( @@ -439,9 +502,70 @@ class TestModelAPI(FixtureAPITestCase): 'rights': ['read', 'write', 'admin'], 'created': self.model1.created.isoformat().replace('+00:00', 'Z'), 'updated': self.model1.updated.isoformat().replace('+00:00', 'Z'), + 'archived': False, } ) + def test_partial_update_model_archived_requires_archivable(self): + self.assertFalse(self.model1.is_archivable(self.user2)) + self.client.force_login(self.user2) + + cases = [ + (timezone.now(), False), + (None, True), + ] + + for current_value, new_value in cases: + with self.subTest(current_value=current_value, new_value=new_value): + self.model1.archived = current_value + self.model1.save() + + with self.assertNumQueries(7): + response = self.client.patch( + reverse('api:model-retrieve', kwargs={'pk': str(self.model1.id)}), + {'archived': new_value}, + format='json', + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertDictEqual(response.json(), {'archived': ['You are not allowed to archive or unarchive this model.']}) + + def test_partial_update_model_archived(self): + self.assertTrue(self.model1.is_archivable(self.user1)) + self.client.force_login(self.user1) + + cases = [ + (timezone.now(), False), + (None, True), + ] + + for current_value, new_value in cases: + with self.subTest(current_value=current_value, new_value=new_value): + self.model1.archived = current_value + self.model1.save() + + with self.assertNumQueries(9): + response = self.client.patch( + reverse('api:model-retrieve', kwargs={'pk': str(self.model1.id)}), + {'archived': new_value}, + format='json', + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + self.model1.refresh_from_db() + self.assertDictEqual( + response.json(), + { + 'id': str(self.model1.id), + 'description': 'first', + 'name': 'First Model', + 'rights': ['read', 'write', 'admin'], + 'created': self.model1.created.isoformat().replace('+00:00', 'Z'), + 'updated': self.model1.updated.isoformat().replace('+00:00', 'Z'), + 'archived': new_value, + } + ) + def test_list_model_versions_requires_logged_in(self): """To list a model's versions, you need to be logged in. """ @@ -492,6 +616,17 @@ class TestModelAPI(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertDictEqual(response.json(), {'detail': 'You need an Admin access to the model to destroy this version.'}) + def test_destroy_model_versions_archived(self): + self.model1.archived = timezone.now() + self.model1.save() + self.client.force_login(self.user1) + + with self.assertNumQueries(6): + response = self.client.delete(reverse('api:model-version-retrieve', kwargs={'pk': str(self.model_version1.id)})) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertListEqual(response.json(), ['This model is archived.']) + def test_destroy_model_versions(self): """To destroy a model version, you need admin rights on the model. This also deletes every worker run that used this model version @@ -600,6 +735,22 @@ class TestModelAPI(FixtureAPITestCase): {'detail': 'You need a Contributor access to the model to update this version.'} ) + @patch('arkindex.project.aws.s3.Object') + @patch('arkindex.project.aws.S3FileMixin.exists') + def test_partial_update_model_version_archived(self, exists, s3_object): + s3_object().content_length = self.model_version3.size + s3_object().e_tag = self.model_version3.archive_hash + exists.return_value = True + self.model1.archived = timezone.now() + self.model1.save() + self.client.force_login(self.user1) + + with self.assertNumQueries(6): + response = self.client.patch(reverse('api:model-version-retrieve', kwargs={'pk': str(self.model_version1.id)}), {'state': 'available'}) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertDictEqual(response.json(), {'model': ['This model is archived.']}) + @patch('arkindex.project.aws.s3.meta.client.generate_presigned_url') @patch('arkindex.project.aws.S3FileMixin.exists') @patch('arkindex.project.aws.s3.Object') @@ -748,12 +899,46 @@ class TestModelAPI(FixtureAPITestCase): with self.assertNumQueries(6): response = self.client.put( reverse('api:model-version-retrieve', kwargs={'pk': str(self.model_version3.id)}), - self.model_version_update_request, + { + 'hash': '94274e84f3de91d1645b1e082b5f3466', + 'archive_hash': '0958a74b060a89fc38318a9a96aef32a', + 'size': 8, + 'state': ModelVersionState.Available.value, + 'description': 'other description', + 'configuration': {'hi': 'Who am I ?'}, + }, format='json', ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertDictEqual(response.json(), {'detail': 'You need a Contributor access to the model to update this version.'}) + @patch('arkindex.project.aws.s3.Object') + @patch('arkindex.project.aws.S3FileMixin.exists') + def test_update_model_version_archived(self, exists, s3_object): + s3_object().content_length = self.model_version3.size + s3_object().e_tag = self.model_version3.archive_hash + exists.return_value = True + self.model1.archived = timezone.now() + self.model1.save() + self.client.force_login(self.user1) + + with self.assertNumQueries(6): + response = self.client.put( + reverse('api:model-version-retrieve', kwargs={'pk': str(self.model_version1.id)}), + { + 'hash': '94274e84f3de91d1645b1e082b5f3466', + 'archive_hash': '0958a74b060a89fc38318a9a96aef32a', + 'size': 8, + 'state': ModelVersionState.Available.value, + 'description': 'other description', + 'configuration': {'hi': 'Who am I ?'}, + }, + format='json', + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertDictEqual(response.json(), {'model': ['This model is archived.']}) + @patch('arkindex.project.aws.S3FileMixin.exists') def test_validate_model_version_requires_contributor(self, exists): self.client.force_login(self.user3) @@ -770,6 +955,23 @@ class TestModelAPI(FixtureAPITestCase): 'detail': 'You need a Contributor access to the model to validate this version.' }) + @patch('arkindex.project.aws.S3FileMixin.exists') + def test_validate_model_version_archived(self, exists): + self.model1.archived = timezone.now() + self.model1.save() + self.client.force_login(self.user1) + exists.return_value = False + + with self.assertNumQueries(6): + response = self.client.post( + reverse('api:model-version-validate', kwargs={'pk': str(self.model_version1.id)}), + {'archive_hash': 'x' * 32, 'hash': 'y' * 32, 'size': 32}, + format='json' + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertEqual(response.json(), ['This model version is part of an archived model.']) + @patch('arkindex.project.aws.S3FileMixin.exists') def test_validate_model_version_required_fields(self, exists): self.client.force_login(self.user1) @@ -997,7 +1199,8 @@ class TestModelAPI(FixtureAPITestCase): 'updated': self.model2.updated.isoformat().replace('+00:00', 'Z'), 'name': 'Second Model', 'description': '', - 'rights': ['read'] + 'rights': ['read'], + 'archived': False, } ]) @@ -1019,7 +1222,8 @@ class TestModelAPI(FixtureAPITestCase): 'updated': self.model1.updated.isoformat().replace('+00:00', 'Z'), 'name': 'First Model', 'description': 'first', - 'rights': ['read', 'write'] + 'rights': ['read', 'write'], + 'archived': False, }, { 'id': str(self.model2.id), @@ -1027,7 +1231,8 @@ class TestModelAPI(FixtureAPITestCase): 'updated': self.model2.updated.isoformat().replace('+00:00', 'Z'), 'name': 'Second Model', 'description': '', - 'rights': ['read', 'write', 'admin'] + 'rights': ['read', 'write', 'admin'], + 'archived': False, } ]) @@ -1047,7 +1252,8 @@ class TestModelAPI(FixtureAPITestCase): 'updated': self.model2.updated.isoformat().replace('+00:00', 'Z'), 'name': 'Second Model', 'description': '', - 'rights': ['read', 'write', 'admin'] + 'rights': ['read', 'write', 'admin'], + 'archived': False, } ]) @@ -1065,10 +1271,47 @@ class TestModelAPI(FixtureAPITestCase): 'updated': self.model2.updated.isoformat().replace('+00:00', 'Z'), 'name': 'Second Model', 'description': '', - 'rights': ['read', 'write', 'admin'] + 'rights': ['read', 'write', 'admin'], + 'archived': False, } ]) + def test_list_models_filter_archived(self): + self.client.force_login(self.user2) + self.model1.archived = timezone.now() + self.model1.save() + + cases = [ + (True, [ + { + 'id': str(self.model1.id), + 'created': self.model1.created.isoformat().replace('+00:00', 'Z'), + 'updated': self.model1.updated.isoformat().replace('+00:00', 'Z'), + 'name': 'First Model', + 'description': 'first', + 'rights': ['read', 'write'], + 'archived': True, + } + ]), + (False, [ + { + 'id': str(self.model2.id), + 'created': self.model2.created.isoformat().replace('+00:00', 'Z'), + 'updated': self.model2.updated.isoformat().replace('+00:00', 'Z'), + 'name': 'Second Model', + 'description': '', + 'rights': ['read', 'write', 'admin'], + 'archived': False, + } + ]), + ] + + for archived, expected_results in cases: + with self.subTest(archived=archived), self.assertNumQueries(7): + response = self.client.get(reverse('api:models'), {'archived': archived}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.json()['results'], expected_results) + def test_list_models_filter_compatible_worker_doesnt_exist(self): self.client.force_login(self.user2) with self.assertNumQueries(5): diff --git a/arkindex/training/tests/test_model_compatible_worker.py b/arkindex/training/tests/test_model_compatible_worker.py index 39c316ef9af9d3e417841f701eed6ada95938659..0231a496da16f888d3a9607da94065224a0b13ec 100644 --- a/arkindex/training/tests/test_model_compatible_worker.py +++ b/arkindex/training/tests/test_model_compatible_worker.py @@ -1,3 +1,5 @@ +from datetime import datetime, timezone + from django.urls import reverse from rest_framework import status @@ -276,6 +278,34 @@ class TestModelCompatibleWorkerManage(FixtureAPITestCase): ordered=False, ) + def test_create_archived(self): + self.model1.archived = datetime.now(timezone.utc) + self.model1.save() + self.client.force_login(self.user) + + with self.assertNumQueries(8): + response = self.client.post( + reverse('api:model-compatible-worker-manage', kwargs={ + 'model': str(self.model1.id), + 'worker': str(self.worker2.id), + }) + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertEqual(response.json(), { + 'model': ['This model is archived.'], + }) + self.assertQuerysetEqual( + self.model1.compatible_workers.all(), + [self.worker1], + ordered=False, + ) + self.assertQuerysetEqual( + self.model2.compatible_workers.all(), + [self.worker2], + ordered=False, + ) + def test_destroy_requires_login(self): with self.assertNumQueries(0): response = self.client.delete( @@ -516,3 +546,32 @@ class TestModelCompatibleWorkerManage(FixtureAPITestCase): [self.worker2], ordered=False, ) + + def test_destroy_archived(self): + self.model1.archived = datetime.now(timezone.utc) + self.model1.save() + self.model1.compatible_workers.add(self.worker2) + self.client.force_login(self.user) + + with self.assertNumQueries(8): + response = self.client.delete( + reverse('api:model-compatible-worker-manage', kwargs={ + 'model': str(self.model1.id), + 'worker': str(self.worker2.id), + }) + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertEqual(response.json(), { + 'model': ['This model is archived.'], + }) + self.assertQuerysetEqual( + self.model1.compatible_workers.all(), + [self.worker1, self.worker2], + ordered=False, + ) + self.assertQuerysetEqual( + self.model2.compatible_workers.all(), + [self.worker2], + ordered=False, + )