From ff5bc245118ee770c87bb64842f4d9bf8788c611 Mon Sep 17 00:00:00 2001
From: Erwan Rouchet <rouchet@teklia.com>
Date: Tue, 10 Sep 2019 10:34:14 +0000
Subject: [PATCH] Django REST Framework 3.10 and custom schema generation

---
 .gitlab-ci.yml                         |   9 +-
 Makefile                               |   7 -
 arkindex/dataimport/api.py             |   2 +-
 arkindex/documents/api/elements.py     |   2 +-
 arkindex/project/openapi.py            |  88 +++++++++
 arkindex/project/pagination.py         |   8 +
 arkindex/project/settings.py           |   1 +
 arkindex/project/tests/test_openapi.py | 260 +++++++++++++++++++++++++
 openapi/Dockerfile                     |   8 -
 openapi/requirements.txt               |   3 -
 openapi/run.sh                         |   3 -
 requirements.txt                       |   2 +-
 tests-requirements.txt                 |   1 +
 13 files changed, 363 insertions(+), 31 deletions(-)
 create mode 100644 arkindex/project/openapi.py
 create mode 100644 arkindex/project/tests/test_openapi.py
 delete mode 100644 openapi/Dockerfile
 delete mode 100644 openapi/requirements.txt
 delete mode 100755 openapi/run.sh

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index aab1140bc0..cd7fd45182 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -45,16 +45,11 @@ backend-lint:
 
 backend-openapi:
   stage: build
-  image: registry.gitlab.com/arkindex/backend/openapi:latest
-
-  before_script:
-    - pip uninstall -y arkindex-common ponos-server
-    - "pip install git+https://gitlab-ci-token:${CI_JOB_TOKEN}@gitlab.com/arkindex/common#egg=arkindex-common"
-    - "pip install git+https://gitlab-ci-token:${CI_JOB_TOKEN}@gitlab.com/arkindex/ponos#egg=ponos-server"
 
   script:
     - mkdir -p output
-    - pip install --no-deps -e .
+    - pip install -e .
+    - pip install uritemplate==3 "apistar>=0.7.2"
     - arkindex/manage.py generateschema > output/original.yml
     - openapi/patch.py openapi/patch.yml output/original.yml > output/schema.yml
 
diff --git a/Makefile b/Makefile
index 2a0a2f70b1..9dffd519a4 100644
--- a/Makefile
+++ b/Makefile
@@ -8,7 +8,6 @@ VERSION=$(shell git rev-parse --short HEAD)
 TAG_APP=arkindex-app
 TAG_BASE=arkindex-base
 TAG_SHELL=arkindex-shell
-TAG_OPENAPI=arkindex-openapi
 .PHONY: build base
 
 all: clean build
@@ -30,17 +29,12 @@ build:
 build-shell:
 	docker build -t $(TAG_SHELL):$(VERSION) -t $(TAG_SHELL):latest $(ROOT_DIR)/shell
 
-build-openapi:
-	docker build --no-cache -t $(TAG_OPENAPI):$(VERSION) -t $(TAG_OPENAPI):latest $(ROOT_DIR)/openapi
-
 publish-version: require-docker-auth
 	[ -f $(ROOT_DIR)/arkindex/project/local_settings.py ] && mv $(ROOT_DIR)/arkindex/project/local_settings.py $(ROOT_DIR)/arkindex/project/local_settings.py.bak || true
 	$(MAKE) build TAG_APP=registry.gitlab.com/arkindex/backend
 	$(MAKE) build-shell TAG_SHELL=registry.gitlab.com/arkindex/backend/shell
-	$(MAKE) build-openapi TAG_OPENAPI=registry.gitlab.com/arkindex/backend/openapi
 	docker push registry.gitlab.com/arkindex/backend:$(VERSION)
 	docker push registry.gitlab.com/arkindex/backend/shell:$(VERSION)
