diff --git a/.gitignore b/.gitignore index 377e82b5c3ea040a0f2485e78e1599233d2a5349..79a14630dbff84a516355482fcaecdc40eb03cda 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ dist .eggs logs media +workers .vscode local_settings.py arkindex/iiif-users/ diff --git a/Makefile b/Makefile index 5c334250e330b48d9501fb432ee5d234cbc9bd5c..7caf753a7db619566487707afc686adb6cad1148 100644 --- a/Makefile +++ b/Makefile @@ -5,6 +5,7 @@ TUNNEL_PORT:=8000 VERSION=$(shell git rev-parse --short HEAD) TAG_APP=arkindex-app TAG_BASE=arkindex-base +FRONTEND_BRANCH=master .PHONY: build base all: clean build @@ -21,7 +22,7 @@ clean: build: python setup.py sdist - docker build $(ROOT_DIR) -t $(TAG_APP):$(VERSION) -t $(TAG_APP):latest + docker build $(ROOT_DIR) -t $(TAG_APP):$(VERSION) -t $(TAG_APP):latest --build-arg FRONTEND_BRANCH=$(FRONTEND_BRANCH) publish-version: require-docker-auth $(MAKE) build TAG_APP=registry.gitlab.com/arkindex/backend diff --git a/arkindex/dataimport/admin.py b/arkindex/dataimport/admin.py index 9530bca1c257feee7a85b70677f56bef998f4f0c..70978a4399b6a1c0d40da0d3caf27777c405b05d 100644 --- a/arkindex/dataimport/admin.py +++ b/arkindex/dataimport/admin.py @@ -28,9 +28,9 @@ class RevisionInline(admin.StackedInline): class RepositoryAdmin(admin.ModelAdmin): - list_display = ('id', 'url', 'user', 'corpus') + list_display = ('id', 'url', 'corpus') list_filter = ('corpus', ) - fields = ('id', 'url', 'user', 'corpus', 'clone_user', 'clone_token', 'hook_token', 'watched_branches') + fields = ('id', 'url', 'corpus', 'hook_token', 'watched_branches') readonly_fields = ('id', ) inlines = [RevisionInline, ] diff --git a/arkindex/dataimport/api.py b/arkindex/dataimport/api.py index 37d7a44aa7e60c8b23c75ab04b5db666203e2f42..4949d81a7c06810fa97e1d347a66753f139aee73 100644 --- a/arkindex/dataimport/api.py +++ b/arkindex/dataimport/api.py @@ -1,17 +1,19 @@ from django.shortcuts import get_object_or_404 from rest_framework.generics import \ - ListAPIView, 
ListCreateAPIView, RetrieveUpdateDestroyAPIView + ListAPIView, ListCreateAPIView, RetrieveUpdateDestroyAPIView, RetrieveAPIView from rest_framework.views import APIView from rest_framework.parsers import MultiPartParser, FileUploadParser -from rest_framework.permissions import IsAuthenticated +from rest_framework.permissions import IsAuthenticated, IsAdminUser from rest_framework.response import Response from rest_framework import status -from rest_framework.exceptions import ValidationError, NotAuthenticated, AuthenticationFailed +from rest_framework.exceptions import ValidationError from arkindex.documents.models import Corpus from arkindex.dataimport.models import \ - DataImport, DataFile, DataImportState, DataImportMode, DataImportFailure, Repository, RepositorySource, Revision + DataImport, DataFile, DataImportState, DataImportMode, DataImportFailure, Repository from arkindex.dataimport.serializers import \ - DataImportLightSerializer, DataImportSerializer, DataImportFailureSerializer, DataFileSerializer + DataImportLightSerializer, DataImportSerializer, DataImportFailureSerializer, DataFileSerializer, \ + RepositoryLightSerializer, RepositorySerializer, ExternalRepositorySerializer +from arkindex.users.models import OAuthCredentials import hashlib import magic @@ -160,40 +162,64 @@ class GitRepositoryImportHook(APIView): def post(self, request, pk=None, **kwargs): repo = get_object_or_404(Repository, id=pk) - - if repo.source == RepositorySource.GitLab: - if 'HTTP_X_GITLAB_EVENT' not in request.META: - raise ValidationError("Missing GitLab event type") - if request.META['HTTP_X_GITLAB_EVENT'] != 'Push Hook': - raise ValidationError("Unsupported GitLab event type") - - if 'HTTP_X_GITLAB_TOKEN' not in request.META: - raise NotAuthenticated("Missing GitLab secret token") - if request.META['HTTP_X_GITLAB_TOKEN'] != repo.hook_token: - raise AuthenticationFailed("Invalid GitLab secret token") - - assert isinstance(request.data, dict) - assert 
request.data['object_kind'] == 'push' - - if request.data['ref'] not in repo.watched_branches: - return Response(status=status.HTTP_204_NO_CONTENT) - - # Already took care of this event - if Revision.objects.filter( - repo=repo, - ref=request.data['ref'], - hash=request.data['checkout_sha']).exists(): - return Response(status=status.HTTP_204_NO_CONTENT) - - rev = Revision.objects.create( - repo=repo, - hash=request.data['checkout_sha'], - ref=request.data['ref'], - message=request.data['commits'][0]['message'], - author=request.data['commits'][0]['author']['name'], - ) - else: - raise NotImplementedError - - rev.start_import() + repo.provider_class(credentials=repo.credentials).handle_webhook(repo, request) return Response(status=status.HTTP_204_NO_CONTENT) + + +class RepositoryList(ListAPIView): + permission_classes = (IsAuthenticated, ) + serializer_class = RepositoryLightSerializer + + def get_queryset(self): + return Repository.objects.filter(credentials__user=self.request.user) + + +class AvailableRepositoriesList(ListCreateAPIView): + permission_classes = (IsAuthenticated, ) + pagination_class = None + serializer_class = ExternalRepositorySerializer + + def get_queryset(self): + cred = get_object_or_404(OAuthCredentials, user=self.request.user, id=self.kwargs['pk']) + return cred.git_provider_class(credentials=cred).list_repos( + query=self.request.query_params.get('search'), + ) + + def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + cred = get_object_or_404(OAuthCredentials, user=self.request.user, id=self.kwargs['pk']) + + provider = cred.git_provider_class(credentials=cred) + repo = provider.create_repo(**serializer.validated_data, request=self.request) + rev, _ = provider.get_or_create_latest_revision(repo) + dataimport = rev.start_import() + + return Response(data={'import_id': str(dataimport.id)}, status=status.HTTP_201_CREATED) + + +class 
RepositoryRetrieve(RetrieveUpdateDestroyAPIView): + permission_classes = (IsAuthenticated, ) + serializer_class = RepositorySerializer + + def get_queryset(self): + return Repository.objects.filter(credentials__user=self.request.user) + + +class RepositoryStartImport(RetrieveAPIView): + permission_classes = (IsAdminUser, ) + + def get_queryset(self): + return Repository.objects.filter(credentials__user=self.request.user) + + def get(self, request, *args, **kwargs): + repo = self.get_object() + + rev, _ = repo.credentials.git_provider_class( + credentials=repo.credentials, + ).get_or_create_latest_revision(repo) + + if rev.dataimports.filter(state=DataImportState.Running).exists(): + raise ValidationError("An import is already running for the latest revision") + + return Response(data={'import_id': str(rev.start_import().id)}) diff --git a/arkindex/dataimport/migrations/0004_remove_tokens.py b/arkindex/dataimport/migrations/0004_remove_tokens.py new file mode 100644 index 0000000000000000000000000000000000000000..f6a88745e83749bd028e6f26e7c30f90ffc0e665 --- /dev/null +++ b/arkindex/dataimport/migrations/0004_remove_tokens.py @@ -0,0 +1,59 @@ +# Generated by Django 2.1 on 2018-08-20 13:22 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('users', '0003_oauthcredentials'), + ('dataimport', '0003_dataimportfailure'), + ] + + operations = [ + migrations.RemoveField( + model_name='repository', + name='clone_token', + ), + migrations.RemoveField( + model_name='repository', + name='clone_user', + ), + migrations.AlterField( + model_name='dataimport', + name='revision', + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name='dataimports', + to='dataimport.Revision', + ), + ), + migrations.RemoveField( + model_name='repository', + name='user', + ), + migrations.AddField( + model_name='repository', + name='credentials', + 
field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name='repos', + to='users.OAuthCredentials', + ), + ), + migrations.AddField( + model_name='repository', + name='provider_name', + field=models.CharField( + default='GitLabProvider', + max_length=50, + choices=[('GitLabProvider', 'GitLab')], + ), + preserve_default=False, + ), + ] diff --git a/arkindex/dataimport/models.py b/arkindex/dataimport/models.py index 76cce5042ac5e98e1d4e3f54966b7706eaa4d21e..a3b887ddc1e2308146c54c7191175344289f44be 100644 --- a/arkindex/dataimport/models.py +++ b/arkindex/dataimport/models.py @@ -6,12 +6,12 @@ from celery import states from celery.canvas import Signature from celery.result import AsyncResult, GroupResult from enumfields import EnumField, Enum +from arkindex.dataimport.providers import git_providers, get_provider from arkindex.project.models import IndexableModel from arkindex.project.fields import ArrayField import uuid import os import re -import urllib.parse class DataImportState(Enum): @@ -40,8 +40,8 @@ class DataImport(IndexableModel): state = EnumField(DataImportState, default=DataImportState.Created, max_length=30) mode = EnumField(DataImportMode, max_length=30) files = models.ManyToManyField('dataimport.DataFile', related_name='imports') - revision = models.OneToOneField( - 'dataimport.Revision', related_name='dataimport', on_delete=models.CASCADE, blank=True, null=True) + revision = models.ForeignKey( + 'dataimport.Revision', related_name='dataimports', on_delete=models.CASCADE, blank=True, null=True) payload = JSONField(null=True, blank=True) root_id = models.UUIDField(null=True, blank=True) task_count = models.PositiveSmallIntegerField(null=True, blank=True) @@ -77,8 +77,8 @@ class DataImport(IndexableModel): from arkindex.dataimport.tasks import check_images, import_images return check_images.s(self) | import_images.s(self) elif self.mode == DataImportMode.Repository: - from arkindex.dataimport.tasks 
import clone_repo, import_repo, cleanup_repo - return clone_repo.si(self) | import_repo.si(self) | cleanup_repo.si(self) + from arkindex.dataimport.tasks import download_repo, import_repo, cleanup_repo + return download_repo.si(self) | import_repo.si(self) | cleanup_repo.si(self) else: raise NotImplementedError @@ -173,11 +173,6 @@ class DataFile(models.Model): return os.path.join(settings.MEDIA_ROOT, str(self.id)) -class RepositorySource(Enum): - GitHub = 'github' - GitLab = 'gitlab' - - def repository_default_branches(): ''' This is needed to avoid re-using the same list instance @@ -191,33 +186,25 @@ class Repository(models.Model): id = models.UUIDField(primary_key=True, default=uuid.uuid4) url = models.URLField(unique=True) hook_token = models.CharField(max_length=250, unique=True) - clone_user = models.CharField(max_length=100) - clone_token = models.CharField(max_length=250) corpus = models.ForeignKey('documents.Corpus', on_delete=models.CASCADE, related_name='repos') - user = models.ForeignKey('users.User', on_delete=models.CASCADE, related_name='repos') + credentials = models.ForeignKey( + 'users.OAuthCredentials', on_delete=models.CASCADE, related_name='repos', blank=True, null=True) watched_branches = ArrayField(models.CharField(max_length=50), default=repository_default_branches) + provider_name = models.CharField( + max_length=50, + choices=[(p.__name__, p.display_name) for p in git_providers], + ) class Meta: verbose_name_plural = 'repositories' @property - def auth_url(self): - """Repository URL with added credentials""" - parsed = list(urllib.parse.urlsplit(self.url)) - if '@' in parsed[1]: # URL seems to already have credentials - return self.url - parsed[1] = '{}:{}@{}'.format(self.clone_user, self.clone_token, parsed[1]) - return urllib.parse.urlunsplit(parsed) + def provider_class(self): + return get_provider(self.provider_name) @property - def source(self): - parsed = urllib.parse.urlsplit(self.url) - if parsed.netloc == 'gitlab.com': - return 
RepositorySource.GitLab - elif parsed.netloc == 'github.com': - return RepositorySource.GitHub - else: - raise ValueError('Unknown repository source') + def provider(self): + return self.provider_class(credentials=self.credentials) @property def clone_dir(self): @@ -240,10 +227,11 @@ class Revision(models.Model): return '{}/commit/{}'.format(self.repo.url.rstrip('/'), self.hash) def start_import(self): - DataImport.objects.create( - creator=self.repo.user, + dataimport = self.dataimports.create( + creator=self.repo.credentials.user, corpus=self.repo.corpus, mode=DataImportMode.Repository, state=DataImportState.Configured, - revision=self, - ).start() + ) + dataimport.start() + return dataimport diff --git a/arkindex/dataimport/providers.py b/arkindex/dataimport/providers.py new file mode 100644 index 0000000000000000000000000000000000000000..a73ce6c2a6be83628a18de7efb3e3cb3c77a1cd1 --- /dev/null +++ b/arkindex/dataimport/providers.py @@ -0,0 +1,174 @@ +from abc import ABC, abstractmethod +from django.urls import reverse +from rest_framework.exceptions import NotAuthenticated, AuthenticationFailed, APIException, ValidationError +from gitlab import Gitlab, GitlabGetError, GitlabCreateError +from arkindex.documents.models import Corpus +import urllib.parse +import base64 +import uuid + + +class GitProvider(ABC): + + display_name = None + + def __init__(self, credentials=None, url=None): + if credentials: + from arkindex.users.models import OAuthCredentials + assert isinstance(credentials, OAuthCredentials) + self.url = credentials.provider_url + if url: + self.url = url + self.credentials = credentials + + @abstractmethod + def list_repos(self, query=None): + """ + List all repositories or filter with a search query. 
+ """ + + @abstractmethod + def create_repo(self, **kwargs): + """ + Create a Repository instance from an external repository + """ + + @abstractmethod + def download_archive(self, revision, path): + """ + Download an archive for a given Revision instance. + """ + + @abstractmethod + def get_or_create_latest_revision(self, repo): + """ + Get a Revision instance for the last revision on the main branch of a given repository. + """ + + @abstractmethod + def handle_webhook(self, repo, request): + """ + Handle a webhook event on a given repository. + """ + + +class GitLabProvider(GitProvider): + + display_name = "GitLab" + url = 'https://gitlab.com' + + def list_repos(self, query=None): + if not self.credentials: + raise NotAuthenticated() + gl = Gitlab(self.url, oauth_token=self.credentials.token) + return gl.projects.list(membership=True, search=query) + + def create_repo(self, id=None, corpus=None, request=None, **kwargs): + assert isinstance(corpus, Corpus) + if not self.credentials and request: + raise NotAuthenticated() + gl = Gitlab(self.url, oauth_token=self.credentials.token) + try: + project = gl.projects.get(int(id)) + except GitlabGetError as e: + raise APIException("Error while fetching GitLab project: {}".format(str(e))) + + from arkindex.dataimport.models import Repository + if Repository.objects.filter(url=project.web_url).exists(): + raise ValidationError("A repository with this URL already exists") + + repo = self.credentials.repos.create( + corpus=corpus, + url=project.web_url, + watched_branches=['refs/heads/{}'.format(project.default_branch)], + hook_token=str(base64.b64encode(uuid.uuid4().bytes)), + provider_name=self.__class__.__name__, + ) + + try: + project.hooks.create({ + 'url': request.build_absolute_uri( + reverse('api:import-hook', kwargs={'pk': repo.id}) + ), + 'push_events': True, + 'token': repo.hook_token, + }) + except GitlabCreateError as e: + raise APIException("Error while creating GitLab hook: {}".format(str(e))) + + return repo + 
+ def download_archive(self, revision, path): + gl = Gitlab(self.url, oauth_token=revision.repo.credentials.token) + try: + project = gl.projects.get(urllib.parse.urlsplit(revision.repo.url).path.strip('/')) + except GitlabGetError as e: + raise APIException("Error while fetching GitLab project: {}".format(str(e))) + + with open(path, 'wb') as f: + project.repository_archive(sha=revision.hash, streamed=True, action=f.write) + + def get_or_create_latest_revision(self, repo): + gl = Gitlab(self.url, oauth_token=repo.credentials.token) + try: + project = gl.projects.get(urllib.parse.urlsplit(repo.url).path.strip('/')) + except GitlabGetError as e: + raise APIException("Error while fetching GitLab project: {}".format(str(e))) + + latest_commit = project.commits.list()[0] + return repo.revisions.get_or_create( + repo=repo, + hash=latest_commit.id, + defaults={ + 'ref': latest_commit.refs()[0]['name'], + 'message': latest_commit.message, + 'author': latest_commit.author_name, + }, + ) + + def handle_webhook(self, repo, request): + if 'HTTP_X_GITLAB_EVENT' not in request.META: + raise ValidationError("Missing GitLab event type") + if request.META['HTTP_X_GITLAB_EVENT'] != 'Push Hook': + raise ValidationError("Unsupported GitLab event type") + + if 'HTTP_X_GITLAB_TOKEN' not in request.META: + raise NotAuthenticated("Missing GitLab secret token") + if request.META['HTTP_X_GITLAB_TOKEN'] != repo.hook_token: + raise AuthenticationFailed("Invalid GitLab secret token") + + assert isinstance(request.data, dict) + assert request.data['object_kind'] == 'push' + + if request.data['ref'] not in repo.watched_branches: + return + + # Already took care of this event + if repo.revisions.filter( + ref=request.data['ref'], + hash=request.data['checkout_sha']).exists(): + return + + rev = repo.revisions.create( + hash=request.data['checkout_sha'], + ref=request.data['ref'], + message=request.data['commits'][0]['message'], + author=request.data['commits'][0]['author']['name'], + ) + 
rev.start_import() + + +git_providers = [ + GitLabProvider, +] +oauth_to_git = { + "GitLabOAuthProvider": GitLabProvider, +} + + +def get_provider(name): + return next(filter(lambda p: p.__name__ == name, git_providers), None) + + +def from_oauth(name): + return oauth_to_git.get(name) diff --git a/arkindex/dataimport/serializers.py b/arkindex/dataimport/serializers.py index fcada061331e44ce746b3e809b90c70584503af0..dbc94a96d4d4d3efe028288bb22d0858dd7c7adb 100644 --- a/arkindex/dataimport/serializers.py +++ b/arkindex/dataimport/serializers.py @@ -2,8 +2,10 @@ from rest_framework import serializers from rest_framework.utils import model_meta from arkindex.project.serializer_fields import EnumField from arkindex.dataimport.models import \ - DataImport, DataImportMode, DataImportState, DataImportFailure, DataFile, Revision + DataImport, DataImportMode, DataImportState, DataImportFailure, DataFile, Repository, Revision +from arkindex.documents.models import Corpus from arkindex.documents.serializers.light import ElementLightSerializer +import gitlab.v4.objects import celery.states @@ -96,6 +98,7 @@ class RevisionSerializer(serializers.ModelSerializer): 'message', 'author', 'commit_url', + 'repo_id', ) @@ -186,3 +189,66 @@ class DataImportFailureSerializer(serializers.ModelSerializer): 'context', 'view_url', ) + + +class RepositoryLightSerializer(serializers.ModelSerializer): + """ + Serialize a repository + """ + + class Meta: + model = Repository + fields = ( + 'id', + 'url', + 'corpus', + ) + extra_kwargs = { + 'id': {'read_only': True}, + 'url': {'read_only': True}, + } + + +class RepositorySerializer(RepositoryLightSerializer): + """ + Fully serialize a repository + """ + + class Meta(RepositoryLightSerializer.Meta): + fields = ( + 'id', + 'url', + 'corpus', + 'watched_branches', + ) + + +class ExternalRepositorySerializer(serializers.BaseSerializer): + """ + Serialize a Git repository from an external API + """ + + def to_representation(self, obj): + if 
isinstance(obj, gitlab.v4.objects.Project): + return { + "id": obj.id, + "name": obj.name_with_namespace, + "url": obj.web_url, + } + else: + raise NotImplementedError + + def to_internal_value(self, data): + """ + Deserializing only requires a 'id' attribute + """ + if not data.get('id'): + raise serializers.ValidationError({ + 'id': 'This field is required.' + }) + if not data.get('corpus'): + raise serializers.ValidationError({ + 'corpus': 'This field is required.' + }) + + return {'id': data['id'], 'corpus': Corpus.objects.get(id=data['corpus'])} diff --git a/arkindex/dataimport/tasks.py b/arkindex/dataimport/tasks.py index 436e815d277815833fba1d66340f7643d872d7dc..ce2296c8af35a00e11bad20a225a41d1f03e745e 100644 --- a/arkindex/dataimport/tasks.py +++ b/arkindex/dataimport/tasks.py @@ -15,7 +15,6 @@ import os import glob import logging import shutil -import git import urllib.parse root_logger = logging.getLogger(__name__) @@ -109,21 +108,26 @@ def import_images(self, valid_files, dataimport, server_id=settings.LOCAL_IMAGES @shared_task(bind=True, base=ReportingTask) -def clone_repo(self, dataimport): +def download_repo(self, dataimport): assert isinstance(dataimport, DataImport) assert dataimport.mode == DataImportMode.Repository assert dataimport.revision is not None - self.report_progress(0, "Cloning repository...") repo_dir = dataimport.revision.repo.clone_dir if os.path.exists(repo_dir): shutil.rmtree(repo_dir) - repo = git.Repo.clone_from(dataimport.revision.repo.auth_url, repo_dir, no_checkout=True) - + archive_path = "{}.tar.gz".format(repo_dir) commit_hash = dataimport.revision.hash - self.report_progress(0.5, "Checking out commit {}...".format(commit_hash)) - repo.head.reference = repo.create_head('commit_{}'.format(commit_hash), commit_hash) - repo.head.reset(index=True, working_tree=True) + + if dataimport.revision.repo.provider_class is None: + raise ValueError("No repository provider found for {}".format(dataimport.revision.repo.url)) + + 
self.report_progress(0, "Downloading repository archive at {}...".format(commit_hash)) + dataimport.revision.repo.provider.download_archive(dataimport.revision, archive_path) + + self.report_progress(0.5, "Extracting...") + shutil.unpack_archive(archive_path, repo_dir) + os.remove(archive_path) @shared_task(bind=True, base=ReportingTask) @@ -134,6 +138,9 @@ def import_repo(self, dataimport): self.report_progress(0, "Finding XML files...") xml_files = glob.glob(os.path.join(dataimport.revision.repo.clone_dir, '**/*.xml'), recursive=True) + if len(xml_files) < 1: + self.report_message("No XML files found.", level=logging.WARNING) + for i, xml_file in enumerate(xml_files, 1): filename = os.path.basename(xml_file) self.report_progress(i / len(xml_files), 'Importing file {} of {}: {}'.format(i, len(xml_files), filename)) diff --git a/arkindex/dataimport/tests/test_imports.py b/arkindex/dataimport/tests/test_imports.py index c4f2d6c215e4b76ba4cee72ced93cc7a5611e1f7..d1cd9cc47daebe692ad3c07d9aaaac4d78c6e467 100644 --- a/arkindex/dataimport/tests/test_imports.py +++ b/arkindex/dataimport/tests/test_imports.py @@ -6,7 +6,7 @@ from arkindex.dataimport.models import \ from arkindex.dataimport.serializers import DataImportSerializer, ImagesPayloadSerializer from arkindex.documents.models import Corpus from arkindex.project.tests import RedisMockAPITestCase -from arkindex.users.models import User +from arkindex.users.models import User, OAuthCredentials class TestImports(RedisMockAPITestCase): @@ -184,10 +184,12 @@ class TestImports(RedisMockAPITestCase): repo=Repository.objects.create( url='http://repo', hook_token='token', - clone_user='user', - clone_token='token', corpus=self.corpus, - user=self.user, + credentials=OAuthCredentials.objects.create( + user=self.user, + provider_name='provider', + provider_url='https://somewhere', + ), ), hash='42', ref='ref/heads/master', diff --git a/arkindex/dataimport/tests/test_providers.py b/arkindex/dataimport/tests/test_providers.py 
new file mode 100644 index 0000000000000000000000000000000000000000..6a8c0b4984ce6adca0a1b8cef35eab5154423820 --- /dev/null +++ b/arkindex/dataimport/tests/test_providers.py @@ -0,0 +1,47 @@ +from unittest.mock import patch +from django.urls import reverse +from rest_framework import status +from arkindex.dataimport.models import Repository +from arkindex.dataimport.providers import GitLabProvider +from arkindex.users.models import OAuthCredentials +from arkindex.project.tests import FixtureAPITestCase + + +class TestProviders(FixtureAPITestCase): + + def setUp(self): + self.creds = OAuthCredentials.objects.create( + user=self.user, + provider_name='GitLabOAuthProvider', + provider_url='https://somewhere', + ) + self.repo = Repository.objects.create( + url='http://repo', + hook_token='token', + corpus=self.corpus, + credentials=self.creds, + provider_name='GitLabProvider', + ) + + def test_init(self): + glp = GitLabProvider() + self.assertEqual(glp.url, GitLabProvider.url) + + glp = GitLabProvider(url='something') + self.assertEqual(glp.url, 'something') + + glp = GitLabProvider(credentials=self.creds) + self.assertEqual(glp.url, self.creds.provider_url) + + with self.assertRaises(Exception): + GitLabProvider(credentials='not a OAuthCredentials') + + glp = GitLabProvider(credentials=self.creds, url='something') + self.assertEqual(glp.url, 'something') + + @patch('arkindex.dataimport.api.Repository.provider_class') + def test_webhook(self, provider_class): + self.client.force_login(self.user) + response = self.client.post(reverse('api:import-hook', kwargs={'pk': self.repo.id})) + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertTrue(provider_class.return_value.handle_webhook.called) diff --git a/arkindex/dataimport/urls.py b/arkindex/dataimport/urls.py index 29769765d943e0f83ef4340250b421f23d1f5f7f..8363e3ac93e5e8df800980635fa0f785d67e6bda 100644 --- a/arkindex/dataimport/urls.py +++ b/arkindex/dataimport/urls.py @@ -1,6 +1,8 @@ from 
django.conf.urls import url from arkindex.dataimport.views import \ - DataImportsList, DataImportCreate, DataImportConfig, DataImportStatus, DataImportFailures, DataFileList + DataImportsList, DataImportCreate, DataImportConfig, DataImportStatus, DataImportFailures, \ + DataFileList, RepositoryList, RepositoryCreate, RepositoryConfig +from arkindex.users.views import CredentialsList, OAuthSignIn, OAuthCallback urlpatterns = [ @@ -10,4 +12,10 @@ urlpatterns = [ url(r'^(?P<pk>[\w\-]+)/status/?$', DataImportStatus.as_view(), name='import-status'), url(r'^(?P<pk>[\w\-]+)/failures/?$', DataImportFailures.as_view(), name='import-failures'), url(r'^files/?$', DataFileList.as_view(), name='files'), + url(r'^repos/?$', RepositoryList.as_view(), name='repositories'), + url(r'^repos/new/?$', RepositoryCreate.as_view(), name='repositories-create'), + url(r'^repos/(?P<pk>[\w\-]+)/?$', RepositoryConfig.as_view(), name='repositories-config'), + url(r'^credentials/?$', CredentialsList.as_view(), name='credentials'), + url(r'^oauth/(?P<provider>\w+)/signin/?$', OAuthSignIn.as_view(), name='oauth-signin'), + url(r'^oauth/(?P<provider>\w+)/callback/?$', OAuthCallback.as_view(), name='oauth-callback'), ] diff --git a/arkindex/dataimport/views.py b/arkindex/dataimport/views.py index 40862a55711551e406f25417899acd5808ea9d07..1ed08eb841e82b4054bc8f6fc1eeea54f274f6bc 100644 --- a/arkindex/dataimport/views.py +++ b/arkindex/dataimport/views.py @@ -1,6 +1,6 @@ from django.views.generic import TemplateView, DetailView from django.contrib.auth.mixins import LoginRequiredMixin -from arkindex.dataimport.models import DataImport, DataImportState +from arkindex.dataimport.models import DataImport, DataImportState, Repository class DataImportsList(LoginRequiredMixin, TemplateView): @@ -61,3 +61,28 @@ class DataFileList(LoginRequiredMixin, TemplateView): View and manage uploaded files """ template_name = 'dataimport/files.html' + + +class RepositoryList(LoginRequiredMixin, TemplateView): + """ + 
Manage repositories + """ + template_name = 'dataimport/repositories.html' + + +class RepositoryCreate(LoginRequiredMixin, TemplateView): + """ + Create a new repository + """ + template_name = 'dataimport/repository.new.html' + + +class RepositoryConfig(LoginRequiredMixin, DetailView): + """ + Create a new repository + """ + template_name = 'dataimport/repository.config.html' + context_object_name = 'repo' + + def get_queryset(self): + return Repository.objects.filter(credentials__user=self.request.user) diff --git a/arkindex/documents/models.py b/arkindex/documents/models.py index 7731497c548784bce6649b616748207e14c90d9b..ad4488ce031ff182d56079c7c22f7b526ed71a38 100644 --- a/arkindex/documents/models.py +++ b/arkindex/documents/models.py @@ -438,7 +438,7 @@ class MetaData(models.Model): name = models.CharField(max_length=250) type = EnumField(MetaType, max_length=50, db_index=True) value = models.TextField() - revision = models.ForeignKey('dataimport.Revision', on_delete=models.CASCADE, blank=True, null=True) + revision = models.ForeignKey('dataimport.Revision', on_delete=models.SET_NULL, blank=True, null=True) class Meta: ordering = ('element', 'name') diff --git a/arkindex/documents/tei.py b/arkindex/documents/tei.py index 12cad6434a26ff99e3421e2d0ba2a5a160589345..d6edebc25dc7870d9389e84865fa5ec6b9dba2b8 100644 --- a/arkindex/documents/tei.py +++ b/arkindex/documents/tei.py @@ -138,7 +138,8 @@ class TeiElement(object): ) if created: continue - if (db_meta.type, db_meta.value) == meta: # Nothing to update + if (db_meta.type, db_meta.value) == meta and db_meta.revision: + # Nothing to update and revision is set continue db_meta.type, db_meta.value = meta db_meta.revision = revision diff --git a/arkindex/documents/tests/test_tei.py b/arkindex/documents/tests/test_tei.py index 650bf52c6a513b719f35148147a648c056d2f1e9..4c397e13cc388ec83e0a538d0a41c363cab10885 100644 --- a/arkindex/documents/tests/test_tei.py +++ b/arkindex/documents/tests/test_tei.py @@ -2,6 +2,7 @@ 
from lxml import etree from arkindex.documents.models import Act from arkindex.documents.tei import Text, TeiParser from arkindex.dataimport.models import Repository, Revision, DataImportFailure +from arkindex.users.models import OAuthCredentials from arkindex.project.tests import FixtureTestCase import os.path @@ -17,10 +18,12 @@ class TestTeiElement(FixtureTestCase): self.repo = Repository.objects.create( url='http://repo', hook_token='token', - clone_user='user', - clone_token='token', corpus=self.corpus, - user=self.user, + credentials=OAuthCredentials.objects.create( + user=self.user, + provider_name='provider', + provider_url='https://somewhere', + ), ) def test_apply_xslt(self): diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py index bfce7f7553a2ce442bd0ca971492f5863d51ea76..df0b1491a39a082828e42367c805fe7887455559 100644 --- a/arkindex/project/api_v1.py +++ b/arkindex/project/api_v1.py @@ -8,7 +8,9 @@ from arkindex.documents.api import \ ActEdit, TranscriptionCreate, TranscriptionBulk, SurfaceDetails from arkindex.dataimport.api import \ DataImportsList, DataImportDetails, DataImportFailures, \ - DataFileList, DataFileRetrieve, DataFileUpload, GitRepositoryImportHook + DataFileList, DataFileRetrieve, DataFileUpload, \ + GitRepositoryImportHook, RepositoryList, AvailableRepositoriesList, RepositoryRetrieve, RepositoryStartImport +from arkindex.users.api import ProvidersList, CredentialsList, CredentialsRetrieve api = [ @@ -70,10 +72,21 @@ api = [ # Import workflows url(r'^imports/$', DataImportsList.as_view(), name='import-list'), + url(r'^imports/repos/?$', RepositoryList.as_view(), name='repository-list'), + url(r'^imports/repos/(?P<pk>[\w\-]+)/?$', RepositoryRetrieve.as_view(), name='repository-retrieve'), + url(r'^imports/repos/(?P<pk>[\w\-]+)/start/?$', RepositoryStartImport.as_view(), name='repository-import'), + url(r'^imports/repos/search/(?P<pk>[\w\-]+)/?$', + AvailableRepositoriesList.as_view(), + name='available-repositories'), 
url(r'^imports/(?P<pk>[\w\-]+)$', DataImportDetails.as_view(), name='import-details'), url(r'^imports/(?P<pk>[\w\-]+)/failures$', DataImportFailures.as_view(), name='import-failures'), url(r'^imports/files/(?P<pk>[\w\-]+)$', DataFileList.as_view(), name='file-list'), url(r'^imports/file/(?P<pk>[\w\-]+)$', DataFileRetrieve.as_view(), name='file-retrieve'), url(r'^imports/upload/(?P<pk>[\w\-]+)$', DataFileUpload.as_view(), name='file-upload'), url(r'^imports/hook/(?P<pk>[\w\-]+)$', GitRepositoryImportHook.as_view(), name='import-hook'), + + # Manage OAuth integrations + url(r'^oauth/providers/?$', ProvidersList.as_view(), name='providers-list'), + url(r'^oauth/credentials/?$', CredentialsList.as_view(), name='credentials-list'), + url(r'^oauth/credentials/(?P<pk>[\w\-]+)/?$', CredentialsRetrieve.as_view(), name='credentials-retrieve'), ] diff --git a/arkindex/project/settings.py b/arkindex/project/settings.py index 4c6e4dfa8b0ad4ce9604a002d6133d5172f51367..c7be6c1063f2ee35801488c57b65bf4d5c0942b3 100644 --- a/arkindex/project/settings.py +++ b/arkindex/project/settings.py @@ -227,6 +227,9 @@ IIIF_TRANSCRIPTION_LIST = False # TEI XSLT file path TEI_XSLT_PATH = os.path.join(BASE_DIR, 'documents/teitohtml.xsl') +# GitLab OAuth +GITLAB_APP_ID = os.environ.get('GITLAB_APP_ID') +GITLAB_APP_SECRET = os.environ.get('GITLAB_APP_SECRET') # Cache into memcached CACHES = { diff --git a/arkindex/templates/base.html b/arkindex/templates/base.html index 097bd77b6c151001d9cd1b0da810521c4210f608..2c160d1dc0a02d9c3277640fb4409e01f19be666 100644 --- a/arkindex/templates/base.html +++ b/arkindex/templates/base.html @@ -33,26 +33,26 @@ <a class="navbar-item" href="{% url 'files' %}"> Files </a> + <a class="navbar-item" href="{% url 'repositories' %}"> + Repositories + </a> {% endif %} </div> <div class="navbar-end"> {% if user.is_authenticated %} - <div class="navbar-item"> - {{ user }} - </div> - <div class="navbar-item"> - <p class="control"> - <a class="button is-info" href="{% url 
'logout' %}">Log out</a> - </p> - </div> - {% if user.is_admin %} - <div class="navbar-item"> - <p class="control"> - <a class="button is-primary" href="{% url 'admin:index' %}">Admin</a> - </p> + <div class="navbar-item has-dropdown is-hoverable"> + <a class="navbar-link"> + {{ user }} + </a> + <div class="navbar-dropdown"> + <a href="{% url 'credentials' %}" class="navbar-item">OAuth</a> + {% if user.is_admin %} + <a href="{% url 'admin:index' %}" class="navbar-item">Admin</a> + {% endif %} + <a href="{% url 'logout' %}" class="navbar-item">Log out</a> + </div> </div> - {% endif %} {% else %} <div class="navbar-item"> <p class="control"> diff --git a/arkindex/templates/dataimport/credentials.html b/arkindex/templates/dataimport/credentials.html new file mode 100644 index 0000000000000000000000000000000000000000..fc4e0e077268f0f70212c876a19860e317dc4241 --- /dev/null +++ b/arkindex/templates/dataimport/credentials.html @@ -0,0 +1,10 @@ +{% extends 'base.html' %} + +{% block content %} +<h1 class="title">Credentials</h1> +<h2 class="subtitle">Manage connected applications</h2> + +<div id="app"> + <OAuth-List /> +</div> +{% endblock %} diff --git a/arkindex/templates/dataimport/new.html b/arkindex/templates/dataimport/new.html index 481098d2e8589f62d36a8bcad74f2f6ad96e48d1..97e09cf8cba30759f9375efc3338378e402c23f3 100644 --- a/arkindex/templates/dataimport/new.html +++ b/arkindex/templates/dataimport/new.html @@ -6,4 +6,5 @@ <div id="app"> <Import-Create /> +</div> {% endblock %} diff --git a/arkindex/templates/dataimport/repositories.html b/arkindex/templates/dataimport/repositories.html new file mode 100644 index 0000000000000000000000000000000000000000..49d5ea665ffbc99736f3dbf3e38fd9d34607c07c --- /dev/null +++ b/arkindex/templates/dataimport/repositories.html @@ -0,0 +1,10 @@ +{% extends 'base.html' %} + +{% block content %} +<h1 class="title">Repositories</h1> +<h2 class="subtitle">Manage Git repositories</h2> + +<div id="app"> + <Repos-List /> +</div> +{% 
endblock %} diff --git a/arkindex/templates/dataimport/repository.config.html b/arkindex/templates/dataimport/repository.config.html new file mode 100644 index 0000000000000000000000000000000000000000..7af0d0b2304fb4eaf3a426f576727a5f1ac14ea1 --- /dev/null +++ b/arkindex/templates/dataimport/repository.config.html @@ -0,0 +1,10 @@ +{% extends 'base.html' %} + +{% block content %} +<h1 class="title">Configure a repository</h1> +<h2 class="subtitle">Change a repository's settings and configure webhooks</h2> + +<div id="app"> + <Repos-Config id="{{ repo.id }}" /> +</div> +{% endblock %} diff --git a/arkindex/templates/dataimport/repository.new.html b/arkindex/templates/dataimport/repository.new.html new file mode 100644 index 0000000000000000000000000000000000000000..6c9644f418aba23da5997bd6fe7db246e7df0503 --- /dev/null +++ b/arkindex/templates/dataimport/repository.new.html @@ -0,0 +1,10 @@ +{% extends 'base.html' %} + +{% block content %} +<h1 class="title">New repository</h1> +<h2 class="subtitle">Add a repository from an external provider</h2> + +<div id="app"> + <Repos-Create /> +</div> +{% endblock %} diff --git a/arkindex/templates/dataimport/status.html b/arkindex/templates/dataimport/status.html index ae580348f2e742a309fd4798ae58b40f06c4624e..465907898c8b3264eb4366e0cc4f2837f16a1304 100644 --- a/arkindex/templates/dataimport/status.html +++ b/arkindex/templates/dataimport/status.html @@ -6,4 +6,5 @@ <div id="app"> <Import-Status id="{{ dataimport.id }}" /> +</div> {% endblock %} diff --git a/arkindex/users/api.py b/arkindex/users/api.py new file mode 100644 index 0000000000000000000000000000000000000000..a25010fd133641cf50bf20fb757a00fde145bdb0 --- /dev/null +++ b/arkindex/users/api.py @@ -0,0 +1,33 @@ +from rest_framework.generics import ListAPIView, RetrieveDestroyAPIView +from rest_framework.permissions import IsAuthenticated +from arkindex.users.providers import oauth_providers +from arkindex.users.serializers import OAuthCredentialsSerializer, 
OAuthProviderClassSerializer + + +class ProvidersList(ListAPIView): + permission_classes = (IsAuthenticated, ) + serializer_class = OAuthProviderClassSerializer + pagination_class = None + + def get_queryset(self): + return list(filter(lambda p: p.enabled(), oauth_providers)) + + +class CredentialsList(ListAPIView): + permission_classes = (IsAuthenticated, ) + serializer_class = OAuthCredentialsSerializer + + def get_queryset(self): + return self.request.user.credentials.exclude(token=None).order_by('id') + + +class CredentialsRetrieve(RetrieveDestroyAPIView): + permission_classes = (IsAuthenticated, ) + serializer_class = OAuthCredentialsSerializer + + def get_queryset(self): + return self.request.user.credentials.exclude(token=None).order_by('id') + + def perform_destroy(self, instance): + instance.provider_class(request=self.request, credentials=instance).disconnect() + super().perform_destroy(instance) diff --git a/arkindex/users/migrations/0003_oauthcredentials.py b/arkindex/users/migrations/0003_oauthcredentials.py new file mode 100644 index 0000000000000000000000000000000000000000..f3cfcc0006d903f1437cf7b7e94a12a76db186b1 --- /dev/null +++ b/arkindex/users/migrations/0003_oauthcredentials.py @@ -0,0 +1,37 @@ +# Generated by Django 2.1 on 2018-08-17 14:25 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +import uuid + + +class Migration(migrations.Migration): + + dependencies = [ + ('users', '0002_create_tokens'), + ] + + operations = [ + migrations.CreateModel( + name='OAuthCredentials', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)), + ('provider_name', models.CharField(max_length=50, choices=[('GitLabOAuthProvider', 'GitLab')])), + ('provider_url', models.URLField()), + ('token', models.CharField(blank=True, max_length=64, null=True)), + ('refresh_token', models.CharField(blank=True, max_length=64, null=True)), + ('expiry', 
models.DateTimeField(blank=True, null=True)), + ('account_name', models.CharField(max_length=100, blank=True, null=True)), + ], + ), + migrations.AddField( + model_name='oauthcredentials', + name='user', + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name='credentials', + to=settings.AUTH_USER_MODEL, + ), + ), + ] diff --git a/arkindex/users/models.py b/arkindex/users/models.py index bded287e15f871499d29dba3e2d03dbf66fd1328..22aa54689e64f035391e14872f4cacf801da8cd4 100644 --- a/arkindex/users/models.py +++ b/arkindex/users/models.py @@ -1,6 +1,8 @@ from django.db import models from django.contrib.auth.models import AbstractBaseUser from arkindex.users.managers import UserManager +from arkindex.users.providers import oauth_providers, get_provider +import uuid class User(AbstractBaseUser): @@ -34,3 +36,34 @@ class User(AbstractBaseUser): "Is the user a member of staff?" # Simplest possible answer: All admins are staff return self.is_admin + + +class OAuthCredentials(models.Model): + id = models.UUIDField(default=uuid.uuid4, primary_key=True) + user = models.ForeignKey('users.User', on_delete=models.CASCADE, related_name='credentials') + provider_name = models.CharField( + max_length=50, + choices=[(p.__name__, p.display_name) for p in oauth_providers], + ) + provider_url = models.URLField() + token = models.CharField(max_length=64, blank=True, null=True) + refresh_token = models.CharField(max_length=64, blank=True, null=True) + expiry = models.DateTimeField(blank=True, null=True) + account_name = models.CharField(max_length=100, blank=True, null=True) + + @property + def provider_class(self): + return get_provider(self.provider_name) + + @property + def provider(self): + return self.provider_class(credentials=self) + + @property + def git_provider_class(self): + from arkindex.dataimport.providers import from_oauth + return from_oauth(self.provider_name) + + @property + def git_provider(self): + return 
self.git_provider_class(credentials=self) diff --git a/arkindex/users/providers.py b/arkindex/users/providers.py new file mode 100644 index 0000000000000000000000000000000000000000..8c68cf6ae8214ab47496c07378463f462eaf3737 --- /dev/null +++ b/arkindex/users/providers.py @@ -0,0 +1,166 @@ +from abc import ABC, abstractmethod +from django.conf import settings +from django.urls import reverse +from rest_framework.exceptions import NotAuthenticated +from gitlab import Gitlab, GitlabError +import urllib.parse +import datetime +import requests + + +class OAuthProvider(ABC): + """ + An OAuth authentication provider. + """ + + display_name = "" + + def __init__(self, request=None, credentials=None, url=None): + if request is not None: + # Allow Django and Django REST Framework requests + from django.http.request import HttpRequest + from rest_framework.request import Request + assert isinstance(request, (HttpRequest, Request)) + if credentials is not None: + from arkindex.users.models import OAuthCredentials + assert isinstance(credentials, OAuthCredentials) + self.url = credentials.provider_url + elif url is not None: + self.url = url + self.request = request + self.credentials = credentials + + @classmethod + @abstractmethod + def enabled(cls): + """ + Boolean stating the provider's availability. + """ + + @abstractmethod + def get_callback_uri(self): + """ + Get the OAuth callback URI + """ + + @abstractmethod + def get_authorize_uri(self): + """ + Get the OAuth authorization endpoint URI + """ + + @abstractmethod + def handle_callback(self): + """ + Handle a OAuth callback and save token data. Should raise exceptions if the process fails. + """ + + @abstractmethod + def disconnect(self): + """ + Erase token data and logout the user from the service. 
+ """ + + +class GitLabOAuthProvider(OAuthProvider): + + display_name = 'GitLab' + url = 'https://gitlab.com' + authorize_endpoint = '/oauth/authorize' + token_endpoint = '/oauth/token' + + @classmethod + def enabled(cls): + return bool(settings.GITLAB_APP_ID and settings.GITLAB_APP_SECRET) + + def get_callback_uri(self): + if not self.request: + return + return self.request.build_absolute_uri( + reverse('oauth-callback', kwargs={'provider': self.__class__.__name__}), + ) + + def get_authorize_uri(self): + if not self.request or not self.credentials: + return + return '{}?{}'.format( + urllib.parse.urljoin(self.credentials.provider_url, self.authorize_endpoint), + urllib.parse.urlencode({ + 'client_id': settings.GITLAB_APP_ID, + 'redirect_uri': self.get_callback_uri(), + 'scope': 'api', + 'response_type': 'code', + 'state': str(self.credentials.id), + }), + ) + + def handle_callback(self): + if not self.request: + return + if not any(param in self.request.GET for param in ('code', 'error')): + raise ValueError('Callback called without a valid response') + if 'error' in self.request.GET: + raise ValueError(self.request.GET.get('error_description', self.request.GET['error'])) + + state = self.request.GET.get('state') + if not state: + raise ValueError('No state hash') + self.credentials = self.request.user.credentials.get(id=state) + + response = requests.post( + urllib.parse.urljoin(self.credentials.provider_url, self.token_endpoint), + { + 'client_id': settings.GITLAB_APP_ID, + 'client_secret': settings.GITLAB_APP_SECRET, + 'code': self.request.GET.get('code', ''), + 'grant_type': 'authorization_code', + 'redirect_uri': self.get_callback_uri(), + } + ) + response.raise_for_status() + data = response.json() + + assert 'access_token' in data, 'GitLab returned no OAuth token' + assert data.get('token_type', '') == 'bearer', 'Unsupported OAuth token type' + + self.credentials.token = data['access_token'] + self.credentials.refresh_token = data.get('refresh_token') + if 
'expires_in' in data: + self.credentials.expiry = datetime.datetime.now() + datetime.timedelta(0, int(data['expires_in'])) + + gl = Gitlab(self.credentials.provider_url, oauth_token=self.credentials.token) + gl.auth() + self.credentials.account_name = gl.user.username + + self.credentials.save() + + def disconnect(self): + if not (self.request and self.credentials): + raise NotAuthenticated() + + if not self.credentials.repos.exists(): + return + + # Remove all webhooks + try: + gl = Gitlab(self.url, oauth_token=self.credentials.token) + for repo in self.credentials.repos.all(): + project = gl.projects.get(urllib.parse.urlsplit(repo.url).path.strip('/')) + hook_url = self.request.build_absolute_uri( + reverse('api:import-hook', kwargs={'pk': repo.id}) + ) + # Try to find the webhook + hook = next((h for h in project.hooks.list() if h.url == hook_url), None) + if hook: + hook.delete() + except GitlabError: + pass + + +oauth_providers = [ + GitLabOAuthProvider, +] + + +def get_provider(name): + return next(filter(lambda p: p.__name__ == name, oauth_providers), None) diff --git a/arkindex/users/serializers.py b/arkindex/users/serializers.py new file mode 100644 index 0000000000000000000000000000000000000000..a786bc63cd96fe433b4d68d26a1d6b79b5ee2ab2 --- /dev/null +++ b/arkindex/users/serializers.py @@ -0,0 +1,24 @@ +from rest_framework import serializers +from arkindex.users.models import OAuthCredentials + + +class OAuthCredentialsSerializer(serializers.ModelSerializer): + + provider_display_name = serializers.CharField(source='provider_class.display_name') + + class Meta: + model = OAuthCredentials + fields = ( + 'id', + 'provider_name', + 'provider_display_name', + 'provider_url', + 'account_name', + ) + + +class OAuthProviderClassSerializer(serializers.Serializer): + + name = serializers.CharField(source='__name__') + display_name = serializers.CharField() + default_url = serializers.URLField(source='url') diff --git a/arkindex/users/views.py 
b/arkindex/users/views.py new file mode 100644 index 0000000000000000000000000000000000000000..452e4d4a5507d73ca14b8c213840c5b5bed05150 --- /dev/null +++ b/arkindex/users/views.py @@ -0,0 +1,46 @@ +from django.views.generic import RedirectView, TemplateView +from django.contrib.auth.mixins import LoginRequiredMixin +from arkindex.users.providers import get_provider + + +class OAuthSignIn(LoginRequiredMixin, RedirectView): + """ + Start a OAuth authentication workflow + """ + + def get_redirect_url(self, *args, **kwargs): + if 'provider' not in kwargs: + return + provider = get_provider(kwargs['provider']) + if not provider: + return + url = self.request.GET.get('url', provider.url) + # Create OAuthCredentials without a token + creds = self.request.user.credentials.create( + provider_name=kwargs['provider'], + provider_url=url, + ) + return provider(request=self.request, credentials=creds).get_authorize_uri() + + +class OAuthCallback(LoginRequiredMixin, RedirectView): + """ + Callback for OAuth responses + """ + + pattern_name = 'credentials' + + def get(self, request, *args, **kwargs): + assert 'provider' in kwargs + provider = get_provider(kwargs['provider']) + if not provider: + raise ValueError('Unknown provider') + provider(self.request).handle_callback() + return super().get(request) + + +class CredentialsList(LoginRequiredMixin, TemplateView): + """ + View and manage OAuth providers + """ + template_name = 'dataimport/credentials.html' diff --git a/requirements.txt b/requirements.txt index 2be728a555ff55b0dbec62e6ef64fc84cb963555..d502a0496723dcb5d774c3fe87334f5c2ebf2732 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,7 @@ gitpython==2.1.11 idna==2.6 jdcal==1.3 olefile==0.44 +python-gitlab==1.5.1 python-magic==0.4.15 python-memcached==1.59 pytz==2017.2