From 34d60e4990c1db44d7a326153178837e0333bb7d Mon Sep 17 00:00:00 2001 From: Erwan Rouchet <rouchet@teklia.com> Date: Fri, 20 May 2022 13:56:12 +0000 Subject: [PATCH] Refresh OAuth tokens --- .gitlab-ci.yml | 2 +- arkindex/dataimport/providers.py | 77 ++-- .../dataimport/tests/test_gitlab_provider.py | 334 ++++++++++++++---- arkindex/dataimport/tests/test_providers.py | 10 +- arkindex/documents/fixtures/data.json | 4 +- .../management/commands/build_fixtures.py | 3 + arkindex/project/config.py | 2 +- .../tests/config_samples/defaults.yaml | 2 +- arkindex/project/tests/test_config.py | 5 +- arkindex/users/admin.py | 2 +- arkindex/users/api.py | 12 +- arkindex/users/models.py | 5 + arkindex/users/providers.py | 97 ++--- arkindex/users/serializers.py | 2 +- arkindex/users/tests/test_gitlab_oauth.py | 197 ++++++++--- ci/static-collect.sh | 2 +- 16 files changed, 540 insertions(+), 216 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 712126b4cf..7fcd617c62 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -22,7 +22,7 @@ include: - "pip install -e git+https://gitlab-ci-token:${CI_JOB_TOKEN}@gitlab.com/arkindex/ponos#egg=ponos-server" - "pip install -e git+https://gitlab-ci-token:${CI_JOB_TOKEN}@gitlab.com/arkindex/transkribus#egg=transkribus-client" - pip install -r tests-requirements.txt codecov - - "echo 'database: {host: postgres, port: 5432}' > $CONFIG_PATH" + - "echo 'database: {host: postgres, port: 5432}\npublic_hostname: http://ci.arkindex.localhost' > $CONFIG_PATH" # Those jobs require the base image; they might fail if the image is not up to date. # Allow them to fail when building a new base image, to prevent them from blocking a new base image build diff --git a/arkindex/dataimport/providers.py b/arkindex/dataimport/providers.py index c43d37d8f9..3d1a4f5820 100644 --- a/arkindex/dataimport/providers.py +++ b/arkindex/dataimport/providers.py @@ -21,13 +21,10 @@ class GitProvider(ABC): display_name = None - def __init__(self, credentials=None, url=None): + def __init__(self, credentials=None): if credentials: from arkindex.users.models import OAuthCredentials assert isinstance(credentials, OAuthCredentials) - self.url = credentials.provider_url - if url: - self.url = url self.credentials = credentials @abstractmethod @@ -49,8 +46,8 @@ class GitProvider(ABC): """ @abstractmethod - def create_hook(self, repository, project_id=None, request=None, base_url=None): - """Create a webhook to receive events from the provide""" + def create_hook(self, repository, project_id=None, base_url=None): + """Create a webhook to receive events from the project""" @abstractmethod def get_repository_type(self, project): @@ -138,7 +135,12 @@ class GitProvider(ABC): class GitLabProvider(GitProvider): display_name = "GitLab" - url = 'https://gitlab.com' + + def _get_gitlab_client(self, credentials): + if credentials.expired: + credentials.provider.refresh_token() + + return Gitlab(credentials.provider_url, oauth_token=credentials.token) def _try_get_project(self, gl, id): try: @@ -149,13 +151,13 @@ class GitLabProvider(GitProvider): def _get_project_from_repo(self, repo): assert repo.credentials, "Missing Gitlab credentials" - gl = Gitlab(self.url, oauth_token=repo.credentials.token) + gl = self._get_gitlab_client(repo.credentials) return self._try_get_project(gl, urllib.parse.urlsplit(repo.url).path.strip('/')) def list_repos(self, query=None): if not self.credentials: - raise NotAuthenticated() - gl = Gitlab(self.url, oauth_token=self.credentials.token) + raise NotAuthenticated + gl = self._get_gitlab_client(self.credentials) # Creating a webhook on a repo requires Maintainer (40) or Owner (50) access levels # See https://docs.gitlab.com/ce/api/members.html#valid-access-levels return gl.projects.list(min_access_level=40, search=query) @@ -200,10 +202,11 @@ class GitLabProvider(GitProvider): 'url': project.web_url } - def create_repo(self, id=None, corpus=None, request=None, **kwargs): - if not self.credentials and request: + def create_repo(self, id=None, **kwargs): + if not self.credentials: raise NotAuthenticated() - gl = Gitlab(self.url, oauth_token=self.credentials.token) + + gl = self._get_gitlab_client(self.credentials) project = self._try_get_project(gl, int(id)) from arkindex.dataimport.models import Repository @@ -233,40 +236,44 @@ class GitLabProvider(GitProvider): provider_name=self.__class__.__name__, ) - # Create a webhook using information from the HTTP request - self.create_hook(repo, project_id=int(id), request=request) + self.create_hook(repo, project_id=int(id)) return repo - def create_hook(self, repository, project_id=None, request=None, base_url=None): - """Configure the Gitlab hook to get events for that project""" + def create_hook(self, repository, project_id=None, base_url=None): + """ + Configure the Gitlab hook to get events for a project. + + If `project_id` is set, then the project is retrieved directly from its GitLab project ID + instead of matched using its path. - # Load project using a project if or its repo path - gitlab = Gitlab(self.url, oauth_token=repository.credentials.token) + The webhook's URL will use `base_url` as its base URL if it is set, and otherwise + falls back on settings.BACKEND_PUBLIC_URL_OAUTH first, then settings.PUBLIC_HOSTNAME. + """ + # Load project using a project ID or its repo path + gitlab = Gitlab(repository.credentials.provider_url, oauth_token=repository.credentials.token) if project_id: project = self._try_get_project(gitlab, project_id) else: path = urllib.parse.urlparse(repository.url).path[1:] project = self._try_get_project(gitlab, path) - try: - url = reverse('api:import-hook', kwargs={'pk': repository.id}) - if settings.BACKEND_PUBLIC_URL_OAUTH: - url = urllib.parse.urljoin(settings.BACKEND_PUBLIC_URL_OAUTH, url) - elif request is not None: - url = request.build_absolute_uri(url) - elif base_url is not None: - url = urllib.parse.urljoin(base_url, url) - else: - raise Exception("You need to specify an HTTP request or a base url") - logger.info(f"Webhook will be created as {url}") + if base_url is None: + base_url = settings.BACKEND_PUBLIC_URL_OAUTH or settings.PUBLIC_HOSTNAME - # Delete already configured hooks to be able to update - for hook in project.hooks.list(all=True): - if hook.url == url: - hook.delete() - logger.info(f"Deleted existing hook {hook.id}") + if base_url is None: + raise APIException('Either the `base_url` argument, settings.BACKEND_PUBLIC_URL_OAUTH or settings.PUBLIC_HOSTNAME must be set.') + url = urllib.parse.urljoin(base_url, reverse('api:import-hook', kwargs={'pk': repository.id})) + logger.info(f"Webhook will be created as {url}") + + # Delete already configured hooks to be able to update + for hook in project.hooks.list(all=True): + if hook.url == url: + hook.delete() + logger.info(f"Deleted existing hook {hook.id}") + + try: # Create a new hook hook = project.hooks.create({ 'url': url, diff --git a/arkindex/dataimport/tests/test_gitlab_provider.py b/arkindex/dataimport/tests/test_gitlab_provider.py index 1aea8a63f1..0c1c593268 100644 --- a/arkindex/dataimport/tests/test_gitlab_provider.py +++ b/arkindex/dataimport/tests/test_gitlab_provider.py @@ -2,8 +2,11 @@ import collections from pathlib import Path from unittest.mock import MagicMock, patch +import responses from django.conf import settings +from django.test import override_settings from gitlab.exceptions import GitlabCreateError, GitlabGetError +from responses import matchers from rest_framework.exceptions import APIException, AuthenticationFailed, NotAuthenticated, ValidationError from arkindex.dataimport.models import DataImport, DataImportMode, GitRefType, RepositoryType, Revision @@ -34,15 +37,16 @@ class TestGitLabProvider(FixtureTestCase): super().tearDown() self.gl_patch.stop() + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_list_repos(self): """ Test GitLabProvider can list repositories from GitLab """ - GitLabProvider(url='http://aaa', credentials=self.creds).list_repos() + GitLabProvider(credentials=self.creds).list_repos() self.assertEqual(self.gl_mock.call_count, 1) args, kwargs = self.gl_mock.call_args - self.assertTupleEqual(args, ('http://aaa', )) + self.assertTupleEqual(args, ('https://somewhere', )) self.assertDictEqual(kwargs, {'oauth_token': self.creds.token}) self.assertEqual(self.gl_mock().projects.list.call_count, 1) @@ -50,15 +54,50 @@ class TestGitLabProvider(FixtureTestCase): self.assertTupleEqual(args, ()) self.assertDictEqual(kwargs, {'min_access_level': 40, 'search': None}) + @responses.activate + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost', GITLAB_APP_ID='abcd', GITLAB_APP_SECRET='s3kr3t') + def test_list_repos_refresh(self): + responses.post( + 'https://somewhere/oauth/token', + match=[ + matchers.urlencoded_params_matcher({ + 'client_id': 'abcd', + 'client_secret': 's3kr3t', + 'redirect_uri': 'https://arkindex.localhost/api/v1/oauth/providers/gitlab/callback/', + 'grant_type': 'refresh_token', + 'refresh_token': 'refresh-token' + }), + ], + json={ + 'access_token': 'new-token', + 'refresh_token': 'new-refresh-token', + 'token_type': 'Bearer', + 'created_at': 1582984800, + 'expires_in': '3600', + }, + ) + responses.get('https://somewhere/api/v4/user', json={'id': 42, 'username': 'Someone'}) + self.creds.expiry = None + self.creds.save() + + GitLabProvider(credentials=self.creds).list_repos() + + self.creds.refresh_from_db() + self.assertEqual(self.creds.token, 'new-token') + self.assertEqual(self.creds.refresh_token, 'new-refresh-token') + self.assertEqual(self.creds.account_name, 'Someone') + self.assertEqual(self.creds.expiry.isoformat(), '2020-02-29T15:00:00+00:00') + + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_list_repos_query(self): """ Test GitLabProvider can search repositories from GitLab """ - GitLabProvider(url='http://aaa', credentials=self.creds).list_repos(query='meh') + GitLabProvider(credentials=self.creds).list_repos(query='meh') self.assertEqual(self.gl_mock.call_count, 1) args, kwargs = self.gl_mock.call_args - self.assertTupleEqual(args, ('http://aaa', )) + self.assertTupleEqual(args, ('https://somewhere', )) self.assertDictEqual(kwargs, {'oauth_token': self.creds.token}) self.assertEqual(self.gl_mock().projects.list.call_count, 1) @@ -66,31 +105,30 @@ class TestGitLabProvider(FixtureTestCase): self.assertTupleEqual(args, ()) self.assertDictEqual(kwargs, {'min_access_level': 40, 'search': 'meh'}) + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_list_repos_requires_credentials(self): """ Test GitLabProvider checks for credentials when requesting repositories list """ with self.assertRaises(NotAuthenticated): - GitLabProvider(url='http://aaa').list_repos() + GitLabProvider().list_repos() + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_create_repo(self): """ Test GitLabProvider can create a Repository instance from a GitLab repo """ self.gl_mock().projects.get.return_value.web_url = 'http://new_repo_url' - self.gl_mock().projects.get.return_value.default_branch = 'branchname' self.gl_mock().projects.get.return_value.permissions = { 'project_access': {'access_level': 50}, 'group_access': None } - request_mock = MagicMock() - request_mock.build_absolute_uri.return_value = 'http://hook' - glp = GitLabProvider(url='http://aaa', credentials=self.creds) + glp = GitLabProvider(credentials=self.creds) glp.get_repository_type = MagicMock() glp.get_repository_type.return_value = RepositoryType.IIIF - new_repo = glp.create_repo(id='1337', request=request_mock, corpus=self.corpus) + new_repo = glp.create_repo(id='1337') self.assertEqual(self.gl_mock().projects.get.call_count, 2) args, kwargs = self.gl_mock().projects.get.call_args @@ -105,104 +143,136 @@ class TestGitLabProvider(FixtureTestCase): self.assertEqual(len(args), 1) self.assertDictEqual(kwargs, {}) self.assertDictEqual(args[0], { - 'url': 'http://hook', + 'url': f'https://arkindex.localhost/api/v1/imports/hook/{new_repo.id}/', 'push_events': True, 'tag_push_events': True, 'token': new_repo.hook_token, }) + @responses.activate + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost', GITLAB_APP_ID='abcd', GITLAB_APP_SECRET='s3kr3t') + def test_create_repo_refresh(self): + responses.post( + 'https://somewhere/oauth/token', + match=[ + matchers.urlencoded_params_matcher({ + 'client_id': 'abcd', + 'client_secret': 's3kr3t', + 'redirect_uri': 'https://arkindex.localhost/api/v1/oauth/providers/gitlab/callback/', + 'grant_type': 'refresh_token', + 'refresh_token': 'refresh-token' + }), + ], + json={ + 'access_token': 'new-token', + 'refresh_token': 'new-refresh-token', + 'token_type': 'Bearer', + 'created_at': 1582984800, + 'expires_in': '3600', + }, + ) + responses.get('https://somewhere/api/v4/user', json={'id': 42, 'username': 'Someone'}) + self.creds.expiry = None + self.creds.save() + + self.gl_mock().projects.get.return_value.web_url = 'http://new_repo_url' + self.gl_mock().projects.get.return_value.permissions = { + 'project_access': {'access_level': 50}, + 'group_access': None + } + + glp = GitLabProvider(credentials=self.creds) + glp.get_repository_type = MagicMock() + glp.get_repository_type.return_value = RepositoryType.IIIF + glp.create_repo(id='1337') + + self.creds.refresh_from_db() + self.assertEqual(self.creds.token, 'new-token') + self.assertEqual(self.creds.refresh_token, 'new-refresh-token') + self.assertEqual(self.creds.account_name, 'Someone') + self.assertEqual(self.creds.expiry.isoformat(), '2020-02-29T15:00:00+00:00') + + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_create_repo_requires_credentials(self): """ Test GitLabProvider checks for credentials when requesting a repository creation """ - request_mock = MagicMock() - request_mock.build_absolute_uri.return_value = 'http://hook' with self.assertRaises(NotAuthenticated): - GitLabProvider(url='http://aaa').create_repo( - id='repo_id', request=request_mock, corpus=self.corpus) + GitLabProvider().create_repo(id='repo_id') + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_create_repo_already_exists(self): """ Test GitLabProvider checks for duplicate repositories """ self.gl_mock().projects.get.return_value.web_url = 'http://new_repo_url' - self.gl_mock().projects.get.return_value.default_branch = 'branchname' self.gl_mock().projects.get.return_value.permissions = { 'project_access': {'access_level': 40}, 'group_access': None } - request_mock = MagicMock() - request_mock.build_absolute_uri.return_value = 'http://hook' - glp = GitLabProvider(url='http://aaa', credentials=self.creds) + glp = GitLabProvider(credentials=self.creds) glp.get_repository_type = MagicMock() glp.get_repository_type.return_value = RepositoryType.IIIF - glp.create_repo(id='1337', request=request_mock, corpus=self.corpus) + glp.create_repo(id='1337') with self.assertRaises(ValidationError): - GitLabProvider(url='http://aaa', credentials=self.creds).create_repo( - id='1337', request=request_mock, corpus=self.corpus) + GitLabProvider(credentials=self.creds).create_repo( + id='1337') + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_create_repo_requires_maintainer(self): """ Test GitLabProvider checks for duplicate repositories """ self.gl_mock().projects.get.return_value.web_url = 'http://new_repo_url' - self.gl_mock().projects.get.return_value.default_branch = 'branchname' self.gl_mock().projects.get.return_value.permissions = { 'group_access': {'access_level': 30}, 'project_access': None, } - request_mock = MagicMock() - request_mock.build_absolute_uri.return_value = 'http://hook' - glp = GitLabProvider(url='http://aaa', credentials=self.creds) + glp = GitLabProvider(credentials=self.creds) glp.get_repository_type = MagicMock() glp.get_repository_type.return_value = RepositoryType.IIIF with self.assertRaisesRegex( ValidationError, 'Maintainer or Owner access is required to add a GitLab repository'): - glp.create_repo(id='1337', request=request_mock, corpus=self.corpus) + glp.create_repo(id='1337') + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_create_repo_handle_get_error(self): """ Test GitLabProvider handles GitLab repo GET errors """ self.gl_mock().projects.get.side_effect = GitlabGetError - request_mock = MagicMock() - request_mock.build_absolute_uri.return_value = 'http://hook' - with self.assertRaises(APIException): - GitLabProvider(url='http://aaa', credentials=self.creds).create_repo( - id='1337', request=request_mock, corpus=self.corpus) + GitLabProvider(credentials=self.creds).create_repo(id='1337') self.assertEqual(self.gl_mock().projects.get.call_count, 1) + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_create_repo_handle_hook_create_error(self): """ Test GitLabProvider handles GitLab hook creation errors """ self.gl_mock().projects.get.return_value.web_url = 'http://new_repo_url' - self.gl_mock().projects.get.return_value.default_branch = 'branchname' self.gl_mock().projects.get.return_value.permissions = { 'project_access': {'access_level': 50} } self.gl_mock().projects.get.return_value.hooks.create.side_effect = GitlabCreateError - request_mock = MagicMock() - request_mock.build_absolute_uri.return_value = 'http://hook' - with self.assertRaises(APIException): - glp = GitLabProvider(url='http://aaa', credentials=self.creds) + glp = GitLabProvider(credentials=self.creds) glp.get_repository_type = MagicMock() glp.get_repository_type.return_value = RepositoryType.IIIF - glp.create_repo(id='1337', request=request_mock, corpus=self.corpus) + glp.create_repo(id='1337') self.assertEqual(self.gl_mock().projects.get.call_count, 2) self.assertEqual(self.gl_mock().projects.get().hooks.create.call_count, 1) + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_update_or_create_ref(self): """ Test GitLabProvider can create or update a GitRef instance for a repo and a revision @@ -229,7 +299,7 @@ class TestGitLabProvider(FixtureTestCase): # Assert that git references are created properly for ref in commit_refs: - GitLabProvider(url='http://aaa', credentials=self.creds) \ + GitLabProvider(credentials=self.creds) \ .update_or_create_ref(self.repo, rev1, ref['name'], ref['type']) refs = [ @@ -245,9 +315,9 @@ class TestGitLabProvider(FixtureTestCase): self.assertEqual(len(self.repo.refs.all()), 3) # Assert that git references are updated with another revision properly - GitLabProvider(url='http://aaa', credentials=self.creds) \ + GitLabProvider(credentials=self.creds) \ .update_or_create_ref(self.repo, rev2, commit_refs[0]['name'], commit_refs[0]['type']) - GitLabProvider(url='http://aaa', credentials=self.creds) \ + GitLabProvider(credentials=self.creds) \ .update_or_create_ref(self.repo, rev2, commit_refs[1]['name'], commit_refs[1]['type']) refs_rev1 = [ @@ -269,17 +339,19 @@ class TestGitLabProvider(FixtureTestCase): ]) self.assertEqual(len(self.repo.refs.all()), 3) + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_get_revision(self): """ Test GitLabProvider can create a Revision instance for a repo by hash """ - revision, created = GitLabProvider(url='http://aaa', credentials=self.creds) \ + revision, created = GitLabProvider(credentials=self.creds) \ .get_or_create_revision(self.repo, '42') self.assertEqual(revision, self.rev) self.assertFalse(created) self.assertEqual(self.gl_mock.call_count, 0) + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_create_revision(self): """ Test GitLabProvider can create a Revision instance for a repo by hash @@ -293,7 +365,7 @@ class TestGitLabProvider(FixtureTestCase): self.gl_mock().projects.get.return_value.commits.get.return_value.message = 'commit message' self.gl_mock().projects.get.return_value.commits.get.return_value.author_name = 'bob' - revision, created = GitLabProvider(url='http://aaa', credentials=self.creds) \ + revision, created = GitLabProvider(credentials=self.creds) \ .get_or_create_revision(self.repo, '1337') self.assertTrue(created) @@ -324,11 +396,50 @@ class TestGitLabProvider(FixtureTestCase): self.assertTupleEqual(args, ('1337', )) self.assertDictEqual(kwargs, {}) + @responses.activate + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost', GITLAB_APP_ID='abcd', GITLAB_APP_SECRET='s3kr3t') + def test_create_revision_refresh(self): + responses.post( + 'https://somewhere/oauth/token', + match=[ + matchers.urlencoded_params_matcher({ + 'client_id': 'abcd', + 'client_secret': 's3kr3t', + 'redirect_uri': 'https://arkindex.localhost/api/v1/oauth/providers/gitlab/callback/', + 'grant_type': 'refresh_token', + 'refresh_token': 'refresh-token' + }), + ], + json={ + 'access_token': 'new-token', + 'refresh_token': 'new-refresh-token', + 'token_type': 'Bearer', + 'created_at': 1582984800, + 'expires_in': '3600', + }, + ) + responses.get('https://somewhere/api/v4/user', json={'id': 42, 'username': 'Someone'}) + self.creds.expiry = None + self.creds.save() + + self.gl_mock().projects.get.return_value.commits.get.return_value.refs.return_value = [] + self.gl_mock().projects.get.return_value.commits.get.return_value.message = 'commit message' + self.gl_mock().projects.get.return_value.commits.get.return_value.author_name = 'bob' + + GitLabProvider(credentials=self.creds).get_or_create_revision(self.repo, '1337') + + self.creds.refresh_from_db() + self.assertEqual(self.creds.token, 'new-token') + self.assertEqual(self.creds.refresh_token, 'new-refresh-token') + self.assertEqual(self.creds.account_name, 'Someone') + self.assertEqual(self.creds.expiry.isoformat(), '2020-02-29T15:00:00+00:00') + + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_handle_webhook_missing_headers(self): """ Test GitLabProvider checks HTTP headers on webhooks """ - glp = GitLabProvider(url='http://aaa', credentials=self.creds) + glp = GitLabProvider(credentials=self.creds) request_mock = MagicMock() request_mock.data = { @@ -375,6 +486,7 @@ class TestGitLabProvider(FixtureTestCase): with self.assertRaises(AuthenticationFailed): glp.handle_webhook(self.repo, request_mock) + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_handle_webhook_create_revision(self): """ Test GitLabProvider does create a revision when a push event is received @@ -411,11 +523,12 @@ class TestGitLabProvider(FixtureTestCase): self.assertFalse(rev.exists()) repo_imports = DataImport.objects.filter(revision__repo_id=str(self.repo.id)) self.assertFalse(repo_imports.exists()) - GitLabProvider(url='http://aaa', credentials=self.creds).handle_webhook(self.repo, request_mock) + GitLabProvider(credentials=self.creds).handle_webhook(self.repo, request_mock) di = repo_imports.get() self.assertEqual(di.mode, DataImportMode.Repository) + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_handle_webhook_duplicate_events(self): """ Test GitLabProvider checks for already handled events @@ -444,13 +557,14 @@ class TestGitLabProvider(FixtureTestCase): self.assertTrue(rev.exists()) repo_imports = DataImport.objects.filter(revision__repo_id=str(self.repo.id)) self.assertFalse(repo_imports.exists()) - GitLabProvider(url='http://aaa', credentials=self.creds).handle_webhook(self.repo, request_mock) + GitLabProvider(credentials=self.creds).handle_webhook(self.repo, request_mock) # Checking that we didn't initiate revision creation self.assertEqual(self.gl_mock().projects.get.call_count, 0) self.assertEqual(self.gl_mock().projects.get().commits.get.call_count, 0) self.assertFalse(repo_imports.exists()) + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_handle_webhook_wrong_kind(self): """ Test GitLabProvider gracefully fails when GitLab breaks things @@ -460,7 +574,7 @@ class TestGitLabProvider(FixtureTestCase): self.assertFalse(rev.exists()) repo_imports = DataImport.objects.filter(revision__repo_id=str(self.repo.id)) - glp = GitLabProvider(url='http://aaa', credentials=self.creds) + glp = GitLabProvider(credentials=self.creds) request_mock = MagicMock() request_mock.META = { 'HTTP_X_GITLAB_EVENT': 'Push Hook', @@ -491,6 +605,7 @@ class TestGitLabProvider(FixtureTestCase): self.assertFalse(rev.exists()) self.assertFalse(repo_imports.exists()) + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_handle_webhook_delete_branch(self): """ Test GitLabProvider properly handles a branch deletion @@ -505,7 +620,7 @@ class TestGitLabProvider(FixtureTestCase): self.assertTrue(self.repo.revisions.filter(hash='1').exists()) repo_imports = DataImport.objects.filter(revision__repo_id=str(self.repo.id)) - glp = GitLabProvider(url='http://aaa', credentials=self.creds) + glp = GitLabProvider(credentials=self.creds) glp.update_or_create_ref(self.repo, rev, 'test', GitRefType.Branch) self.assertEqual(len(self.repo.refs.all()), 1) request_mock = MagicMock() @@ -524,6 +639,58 @@ class TestGitLabProvider(FixtureTestCase): self.assertEqual(len(self.repo.refs.all()), 0) self.assertFalse(repo_imports.exists()) + @responses.activate + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost', GITLAB_APP_ID='abcd', GITLAB_APP_SECRET='s3kr3t') + def test_handle_webhook_refresh(self): + responses.post( + 'https://somewhere/oauth/token', + match=[ + matchers.urlencoded_params_matcher({ + 'client_id': 'abcd', + 'client_secret': 's3kr3t', + 'redirect_uri': 'https://arkindex.localhost/api/v1/oauth/providers/gitlab/callback/', + 'grant_type': 'refresh_token', + 'refresh_token': 'refresh-token' + }), + ], + json={ + 'access_token': 'new-token', + 'refresh_token': 'new-refresh-token', + 'token_type': 'Bearer', + 'created_at': 1582984800, + 'expires_in': '3600', + }, + ) + responses.get('https://somewhere/api/v4/user', json={'id': 42, 'username': 'Someone'}) + self.creds.expiry = None + self.creds.save() + + self.gl_mock().projects.get.return_value.commits.get.return_value.refs.return_value = [] + self.gl_mock().projects.get.return_value.commits.get.return_value.message = 'commit message' + self.gl_mock().projects.get.return_value.commits.get.return_value.author_name = 'bob' + + request_mock = MagicMock() + request_mock.META = { + 'HTTP_X_GITLAB_EVENT': 'Push Hook', + 'HTTP_X_GITLAB_TOKEN': 'hook-token', + } + request_mock.data = { + 'object_kind': 'push', + 'ref': 'refs/heads/something', + 'commits': [], + 'checkout_sha': '1337', + } + + glp = GitLabProvider(credentials=self.creds) + glp.handle_webhook(self.repo, request_mock) + + self.creds.refresh_from_db() + self.assertEqual(self.creds.token, 'new-token') + self.assertEqual(self.creds.refresh_token, 'new-refresh-token') + self.assertEqual(self.creds.account_name, 'Someone') + self.assertEqual(self.creds.expiry.isoformat(), '2020-02-29T15:00:00+00:00') + + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_retrieve_repo_type(self): """ Gitlab provider allow to retrieve a project type @@ -538,13 +705,14 @@ class TestGitLabProvider(FixtureTestCase): with (SAMPLES / 'worker.yml').open('r') as f: project.repository_raw_blob.return_value = f.read() - glp = GitLabProvider(url='http://aaa', credentials=self.creds) + glp = GitLabProvider(credentials=self.creds) repo_type = glp.get_repository_type(project) raw_blob_args, _ = project.repository_raw_blob.call_args self.assertEqual(raw_blob_args, ('424242', )) self.assertEqual(repo_type, RepositoryType.Worker) + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_retrieve_repo_type_file_missing(self): project = MagicMock() project.name = 'Test' @@ -554,10 +722,11 @@ class TestGitLabProvider(FixtureTestCase): {'id': '445566', 'name': 'pong'} ] - glp = GitLabProvider(url='http://aaa', credentials=self.creds) + glp = GitLabProvider(credentials=self.creds) with self.assertRaisesMessage(ValidationError, 'Test project is missing a .arkindex.yml configuration file'): glp.get_repository_type(project) + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_retrieve_repo_type_wrong_type(self): project = MagicMock() self.gl_mock().projects.get.return_value = project @@ -566,11 +735,12 @@ class TestGitLabProvider(FixtureTestCase): with (SAMPLES / 'wrong_type.yml').open('r') as f: project.repository_raw_blob.return_value = f.read() - glp = GitLabProvider(url='http://aaa', credentials=self.creds) + glp = GitLabProvider(credentials=self.creds) type_err = f'type is not defined or does not match handled repository types ({list(RepositoryType)})' with self.assertRaisesMessage(ValidationError, type_err): glp.get_repository_type(project) + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_retrieve_repo_type_wrong_version(self): project = MagicMock() self.gl_mock().projects.get.return_value = project @@ -579,26 +749,28 @@ class TestGitLabProvider(FixtureTestCase): with (SAMPLES / 'wrong_version.yml').open('r') as f: project.repository_raw_blob.return_value = f.read() - glp = GitLabProvider(url='http://aaa', credentials=self.creds) + glp = GitLabProvider(credentials=self.creds) version_err = f'version is not defined or is different from the latest version ({settings.WORKERS_CONF_VERSION})' with self.assertRaisesMessage(ValidationError, version_err): glp.get_repository_type(project) + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_retrieve_repo_type_invalid_file(self): project = MagicMock() project.name = 'Test' self.gl_mock().projects.get.return_value = project project.repository_tree.return_value = [{'id': 'file_id', 'name': settings.WORKERS_CONF_PATH}] - for invalid_file in ('wrong_format.txt', 'text.txt', 'empty'): - with (SAMPLES / 'wrong_format.txt').open('r') as f: - project.repository_raw_blob.return_value = f.read() + for invalid_file in ('wrong_format.txt', 'text.txt', 'empty.yml'): + with self.subTest(file=invalid_file): + project.repository_raw_blob.return_value = (SAMPLES / invalid_file).read_text() - glp = GitLabProvider(url='http://aaa', credentials=self.creds) - parsing_err = f'Test project has an invalid {settings.WORKERS_CONF_PATH} configuration file' - with self.assertRaisesMessage(ValidationError, parsing_err): - glp.get_repository_type(project) + glp = GitLabProvider(credentials=self.creds) + parsing_err = f'Test project has an invalid {settings.WORKERS_CONF_PATH} configuration file' + with self.assertRaisesMessage(ValidationError, parsing_err): + glp.get_repository_type(project) + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost') def test_update_repo_references(self): """ Check that we are able to fetch new branch and tags references @@ -606,14 +778,14 @@ class TestGitLabProvider(FixtureTestCase): Ref = collections.namedtuple('Ref', 'name, commit') # Add 2 commits on 2 branches - self.gl_mock().projects.get.return_value.branches.list = lambda: [ + self.gl_mock().projects.get.return_value.branches.list.return_value = [ Ref("master", {"id": "commit1", "message": "A commit on master", "author_email": "someone@teklia.com"}), # Create a commit with a very long branch name Ref("A" * 1000, {"id": "commit2", "message": "My fancy feature", "author_email": "another@teklia.com"}) ] # Add 1 new commit on a tag, and reuse a commit on another tag - self.gl_mock().projects.get.return_value.tags.list = lambda: [ + self.gl_mock().projects.get.return_value.tags.list.return_value = [ Ref("v0.1", {"id": "commit0", "message": "This is a legacy version", "author_email": "old@teklia.com"}), Ref("v1.0", {"id": "commit1", "message": "A commit on master", "author_email": "someone@teklia.com"}) ] @@ -624,7 +796,7 @@ class TestGitLabProvider(FixtureTestCase): self.assertEqual(DataImport.objects.count(), 1) # Update references for this repo - provider = GitLabProvider(url='http://aaa', credentials=self.creds) + provider = GitLabProvider(credentials=self.creds) provider.update_repository_references(self.repo) # We should now have 3 revisions @@ -662,3 +834,41 @@ class TestGitLabProvider(FixtureTestCase): (f"Import {'A' * 93}", DataImportMode.Repository, 'commit2'), ("Process fixture", DataImportMode.Workers, None), ]) + + @responses.activate + @override_settings(PUBLIC_HOSTNAME='https://arkindex.localhost', GITLAB_APP_ID='abcd', GITLAB_APP_SECRET='s3kr3t') + def test_update_repo_references_refresh(self): + responses.post( + 'https://somewhere/oauth/token', + match=[ + matchers.urlencoded_params_matcher({ + 'client_id': 'abcd', + 'client_secret': 's3kr3t', + 'redirect_uri': 'https://arkindex.localhost/api/v1/oauth/providers/gitlab/callback/', + 'grant_type': 'refresh_token', + 'refresh_token': 'refresh-token' + }), + ], + json={ + 'access_token': 'new-token', + 'refresh_token': 'new-refresh-token', + 'token_type': 'Bearer', + 'created_at': 1582984800, + 'expires_in': '3600', + }, + ) + responses.get('https://somewhere/api/v4/user', json={'id': 42, 'username': 'Someone'}) + self.creds.expiry = None + self.creds.save() + + self.gl_mock().projects.get.return_value.branches.list.return_value = [] + self.gl_mock().projects.get.return_value.tags.list.return_value = [] + + provider = GitLabProvider(credentials=self.creds) + provider.update_repository_references(self.repo) + + self.creds.refresh_from_db() + self.assertEqual(self.creds.token, 'new-token') + self.assertEqual(self.creds.refresh_token, 'new-refresh-token') + self.assertEqual(self.creds.account_name, 'Someone') + self.assertEqual(self.creds.expiry.isoformat(), '2020-02-29T15:00:00+00:00') diff --git a/arkindex/dataimport/tests/test_providers.py b/arkindex/dataimport/tests/test_providers.py index f3d6318e24..e53aadb9d4 100644 --- a/arkindex/dataimport/tests/test_providers.py +++ b/arkindex/dataimport/tests/test_providers.py @@ -35,20 +35,14 @@ class TestProviders(FixtureAPITestCase): def test_init(self): glp = GitLabProvider() - self.assertEqual(glp.url, GitLabProvider.url) - - glp = GitLabProvider(url='something') - self.assertEqual(glp.url, 'something') + self.assertIsNone(glp.credentials) glp = GitLabProvider(credentials=self.creds) - self.assertEqual(glp.url, self.creds.provider_url) + self.assertEqual(glp.credentials, self.creds) with self.assertRaises(Exception): GitLabProvider(credentials='not a OAuthCredentials') - glp = GitLabProvider(credentials=self.creds, url='something') - self.assertEqual(glp.url, 'something') - @patch('arkindex.dataimport.api.Repository.provider_class') def test_webhook_no_credentials(self, provider_class): self.client.force_login(self.user) diff --git a/arkindex/documents/fixtures/data.json b/arkindex/documents/fixtures/data.json index 1a390c897b..fbe896df80 100644 --- a/arkindex/documents/fixtures/data.json +++ b/arkindex/documents/fixtures/data.json @@ -1589,8 +1589,8 @@ "provider_url": "https://somewhere", "status": "created", "token": "oauth-token", - "refresh_token": null, - "expiry": null, + "refresh_token": "refresh-token", + "expiry": "2100-12-31T23:59:59.999Z", "account_name": null } }, diff --git a/arkindex/documents/management/commands/build_fixtures.py b/arkindex/documents/management/commands/build_fixtures.py index c70fc9826c..a143ba32f3 100644 --- a/arkindex/documents/management/commands/build_fixtures.py +++ b/arkindex/documents/management/commands/build_fixtures.py @@ -87,6 +87,9 @@ class Command(BaseCommand): provider_name='gitlab', provider_url='https://somewhere', token='oauth-token', + refresh_token='refresh-token', + # Use an expiry very far away to avoid OAuth token refreshes in every test + expiry=datetime(2100, 12, 31, 23, 59, 59, 999999, timezone.utc), ) # Create a worker repository diff --git a/arkindex/project/config.py b/arkindex/project/config.py index 857e0afc7d..4243317cdb 100644 --- a/arkindex/project/config.py +++ b/arkindex/project/config.py @@ -58,7 +58,7 @@ def get_settings_parser(base_dir): parser.add_option('imports_worker_version', type=uuid.UUID, default=None) parser.add_option('workers_max_chunks', type=int, default=10) parser.add_option('robots_txt_disallow', type=str, many=True, default=[]) - parser.add_option('public_hostname', type=public_hostname, default=None) + parser.add_option('public_hostname', type=public_hostname) # SECURITY WARNING: keep the secret key used in production secret! parser.add_option('secret_key', type=str, default='jf0w^y&ml(caax8f&a1mub)(js9(l5mhbbhosz3gi+m01ex+lo') diff --git a/arkindex/project/tests/config_samples/defaults.yaml b/arkindex/project/tests/config_samples/defaults.yaml index a6d0507367..00553e7d8d 100644 --- a/arkindex/project/tests/config_samples/defaults.yaml +++ b/arkindex/project/tests/config_samples/defaults.yaml @@ -60,7 +60,7 @@ ponos: default_env: {} default_farm: null private_key: /somewhere/backend/arkindex/ponos.key -public_hostname: null +public_hostname: https://default.config.arkindex.localhost redis: db: 0 host: localhost diff --git a/arkindex/project/tests/test_config.py b/arkindex/project/tests/test_config.py index ffa2863967..33ba744d05 100644 --- a/arkindex/project/tests/test_config.py +++ b/arkindex/project/tests/test_config.py @@ -46,7 +46,10 @@ class TestConfig(TestCase): def test_settings_defaults(self): parser = get_settings_parser(Path('/somewhere/backend/arkindex')) self.assertIsInstance(parser, ConfigParser) - data = parser.parse_data({}) + data = parser.parse_data({ + # Settings that have no default values at all - not setting them causes an error + 'public_hostname': 'https://default.config.arkindex.localhost' + }) with (SAMPLES / 'defaults.yaml').open() as f: expected = f.read() diff --git a/arkindex/users/admin.py b/arkindex/users/admin.py index 03a7f08ccf..b4f3acfd9d 100644 --- a/arkindex/users/admin.py +++ b/arkindex/users/admin.py @@ -124,7 +124,7 @@ class GroupAdmin(admin.ModelAdmin): class OAuthCredentialAdmin(admin.ModelAdmin): list_display = ('id', 'user', 'provider_name') - fields = ('id', 'user', 'provider_name', 'token', 'refresh_token', 'status') + fields = ('id', 'user', 'provider_name', 'token', 'refresh_token', 'expiry', 'status') readonly_fields = ('id', ) list_filter = ('provider_name', ) search_fields = ('user', ) diff --git a/arkindex/users/api.py b/arkindex/users/api.py index 56173b0796..62ded6097a 100644 --- a/arkindex/users/api.py +++ b/arkindex/users/api.py @@ -107,7 +107,7 @@ class CredentialsRetrieve(RetrieveDestroyAPIView): return self.request.user.credentials.order_by('id') def perform_destroy(self, instance): - instance.provider_class(request=self.request, credentials=instance).disconnect() + instance.provider_class(credentials=instance).disconnect() super().perform_destroy(instance) @@ -384,7 +384,7 @@ class OAuthSignIn(APIView): provider_class = get_provider(kwargs['provider']) if not provider_class: raise ValidationError('Unknown provider') - url = self.request.GET.get('url', provider_class.url) + url = self.request.GET.get('url', provider_class.default_url) if url: parsed = urllib.parse.urlparse(url) if (parsed.scheme != 'https' @@ -399,7 +399,7 @@ class OAuthSignIn(APIView): provider_url=url, ) return Response({ - 'url': provider_class(request=self.request, credentials=creds).get_authorize_uri(), + 'url': provider_class(credentials=creds).get_authorize_uri(), }) @@ -424,7 +424,7 @@ class OAuthRetry(RetrieveAPIView): def retrieve(self, request, *args, **kwargs): creds = self.get_object() return Response({ - 'url': creds.provider_class(request=self.request, credentials=creds).get_authorize_uri(), + 'url': creds.provider_class(credentials=creds).get_authorize_uri(), }) @@ -447,9 +447,9 @@ class OAuthCallback(UserPassesTestMixin, RedirectView): provider_class = get_provider(kwargs['provider']) if not provider_class: raise ValueError('Unknown provider') - provider = provider_class(self.request) + provider = provider_class() try: - provider.handle_callback() + provider.handle_callback(request) provider.credentials.status = OAuthStatus.Done except Exception as e: if not provider.credentials: diff --git a/arkindex/users/models.py b/arkindex/users/models.py index 5fcd2f1bed..629116fb39 100644 --- a/arkindex/users/models.py +++ b/arkindex/users/models.py @@ -1,4 +1,5 @@ import uuid +from datetime import datetime, timezone from django.contrib.auth.models import AbstractBaseUser from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation @@ -203,6 +204,10 @@ class OAuthCredentials(models.Model): def git_provider(self): return self.git_provider_class(credentials=self) + @property + def expired(self): + return self.expiry is None or self.expiry < datetime.now(timezone.utc) + class Meta: verbose_name = 'OAuth credentials' verbose_name_plural = 'OAuth credentials' diff --git a/arkindex/users/providers.py b/arkindex/users/providers.py index 6bfbfc895a..ad4ccc8570 100644 --- a/arkindex/users/providers.py +++ b/arkindex/users/providers.py @@ -1,6 +1,6 @@ -import datetime import urllib.parse from abc import ABC, abstractmethod +from datetime import datetime, timedelta, timezone import requests from django.conf import settings @@ -17,19 +17,10 @@ class OAuthProvider(ABC): display_name = "" slug = "" - def __init__(self, request=None, credentials=None, url=None): - if request is not None: - # Allow Django and Django REST Framework requests - from django.http.request import HttpRequest - from rest_framework.request import Request - assert isinstance(request, (HttpRequest, Request)) + def __init__(self, credentials=None, url=None): if credentials is not None: from arkindex.users.models import OAuthCredentials assert isinstance(credentials, OAuthCredentials) - self.url = credentials.provider_url - elif url is not None: - self.url = url - self.request = request self.credentials = credentials @classmethod @@ -52,11 +43,17 @@ class OAuthProvider(ABC): """ @abstractmethod - def handle_callback(self): + def handle_callback(self, request): """ Handle a OAuth callback and save token data. Should raise exceptions if the process fails. """ + @abstractmethod + def refresh_token(self): + """ + Refresh an expired OAuth token and update OAuthCredentials attributes. + """ + @abstractmethod def disconnect(self): """ @@ -68,7 +65,7 @@ class GitLabOAuthProvider(OAuthProvider): display_name = 'GitLab' slug = 'gitlab' - url = 'https://gitlab.com' + default_url = 'https://gitlab.com' authorize_endpoint = '/oauth/authorize' token_endpoint = '/oauth/token' @@ -77,18 +74,18 @@ class GitLabOAuthProvider(OAuthProvider): return settings.GITLAB_APP_ID and settings.GITLAB_APP_SECRET def get_callback_uri(self): - if not self.request: - return - url = reverse('api:oauth-callback', kwargs={'provider': self.slug}) + if settings.BACKEND_PUBLIC_URL_OAUTH: return urllib.parse.urljoin(settings.BACKEND_PUBLIC_URL_OAUTH, url) - return self.request.build_absolute_uri(url) + assert settings.PUBLIC_HOSTNAME, 'PUBLIC_HOSTNAME is required to generate callback URIs' + return urllib.parse.urljoin(settings.PUBLIC_HOSTNAME, url) def get_authorize_uri(self): - if not self.request or not self.credentials: + if not self.credentials: return + return '{}?{}'.format( urllib.parse.urljoin(self.credentials.provider_url, self.authorize_endpoint), urllib.parse.urlencode({ @@ -100,33 +97,51 @@ class GitLabOAuthProvider(OAuthProvider): }), ) - def handle_callback(self): - if not self.request: - return - - state = self.request.GET.get('state') + def handle_callback(self, request): + state = request.GET.get('state') if not state: raise ValueError('No state hash') - self.credentials = self.request.user.credentials.get(id=state) + self.credentials = request.user.credentials.get(id=state) - if not any(param in self.request.GET for param in ('code', 'error')): + if not any(param in request.GET for param in ('code', 'error')): raise ValueError('Callback called without a valid response') - if 'error' in self.request.GET: - raise ValueError(self.request.GET.get('error_description', self.request.GET['error'])) + if 'error' in request.GET: + raise ValueError(request.GET.get('error_description', request.GET['error'])) from arkindex.users.models import OAuthStatus assert self.credentials.status != OAuthStatus.Done, 'Cannot overwrite existing credentials' + self.grant_token(grant_type='authorization_code', code=request.GET.get('code', '')) + + def refresh_token(self): + try: + self.grant_token(grant_type='refresh_token', refresh_token=self.credentials.refresh_token) + except Exception: + # Set an error state to allow the user to retry manually + from arkindex.users.models import OAuthStatus + self.credentials.status = OAuthStatus.Error + self.credentials.save() + raise + + def grant_token(self, **kwargs): + """ + Use Doorkeeper's OAuth token endpoint to get a token, either from an OAuth authorization code flow + or when refreshing a token, and update the OAuthCredentials attributes. + + https://github.com/doorkeeper-gem/doorkeeper/wiki/API-endpoint-descriptions-and-examples#post---oauthtoken + """ + assert self.credentials is not None, 'An OAuthCredentials instance is required to use the OAuth token endpoint' + + payload = { + 'client_id': settings.GITLAB_APP_ID, + 'client_secret': settings.GITLAB_APP_SECRET, + 'redirect_uri': self.get_callback_uri(), + } + payload.update(kwargs) response = requests.post( urllib.parse.urljoin(self.credentials.provider_url, self.token_endpoint), - { - 'client_id': settings.GITLAB_APP_ID, - 'client_secret': settings.GITLAB_APP_SECRET, - 'code': self.request.GET.get('code', ''), - 'grant_type': 'authorization_code', - 'redirect_uri': self.get_callback_uri(), - } + payload, ) response.raise_for_status() data = response.json() @@ -136,29 +151,31 @@ class GitLabOAuthProvider(OAuthProvider): self.credentials.token = data['access_token'] self.credentials.refresh_token = data.get('refresh_token') - if 'expires_in' in data: - self.credentials.expiry = datetime.datetime.now() + datetime.timedelta(0, int(data['expires_in'])) + # Get absolute date of expiration from the created_at UNIX timestamp and the seconds until expiration + self.credentials.expiry = datetime.fromtimestamp(data['created_at'], tz=timezone.utc) + timedelta(seconds=int(data['expires_in'])) gl = Gitlab(self.credentials.provider_url, oauth_token=self.credentials.token) gl.auth() self.credentials.account_name = gl.user.username + from arkindex.users.models import OAuthStatus self.credentials.status = OAuthStatus.Done self.credentials.save() def disconnect(self): - if not self.request and self.credentials: - raise NotAuthenticated() + if not self.credentials: + raise NotAuthenticated if not self.credentials.token or not self.credentials.repos.exists(): return # Remove all webhooks try: - gl = Gitlab(self.url, oauth_token=self.credentials.token) + gl = Gitlab(self.credentials.provider_url, oauth_token=self.credentials.token) for repo in self.credentials.repos.all(): project = gl.projects.get(urllib.parse.urlsplit(repo.url).path.strip('/')) - hook_url = self.request.build_absolute_uri( + hook_url = urllib.parse.urljoin( + settings.PUBLIC_HOSTNAME, reverse('api:import-hook', kwargs={'pk': repo.id}) ) # Try to find the webhook diff --git a/arkindex/users/serializers.py b/arkindex/users/serializers.py index 0df1bfb172..ef172a1377 100644 --- a/arkindex/users/serializers.py +++ b/arkindex/users/serializers.py @@ -50,7 +50,7 @@ class OAuthProviderClassSerializer(serializers.Serializer): name = serializers.CharField(source='slug') display_name = serializers.CharField() - default_url = serializers.URLField(source='url') + default_url = serializers.URLField() class OAuthRetrySerializer(serializers.Serializer): diff --git a/arkindex/users/tests/test_gitlab_oauth.py b/arkindex/users/tests/test_gitlab_oauth.py index 5a85cfd4b7..bb3649a313 100644 --- a/arkindex/users/tests/test_gitlab_oauth.py +++ b/arkindex/users/tests/test_gitlab_oauth.py @@ -1,8 +1,11 @@ import urllib.parse from unittest.mock import MagicMock, patch +import responses from django.http.request import HttpRequest -from django.urls import reverse +from django.test import override_settings +from requests import HTTPError +from responses import matchers from arkindex.project.tests import FixtureTestCase from arkindex.users.models import OAuthStatus @@ -31,19 +34,23 @@ class TestGitLabOAuthProvider(FixtureTestCase): with self.settings(GITLAB_APP_ID='abcd', GITLAB_APP_SECRET='1234'): self.assertTrue(GitLabOAuthProvider.enabled()) - def test_callback_uri(self): - request_mock = MagicMock(spec=HttpRequest) - GitLabOAuthProvider(request=request_mock).get_callback_uri() - self.assertEqual(request_mock.build_absolute_uri.call_count, 1) - args, kwargs = request_mock.build_absolute_uri.call_args - self.assertTupleEqual(args, (reverse('api:oauth-callback', kwargs={'provider': 'gitlab'}), )) - self.assertDictEqual(kwargs, {}) + @override_settings(BACKEND_PUBLIC_URL_OAUTH='http://arkindex.localhost:8000', PUBLIC_HOSTNAME='http://arkindex.localhost:8080') + def test_callback_uri_dev_override(self): + self.assertEqual(GitLabOAuthProvider().get_callback_uri(), 'http://arkindex.localhost:8000/api/v1/oauth/providers/gitlab/callback/') + + @override_settings(BACKEND_PUBLIC_URL_OAUTH=None, PUBLIC_HOSTNAME='https://arkindex.localhost/') + def test_callback_uri_public_hostname(self): + self.assertEqual(GitLabOAuthProvider().get_callback_uri(), 'https://arkindex.localhost/api/v1/oauth/providers/gitlab/callback/') + @override_settings(BACKEND_PUBLIC_URL_OAUTH=None, PUBLIC_HOSTNAME=None) + def test_callback_uri_request(self): + with self.assertRaisesMessage(AssertionError, 'PUBLIC_HOSTNAME is required to generate callback URIs'): + GitLabOAuthProvider().get_callback_uri() + + @override_settings(BACKEND_PUBLIC_URL_OAUTH=None, PUBLIC_HOSTNAME='https://arkindex.localhost/') def test_authorize_uri(self): - request_mock = MagicMock(spec=HttpRequest) - request_mock.build_absolute_uri.return_value = 'callback' with self.settings(GITLAB_APP_ID='abcd'): - uri = GitLabOAuthProvider(request=request_mock, credentials=self.creds).get_authorize_uri() + uri = GitLabOAuthProvider(credentials=self.creds).get_authorize_uri() parsed = urllib.parse.urlparse(uri) self.assertEqual(parsed.scheme, 'https') self.assertEqual(parsed.netloc, 'somewhere') @@ -51,7 +58,7 @@ class TestGitLabOAuthProvider(FixtureTestCase): query = urllib.parse.parse_qs(parsed.query) self.assertDictEqual(query, { 'client_id': ['abcd'], - 'redirect_uri': ['callback'], + 'redirect_uri': ['https://arkindex.localhost/api/v1/oauth/providers/gitlab/callback/'], 'scope': ['api'], 'response_type': ['code'], 'state': [str(self.creds.id)], @@ -62,21 +69,21 @@ class TestGitLabOAuthProvider(FixtureTestCase): request_mock.user = self.user request_mock.GET = {} with self.assertRaisesMessage(ValueError, 'No state hash'): - GitLabOAuthProvider(request=request_mock).handle_callback() + GitLabOAuthProvider().handle_callback(request_mock) def test_handle_callback_bad_response(self): request_mock = MagicMock(spec=HttpRequest) request_mock.user = self.user request_mock.GET = {'state': str(self.creds.id)} with self.assertRaisesRegex(ValueError, 'valid response'): - GitLabOAuthProvider(request=request_mock).handle_callback() + GitLabOAuthProvider().handle_callback(request_mock) def test_handle_callback_error(self): request_mock = MagicMock(spec=HttpRequest) request_mock.user = self.user request_mock.GET = {'state': str(self.creds.id), 'error': 'error message'} with self.assertRaisesMessage(ValueError, 'error message'): - GitLabOAuthProvider(request=request_mock).handle_callback() + GitLabOAuthProvider().handle_callback(request_mock) def test_handle_callback_overwrite(self): self.creds.status = OAuthStatus.Done @@ -85,10 +92,30 @@ class TestGitLabOAuthProvider(FixtureTestCase): request_mock.user = self.user request_mock.GET = {'state': str(self.creds.id), 'code': 'something'} with self.assertRaisesRegex(Exception, 'overwrite'): - GitLabOAuthProvider(request=request_mock).handle_callback() - - @patch('arkindex.users.providers.requests') - def test_handle_callback_success(self, requests): + GitLabOAuthProvider().handle_callback(request_mock) + + @responses.activate + @override_settings(BACKEND_PUBLIC_URL_OAUTH=None, PUBLIC_HOSTNAME='https://arkindex.localhost') + def test_handle_callback_success(self): + responses.post( + 'https://somewhere/oauth/token', + match=[ + matchers.urlencoded_params_matcher({ + 'client_id': 'abcd', + 'client_secret': 's3kr3t', + 'code': 'abc123', + 'redirect_uri': 'https://arkindex.localhost/api/v1/oauth/providers/gitlab/callback/', + 'grant_type': 'authorization_code', + }), + ], + json={ + 'access_token': 't0k3n', + 'refresh_token': 'r3fr3sh', + 'token_type': 'Bearer', + 'created_at': 1582984800, + 'expires_in': '3600', + }, + ) request_mock = MagicMock(spec=HttpRequest) request_mock.user = self.user request_mock.GET = {'state': str(self.creds.id), 'code': 'abc123'} @@ -96,57 +123,115 @@ class TestGitLabOAuthProvider(FixtureTestCase): self.gl_mock.return_value.user.username = 'bobby' - requests.post.return_value.json.return_value = { - 'access_token': 't0k3n', - 'refresh_token': 'r3fr3sh', - 'token_type': 'Bearer', - 'expires_in': '3600', - } - with self.settings(GITLAB_APP_ID='abcd', GITLAB_APP_SECRET='s3kr3t'): - GitLabOAuthProvider(request=request_mock).handle_callback() + GitLabOAuthProvider().handle_callback(request_mock) self.creds.refresh_from_db() self.assertEqual(self.creds.token, 't0k3n') self.assertEqual(self.creds.refresh_token, 'r3fr3sh') + self.assertEqual(self.creds.expiry.isoformat(), '2020-02-29T15:00:00+00:00') self.assertEqual(self.creds.account_name, 'bobby') self.assertEqual(self.creds.status, OAuthStatus.Done) - self.assertEqual(requests.post.call_count, 1) self.assertEqual(self.gl_mock.return_value.auth.call_count, 1) - args, kwargs = requests.post.call_args - self.assertDictEqual(kwargs, {}) - self.assertEqual(len(args), 2) - parsed = urllib.parse.urlparse(args[0]) - self.assertEqual(parsed.scheme, 'https') - self.assertEqual(parsed.netloc, 'somewhere') - self.assertEqual(parsed.path, '/oauth/token') - self.assertEqual(parsed.query, '') - self.assertDictEqual(args[1], { - 'client_id': 'abcd', - 'client_secret': 's3kr3t', - 'code': 'abc123', - 'redirect_uri': 'callback', - 'grant_type': 'authorization_code', - }) + args, kwargs = self.gl_mock.call_args + self.assertEqual(args, ('https://somewhere', )) + self.assertDictEqual(kwargs, {'oauth_token': 't0k3n'}) + + @responses.activate + @override_settings(BACKEND_PUBLIC_URL_OAUTH=None, PUBLIC_HOSTNAME='https://arkindex.localhost') + def test_refresh_token(self): + responses.post( + 'https://somewhere/oauth/token', + match=[ + matchers.urlencoded_params_matcher({ + 'client_id': 'abcd', + 'client_secret': 's3kr3t', + 'redirect_uri': 'https://arkindex.localhost/api/v1/oauth/providers/gitlab/callback/', + 'grant_type': 'refresh_token', + 'refresh_token': 'r3fr3sh' + }), + ], + json={ + 'access_token': 't0k3n', + 'refresh_token': 'r4fr4sh', + 'token_type': 'Bearer', + 'created_at': 1582984800, + 'expires_in': '3600', + }, + ) + + self.gl_mock.return_value.user.username = 'bobby' + + self.creds.refresh_token = 'r3fr3sh' + self.creds.save() + self.assertEqual(self.creds.expiry.isoformat(), '2100-12-31T23:59:59.999000+00:00') + self.assertEqual(self.creds.token, 'oauth-token') + + with self.settings(GITLAB_APP_ID='abcd', GITLAB_APP_SECRET='s3kr3t'): + GitLabOAuthProvider(credentials=self.creds).refresh_token() + + self.assertEqual(self.creds.token, 't0k3n') + self.assertEqual(self.creds.refresh_token, 'r4fr4sh') + self.assertEqual(self.creds.expiry.isoformat(), '2020-02-29T15:00:00+00:00') + self.assertEqual(self.creds.account_name, 'bobby') + self.assertEqual(self.creds.status, OAuthStatus.Done) + self.assertEqual(self.gl_mock.return_value.auth.call_count, 1) args, kwargs = self.gl_mock.call_args self.assertEqual(args, ('https://somewhere', )) self.assertDictEqual(kwargs, {'oauth_token': 't0k3n'}) - def test_disconnect(self): - request_mock = MagicMock(spec=HttpRequest) - request_mock.build_absolute_uri.return_value = 'hook' + @responses.activate + @override_settings(BACKEND_PUBLIC_URL_OAUTH=None, PUBLIC_HOSTNAME='https://arkindex.localhost') + def test_refresh_token_error(self): + responses.post( + 'https://somewhere/oauth/token', + match=[ + matchers.urlencoded_params_matcher({ + 'client_id': 'abcd', + 'client_secret': 's3kr3t', + 'redirect_uri': 'https://arkindex.localhost/api/v1/oauth/providers/gitlab/callback/', + 'grant_type': 'refresh_token', + 'refresh_token': 'r3fr3sh' + }), + ], + status=418, + ) + + self.creds.refresh_token = 'r3fr3sh' + self.creds.save() + self.assertEqual(self.creds.expiry.isoformat(), '2100-12-31T23:59:59.999000+00:00') + self.assertEqual(self.creds.token, 'oauth-token') - hook_mock = MagicMock() - hook_mock.url = 'hook' - self.gl_mock.return_value.projects.get.return_value.id = 'repo_id' - self.gl_mock.return_value.projects.get.return_value.hooks.list.return_value = [hook_mock] + with self.settings(GITLAB_APP_ID='abcd', GITLAB_APP_SECRET='s3kr3t'), self.assertRaises(HTTPError): + GitLabOAuthProvider(credentials=self.creds).refresh_token() - GitLabOAuthProvider(request=request_mock, credentials=self.creds).disconnect() + self.assertEqual(self.creds.token, 'oauth-token') + self.assertEqual(self.creds.refresh_token, 'r3fr3sh') + self.assertEqual(self.creds.expiry.isoformat(), '2100-12-31T23:59:59.999000+00:00') + self.assertEqual(self.creds.status, OAuthStatus.Error) + self.assertFalse(self.gl_mock.return_value.auth.called) - # Number of repositories associated to those credentials - repos_count = self.creds.repos.count() - self.assertEqual(self.gl_mock().projects.get.call_count, repos_count) - self.assertEqual(self.gl_mock().projects.get.return_value.hooks.list.call_count, repos_count) - self.assertEqual(hook_mock.delete.call_count, repos_count) + @override_settings(BACKEND_PUBLIC_URL_OAUTH=None, PUBLIC_HOSTNAME='https://arkindex.localhost') + def test_disconnect(self): + iiif_repo, worker_repo = self.creds.repos.order_by('type') + # This hook does not match the expected hook URL and should not be removed + ignored_hook = MagicMock() + ignored_hook.url = 'https://potato.localhost/something' + iiif_hook = MagicMock() + iiif_hook.url = f'https://arkindex.localhost/api/v1/imports/hook/{iiif_repo.id}/' + worker_hook = MagicMock() + worker_hook.url = f'https://arkindex.localhost/api/v1/imports/hook/{worker_repo.id}/' + self.gl_mock().projects.get.return_value.id = 'repo_id' + self.gl_mock().projects.get.return_value.hooks.list.return_value = [ + ignored_hook, iiif_hook, worker_hook + ] + + GitLabOAuthProvider(credentials=self.creds).disconnect() + + self.assertEqual(self.gl_mock().projects.get.call_count, 2) + self.assertEqual(self.gl_mock().projects.get.return_value.hooks.list.call_count, 2) + self.assertFalse(ignored_hook.delete.called) + self.assertEqual(iiif_hook.delete.call_count, 1) + self.assertEqual(worker_hook.delete.call_count, 1) diff --git a/ci/static-collect.sh b/ci/static-collect.sh index 989ac6442b..81ac94ec17 100755 --- a/ci/static-collect.sh +++ b/ci/static-collect.sh @@ -1,5 +1,5 @@ #!/bin/sh mkdir -p static pip install -e . -echo "static: {root_path: '$(pwd)/static'}" > "$CONFIG_PATH" +echo "static: {root_path: '$(pwd)/static'}" >> "$CONFIG_PATH" arkindex/manage.py collectstatic --noinput -- GitLab