diff --git a/arkindex/process/serializers/workers.py b/arkindex/process/serializers/workers.py index 55549a04e5531d03ca1ecf833d34140d5d491ff6..7c91b90adc3f3cd08a3abc2781279e887e44e429 100644 --- a/arkindex/process/serializers/workers.py +++ b/arkindex/process/serializers/workers.py @@ -5,6 +5,7 @@ from collections import defaultdict from enum import Enum from textwrap import dedent +from django.core.exceptions import ValidationError as DjangoValidationError from django.db import transaction from django.db.models import Max, Q from drf_spectacular.utils import extend_schema_field @@ -32,6 +33,7 @@ from arkindex.process.models import ( from arkindex.process.serializers.git import GitRefSerializer, RevisionWithRefsSerializer from arkindex.process.utils import hash_object from arkindex.project.serializer_fields import ArchivedField, EnumField +from arkindex.training.models import Model from arkindex.training.serializers import ModelVersionLightSerializer from arkindex.users.models import Role from arkindex.users.utils import get_max_level @@ -160,6 +162,7 @@ class UserConfigurationFieldType(Enum): Boolean = 'bool' Dict = 'dict' List = 'list' + Model = 'model' class UserConfigurationFieldSerializer(serializers.Serializer): @@ -180,7 +183,8 @@ class UserConfigurationFieldSerializer(serializers.Serializer): UserConfigurationFieldType.Enum: [serializers.ChoiceField(choices=[]), None], UserConfigurationFieldType.Boolean: [serializers.BooleanField(), bool], UserConfigurationFieldType.Dict: [serializers.DictField(child=serializers.CharField()), dict], - UserConfigurationFieldType.List: [serializers.ListField(), list] + UserConfigurationFieldType.List: [serializers.ListField(), list], + UserConfigurationFieldType.Model: [serializers.UUIDField(), str] } if not isinstance(data, dict): @@ -221,16 +225,21 @@ class UserConfigurationFieldSerializer(serializers.Serializer): data_type, data_class = data_types[field_type] if not isinstance(default_value, data_class): raise ValidationError + # In the case of model fields, the validation error is raised here if the default value is a string but not a UUID data_type.to_internal_value(default_value) # For lists, check that list elements are of given subtype if field_type == UserConfigurationFieldType.List and not errors.get('subtype'): _, data_subclass = data_types[subtype] if any(not isinstance(item, data_subclass) for item in default_value): errors['default'].append(f'All items in the default value must be of type {data_subclass.__name__}.') - except ValidationError: - errors['default'].append(f'Default value is not of type {field_type.value}.') + except (ValidationError, DjangoValidationError): + errors['default'].append(f'This is not a valid value for a field of type {field_type.value}.') except KeyError: errors['default'].append(f'Cannot check type: {field_type.value}.') + # Check that the UUID is that of a model that exists, for Model fields + if field_type == UserConfigurationFieldType.Model and 'default' not in errors: + if not Model.objects.filter(id=default_value).exists(): + errors['default'].append(f'Model {default_value} not found.') if errors: raise ValidationError(errors) diff --git a/arkindex/process/tests/test_workers.py b/arkindex/process/tests/test_workers.py index 9202c56602411ef5cac9a325fde3aab4210a53aa..9d5f603c6589c0d1235891232fe76e90d8f35c4a 100644 --- a/arkindex/process/tests/test_workers.py +++ b/arkindex/process/tests/test_workers.py @@ -1618,6 +1618,38 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): } }) + def test_create_version_valid_user_configuration_model(self): + response = self.client.post( + reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), + data={ + "revision_id": str(self.rev2.id), + "configuration": { + "user_configuration": { + "demo_model": {"title": "Model for training", "type": "model", "required": True}, + "other_model": {"title": "Model the second", "type": "model", "default": str(self.model.id)} + } + }, + "gpu_usage": "disabled", + }, + format="json", + HTTP_AUTHORIZATION=f'Ponos {self.task.token}', + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.json()['configuration'], { + "user_configuration": { + "demo_model": { + "title": "Model for training", + "type": "model", + "required": True + }, + "other_model": { + "title": "Model the second", + "type": "model", + "default": str(self.model.id) + } + } + }) + def test_create_version_invalid_user_configuration_list_requires_subtype(self): response = self.client.post( reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), @@ -1664,7 +1696,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): "configuration": { "user_configuration": [{ "demo_list": { - "default": ["Default value is not of type list."] + "default": ["This is not a valid value for a field of type list."] } }] } @@ -1843,7 +1875,7 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): "configuration": { "user_configuration": [{ "one_float": { - "default": ["Default value is not of type float."] + "default": ["This is not a valid value for a field of type float."] } }] } @@ -1952,7 +1984,9 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): ({"type": "bool", "default": 0}, 'bool'), ({"type": "bool", "default": 1}, 'bool'), ({"type": "string", "default": 1}, 'string'), - ({"type": "dict", "default": ["a", "b"]}, 'dict') + ({"type": "dict", "default": ["a", "b"]}, 'dict'), + ({"type": "model", "default": "gigi hadid"}, 'model'), + ({"type": "model", "default": False}, 'model') ] for params, expected in cases: with self.subTest(**params): @@ -1970,7 +2004,26 @@ class TestWorkersWorkerVersions(FixtureAPITestCase): HTTP_AUTHORIZATION=f'Ponos {self.task.token}', ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEqual(response.json(), {"configuration": {"user_configuration": [{'param': {'default': [f'Default value is not of type {expected}.']}}]}}) + self.assertEqual(response.json(), {"configuration": {"user_configuration": [{'param': {'default': [f'This is not a valid value for a field of type {expected}.']}}]}}) + + def test_create_version_user_configuration_model_default_doesnt_exist(self): + response = self.client.post( + reverse("api:worker-versions", kwargs={"pk": str(self.worker_dla.id)}), + data={ + "revision_id": str(self.rev2.id), + "configuration": { + "user_configuration": { + "param": {"title": "Model to train", "type": "model", "default": "12341234-1234-1234-1234-123412341234"} + } + }, + }, + format="json", + HTTP_AUTHORIZATION=f'Ponos {self.task.token}', + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), {"configuration": {"user_configuration": [{'param': {'default': [ + 'Model 12341234-1234-1234-1234-123412341234 not found.' + ]}}]}}) def test_retrieve_version_invalid_id(self): self.client.force_login(self.user)