diff --git a/arkindex/dataimport/api.py b/arkindex/dataimport/api.py index 4949d81a7c06810fa97e1d347a66753f139aee73..a979eb92be52a2053b7a6d218a411e27393cf384 100644 --- a/arkindex/dataimport/api.py +++ b/arkindex/dataimport/api.py @@ -1,4 +1,5 @@ from django.shortcuts import get_object_or_404 +from django.core.exceptions import PermissionDenied from rest_framework.generics import \ ListAPIView, ListCreateAPIView, RetrieveUpdateDestroyAPIView, RetrieveAPIView from rest_framework.views import APIView @@ -7,7 +8,7 @@ from rest_framework.permissions import IsAuthenticated, IsAdminUser from rest_framework.response import Response from rest_framework import status from rest_framework.exceptions import ValidationError -from arkindex.documents.models import Corpus +from arkindex.documents.models import Corpus, Right from arkindex.dataimport.models import \ DataImport, DataFile, DataImportState, DataImportMode, DataImportFailure, Repository from arkindex.dataimport.serializers import \ @@ -123,6 +124,10 @@ class DataFileUpload(APIView): raise ValidationError({'corpus': ['Corpus not found']}) corpus = corpus_qs.get() + # Check corpus is writable for current user + if Right.Write not in corpus.get_acl_rights(self.request.user): + raise PermissionDenied + file_obj = request.FILES['file'] md5hash = hashlib.md5() diff --git a/arkindex/dataimport/models.py b/arkindex/dataimport/models.py index a3b887ddc1e2308146c54c7191175344289f44be..48c69c11fae5bbe70775dfdbcb86305b471ce8eb 100644 --- a/arkindex/dataimport/models.py +++ b/arkindex/dataimport/models.py @@ -6,6 +6,7 @@ from celery import states from celery.canvas import Signature from celery.result import AsyncResult, GroupResult from enumfields import EnumField, Enum +from arkindex.project.celery import app as celery_app from arkindex.dataimport.providers import git_providers, get_provider from arkindex.project.models import IndexableModel from arkindex.project.fields import ArrayField @@ -73,12 +74,14 @@ class 
DataImport(IndexableModel): def build_workflow(self): if self.mode == DataImportMode.Images: - # Prevent circular imports - from arkindex.dataimport.tasks import check_images, import_images - return check_images.s(self) | import_images.s(self) + from arkindex.dataimport.tasks import check_images, import_images, save_classification + classify = celery_app.signature('arkindex_ml.tasks.classify_pages') + return check_images.s(self) | import_images.s(self) | classify | save_classification.s(dataimport_id=self.id) # noqa + elif self.mode == DataImportMode.Repository: from arkindex.dataimport.tasks import download_repo, import_repo, cleanup_repo return download_repo.si(self) | import_repo.si(self) | cleanup_repo.si(self) + else: raise NotImplementedError diff --git a/arkindex/dataimport/tasks.py b/arkindex/dataimport/tasks.py index ce2296c8af35a00e11bad20a225a41d1f03e745e..7eeefe0d9499929a28eb7074074104cbf79ab830 100644 --- a/arkindex/dataimport/tasks.py +++ b/arkindex/dataimport/tasks.py @@ -5,7 +5,7 @@ from celery.states import EXCEPTION_STATES from django.conf import settings from django.db import transaction from arkindex.project.celery import ReportingTask -from arkindex.documents.models import Element, ElementType +from arkindex.documents.models import Element, ElementType, Page from arkindex.documents.importer import import_page from arkindex.documents.tei import TeiParser from arkindex.images.models import ImageServer, ImageStatus @@ -82,6 +82,7 @@ def import_images(self, valid_files, dataimport, server_id=settings.LOCAL_IMAGES ) datafiles = dataimport.files.all() + pages = [] for i, datafile in enumerate(datafiles, 1): self.report_progress(i / len(datafiles), 'Importing image {} of {}'.format(i, len(datafiles))) @@ -101,10 +102,32 @@ def import_images(self, valid_files, dataimport, server_id=settings.LOCAL_IMAGES } ) - import_page(vol, img, volume_name, str(i), i) + page = import_page(vol, img, volume_name, str(i), i) + pages.append((page.id, 
img.get_thumbnail_url(max_width=500))) self.report_message("Imported files into {}".format(vol)) - return {'volume': str(vol.id)} + return { + 'volume': str(vol.id), + 'pages': pages, + } + + +@shared_task(bind=True, base=ReportingTask) +def save_classification(self, classification, **kwargs): + ''' + Save ML classification on images into our DB + ''' + assert isinstance(classification, dict) + + for page_id, classes in classification.items(): + try: + page = Page.objects.get(pk=page_id) + page.classification = classes + page.save() + except Page.DoesNotExist: + continue + + self.report_message("Saved classification for {}".format(page)) @shared_task(bind=True, base=ReportingTask) @@ -179,9 +202,21 @@ def dataimport_postrun(task_id, task, state, args=(), **kwargs): # Look for dataimport in args imps = [a for a in args if isinstance(a, DataImport)] - assert len(imps) == 1, 'No args on dataimport task.' - dataimport = imps[0] - dataimport.refresh_from_db() # avoid inconsistency + if len(imps) == 1: + dataimport = imps[0] + dataimport.refresh_from_db() # avoid inconsistency + + elif 'kwargs' in kwargs and 'dataimport_id' in kwargs['kwargs']: + # Try to load a dataimport from its id + # Needed to communicate cleanly with worker-ml (no django instances) + try: + dataimport = DataImport.objects.get(pk=kwargs['kwargs']['dataimport_id']) + except DataImport.DoesNotExist: + raise Exception('No dataimport found in args') + + else: + raise Exception('No args on dataimport task.') + assert isinstance(dataimport, DataImport), \ 'DataImport not found as first arg of task {}'.format(task) diff --git a/arkindex/dataimport/tests/test_files.py b/arkindex/dataimport/tests/test_files.py index f6cd813600bfbc951a2ea4c01953f15c516476be..0029fef3249d0b3279ccb730ace8468757cf664e 100644 --- a/arkindex/dataimport/tests/test_files.py +++ b/arkindex/dataimport/tests/test_files.py @@ -19,6 +19,9 @@ class TestFiles(APITestCase): self.corpus = Corpus.objects.create(name='Unit Tests') self.user =
User.objects.create_user('test@test.test', 'testtest') + # Add write access + self.user.corpus_right.create(corpus=self.corpus, can_write=True) + def test_file_list(self): df = DataFile.objects.create( name='test.txt', size=42, hash='aaaa', content_type='text/plain', corpus=self.corpus) @@ -42,6 +45,17 @@ class TestFiles(APITestCase): response = self.client.get(reverse('api:file-list', kwargs={'pk': self.corpus.id})) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + def test_file_upload_ro(self): + ''' + An upload on a read only corpus should fail + ''' + public = Corpus.objects.create(name='Unit Tests', public=True) + f = SimpleUploadedFile('test.txt', b'This is a text file') + self.client.force_login(self.user) + + response = self.client.post(reverse('api:file-upload', kwargs={'pk': public.id}), data={'file': f}) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + def test_file_upload(self): """ Assert a file upload creates a database instance and saves the file diff --git a/arkindex/dataimport/tests/test_tasks.py b/arkindex/dataimport/tests/test_tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..33ba0dbd330245f99de827c86eaca6f83ea06729 --- /dev/null +++ b/arkindex/dataimport/tests/test_tasks.py @@ -0,0 +1,41 @@ +from arkindex.project.tests import RedisMockAPITestCase +from arkindex.dataimport.tasks import save_classification +from arkindex.documents.models import Page, Corpus + + +class TestTasks(RedisMockAPITestCase): + """ + Test data imports tasks + """ + def test_save_classification(self): + corpus = Corpus.objects.create(name='test class') + dog = Page.objects.create(corpus=corpus, name='A dog') + cat = Page.objects.create(corpus=corpus, name='A cat') + + classification = { + dog.id: [ + { + 'label': 'dog', + 'probability': 0.9, + } + ], + cat.id: [ + { + 'label': 'cat', + 'probability': 0.8, + } + ] + } + save_classification(classification) + + dog.refresh_from_db() + 
self.assertEqual(dog.classification, [{ + 'label': 'dog', + 'probability': 0.9, + }]) + + cat.refresh_from_db() + self.assertEqual(cat.classification, [{ + 'label': 'cat', + 'probability': 0.8, + }]) diff --git a/arkindex/documents/migrations/0023_auto_20180821_1606.py b/arkindex/documents/migrations/0023_auto_20180821_1606.py new file mode 100644 index 0000000000000000000000000000000000000000..9b0504d145d3953520adb4cbca016b9c5a4f69c4 --- /dev/null +++ b/arkindex/documents/migrations/0023_auto_20180821_1606.py @@ -0,0 +1,30 @@ +# Generated by Django 2.1 on 2018-08-21 16:06 + +import django.contrib.postgres.fields.jsonb +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('documents', '0022_new_corpus_ids'), + ] + + operations = [ + migrations.AddField( + model_name='page', + name='classification', + field=django.contrib.postgres.fields.jsonb.JSONField(blank=True, null=True), + ), + migrations.AlterField( + model_name='metadata', + name='revision', + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to='dataimport.Revision', + ), + ), + ] diff --git a/arkindex/documents/models.py b/arkindex/documents/models.py index f66d0e4451ea05c1e8c8f3c49e657c56024f8efb..467a165dec0ae1ab251da857c06405eb8a2015c7 100644 --- a/arkindex/documents/models.py +++ b/arkindex/documents/models.py @@ -1,5 +1,6 @@ from django.db import models, transaction from django.contrib.postgres.indexes import GinIndex +from django.contrib.postgres.fields import JSONField from enumfields import EnumField, Enum from arkindex.project.models import IndexableModel from arkindex.project.celery import app as celery_app @@ -283,6 +284,9 @@ class Page(Element): """ folio = models.CharField(max_length=250) + # Machine learning classes + classification = JSONField(null=True, blank=True) + # Parsed folio page_type = EnumField(PageType, max_length=50, null=True, blank=True) 
nb = models.PositiveIntegerField(null=True, blank=True) @@ -361,14 +365,18 @@ class Page(Element): out.append(self.direction.value) return ' '.join(out) - def ml_classify(self): + def classify(self): ''' Use a machine learning worker to classify the page using its image Celery is triggered through an external signature ''' + # Classify a page through ML worker signature = celery_app.signature('arkindex_ml.tasks.classify') - return signature.delay(self.zone.image.get_thumbnail_url(max_width=None)) + task = signature.delay(self.zone.image.get_thumbnail_url(max_width=500)) + + # Wait for result + self.classification = task.get() class Act(Element): diff --git a/arkindex/documents/serializers/elements.py b/arkindex/documents/serializers/elements.py index 350376d3266f11fbc7c1a1f048bce5599da4793e..031029cab9c620f599497dc1bed07de431d2002a 100644 --- a/arkindex/documents/serializers/elements.py +++ b/arkindex/documents/serializers/elements.py @@ -42,6 +42,7 @@ class PageLightSerializer(serializers.ModelSerializer): 'direction', 'display_name', 'image', + 'classification', ) diff --git a/arkindex/project/celery.py b/arkindex/project/celery.py index 02a18c1823f79f8a698d1394268a235c183db725..702479d4cbe07a09b0bc7216bc073f3573ff4814 100644 --- a/arkindex/project/celery.py +++ b/arkindex/project/celery.py @@ -42,6 +42,8 @@ class ExtendedRedisBackend(RedisBackend): key = self.get_key_for_task(task_id) existing = self.get(key) meta = existing and self.decode(existing) or {} + if 'messages' not in meta: + meta['messages'] = [] meta['messages'].append(message) self.set(key, self.encode(meta)) diff --git a/arkindex/project/settings.py b/arkindex/project/settings.py index 82d4430b726aa689b1d853ded14e926b83cbfed3..a6bc942a2dbb0c9fdecffa6319a52db20a702dd9 100644 --- a/arkindex/project/settings.py +++ b/arkindex/project/settings.py @@ -369,3 +369,11 @@ try: INSTALLED_APPS.append('debug_toolbar') except ImportError: pass + +try: + import corsheaders # noqa + MIDDLEWARE.insert(0, 
'corsheaders.middleware.CorsMiddleware') + INSTALLED_APPS.append('corsheaders') + CORS_ORIGIN_WHITELIST = ('localhost:5000', ) +except ImportError: + pass diff --git a/arkindex/project/tests/__init__.py b/arkindex/project/tests/__init__.py index 6734f1c1b2562ca01a148be44cf660899a98792c..74b0c4a9fc895a02f804ea748841a756a058bfc9 100644 --- a/arkindex/project/tests/__init__.py +++ b/arkindex/project/tests/__init__.py @@ -50,6 +50,10 @@ class RedisMockMixin(object): for m in self.mocked: m.return_value = self.redis + # Patch the add_message + self.messages = patch('arkindex.project.celery.ExtendedRedisBackend.add_message') + self.messages.start() + def tearDown(self): for p in self.patches: p.stop()