-	docker push registry.gitlab.com/arkindex/backend/openapi:$(VERSION)
 	[ -f $(ROOT_DIR)/arkindex/project/local_settings.py.bak ] && mv $(ROOT_DIR)/arkindex/project/local_settings.py.bak $(ROOT_DIR)/arkindex/project/local_settings.py || true
 
 latest:
@@ -51,7 +45,6 @@ release:
 	$(MAKE) publish-version VERSION=$(version)
 	docker push registry.gitlab.com/arkindex/backend:latest
 	docker push registry.gitlab.com/arkindex/backend/shell:latest
-	docker push registry.gitlab.com/arkindex/backend/openapi:latest
 	git tag $(version)
 
 tunnel:
diff --git a/arkindex/dataimport/api.py b/arkindex/dataimport/api.py
index 32895b8563..2649e7df59 100644
--- a/arkindex/dataimport/api.py
+++ b/arkindex/dataimport/api.py
@@ -268,7 +268,7 @@ class DataFileList(CorpusACLMixin, ListAPIView):
     permission_classes = (IsVerified, )
     serializer_class = DataFileSerializer
     # Tell the OpenAPI schema generator to believe this view is a list
-    action = 'list'
+    action = 'List'
 
     def get_queryset(self):
         return DataFile.objects.filter(corpus=self.get_corpus(self.kwargs['pk'])).prefetch_related('images__server')
