Skip to content
Snippets Groups Projects
Commit cd466f78 authored by Bastien Abadie's avatar Bastien Abadie
Browse files

API for classification demo

See merge request !81
parents bdccbc76 86e16dca
No related branches found
No related tags found
1 merge request!81Api demo
from django.shortcuts import get_object_or_404
from django.core.exceptions import PermissionDenied
from rest_framework.generics import \
ListAPIView, ListCreateAPIView, RetrieveUpdateDestroyAPIView, RetrieveAPIView
from rest_framework.views import APIView
......@@ -7,7 +8,7 @@ from rest_framework.permissions import IsAuthenticated, IsAdminUser
from rest_framework.response import Response
from rest_framework import status
from rest_framework.exceptions import ValidationError
from arkindex.documents.models import Corpus
from arkindex.documents.models import Corpus, Right
from arkindex.dataimport.models import \
DataImport, DataFile, DataImportState, DataImportMode, DataImportFailure, Repository
from arkindex.dataimport.serializers import \
......@@ -123,6 +124,10 @@ class DataFileUpload(APIView):
raise ValidationError({'corpus': ['Corpus not found']})
corpus = corpus_qs.get()
# Check corpus is writable for current user
if Right.Write not in corpus.get_acl_rights(self.request.user):
raise PermissionDenied
file_obj = request.FILES['file']
md5hash = hashlib.md5()
......
......@@ -6,6 +6,7 @@ from celery import states
from celery.canvas import Signature
from celery.result import AsyncResult, GroupResult
from enumfields import EnumField, Enum
from arkindex.project.celery import app as celery_app
from arkindex.dataimport.providers import git_providers, get_provider
from arkindex.project.models import IndexableModel
from arkindex.project.fields import ArrayField
......@@ -73,12 +74,14 @@ class DataImport(IndexableModel):
def build_workflow(self):
if self.mode == DataImportMode.Images:
# Prevent circular imports
from arkindex.dataimport.tasks import check_images, import_images
return check_images.s(self) | import_images.s(self)
from arkindex.dataimport.tasks import check_images, import_images, save_classification
classify = celery_app.signature('arkindex_ml.tasks.classify_pages')
return check_images.s(self) | import_images.s(self) | classify | save_classification.s(dataimport_id=self.id) # noqa
elif self.mode == DataImportMode.Repository:
from arkindex.dataimport.tasks import download_repo, import_repo, cleanup_repo
return download_repo.si(self) | import_repo.si(self) | cleanup_repo.si(self)
else:
raise NotImplementedError
......
......@@ -5,7 +5,7 @@ from celery.states import EXCEPTION_STATES
from django.conf import settings
from django.db import transaction
from arkindex.project.celery import ReportingTask
from arkindex.documents.models import Element, ElementType
from arkindex.documents.models import Element, ElementType, Page
from arkindex.documents.importer import import_page
from arkindex.documents.tei import TeiParser
from arkindex.images.models import ImageServer, ImageStatus
......@@ -82,6 +82,7 @@ def import_images(self, valid_files, dataimport, server_id=settings.LOCAL_IMAGES
)
datafiles = dataimport.files.all()
pages = []
for i, datafile in enumerate(datafiles, 1):
self.report_progress(i / len(datafiles), 'Importing image {} of {}'.format(i, len(datafiles)))
......@@ -101,10 +102,32 @@ def import_images(self, valid_files, dataimport, server_id=settings.LOCAL_IMAGES
}
)
import_page(vol, img, volume_name, str(i), i)
page = import_page(vol, img, volume_name, str(i), i)
pages.append((page.id, img.get_thumbnail_url(max_width=500)))
self.report_message("Imported files into {}".format(vol))
return {'volume': str(vol.id)}
return {
'volume': str(vol.id),
'pages': pages,
}
@shared_task(bind=True, base=ReportingTask)
def save_classification(self, classification, **kwargs):
'''
Save ML classification on images into our DB
'''
assert isinstance(classification, dict)
for page_id, classes in classification.items():
try:
page = Page.objects.get(pk=page_id)
page.classification = classes
page.save()
except Page.DoesNotExist:
continue
self.report_message("Saved classification for {}".format(page))
@shared_task(bind=True, base=ReportingTask)
......@@ -179,9 +202,21 @@ def dataimport_postrun(task_id, task, state, args=(), **kwargs):
# Look for dataimport in args
imps = [a for a in args if isinstance(a, DataImport)]
assert len(imps) == 1, 'No args on dataimport task.'
dataimport = imps[0]
dataimport.refresh_from_db() # avoid inconsistency
if len(imps) == 1:
dataimport = imps[0]
dataimport.refresh_from_db() # avoid inconsistency
elif 'kwargs' in kwargs and 'dataimport_id' in kwargs['kwargs']:
# Try to load a dataimport from its id
# Needed to communicate cleanly with worker-ml (no django instances)
try:
dataimport = DataImport.objects.get(pk=kwargs['kwargs']['dataimport_id'])
except DataImport.DoesNotExists:
raise Exception('No dataimport found in args')
else:
raise Exception('No args on dataimport task.')
assert isinstance(dataimport, DataImport), \
'DataImport not found as first arg of task {}'.format(task)
......
......@@ -19,6 +19,9 @@ class TestFiles(APITestCase):
self.corpus = Corpus.objects.create(name='Unit Tests')
self.user = User.objects.create_user('test@test.test', 'testtest')
# Add write access
self.user.corpus_right.create(corpus=self.corpus, can_write=True)
def test_file_list(self):
df = DataFile.objects.create(
name='test.txt', size=42, hash='aaaa', content_type='text/plain', corpus=self.corpus)
......@@ -42,6 +45,17 @@ class TestFiles(APITestCase):
response = self.client.get(reverse('api:file-list', kwargs={'pk': self.corpus.id}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_file_upload_ro(self):
'''
An upload on a read only corpus should fail
'''
public = Corpus.objects.create(name='Unit Tests', public=True)
f = SimpleUploadedFile('test.txt', b'This is a text file')
self.client.force_login(self.user)
response = self.client.post(reverse('api:file-upload', kwargs={'pk': public.id}), data={'file': f})
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_file_upload(self):
"""
Assert a file upload creates a database instance and saves the file
......
from arkindex.project.tests import RedisMockAPITestCase
from arkindex.dataimport.tasks import save_classification
from arkindex.documents.models import Page, Corpus
class TestTasks(RedisMockAPITestCase):
"""
Test data imports tasks
"""
def test_save_classification(self):
corpus = Corpus.objects.create(name='test class')
dog = Page.objects.create(corpus=corpus, name='A dog')
cat = Page.objects.create(corpus=corpus, name='A cat')
classification = {
dog.id: [
{
'label': 'dog',
'probability': 0.9,
}
],
cat.id: [
{
'label': 'cat',
'probability': 0.8,
}
]
}
save_classification(classification)
dog.refresh_from_db()
self.assertEqual(dog.classification, [{
'label': 'dog',
'probability': 0.9,
}])
cat.refresh_from_db()
self.assertEqual(cat.classification, [{
'label': 'cat',
'probability': 0.8,
}])
# Generated by Django 2.1 on 2018-08-21 16:06
import django.contrib.postgres.fields.jsonb
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('documents', '0022_new_corpus_ids'),
]
operations = [
migrations.AddField(
model_name='page',
name='classification',
field=django.contrib.postgres.fields.jsonb.JSONField(blank=True, null=True),
),
migrations.AlterField(
model_name='metadata',
name='revision',
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to='dataimport.Revision',
),
),
]
from django.db import models, transaction
from django.contrib.postgres.indexes import GinIndex
from django.contrib.postgres.fields import JSONField
from enumfields import EnumField, Enum
from arkindex.project.models import IndexableModel
from arkindex.project.celery import app as celery_app
......@@ -283,6 +284,9 @@ class Page(Element):
"""
folio = models.CharField(max_length=250)
# Machine learning classes
classification = JSONField(null=True, blank=True)
# Parsed folio
page_type = EnumField(PageType, max_length=50, null=True, blank=True)
nb = models.PositiveIntegerField(null=True, blank=True)
......@@ -361,14 +365,18 @@ class Page(Element):
out.append(self.direction.value)
return ' '.join(out)
def ml_classify(self):
def classify(self):
'''
Use a machine learning worker to classify the page
using its image
Celery is triggered through an external signature
'''
# Classify a page through ML worker
signature = celery_app.signature('arkindex_ml.tasks.classify')
return signature.delay(self.zone.image.get_thumbnail_url(max_width=None))
task = signature.delay(self.zone.image.get_thumbnail_url(max_width=500))
# Wait for result
self.classification = task.get()
class Act(Element):
......
......@@ -42,6 +42,7 @@ class PageLightSerializer(serializers.ModelSerializer):
'direction',
'display_name',
'image',
'classification',
)
......
......@@ -42,6 +42,8 @@ class ExtendedRedisBackend(RedisBackend):
key = self.get_key_for_task(task_id)
existing = self.get(key)
meta = existing and self.decode(existing) or {}
if 'messages' not in meta:
meta['messages'] = []
meta['messages'].append(message)
self.set(key, self.encode(meta))
......
......@@ -369,3 +369,11 @@ try:
INSTALLED_APPS.append('debug_toolbar')
except ImportError:
pass
try:
import corsheaders # noqa
MIDDLEWARE.insert(0, 'corsheaders.middleware.CorsMiddleware')
INSTALLED_APPS.append('corsheaders')
CORS_ORIGIN_WHITELIST = ('localhost:5000', )
except ImportError:
pass
......@@ -50,6 +50,10 @@ class RedisMockMixin(object):
for m in self.mocked:
m.return_value = self.redis
# Patch the add_message
self.messages = patch('arkindex.project.celery.ExtendedRedisBackend.add_message')
self.messages.start()
def tearDown(self):
for p in self.patches:
p.stop()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment