diff --git a/arkindex/process/api.py b/arkindex/process/api.py index 0b76db44bc116caf6235b38e94c09a4229380a93..c1c13851aeb74d48e527060053e712ffbbabe4d7 100644 --- a/arkindex/process/api.py +++ b/arkindex/process/api.py @@ -574,7 +574,11 @@ class StartProcess(CorpusACLMixin, CreateAPIView): .objects .select_related('corpus') .filter(corpus_id__isnull=False) - .prefetch_related(Prefetch('worker_runs', queryset=WorkerRun.objects.select_related('version', 'model_version'))) + .prefetch_related(Prefetch('worker_runs', queryset=WorkerRun.objects.select_related( + 'version__worker', + 'model_version', + 'configuration', + ))) .prefetch_related('datasets') # Uses Exists() for has_tasks and not a __isnull because we are not joining on tasks and do not need to fetch them .annotate(has_tasks=Exists(Task.objects.filter(process=OuterRef('pk')))) diff --git a/arkindex/process/serializers/imports.py b/arkindex/process/serializers/imports.py index 0a3506f72fa2306a8d722d8b24c2b99077c1c814..662382004bb7972fb81ff32ef29330d8781a73b4 100644 --- a/arkindex/process/serializers/imports.py +++ b/arkindex/process/serializers/imports.py @@ -368,6 +368,7 @@ class StartProcessSerializer(serializers.Serializer): missing_model_versions = [] unavailable_versions = [] has_unavailable_model_versions = False + missing_required_configurations = [] for worker_run in self.instance.worker_runs.all(): if worker_run.version.model_usage and worker_run.model_version_id is None: @@ -380,12 +381,34 @@ class StartProcessSerializer(serializers.Serializer): if worker_run.model_version_id and worker_run.model_version.state != ModelVersionState.Available: has_unavailable_model_versions = True + # If the worker version has a user configuration, check that we aren't missing any required fields + if isinstance(worker_run.version.configuration.get('user_configuration'), dict): + required_fields = set( + name + for name, options in worker_run.version.configuration['user_configuration'].items() + # Only pick the fields that are required and without default values + if options.get('required') and 'default' not in options + ) + + # List all the fields defined on the WorkerRun's configuration if there is one + worker_run_fields = set() + if worker_run.configuration is not None: + worker_run_fields = set(worker_run.configuration.configuration.keys()) + + # Check that all the worker version's required fields are set in that configuration + if not required_fields.issubset(worker_run_fields): + missing_required_configurations.append(worker_run.version.worker.name) + if len(missing_model_versions) > 0: errors['model_version'].append(f"The following workers require a model version and none was set: {missing_model_versions}") if has_unavailable_model_versions: errors['model_version'].append('This process contains one or more unavailable model versions and cannot be started.') if len(unavailable_versions) > 0: errors['version'].append('This process contains one or more unavailable worker versions and cannot be started.') + if missing_required_configurations: + errors['worker_configuration'].append( + f'The following workers have required configuration fields that have not been set: {", ".join(sorted(missing_required_configurations))}', + ) else: if validated_data.get('worker_activity'): diff --git a/arkindex/process/tests/test_processes.py b/arkindex/process/tests/test_processes.py index e88d40fce636168af54bf5a8adbfb74f787c462a..84fbc98fe011e97e55b704e9863407da35d08ef4 100644 --- a/arkindex/process/tests/test_processes.py +++ b/arkindex/process/tests/test_processes.py @@ -122,10 +122,6 @@ class TestProcesses(FixtureAPITestCase): cls.import_worker_version = WorkerVersion.objects.get(worker__slug='file_import') - def setUp(self): - super().setUp() - self.maxDiff = None - def test_list_requires_login(self): with self.assertNumQueries(0): response = self.client.get(reverse('api:process-list')) @@ -2116,6 +2112,58 @@ class TestProcesses(FixtureAPITestCase): {'model_version': ['This process contains one or more unavailable model versions and cannot be started.']}, ) + def test_start_process_required_fields_no_config(self): + # Both workers now have a required field without a default value + self.dla.configuration['user_configuration'] = { + 'some_field': { + 'title': 'pls configure me', + 'type': 'bool', + 'required': True, + } + } + self.dla.save() + self.recognizer.configuration['user_configuration'] = { + 'some_field': { + 'title': 'pls configure me', + 'type': 'bool', + 'required': True, + } + } + self.recognizer.save() + + process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Workers) + # One worker version is used without any configuration + process2.worker_runs.create( + version=self.recognizer, + configuration=None, + model_version=None, + ) + # The other version is used with a configuration missing the required field + process2.worker_runs.create( + version=self.dla, + configuration=self.dla.worker.configurations.create( + name='oh no', + configuration={ + 'not_some_field': 'oops', + }, + ), + model_version=None, + ) + self.client.force_login(self.user) + + with self.assertNumQueries(8): + response = self.client.post( + reverse('api:process-start', kwargs={'pk': str(process2.id)}) + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertEqual(response.json(), { + 'worker_configuration': [ + 'The following workers have required configuration fields that have not been set: ' + 'Document layout analyser, Recognizer', + ], + }) + def test_start_process_workers(self): """ A user can start a process with no parameters.