diff --git a/arkindex/documents/api/elements.py b/arkindex/documents/api/elements.py
index 67654a2197..6d64c1ee75 100644
--- a/arkindex/documents/api/elements.py
+++ b/arkindex/documents/api/elements.py
@@ -77,7 +77,7 @@ class RelatedElementsList(ListAPIView):
     """
     serializer_class = ElementSlimSerializer
     # Tell the OpenAPI schema generator to believe this view is a list
-    action = 'list'
+    action = 'List'
 
     def get_queryset(self):
         filtering = {
diff --git a/arkindex/project/openapi.py b/arkindex/project/openapi.py
new file mode 100644
index 0000000000..2de3e67c09
--- /dev/null
+++ b/arkindex/project/openapi.py
@@ -0,0 +1,88 @@
+from enum import Enum
+from rest_framework import serializers
+from rest_framework.schemas.openapi import AutoSchema as BaseAutoSchema
+import warnings
+
+
+class AutoSchema(BaseAutoSchema):
+    """
+    A custom view schema generator.
+    The docs on OpenAPI generation are currently very incomplete.
+    If you need to implement more features to avoid the "patch.py" and allow
+    views to customize their OpenAPI schemas by themselves, you may see for
+    yourself how DRF does the schema generation:
+    https://github.com/encode/django-rest-framework/blob/master/rest_framework/schemas/openapi.py
+    """
+
+    def _map_serializer(self, serializer):
+        """
+        This is a temporary patch because the default schema generation does not handle
+        callable defaults in fields properly, and adds 'default=null' to any field that
+        does not have a default value, even required fields.
+
+        https://github.com/encode/django-rest-framework/issues/6858
+        """
+        schema = super()._map_serializer(serializer)
+        for field_name in schema['properties']:
+            if 'default' not in schema['properties'][field_name]:
+                continue
+            default = schema['properties'][field_name]['default']
+
+            if hasattr(default, 'openapi_value'):
+                # Allow a 'openapi_value' attribute on the default for custom defaults
+                default = default.openapi_value
+
+            elif callable(default):
+                # Try to call the callable default; if it does not work, warn, then remove it
+                try:
+                    default = default()
+                except Exception as e:
+                    warnings.warn('Unsupported callable default for field {}: {!s}'.format(field_name, e))
+                    del schema['properties'][field_name]['default']
+                    continue
+
+            elif isinstance(default, Enum):
+                # Convert enums into their string values
+                default = default.value
+
+            # Remove null defaults on required fields
+            if default is None and field_name in schema.get('required', []):
+                del schema['properties'][field_name]['default']
+            else:
+                schema['properties'][field_name]['default'] = default
+
+        return schema
+
+    def _map_field(self, field):
+        """
+        Yet another temporary patch because HStoreField is not properly treated as `type: object`.
+
+        https://github.com/encode/django-rest-framework/issues/6913
+        https://github.com/encode/django-rest-framework/pull/6914
+        """
+        schema = super()._map_field(field)
+        if isinstance(field, serializers.HStoreField):
+            assert schema['type'] == 'string'  # Will crash when this is fixed upstream
+            return {
+                'type': 'object'
+            }
+        return schema
+
+    def get_operation(self, path, method):
+        operation = super().get_operation(path, method)
+
+        # Operation IDs for list endpoints are improperly cased: listThings instead of ListThings
+        # https://github.com/encode/django-rest-framework/pull/6917
+        if operation['operationId'][0].islower():
+            operation['operationId'] = operation['operationId'][0].upper() + operation['operationId'][1:]
+
+        # Setting deprecated = True on a View makes it deprecated in OpenAPI
+        if getattr(self.view, 'deprecated', False):
+            operation['deprecated'] = True
+
+        # Allow an `openapi_overrides` attribute to override the operation's properties
+        # TODO: Quite crude, not enough to make overriding request/response bodies easy
+        if hasattr(self.view, 'openapi_overrides'):
+            operation.update(self.view.openapi_overrides)
+
+        return operation
diff --git a/arkindex/project/pagination.py b/arkindex/project/pagination.py
index d430b96abe..c4160e51f2 100644
--- a/arkindex/project/pagination.py
+++ b/arkindex/project/pagination.py
@@ -18,3 +18,11 @@ class PageNumberPagination(pagination.PageNumberPagination):
             ('previous', self.get_previous_link()),
             ('results', data)
         ]))
+
+    def get_paginated_response_schema(self, schema):
+        schema = super().get_paginated_response_schema(schema)
+        schema['properties']['number'] = {
+            'type': 'integer',
+            'example': 123,
+        }
+        return schema
diff --git a/arkindex/project/settings.py b/arkindex/project/settings.py
index 6de5148ec8..73ecbd1ff3 100644
--- a/arkindex/project/settings.py
+++ b/arkindex/project/settings.py
@@ -223,6 +223,7 @@ REST_FRAMEWORK = {
         'ponos.authentication.AgentAuthentication',
     ),
     'DEFAULT_PAGINATION_CLASS': 'arkindex.project.pagination.PageNumberPagination',
+    'DEFAULT_SCHEMA_CLASS': 'arkindex.project.openapi.AutoSchema',
     'PAGE_SIZE': 20,
 }
 
diff --git a/arkindex/project/tests/test_openapi.py b/arkindex/project/tests/test_openapi.py
new file mode 100644
index 0000000000..c874a732da
--- /dev/null
+++ b/arkindex/project/tests/test_openapi.py
@@ -0,0 +1,260 @@
+from unittest import TestCase
+from django.test import RequestFactory
+from rest_framework import serializers
+from rest_framework.request import Request
+from rest_framework.views import APIView
+from rest_framework.schemas.openapi import SchemaGenerator
+from arkindex.project.serializer_fields import EnumField
+from arkindex.project.openapi import AutoSchema
+from arkindex_common.enums import DataImportMode
+import warnings
+
+
+# Helper methods taken straight from the DRF test suite
+def create_request(path):
+    factory = RequestFactory()
+    request = Request(factory.get(path))
+    return request
+
+
+def create_view(view_cls, method, request):
+    generator = SchemaGenerator()
+    view = generator.create_view(view_cls.as_view(), method, request)
+    return view
+
+
+class TestOpenAPI(TestCase):
+
+    def test_no_deprecated(self):
+        """
+        Test deprecated is omitted in schemas if the attribute is absent
+        """
+
+        class ThingView(APIView):
+            action = 'Retrieve'
+
+            def get(self, *args, **kwargs):
+                pass
+
+        inspector = AutoSchema()
+        inspector.view = create_view(ThingView, 'GET', create_request('/test/'))
+        self.assertDictEqual(
+            inspector.get_operation('GET', '/test/'),
+            {
+                'operationId': 'RetrieveThing',
+                'parameters': [],
+                'responses': {
+                  '200': {
+                    'content': {
+                        'application/json': {
+                            'schema': {}
+                        }
+                    },
+                    'description': '',
+                  }
+                }
+            }
+        )
+
+    def test_deprecated_attribute(self):
+        """
+        Test the optional `deprecated` attribute on views
+        """
+
+        class ThingView(APIView):
+            action = 'Retrieve'
+            deprecated = True
+
+            def get(self, *args, **kwargs):
+                pass
+
+        inspector = AutoSchema()
+        inspector.view = create_view(ThingView, 'GET', create_request('/test/'))
+        self.assertDictEqual(
+            inspector.get_operation('GET', '/test/'),
+            {
+                'operationId': 'RetrieveThing',
+                'parameters': [],
+                'responses': {
+                  '200': {
+                    'content': {
+                        'application/json': {
+                            'schema': {}
+                        }
+                    },
+                    'description': '',
+                  }
+                },
+                'deprecated': True,
+            }
+        )
+
+    def test_overrides(self):
+        """
+        Test the optional `openapi_overrides` attribute on views
+        """
+
+        class ThingView(APIView):
+            action = 'Retrieve'
+            openapi_overrides = {
+                'operationId': 'HaltAndCatchFire',
+                'tags': ['bad-ideas'],
+            }
+
+            def get(self, *args, **kwargs):
+                pass
+
+        inspector = AutoSchema()
+        inspector.view = create_view(ThingView, 'GET', create_request('/test/'))
+        self.assertDictEqual(
+            inspector.get_operation('GET', '/test/'),
+            {
+                'operationId': 'HaltAndCatchFire',
+                'parameters': [],
+                'responses': {
+                  '200': {
+                    'content': {
+                        'application/json': {
+                            'schema': {}
+                        }
+                    },
+                    'description': '',
+                  }
+                },
+                'tags': ['bad-ideas'],
+            }
+        )
+
+    def test_bugfix_list_uppercase(self):
+        """
+        Test list API views have title-cased endpoint names
+        """
+
+        class ThingView(APIView):
+            action = 'list'
+
+            def get(self, *args, **kwargs):
+                pass
+
+        inspector = AutoSchema()
+        inspector.view = create_view(ThingView, 'GET', create_request('/test/'))
+        self.assertDictEqual(
+            inspector.get_operation('GET', '/test/'),
+            {
+                'operationId': 'ListThings',
+                'parameters': [],
+                'responses': {
+                  '200': {
+                    'content': {
+                        'application/json': {
+                            'schema': {
+                                'type': 'array',
+                                'items': {},
+                            }
+                        }
+                    },
+                    'description': '',
+                  }
+                },
+            }
+        )
+
+    def test_bugfix_hstorefield(self):
+        """
+        Test the temporary fix for serializers.HStoreField being treated as a string
+        """
+        inspector = AutoSchema()
+        self.assertDictEqual(
+            inspector._map_field(serializers.HStoreField()),
+            {'type': 'object'},
+        )
+        # Other fields should be unaffected
+        self.assertDictEqual(
+            inspector._map_field(serializers.CharField()),
+            {'type': 'string'},
+        )
+
+    def test_bugfix_callable_defaults(self):
+        """
+        Test the temporary fix to convert callable defaults into their results
+        """
+
+        def bad_default():
+            raise Exception('Nope')
+
+        def fancy_default():
+            raise Exception("Don't touch me!")
+
+        fancy_default.openapi_value = 1337.0
+
+        class ThingSerializer(serializers.Serializer):
+            my_field = serializers.FloatField(default=float)
+            my_field_on_fire = serializers.FloatField(default=bad_default)
+            my_fancy_field = serializers.FloatField(default=fancy_default)
+
+        inspector = AutoSchema()
+        with warnings.catch_warnings(record=True) as warn_list:
+            schema = inspector._map_serializer(ThingSerializer())
+
+        self.assertEqual(len(warn_list), 1)
+        self.assertEqual(
+            str(warn_list[0].message),
+            'Unsupported callable default for field my_field_on_fire: Nope',
+        )
+
+        self.assertDictEqual(schema, {
+            'properties': {
+                'my_field': {
+                    'type': 'number',
+                    'default': 0.0,
+                },
+                'my_field_on_fire': {
+                    'type': 'number',
+                },
+                'my_fancy_field': {
+                    'type': 'number',
+                    'default': 1337.0,
+                },
+            },
+        })
+
+    def test_bugfix_enum_defaults(self):
+        """
+        Test the temporary fix to convert enums into strings
+        """
+
+        class ThingSerializer(serializers.Serializer):
+            my_field = EnumField(DataImportMode, default=DataImportMode.Images)
+
+        inspector = AutoSchema()
+        self.assertDictEqual(inspector._map_serializer(ThingSerializer()), {
+            'properties': {
+                'my_field': {
+                    'default': 'images',
+                    'enum': [
+                        'images',
+                        'pdf',
+                        'repository',
+                        'elements',
+                    ]
+                }
+            }
+        })
+
+    def test_bugfix_null_defaults(self):
+        """
+        Test the temporary fix that hides None defaults when they do not actually exist
+        """
+
+        class ThingSerializer(serializers.Serializer):
+            my_field = serializers.CharField()
+
+        inspector = AutoSchema()
+        self.assertDictEqual(inspector._map_serializer(ThingSerializer()), {
+            'properties': {
+                'my_field': {
+                    'type': 'string',
+                }
+            },
+            'required': ['my_field'],
+        })
diff --git a/openapi/Dockerfile b/openapi/Dockerfile
deleted file mode 100644
index f778aee54e..0000000000
--- a/openapi/Dockerfile
+++ /dev/null
@@ -1,8 +0,0 @@
-FROM registry.gitlab.com/arkindex/backend:latest
-
-RUN pip uninstall -y djangorestframework
-COPY ["patch.py", "run.sh", "requirements.txt", "patch.yml", "/"]
-RUN pip install -r /requirements.txt && rm /requirements.txt
-
-ENTRYPOINT ["/bin/sh", "-c"]
-CMD ["/run.sh"]
diff --git a/openapi/requirements.txt b/openapi/requirements.txt
deleted file mode 100644
index ceabbab7b9..0000000000
--- a/openapi/requirements.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-git+https://github.com/encode/django-rest-framework.git@37f210a455cc92cb3f61a23e194a1d0de58d149b#egg=djangorestframework
-coreapi==2.3.3
-apistar>=0.7.2
diff --git a/openapi/run.sh b/openapi/run.sh
deleted file mode 100755
index 47c4f7391c..0000000000
--- a/openapi/run.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-#!/bin/sh
-PONOS_DATA_DIR=/tmp manage.py generateschema > original.yml
-./patch.py patch.yml original.yml
diff --git a/requirements.txt b/requirements.txt
index a865b2167c..4e9d9ec810 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,7 +8,7 @@ django-admin-hstore-widget==1.0.1
 django-cors-headers==2.4.0
 django-enumfields==1.0.0
 django-filter==2.2.0
-djangorestframework==3.9.2
+djangorestframework==3.10.3
 elasticsearch-dsl>=6.0.0,<7.0.0
 et-xmlfile==1.0.1
 gitpython==2.1.11
diff --git a/tests-requirements.txt b/tests-requirements.txt
index 3c40425e7b..243c31c1b2 100644
--- a/tests-requirements.txt
+++ b/tests-requirements.txt
@@ -2,3 +2,4 @@ flake8==3.6.0
 tripoli
 django-nose
 coverage
+uritemplate==3
-- 
GitLab