diff --git a/arkindex/process/serializers/workers.py b/arkindex/process/serializers/workers.py index 9faef9fcefe3938a44c5ac498ed6ddae90ad834d..6a00e11b79a3ece24b90f75a001c0aab4b09e0fc 100644 --- a/arkindex/process/serializers/workers.py +++ b/arkindex/process/serializers/workers.py @@ -605,6 +605,7 @@ class DockerWorkerVersionSerializer(serializers.ModelSerializer): ) gpu_usage = EnumField(FeatureUsage, required=False, default=FeatureUsage.Disabled) model_usage = EnumField(FeatureUsage, required=False, default=FeatureUsage.Disabled) + configuration = serializers.DictField(required=False, default={}) class Meta: model = WorkerVersion @@ -630,6 +631,20 @@ class DockerWorkerVersionSerializer(serializers.ModelSerializer): } } + def validate_configuration(self, configuration): + errors = defaultdict(list) + user_configuration = configuration.get("user_configuration") + if not user_configuration: + return configuration + field = serializers.DictField(child=UserConfigurationFieldSerializer()) + try: + field.to_internal_value(user_configuration) + except ValidationError as e: + errors["user_configuration"].append(e.detail) + if errors: + raise ValidationError(errors) + return configuration + @transaction.atomic def create(self, validated_data): """ diff --git a/arkindex/process/tests/test_docker_worker_version.py b/arkindex/process/tests/test_docker_worker_version.py index 38c872135b0d169496c1afa5e0c9544ef36ca8a3..bea809bbacc6cf5ad02407964f8e1c85857b9133 100644 --- a/arkindex/process/tests/test_docker_worker_version.py +++ b/arkindex/process/tests/test_docker_worker_version.py @@ -6,6 +6,7 @@ from rest_framework import status from arkindex.ponos.models import Farm from arkindex.process.models import FeatureUsage, GitRefType, Process, ProcessMode, Repository, Worker, WorkerType from arkindex.project.tests import FixtureAPITestCase +from arkindex.training.models import Model from arkindex.users.models import Role, Scope @@ -128,6 +129,7 @@ class TestDockerWorkerVersion(FixtureAPITestCase): ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), { + "configuration": ['Expected a dictionary of items but got type "str".'], "docker_image_iid": ["Not a valid string."], "gpu_usage": ["Value is not of type FeatureUsage"], "model_usage": ["Value is not of type FeatureUsage"], @@ -546,3 +548,209 @@ class TestDockerWorkerVersion(FixtureAPITestCase): self.assertListEqual(list(new_repo.memberships.values_list("user", "level")), [ (self.user.id, Role.Admin.value) ]) + + # Test user configuration + + def test_create_version_valid_user_configuration(self): + test_model = Model.objects.create( + name="Generic model", + public=False, + ) + self.client.force_login(self.user) + response = self.client.post( + reverse("api:version-from-docker"), + data={ + "docker_image_iid": "a_docker_image", + "repository_url": self.repo.url, + "revision_hash": "new_revision_hash", + "worker_slug": self.worker.slug, + "configuration": { + "user_configuration": { + "demo_integer": {"title": "Demo Integer", "type": "int", "required": True, "default": 1}, + "demo_boolean": {"title": "Demo Boolean", "type": "bool", "required": False, "default": True}, + "demo_dict": {"title": "Demo Dict", "type": "dict", "required": True, "default": {"a": "b", "c": "d"}}, + "demo_choice": {"title": "Decisions", "type": "enum", "required": True, "default": 1, "choices": [1, 2, 3]}, + "demo_list": {"title": "Demo List", "type": "list", "required": True, "subtype": "int", "default": [1, 2, 3, 4]}, + "boolean_list": {"title": "It's a list of booleans", "type": "list", "required": False, "subtype": "bool", "default": [True, False, False]}, + "demo_model": {"title": "Model for training", "type": "model", "required": True}, + "other_model": {"title": "Model the second", "type": "model", "default": str(test_model.id)} + } + }, + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.json()["configuration"], { + "user_configuration": { + "demo_integer": { + "title": "Demo Integer", + "type": "int", + "required": True, + "default": 1 + }, + "demo_boolean": { + "title": "Demo Boolean", + "type": "bool", + "required": False, + "default": True + }, + "demo_dict": { + "title": "Demo Dict", + "type": "dict", + "required": True, + "default": {"a": "b", "c": "d"} + }, + "demo_choice": { + "title": "Decisions", + "type": "enum", + "required": True, + "default": 1, + "choices": [1, 2, 3] + }, + "demo_list": { + "title": "Demo List", + "type": "list", + "subtype": "int", + "required": True, + "default": [1, 2, 3, 4] + }, + "boolean_list": { + "title": "It's a list of booleans", + "type": "list", + "subtype": "bool", + "required": False, + "default": [True, False, False] + }, + "demo_model": { + "title": "Model for training", + "type": "model", + "required": True + }, + "other_model": { + "title": "Model the second", + "type": "model", + "default": str(test_model.id) + } + } + }) + + def test_create_invalid_user_configuration(self): + cases = [ + ( + "non", + ['Expected a dictionary of items but got type "str".'] + ), + ( + {"demo_list": {"title": "Demo List", "type": "list", "required": True, "default": [1, 2, 3, 4]}}, + {"demo_list": { + "subtype": ['The "subtype" field must be set for "list" type properties.'] + }} + ), + ( + {"demo_list": {"title": "Demo List", "type": "list", "required": True, "subtype": "dict", "default": [1, 2, 3, 4]}}, + {"demo_list": { + "subtype": ["Subtype can only be int, float, bool or string."] + }} + ), + ( + {"demo_list": {"title": "Demo List", "type": "list", "required": True, "subtype": "int", "default": [1, 2, "three", 4]}}, + {"demo_list": { + "default": ["All items in the default value must be of type int."] + }} + ), + ( + {"demo_choice": {"title": "Decisions", "type": "enum", "required": True, "default": 1, "choices": "eeee"}}, + {"demo_choice": { + "choices": ['Expected a list of items but got type "str".'] + }} + ), + ( + {"secrets": ["aaaaaa"]}, + {"secrets": {"__all__": ["User configuration field definitions should be of type dict, not list."]}} + ), + ( + {"something": {"title": "some thing", "type": "uh oh", "required": 2}}, + {"something": { + "required": ["Must be a valid boolean."], + "type": ["Value is not of type UserConfigurationFieldType"] + }} + ), + ( + {"something": {"title": "some thing", "type": "int", "required": False, "choices": [1, 2, 3]}}, + {"something": { + "choices": ['The "choices" field can only be set for an "enum" type property.'] + }} + ), + ( + {"demo_integer": {"type": "int", "required": True, "default": 1}}, + {"demo_integer": {"title": ["This field is required."]}} + ), + ( + {"demo_integer": {"title": "an integer", "type": "int", "required": True, "default": 1, "some_key": "oh no"}}, + {"demo_integer": { + "some_key": ["Configurable properties can only be defined using the following keys: title, type, required, default, subtype, choices."] + }} + ), + ( + {"param": {"title": "Model to train", "type": "model", "default": "12341234-1234-1234-1234-123412341234"}}, + {"param": {"default": ["Model 12341234-1234-1234-1234-123412341234 not found."]}} + ) + ] + + self.client.force_login(self.user) + for user_configuration, error in cases: + with self.subTest(user_configuration=user_configuration, error=error): + response = self.client.post( + reverse("api:version-from-docker"), + data={ + "docker_image_iid": "a_docker_image", + "repository_url": self.repo.url, + "revision_hash": "new_revision_hash", + "worker_slug": self.worker.slug, + "configuration": { + "user_configuration": user_configuration + }, + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), { + "configuration": { + "user_configuration": [error] + } + }) + + def test_create_version_invalid_user_configuration_default_value(self): + self.client.force_login(self.user) + cases = [ + ({"type": "int", "default": False}, "int"), + ({"type": "int", "default": True}, "int"), + ({"type": "float", "default": False}, "float"), + ({"type": "float", "default": True}, "float"), + ({"type": "bool", "default": 0}, "bool"), + ({"type": "bool", "default": 1}, "bool"), + ({"type": "string", "default": 1}, "string"), + ({"type": "dict", "default": ["a", "b"]}, "dict"), + ({"type": "model", "default": "gigi hadid"}, "model"), + ({"type": "model", "default": False}, "model"), + ({"type": "list", "subtype": "int", "default": 12}, "list") + ] + for params, expected in cases: + with self.subTest(**params): + response = self.client.post( + reverse("api:version-from-docker"), + data={ + "docker_image_iid": "a_docker_image", + "repository_url": self.repo.url, + "revision_hash": "new_revision_hash", + "worker_slug": self.worker.slug, + "configuration": { + "user_configuration": { + "param": {"title": "param", **params} + } + }, + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), {"configuration": {"user_configuration": [{"param": {"default": [f"This is not a valid value for a field of type {expected}."]}}]}}) diff --git a/arkindex/process/tests/test_workers.py b/arkindex/process/tests/test_workers.py index b1a24c423fe0b298b581d5a28522514b9762c6ae..42efd23b2a115cce187872da243f93f9ca7508d4 100644 --- a/arkindex/process/tests/test_workers.py +++ b/arkindex/process/tests/test_workers.py @@ -1552,7 +1552,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): }, }) - def test_create_version_empty_configuration(self): + def test_create_version_configuration_wrong_type(self): """ Configuration body must be an object """