Skip to content
Snippets Groups Projects
Commit 2b0e1ed4 authored by Erwan Rouchet's avatar Erwan Rouchet Committed by Bastien Abadie
Browse files

ML tools

parent e03875b4
No related branches found
No related tags found
No related merge requests found
Showing
with 281 additions and 40 deletions
......@@ -27,6 +27,8 @@ backend-tests:
before_script:
- apk --update add build-base
# Custom line to install arkindex-common from Git using GitLab CI credentials
- "pip install -e git+https://gitlab-ci-token:${CI_JOB_TOKEN}@gitlab.com/arkindex/common#egg=arkindex-common"
- pip install -r tests-requirements.txt codecov
script:
......
......@@ -2,8 +2,19 @@ FROM registry.gitlab.com/arkindex/backend:base-0.8.8
ARG FRONTEND_BRANCH=master
ARG FRONTEND_ID=4675768
# GitLab project ID and branch of the private arkindex-common repository
ARG COMMON_BRANCH=master
ARG COMMON_ID=9855787
# SECURITY(review): a real-looking GitLab token is committed as a default
# value here; ARG defaults also remain visible in the image build history.
# This token should be revoked and supplied at build time instead
# (docker build --build-arg GITLAB_TOKEN=...).
ARG GITLAB_TOKEN="M98p3wihATgCx4Z5ChvK"
# Install arkindex-common from private repo
# Fetches a tarball of the repo through the GitLab API (authenticated with
# the token above), unpacks it, installs it, then removes the temp files.
RUN \
apk add tar && \
mkdir /tmp/common && \
wget --header "PRIVATE-TOKEN: $GITLAB_TOKEN" https://gitlab.com/api/v4/projects/$COMMON_ID/repository/archive.tar.gz?sha=$COMMON_BRANCH -O /tmp/common/archive.tar.gz && \
tar --strip-components=1 -xvf /tmp/common/archive.tar.gz -C /tmp/common && \
cd /tmp/common && python setup.py install && \
rm -rf /tmp/common
# Install arkindex and its deps
# Uses a source archive instead of full local copy to speedup docker build
COPY dist/arkindex-*.tar.gz /tmp/arkindex.tar.gz
......
......@@ -19,6 +19,7 @@ Dev Setup
git clone git@gitlab.com:arkindex/backend.git
mkvirtualenv -p /usr/bin/python3 ark
cd backend
pip install -r tests-requirements.txt
pip install -e .[test]
```
......
0.8.8
0.8.9.dev
......@@ -20,9 +20,10 @@ from arkindex.dataimport.models import \
from arkindex.dataimport.serializers import (
DataImportLightSerializer, DataImportSerializer, DataImportFromFilesSerializer,
DataImportFailureSerializer, DataFileSerializer,
RepositorySerializer, ExternalRepositorySerializer, EventSerializer
RepositorySerializer, ExternalRepositorySerializer, EventSerializer, MLToolSerializer,
)
from arkindex.users.models import OAuthCredentials
from arkindex_common.ml_tool import MLTool
from datetime import datetime
import hashlib
import magic
......@@ -375,3 +376,17 @@ class ElementHistory(ListAPIView):
element_id=self.kwargs['pk'],
element__corpus__in=Corpus.objects.readable(self.request.user),
)
class MLToolList(ListAPIView):
    """
    List every machine learning tool available on this server,
    ordered alphabetically by slug.

    Pagination is disabled: the set of installed tools is expected
    to be small enough to return in a single response.
    """
    serializer_class = MLToolSerializer
    pagination_class = None

    def get_queryset(self):
        # MLTool.list discovers the tools in the configured directory;
        # sorting by slug gives the API a stable, deterministic ordering.
        available_tools = MLTool.list(settings.ML_CLASSIFIERS_DIR)
        return sorted(available_tools, key=lambda tool: tool.slug)
import fnmatch
import yaml
from enum import Enum
from django.conf import settings
from django.core.validators import URLValidator
from django.core.exceptions import ValidationError
from django.utils.functional import cached_property
from arkindex_common.ml_tool import MLTool, MLToolType
from arkindex.images.models import ImageServer
......@@ -44,7 +46,7 @@ class ConfigFile(object):
"""
FILE_NAME = '.arkindex.yml'
REQUIRED_ITEMS = ('version', 'branches')
REQUIRED_ITEMS = {'version', 'branches', 'classifier', 'recognizer'}
FORMAT_ENUMS = {
ImportType.Volumes: VolumesImportFormat,
ImportType.Transcriptions: TranscriptionsImportFormat,
......@@ -56,7 +58,10 @@ class ConfigFile(object):
def __init__(self, data=None):
    """
    Parse and validate an .arkindex.yml configuration.

    :param data: Either an already-parsed dict (used by tests), a YAML
        string/bytes to parse, or a falsy value to build an empty,
        unvalidated instance.
    :raises ValidationError: When the parsed configuration is invalid.
    """
    if not data:
        # Allow building an empty ConfigFile without running validation
        return
    if isinstance(data, dict):
        # Already parsed; callers such as unit tests pass a dict directly
        parsed = data
    else:
        # Repository config files are untrusted input coming from arbitrary
        # Git repositories: safe_load refuses arbitrary Python object
        # construction, unlike the bare yaml.load previously used here
        # (which is also deprecated when called without an explicit Loader).
        parsed = yaml.safe_load(data)
    self.validated_data = self.validate(parsed)
    self.setattrs(self.validated_data)
......@@ -93,6 +98,19 @@ class ConfigFile(object):
raise ValidationError("Bad 'branches' format: should be a list of branch names")
validated_data['branches'] = parsed['branches']
# ML tools
try:
MLTool.get(settings.ML_CLASSIFIERS_DIR, MLToolType.Classifier, parsed['classifier'])
except ValueError:
raise ValidationError('Classifier "{}" not found'.format(parsed['classifier']))
validated_data['classifier'] = parsed['classifier']
try:
MLTool.get(settings.ML_CLASSIFIERS_DIR, MLToolType.Recognizer, parsed['recognizer'])
except ValueError:
raise ValidationError('Recognizer "{}" not found'.format(parsed['recognizer']))
validated_data['recognizer'] = parsed['recognizer']
# At least one import type required
if not any(it.value in parsed for it in ImportType):
raise ValidationError("No import types were specified")
......@@ -161,6 +179,8 @@ class ConfigFile(object):
self.version = validated_data['version']
self.branches = validated_data['branches']
self.imports = list(filter(lambda it: it.value in validated_data, ImportType))
self.classifier_slug = validated_data['classifier']
self.recognizer_slug = validated_data['recognizer']
# Default formats
self.volumes_format = VolumesImportFormat.IIIF
......@@ -223,3 +243,11 @@ class ConfigFile(object):
s, _ = ImageServer.objects.get_or_create(url=url, defaults={'name': name})
servers.append(s)
return servers
@cached_property
def classifier(self):
    # Resolve the configured classifier slug into an actual MLTool
    # definition; cached because the on-disk tool definitions do not
    # change during an import run.
    return MLTool.get(settings.ML_CLASSIFIERS_DIR, MLToolType.Classifier, self.classifier_slug)

@cached_property
def recognizer(self):
    # Same as `classifier` above, for the configured recognizer slug.
    return MLTool.get(settings.ML_CLASSIFIERS_DIR, MLToolType.Recognizer, self.recognizer_slug)
......@@ -26,7 +26,6 @@ class Command(BaseCommand):
pk__in=Element.objects.get_descending(volume.id, type=ElementType.Page)
)
payload = {
'volume': str(volume.id),
'pages': [
(page.id, page.zone.image.get_thumbnail_url(max_width=1500))
for page in pages
......
......@@ -7,6 +7,7 @@ from arkindex.dataimport.models import (
)
from arkindex.documents.models import Corpus, Element, ElementType
from arkindex.documents.serializers.light import ElementLightSerializer
from arkindex_common.ml_tool import MLToolType
import gitlab.v4.objects
import celery.states
......@@ -333,3 +334,13 @@ class EventSerializer(serializers.ModelSerializer):
'date',
'revision',
)
class MLToolSerializer(serializers.Serializer):
    """
    Serialize a machine learning tool for display.

    A plain Serializer (not a ModelSerializer) because MLTool instances
    come from arkindex_common, not from a Django model.
    """
    name = serializers.CharField()
    slug = serializers.SlugField()
    # Serialized through EnumField so the MLToolType enum value is rendered
    # consistently with the project's other enum fields
    type = EnumField(MLToolType)
    version = serializers.CharField()
......@@ -58,8 +58,9 @@ def build_volume(self, files, dataimport, server_id=settings.LOCAL_IMAGESERVER_I
self.report_message("Imported {} pages into {}".format(len(pages), volume.name))
generate_thumbnail(volume.id)
return {
'volume': str(volume.id),
'pages': pages,
'classifier': settings.ML_DEFAULT_CLASSIFIER,
'recognizer': settings.ML_DEFAULT_RECOGNIZER,
}
......
......@@ -59,6 +59,8 @@ def parse_config(self, dataimport):
self.report_message('Configuration file version {}'.format(config.version))
self.report_message('Git branches to trigger imports on: {}'.format(', '.join(config.branches)))
self.report_message('Classifier: {}'.format(config.classifier.name))
self.report_message('Recognizer: {}'.format(config.recognizer.name))
self.report_message('Active import types: {}'.format(', '.join(import_type.name for import_type in config.imports)))
for import_type in config.imports:
......
version: 1
branches:
- master
classifier: dummy_classifier
recognizer: dummy_recognizer
volumes:
format: iiif
......
---
# Dummy classifier tool definition, used as a unit test fixture
# (referenced by slug from the test configuration data)
name: Unit test classifier
type: classifier
slug: dummy_classifier
version: 1.0.0
classes:
  - class_1
  - class_2
---
# Dummy recognizer tool definition, used as a unit test fixture
name: Unit test recognizer
type: recognizer
slug: dummy_recognizer
version: 0.4.2
tesseract: []
from django.core.exceptions import ValidationError
from django.test import override_settings
from arkindex.project.tests import FixtureTestCase
from arkindex.images.models import ImageServer
from arkindex.dataimport.config import \
ConfigFile, ImportType, VolumesImportFormat, SurfacesImportFormat
from arkindex_common.ml_tool import MLToolType
import os.path
# Directory holding the dummy ML tool definitions used as test fixtures;
# injected into the settings via @override_settings(ML_CLASSIFIERS_DIR=ML_TOOLS)
ML_TOOLS = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    'ml_tools',
)
@override_settings(ML_CLASSIFIERS_DIR=ML_TOOLS)
class TestConfigFile(FixtureTestCase):
    """
    Tests for the .arkindex.yml configuration file parser
    """

    @classmethod
    def setUpTestData(cls):
        # Created once for the whole TestCase: base_data below references
        # this server by URL, so it must exist for image server resolution.
        super().setUpTestData()
        cls.sample_server = ImageServer.objects.create(name='sample', url='http://example.com/iiif/')

    def setUp(self):
        # A valid configuration; tests will edit/remove values to check for validations
        self.base_data = {
            "version": 1,
            "branches": ["master", "dev"],
            # Slugs matching the dummy tool fixtures in the ml_tools directory
            "classifier": "dummy_classifier",
            "recognizer": "dummy_recognizer",
            "volumes": {
                "format": "txt",
                "paths": ["volumes/*", "manifests/*"],
                "image_servers": {
                    "sample": "http://example.com/iiif/",
                },
                "lazy_checks": True,
                "autoconvert_https": True,
            },
            "surfaces": {
                "format": "xml",
                "paths": ["surfaces/*"],
            }
        }

    def test_base(self):
        """
        Test parsing a valid configuration file
        """
        cfg = ConfigFile(data=self.base_data)
        self.assertEqual(cfg.version, 1)
        self.assertListEqual(cfg.branches, ['master', 'dev'])
        self.assertCountEqual(cfg.imports, [ImportType.Volumes, ImportType.Surfaces])
        self.assertEqual(cfg.volumes_format, VolumesImportFormat.TXT)
        self.assertEqual(cfg.surfaces_format, SurfacesImportFormat.XML)
        self.assertListEqual(cfg.volumes_paths, ['volumes/*', 'manifests/*'])
        self.assertListEqual(cfg.surfaces_paths, ['surfaces/*'])
        # The slug strings should have been resolved into full MLTool objects
        self.assertEqual(cfg.classifier.slug, 'dummy_classifier')
        self.assertEqual(cfg.recognizer.slug, 'dummy_recognizer')
        self.assertEqual(str(cfg.classifier.version), '1.0.0')
        self.assertEqual(str(cfg.recognizer.version), '0.4.2')
        self.assertEqual(cfg.classifier.name, 'Unit test classifier')
        self.assertEqual(cfg.classifier.type, MLToolType.Classifier)
        self.assertEqual(cfg.recognizer.name, 'Unit test recognizer')
        self.assertEqual(cfg.recognizer.type, MLToolType.Recognizer)
        self.assertTrue(cfg.volumes_lazy_checks)
        self.assertTrue(cfg.volumes_autoconvert_https)
        self.assertEqual(cfg.volumes_image_servers, [self.sample_server])

    def test_version(self):
        """
        Test the ConfigFile handles versioning
        """
        self.base_data['version'] = 42  # Bad version number
        with self.assertRaisesRegex(ValidationError, 'version'):
            ConfigFile(data=self.base_data)

    def test_branches(self):
        """
        Test the ConfigFile handles branches
        """
        # Branches must be a list of plain strings
        self.base_data['branches'] = ['this', 'is', 'a', {'bad': 'format'}]
        with self.assertRaisesRegex(ValidationError, 'branches'):
            ConfigFile(data=self.base_data)

    def test_ml_tools(self):
        """
        Test the config parser checks for existence of ML tools
        """
        # Unknown classifier slug
        self.base_data['classifier'] = 'nope'
        with self.assertRaisesRegex(ValidationError, 'Classifier'):
            ConfigFile(data=self.base_data)
        # Restore a valid classifier, then use an unknown recognizer slug
        self.base_data['classifier'] = 'dummy_classifier'
        self.base_data['recognizer'] = 'nope'
        with self.assertRaisesRegex(ValidationError, 'Recognizer'):
            ConfigFile(data=self.base_data)

    def test_no_imports(self):
        """
        Test the config parser requires at least one import
        """
        del self.base_data['volumes']
        del self.base_data['surfaces']
        with self.assertRaisesRegex(ValidationError, 'import types'):
            ConfigFile(data=self.base_data)

    def test_import_paths(self):
        """
        Test the config parser checks for valid paths
        """
        # Missing paths entirely
        del self.base_data['volumes']['paths']
        with self.assertRaisesRegex(ValidationError, 'paths'):
            ConfigFile(data=self.base_data)
        # Paths present but not all strings
        self.base_data['volumes']['paths'] = ['bad', {'paths': 'format'}]
        with self.assertRaisesRegex(ValidationError, 'paths'):
            ConfigFile(data=self.base_data)

    def test_import_formats(self):
        """
        Test omitting import formats sets to default values
        """
        del self.base_data['volumes']['format']
        del self.base_data['surfaces']['format']
        cfg = ConfigFile(data=self.base_data)
        self.assertEqual(cfg.volumes_format, VolumesImportFormat.IIIF)
        self.assertEqual(cfg.surfaces_format, SurfacesImportFormat.XML)
from unittest.mock import patch
from django.test import override_settings
from arkindex.project.tests import RedisMockMixin, FixtureTestCase
from arkindex.documents.models import Element, ElementType, Page, MetaType
from arkindex.images.models import ImageStatus, ImageServer
......@@ -13,8 +14,13 @@ FIXTURES = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
'manifest_samples',
)
ML_TOOLS = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
'ml_tools',
)
@override_settings(ML_CLASSIFIERS_DIR=ML_TOOLS)
class TestManifestParser(RedisMockMixin, FixtureTestCase):
@classmethod
......@@ -222,6 +228,7 @@ class TestManifestParser(RedisMockMixin, FixtureTestCase):
gl_mock().projects.get().files.get.side_effect = gl_get_file
# Helper method to prepare commits
def copy_commit(message, src=[], dst=[]):
src = [os.path.join(FIXTURES, path) for path in src]
dst = [os.path.join(self.repo_dir, path) for path in dst]
......@@ -229,6 +236,7 @@ class TestManifestParser(RedisMockMixin, FixtureTestCase):
repo.index.add(dst)
return repo.index.commit(message)
# Helper method to run a Git import from a Git commit
def run_import(commit):
"""
Create a revision and run a synchronous import
......
......@@ -9,4 +9,5 @@ urlpatterns = [
url(r'^repos/?$', FrontendView.as_view(), name='repositories'),
url(r'^repos/new/?$', FrontendView.as_view(), name='repositories-create'),
url(r'^credentials/?$', FrontendView.as_view(), name='credentials'),
url(r'^mltools/?$', FrontendView.as_view(), name='mltools-list'),
]
......@@ -4,7 +4,6 @@ from django.contrib.postgres.fields import JSONField
from django.utils.functional import cached_property
from enumfields import EnumField, Enum
from arkindex.project.models import IndexableModel
from arkindex.project.celery import app as celery_app
from arkindex.project.fields import ArrayField
from arkindex.documents.managers import ElementManager, CorpusManager
import uuid
......@@ -398,32 +397,6 @@ class Page(Element):
]
}
def classify(self):
    '''
    Use a machine learning worker to classify the page
    using its image
    Celery is triggered through an external signature
    '''
    # Classify a page through ML worker
    signature = celery_app.signature('arkindex_ml.tasks.classify')
    task = signature.delay(self.zone.image.get_thumbnail_url(max_width=500))
    # Wait for result
    # NOTE(review): task.get() blocks the caller until the remote worker
    # responds; the result is stored on the instance but not saved here.
    self.classification = task.get()

def ocr(self):
    '''
    Use a machine learning worker to extract text from the page
    using its image
    Celery is triggered through an external signature
    '''
    # Classify a page through ML worker
    signature = celery_app.signature('arkindex_ml.tasks.extract_text')
    task = signature.delay(self.zone.image.get_thumbnail_url(max_width=None))
    # Wait for result
    # NOTE(review): blocking call, like classify() above; only sets the
    # attribute in memory, callers must persist it themselves.
    self.text = task.get()
def build_text(self):
self.text = '\n'.join(
t.text
......
......@@ -10,3 +10,4 @@ __all__ = ['celery_app']
# Register system checks
register(checks.api_urls_check)
register(checks.ml_default_tools_check)
......@@ -15,7 +15,7 @@ from arkindex.dataimport.api import (
DataImportsList, DataImportDetails, DataImportFailures, DataImportDemo,
DataFileList, DataFileRetrieve, DataFileUpload, DataImportFromFiles,
RepositoryList, RepositoryRetrieve, RepositoryStartImport,
GitRepositoryImportHook, AvailableRepositoriesList, ElementHistory,
GitRepositoryImportHook, AvailableRepositoriesList, ElementHistory, MLToolList,
)
from arkindex.users.api import (
ProvidersList, CredentialsList, CredentialsRetrieve, OAuthSignIn, OAuthRetry, OAuthCallback,
......@@ -88,19 +88,22 @@ api = [
# Import workflows
url(r'^imports/$', DataImportsList.as_view(), name='import-list'),
url(r'^imports/repos/$', RepositoryList.as_view(), name='repository-list'),
url(r'^imports/repos/(?P<pk>[\w\-]+)/$', RepositoryRetrieve.as_view(), name='repository-retrieve'),
url(r'^imports/repos/(?P<pk>[\w\-]+)/start/$', RepositoryStartImport.as_view(), name='repository-import'),
url(r'^imports/repos/search/(?P<pk>[\w\-]+)/$',
AvailableRepositoriesList.as_view(),
name='available-repositories'),
url(r'^imports/fromfiles/$', DataImportFromFiles.as_view(), name='import-from-files'),
url(r'^imports/mltools/$', MLToolList.as_view(), name='ml-tool-list'),
url(r'^imports/(?P<pk>[\w\-]+)/$', DataImportDetails.as_view(), name='import-details'),
url(r'^imports/(?P<pk>[\w\-]+)/failures/$', DataImportFailures.as_view(), name='import-failures'),
url(r'^imports/demo/(?P<pk>[\w\-]+)/$', DataImportDemo.as_view(), name='import-demo'),
url(r'^imports/files/(?P<pk>[\w\-]+)/$', DataFileList.as_view(), name='file-list'),
url(r'^imports/file/(?P<pk>[\w\-]+)/$', DataFileRetrieve.as_view(), name='file-retrieve'),
url(r'^imports/upload/(?P<pk>[\w\-]+)/$', DataFileUpload.as_view(), name='file-upload'),
# Git import workflows
url(r'^imports/repos/$', RepositoryList.as_view(), name='repository-list'),
url(r'^imports/repos/(?P<pk>[\w\-]+)/$', RepositoryRetrieve.as_view(), name='repository-retrieve'),
url(r'^imports/repos/(?P<pk>[\w\-]+)/start/$', RepositoryStartImport.as_view(), name='repository-import'),
url(r'^imports/repos/search/(?P<pk>[\w\-]+)/$',
AvailableRepositoriesList.as_view(),
name='available-repositories'),
url(r'^imports/hook/(?P<pk>[\w\-]+)/$', GitRepositoryImportHook.as_view(), name='import-hook'),
# Manage OAuth integrations
......
......@@ -15,3 +15,35 @@ def api_urls_check(*args, **kwargs):
for url in api
if not str(url.pattern).endswith('/$')
]
def ml_default_tools_check(*args, **kwargs):
    """
    Django system check ensuring the default ML tools named in the
    settings actually exist in the ML tools directory.

    Returns a list of Error instances; empty when everything is fine.
    """
    # Imported lazily so this module can be loaded before the Django
    # settings are fully configured.
    from django.conf import settings
    from arkindex_common.ml_tool import MLTool, MLToolType

    errors = []
    checked_settings = (
        ('ML_DEFAULT_CLASSIFIER', MLToolType.Classifier),
        ('ML_DEFAULT_RECOGNIZER', MLToolType.Recognizer),
    )
    for setting_name, tool_type in checked_settings:
        # The setting itself must be defined...
        try:
            slug = getattr(settings, setting_name)
        except AttributeError:
            errors.append(Error(
                'Default {} has not been set'.format(tool_type.name),
                hint='settings.{}'.format(setting_name),
                id='arkindex.E002',
            ))
            continue
        # ...and must name a tool that exists on disk.
        try:
            MLTool.get(settings.ML_CLASSIFIERS_DIR, tool_type, slug)
        except ValueError:
            errors.append(Error(
                'Default {} does not exist'.format(tool_type.name),
                hint='settings.{} = "{}"'.format(setting_name, slug),
                id='arkindex.E003',
            ))
    return errors
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment