From ff5bc245118ee770c87bb64842f4d9bf8788c611 Mon Sep 17 00:00:00 2001 From: Erwan Rouchet <rouchet@teklia.com> Date: Tue, 10 Sep 2019 10:34:14 +0000 Subject: [PATCH] Django REST Framework 3.10 and custom schema generation --- .gitlab-ci.yml | 9 +- Makefile | 7 - arkindex/dataimport/api.py | 2 +- arkindex/documents/api/elements.py | 2 +- arkindex/project/openapi.py | 88 +++++++++ arkindex/project/pagination.py | 8 + arkindex/project/settings.py | 1 + arkindex/project/tests/test_openapi.py | 260 +++++++++++++++++++++++++ openapi/Dockerfile | 8 - openapi/requirements.txt | 3 - openapi/run.sh | 3 - requirements.txt | 2 +- tests-requirements.txt | 1 + 13 files changed, 363 insertions(+), 31 deletions(-) create mode 100644 arkindex/project/openapi.py create mode 100644 arkindex/project/tests/test_openapi.py delete mode 100644 openapi/Dockerfile delete mode 100644 openapi/requirements.txt delete mode 100755 openapi/run.sh diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index aab1140bc0..cd7fd45182 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -45,16 +45,11 @@ backend-lint: backend-openapi: stage: build - image: registry.gitlab.com/arkindex/backend/openapi:latest - - before_script: - - pip uninstall -y arkindex-common ponos-server - - "pip install git+https://gitlab-ci-token:${CI_JOB_TOKEN}@gitlab.com/arkindex/common#egg=arkindex-common" - - "pip install git+https://gitlab-ci-token:${CI_JOB_TOKEN}@gitlab.com/arkindex/ponos#egg=ponos-server" script: - mkdir -p output - - pip install --no-deps -e . + - pip install -e . + - pip install uritemplate==3 apistar>=0.7.2 - arkindex/manage.py generateschema > output/original.yml - openapi/patch.py openapi/patch.yml output/original.yml > output/schema.yml diff --git a/Makefile b/Makefile index 2a0a2f70b1..9dffd519a4 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,6 @@ VERSION=$(shell git rev-parse --short HEAD) TAG_APP=arkindex-app TAG_BASE=arkindex-base TAG_SHELL=arkindex-shell -TAG_OPENAPI=arkindex-openapi .PHONY: build base all: clean build @@ -30,17 +29,12 @@ build: build-shell: docker build -t $(TAG_SHELL):$(VERSION) -t $(TAG_SHELL):latest $(ROOT_DIR)/shell -build-openapi: - docker build --no-cache -t $(TAG_OPENAPI):$(VERSION) -t $(TAG_OPENAPI):latest $(ROOT_DIR)/openapi - publish-version: require-docker-auth [ -f $(ROOT_DIR)/arkindex/project/local_settings.py ] && mv $(ROOT_DIR)/arkindex/project/local_settings.py $(ROOT_DIR)/arkindex/project/local_settings.py.bak || true $(MAKE) build TAG_APP=registry.gitlab.com/arkindex/backend $(MAKE) build-shell TAG_SHELL=registry.gitlab.com/arkindex/backend/shell - $(MAKE) build-openapi TAG_OPENAPI=registry.gitlab.com/arkindex/backend/openapi docker push registry.gitlab.com/arkindex/backend:$(VERSION) docker push registry.gitlab.com/arkindex/backend/shell:$(VERSION) - docker push registry.gitlab.com/arkindex/backend/openapi:$(VERSION) [ -f $(ROOT_DIR)/arkindex/project/local_settings.py.bak ] && mv $(ROOT_DIR)/arkindex/project/local_settings.py.bak $(ROOT_DIR)/arkindex/project/local_settings.py || true latest: @@ -51,7 +45,6 @@ release: $(MAKE) publish-version VERSION=$(version) docker push registry.gitlab.com/arkindex/backend:latest docker push registry.gitlab.com/arkindex/backend/shell:latest - docker push registry.gitlab.com/arkindex/backend/openapi:latest git tag $(version) tunnel: diff --git a/arkindex/dataimport/api.py b/arkindex/dataimport/api.py index 32895b8563..2649e7df59 100644 --- a/arkindex/dataimport/api.py +++ b/arkindex/dataimport/api.py @@ -268,7 +268,7 @@ class DataFileList(CorpusACLMixin, ListAPIView): permission_classes = (IsVerified, ) serializer_class = DataFileSerializer # Tell the OpenAPI schema generator to believe this view is a list - action = 'list' + action = 'List' def get_queryset(self): return DataFile.objects.filter(corpus=self.get_corpus(self.kwargs['pk'])).prefetch_related('images__server') diff --git a/arkindex/documents/api/elements.py b/arkindex/documents/api/elements.py index 67654a2197..6d64c1ee75 100644 --- a/arkindex/documents/api/elements.py +++ b/arkindex/documents/api/elements.py @@ -77,7 +77,7 @@ class RelatedElementsList(ListAPIView): """ serializer_class = ElementSlimSerializer # Tell the OpenAPI schema generator to believe this view is a list - action = 'list' + action = 'List' def get_queryset(self): filtering = { diff --git a/arkindex/project/openapi.py b/arkindex/project/openapi.py new file mode 100644 index 0000000000..2de3e67c09 --- /dev/null +++ b/arkindex/project/openapi.py @@ -0,0 +1,88 @@ +from enum import Enum +from rest_framework import serializers +from rest_framework.schemas.openapi import AutoSchema as BaseAutoSchema +import warnings + + +class AutoSchema(BaseAutoSchema): + """ + A custom view schema generator. + The docs on OpenAPI generation are currently very incomplete. + If you need to implement more features to avoid the "patch.py" and allow + views to customize their OpenAPI schemas by themselves, you may see for + yourself how DRF does the schema generation: + https://github.com/encode/django-rest-framework/blob/master/rest_framework/schemas/openapi.py + """ + + def _map_serializer(self, serializer): + """ + This is a temporary patch because the default schema generation does not handle + callable defaults in fields properly, and adds 'default=null' to any field that + does not have a default value, even required fields. + + https://github.com/encode/django-rest-framework/issues/6858 + """ + schema = super()._map_serializer(serializer) + for field_name in schema['properties']: + if 'default' not in schema['properties'][field_name]: + continue + default = schema['properties'][field_name]['default'] + + if hasattr(default, 'openapi_value'): + # Allow a 'openapi_value' attribute on the default for custom defaults + default = default.openapi_value + + elif callable(default): + # Try to call the callable default; if it does not work, warn, then remove it + try: + default = default() + except Exception as e: + warnings.warn('Unsupported callable default for field {}: {!s}'.format(field_name, e)) + del schema['properties'][field_name]['default'] + continue + + elif isinstance(default, Enum): + # Convert enums into their string values + default = default.value + + # Remove null defaults on required fields + if default is None and field_name in schema.get('required', []): + del schema['properties'][field_name]['default'] + else: + schema['properties'][field_name]['default'] = default + + return schema + + def _map_field(self, field): + """ + Yet another temporary patch because HStoreField is not properly treated as `type: object`. + + https://github.com/encode/django-rest-framework/issues/6913 + https://github.com/encode/django-rest-framework/pull/6914 + """ + schema = super()._map_field(field) + if isinstance(field, serializers.HStoreField): + assert schema['type'] == 'string' # Will crash when this is fixed upstream + return { + 'type': 'object' + } + return schema + + def get_operation(self, path, method): + operation = super().get_operation(path, method) + + # Operation IDs for list endpoints are improperly cased: listThings instead of ListThings + # https://github.com/encode/django-rest-framework/pull/6917 + if operation['operationId'][0].islower(): + operation['operationId'] = operation['operationId'][0].upper() + operation['operationId'][1:] + + # Setting deprecated = True on a View makes it deprecated in OpenAPI + if getattr(self.view, 'deprecated', False): + operation['deprecated'] = True + + # Allow an `openapi_overrides` attribute to override the operation's properties + # TODO: Quite crude, not enough to make overriding request/response bodies easy + if hasattr(self.view, 'openapi_overrides'): + operation.update(self.view.openapi_overrides) + + return operation diff --git a/arkindex/project/pagination.py b/arkindex/project/pagination.py index d430b96abe..c4160e51f2 100644 --- a/arkindex/project/pagination.py +++ b/arkindex/project/pagination.py @@ -18,3 +18,11 @@ class PageNumberPagination(pagination.PageNumberPagination): ('previous', self.get_previous_link()), ('results', data) ])) + + def get_paginated_response_schema(self, schema): + schema = super().get_paginated_response_schema(schema) + schema['properties']['number'] = { + 'type': 'integer', + 'example': 123, + } + return schema diff --git a/arkindex/project/settings.py b/arkindex/project/settings.py index 6de5148ec8..73ecbd1ff3 100644 --- a/arkindex/project/settings.py +++ b/arkindex/project/settings.py @@ -223,6 +223,7 @@ REST_FRAMEWORK = { 'ponos.authentication.AgentAuthentication', ), 'DEFAULT_PAGINATION_CLASS': 'arkindex.project.pagination.PageNumberPagination', + 'DEFAULT_SCHEMA_CLASS': 'arkindex.project.openapi.AutoSchema', 'PAGE_SIZE': 20, } diff --git a/arkindex/project/tests/test_openapi.py b/arkindex/project/tests/test_openapi.py new file mode 100644 index 0000000000..c874a732da --- /dev/null +++ b/arkindex/project/tests/test_openapi.py @@ -0,0 +1,260 @@ +from unittest import TestCase +from django.test import RequestFactory +from rest_framework import serializers +from rest_framework.request import Request +from rest_framework.views import APIView +from rest_framework.schemas.openapi import SchemaGenerator +from arkindex.project.serializer_fields import EnumField +from arkindex.project.openapi import AutoSchema +from arkindex_common.enums import DataImportMode +import warnings + + +# Helper methods taken straight from the DRF test suite +def create_request(path): + factory = RequestFactory() + request = Request(factory.get(path)) + return request + + +def create_view(view_cls, method, request): + generator = SchemaGenerator() + view = generator.create_view(view_cls.as_view(), method, request) + return view + + +class TestOpenAPI(TestCase): + + def test_no_deprecated(self): + """ + Test deprecated is omitted in schemas if the attribute is absent + """ + + class ThingView(APIView): + action = 'Retrieve' + + def get(self, *args, **kwargs): + pass + + inspector = AutoSchema() + inspector.view = create_view(ThingView, 'GET', create_request('/test/')) + self.assertDictEqual( + inspector.get_operation('GET', '/test/'), + { + 'operationId': 'RetrieveThing', + 'parameters': [], + 'responses': { + '200': { + 'content': { + 'application/json': { + 'schema': {} + } + }, + 'description': '', + } + } + } + ) + + def test_deprecated_attribute(self): + """ + Test the optional `deprecated` attribute on views + """ + + class ThingView(APIView): + action = 'Retrieve' + deprecated = True + + def get(self, *args, **kwargs): + pass + + inspector = AutoSchema() + inspector.view = create_view(ThingView, 'GET', create_request('/test/')) + self.assertDictEqual( + inspector.get_operation('GET', '/test/'), + { + 'operationId': 'RetrieveThing', + 'parameters': [], + 'responses': { + '200': { + 'content': { + 'application/json': { + 'schema': {} + } + }, + 'description': '', + } + }, + 'deprecated': True, + } + ) + + def test_overrides(self): + """ + Test the optional `openapi_overrides` attribute on views + """ + + class ThingView(APIView): + action = 'Retrieve' + openapi_overrides = { + 'operationId': 'HaltAndCatchFire', + 'tags': ['bad-ideas'], + } + + def get(self, *args, **kwargs): + pass + + inspector = AutoSchema() + inspector.view = create_view(ThingView, 'GET', create_request('/test/')) + self.assertDictEqual( + inspector.get_operation('GET', '/test/'), + { + 'operationId': 'HaltAndCatchFire', + 'parameters': [], + 'responses': { + '200': { + 'content': { + 'application/json': { + 'schema': {} + } + }, + 'description': '', + } + }, + 'tags': ['bad-ideas'], + } + ) + + def test_bugfix_list_uppercase(self): + """ + Test list API views have title-cased endpoint names + """ + + class ThingView(APIView): + action = 'list' + + def get(self, *args, **kwargs): + pass + + inspector = AutoSchema() + inspector.view = create_view(ThingView, 'GET', create_request('/test/')) + self.assertDictEqual( + inspector.get_operation('GET', '/test/'), + { + 'operationId': 'ListThings', + 'parameters': [], + 'responses': { + '200': { + 'content': { + 'application/json': { + 'schema': { + 'type': 'array', + 'items': {}, + } + } + }, + 'description': '', + } + }, + } + ) + + def test_bugfix_hstorefield(self): + """ + Test the temporary fix for serializers.HStoreField being treated as a string + """ + inspector = AutoSchema() + self.assertDictEqual( + inspector._map_field(serializers.HStoreField()), + {'type': 'object'}, + ) + # Other fields should be unaffected + self.assertDictEqual( + inspector._map_field(serializers.CharField()), + {'type': 'string'}, + ) + + def test_bugfix_callable_defaults(self): + """ + Test the temporary fix to convert callable defaults into their results + """ + + def bad_default(): + raise Exception('Nope') + + def fancy_default(): + raise Exception("Don't touch me!") + + fancy_default.openapi_value = 1337.0 + + class ThingSerializer(serializers.Serializer): + my_field = serializers.FloatField(default=float) + my_field_on_fire = serializers.FloatField(default=bad_default) + my_fancy_field = serializers.FloatField(default=fancy_default) + + inspector = AutoSchema() + with warnings.catch_warnings(record=True) as warn_list: + schema = inspector._map_serializer(ThingSerializer()) + + self.assertEqual(len(warn_list), 1) + self.assertEqual( + str(warn_list[0].message), + 'Unsupported callable default for field my_field_on_fire: Nope', + ) + + self.assertDictEqual(schema, { + 'properties': { + 'my_field': { + 'type': 'number', + 'default': 0.0, + }, + 'my_field_on_fire': { + 'type': 'number', + }, + 'my_fancy_field': { + 'type': 'number', + 'default': 1337.0, + }, + }, + }) + + def test_bugfix_enum_defaults(self): + """ + Test the temporary fix to convert enums into strings + """ + + class ThingSerializer(serializers.Serializer): + my_field = EnumField(DataImportMode, default=DataImportMode.Images) + + inspector = AutoSchema() + self.assertDictEqual(inspector._map_serializer(ThingSerializer()), { + 'properties': { + 'my_field': { + 'default': 'images', + 'enum': [ + 'images', + 'pdf', + 'repository', + 'elements', + ] + } + } + }) + + def test_bugfix_null_defaults(self): + """ + Test the temporary fix that hides None defaults when they do not actually exist + """ + + class ThingSerializer(serializers.Serializer): + my_field = serializers.CharField() + + inspector = AutoSchema() + self.assertDictEqual(inspector._map_serializer(ThingSerializer()), { + 'properties': { + 'my_field': { + 'type': 'string', + } + }, + 'required': ['my_field'], + }) diff --git a/openapi/Dockerfile b/openapi/Dockerfile deleted file mode 100644 index f778aee54e..0000000000 --- a/openapi/Dockerfile +++ /dev/null @@ -1,8 +0,0 @@ -FROM registry.gitlab.com/arkindex/backend:latest - -RUN pip uninstall -y djangorestframework -COPY ["patch.py", "run.sh", "requirements.txt", "patch.yml", "/"] -RUN pip install -r /requirements.txt && rm /requirements.txt - -ENTRYPOINT ["/bin/sh", "-c"] -CMD ["/run.sh"] diff --git a/openapi/requirements.txt b/openapi/requirements.txt deleted file mode 100644 index ceabbab7b9..0000000000 --- a/openapi/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -git+https://github.com/encode/django-rest-framework.git@37f210a455cc92cb3f61a23e194a1d0de58d149b#egg=djangorestframework -coreapi==2.3.3 -apistar>=0.7.2 diff --git a/openapi/run.sh b/openapi/run.sh deleted file mode 100755 index 47c4f7391c..0000000000 --- a/openapi/run.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/sh -PONOS_DATA_DIR=/tmp manage.py generateschema > original.yml -./patch.py patch.yml original.yml diff --git a/requirements.txt b/requirements.txt index a865b2167c..4e9d9ec810 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ django-admin-hstore-widget==1.0.1 django-cors-headers==2.4.0 django-enumfields==1.0.0 django-filter==2.2.0 -djangorestframework==3.9.2 +djangorestframework==3.10.3 elasticsearch-dsl>=6.0.0,<7.0.0 et-xmlfile==1.0.1 gitpython==2.1.11 diff --git a/tests-requirements.txt b/tests-requirements.txt index 3c40425e7b..243c31c1b2 100644 --- a/tests-requirements.txt +++ b/tests-requirements.txt @@ -2,3 +2,4 @@ flake8==3.6.0 tripoli django-nose coverage +uritemplate==3 -- GitLab