From 2b0e1ed45186cdc90e4b68a53acc1c396e637722 Mon Sep 17 00:00:00 2001 From: Erwan Rouchet <rouchet@teklia.com> Date: Thu, 13 Dec 2018 12:57:45 +0000 Subject: [PATCH] ML tools --- .gitlab-ci.yml | 2 + Dockerfile | 11 ++ README.md | 1 + VERSION | 2 +- arkindex/dataimport/api.py | 17 ++- arkindex/dataimport/config.py | 32 +++- .../management/commands/analyze_ml.py | 1 - arkindex/dataimport/serializers.py | 11 ++ arkindex/dataimport/tasks/base.py | 3 +- arkindex/dataimport/tasks/git.py | 2 + .../tests/manifest_samples/.arkindex.yml | 2 + .../tests/ml_tools/classifier/config.yml | 8 + .../tests/ml_tools/recognizer/config.yml | 6 + arkindex/dataimport/tests/test_config.py | 137 ++++++++++++++++++ arkindex/dataimport/tests/test_iiif.py | 8 + arkindex/dataimport/urls.py | 1 + arkindex/documents/models.py | 27 ---- arkindex/project/__init__.py | 1 + arkindex/project/api_v1.py | 17 ++- arkindex/project/checks.py | 32 ++++ arkindex/project/settings.py | 5 + arkindex/project/tests/test_checks.py | 31 +++- arkindex/templates/base.html | 1 + requirements.txt | 1 + setup.py | 10 +- 25 files changed, 327 insertions(+), 42 deletions(-) create mode 100644 arkindex/dataimport/tests/ml_tools/classifier/config.yml create mode 100644 arkindex/dataimport/tests/ml_tools/recognizer/config.yml create mode 100644 arkindex/dataimport/tests/test_config.py mode change 100644 => 100755 setup.py diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 487448abc4..2602086906 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -27,6 +27,8 @@ backend-tests: before_script: - apk --update add build-base + # Custom line to install arkindex-common from Git using GitLab CI credentials + - "pip install -e git+https://gitlab-ci-token:${CI_JOB_TOKEN}@gitlab.com/arkindex/common#egg=arkindex-common" - pip install -r tests-requirements.txt codecov script: diff --git a/Dockerfile b/Dockerfile index 560f48302a..642e5edc33 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,8 +2,19 @@ FROM 
registry.gitlab.com/arkindex/backend:base-0.8.8 ARG FRONTEND_BRANCH=master ARG FRONTEND_ID=4675768 +ARG COMMON_BRANCH=master +ARG COMMON_ID=9855787 ARG GITLAB_TOKEN="M98p3wihATgCx4Z5ChvK" +# Install arkindex-common from private repo +RUN \ + apk add tar && \ + mkdir /tmp/common && \ + wget --header "PRIVATE-TOKEN: $GITLAB_TOKEN" https://gitlab.com/api/v4/projects/$COMMON_ID/repository/archive.tar.gz?sha=$COMMON_BRANCH -O /tmp/common/archive.tar.gz && \ + tar --strip-components=1 -xvf /tmp/common/archive.tar.gz -C /tmp/common && \ + cd /tmp/common && python setup.py install && \ + rm -rf /tmp/common + # Install arkindex and its deps # Uses a source archive instead of full local copy to speedup docker build COPY dist/arkindex-*.tar.gz /tmp/arkindex.tar.gz diff --git a/README.md b/README.md index 46bd896363..8fe5a2bdf4 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ Dev Setup git clone git@gitlab.com:arkindex/backend.git mkvirtualenv -p /usr/bin/python3 ark cd backend +pip install -r tests-requirements.txt pip install -e .[test] ``` diff --git a/VERSION b/VERSION index 6201b5f77f..1de5fe2742 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.8.8 +0.8.9.dev diff --git a/arkindex/dataimport/api.py b/arkindex/dataimport/api.py index d7ff7e47a9..fd8a220062 100644 --- a/arkindex/dataimport/api.py +++ b/arkindex/dataimport/api.py @@ -20,9 +20,10 @@ from arkindex.dataimport.models import \ from arkindex.dataimport.serializers import ( DataImportLightSerializer, DataImportSerializer, DataImportFromFilesSerializer, DataImportFailureSerializer, DataFileSerializer, - RepositorySerializer, ExternalRepositorySerializer, EventSerializer + RepositorySerializer, ExternalRepositorySerializer, EventSerializer, MLToolSerializer, ) from arkindex.users.models import OAuthCredentials +from arkindex_common.ml_tool import MLTool from datetime import datetime import hashlib import magic @@ -375,3 +376,17 @@ class ElementHistory(ListAPIView): element_id=self.kwargs['pk'], 
element__corpus__in=Corpus.objects.readable(self.request.user), ) + + +class MLToolList(ListAPIView): + """ + List available machine learning tools + """ + serializer_class = MLToolSerializer + pagination_class = None + + def get_queryset(self): + return sorted( + MLTool.list(settings.ML_CLASSIFIERS_DIR), + key=lambda tool: tool.slug, + ) diff --git a/arkindex/dataimport/config.py b/arkindex/dataimport/config.py index abf6e43766..f95cd99f43 100644 --- a/arkindex/dataimport/config.py +++ b/arkindex/dataimport/config.py @@ -1,9 +1,11 @@ import fnmatch import yaml from enum import Enum +from django.conf import settings from django.core.validators import URLValidator from django.core.exceptions import ValidationError from django.utils.functional import cached_property +from arkindex_common.ml_tool import MLTool, MLToolType from arkindex.images.models import ImageServer @@ -44,7 +46,7 @@ class ConfigFile(object): """ FILE_NAME = '.arkindex.yml' - REQUIRED_ITEMS = ('version', 'branches') + REQUIRED_ITEMS = {'version', 'branches', 'classifier', 'recognizer'} FORMAT_ENUMS = { ImportType.Volumes: VolumesImportFormat, ImportType.Transcriptions: TranscriptionsImportFormat, @@ -56,7 +58,10 @@ class ConfigFile(object): def __init__(self, data=None): if not data: return - parsed = yaml.load(data) + if isinstance(data, dict): + parsed = data + else: + parsed = yaml.load(data) self.validated_data = self.validate(parsed) self.setattrs(self.validated_data) @@ -93,6 +98,19 @@ class ConfigFile(object): raise ValidationError("Bad 'branches' format: should be a list of branch names") validated_data['branches'] = parsed['branches'] + # ML tools + try: + MLTool.get(settings.ML_CLASSIFIERS_DIR, MLToolType.Classifier, parsed['classifier']) + except ValueError: + raise ValidationError('Classifier "{}" not found'.format(parsed['classifier'])) + validated_data['classifier'] = parsed['classifier'] + + try: + MLTool.get(settings.ML_CLASSIFIERS_DIR, MLToolType.Recognizer, parsed['recognizer']) + 
except ValueError: + raise ValidationError('Recognizer "{}" not found'.format(parsed['recognizer'])) + validated_data['recognizer'] = parsed['recognizer'] + # At least one import type required if not any(it.value in parsed for it in ImportType): raise ValidationError("No import types were specified") @@ -161,6 +179,8 @@ class ConfigFile(object): self.version = validated_data['version'] self.branches = validated_data['branches'] self.imports = list(filter(lambda it: it.value in validated_data, ImportType)) + self.classifier_slug = validated_data['classifier'] + self.recognizer_slug = validated_data['recognizer'] # Default formats self.volumes_format = VolumesImportFormat.IIIF @@ -223,3 +243,11 @@ class ConfigFile(object): s, _ = ImageServer.objects.get_or_create(url=url, defaults={'name': name}) servers.append(s) return servers + + @cached_property + def classifier(self): + return MLTool.get(settings.ML_CLASSIFIERS_DIR, MLToolType.Classifier, self.classifier_slug) + + @cached_property + def recognizer(self): + return MLTool.get(settings.ML_CLASSIFIERS_DIR, MLToolType.Recognizer, self.recognizer_slug) diff --git a/arkindex/dataimport/management/commands/analyze_ml.py b/arkindex/dataimport/management/commands/analyze_ml.py index baaf786397..6babe54d4f 100644 --- a/arkindex/dataimport/management/commands/analyze_ml.py +++ b/arkindex/dataimport/management/commands/analyze_ml.py @@ -26,7 +26,6 @@ class Command(BaseCommand): pk__in=Element.objects.get_descending(volume.id, type=ElementType.Page) ) payload = { - 'volume': str(volume.id), 'pages': [ (page.id, page.zone.image.get_thumbnail_url(max_width=1500)) for page in pages diff --git a/arkindex/dataimport/serializers.py b/arkindex/dataimport/serializers.py index edd4c4be4d..9f0c5b9bd8 100644 --- a/arkindex/dataimport/serializers.py +++ b/arkindex/dataimport/serializers.py @@ -7,6 +7,7 @@ from arkindex.dataimport.models import ( ) from arkindex.documents.models import Corpus, Element, ElementType from 
arkindex.documents.serializers.light import ElementLightSerializer +from arkindex_common.ml_tool import MLToolType import gitlab.v4.objects import celery.states @@ -333,3 +334,13 @@ class EventSerializer(serializers.ModelSerializer): 'date', 'revision', ) + + +class MLToolSerializer(serializers.Serializer): + """ + Serialize a machine learning tool for display + """ + name = serializers.CharField() + slug = serializers.SlugField() + type = EnumField(MLToolType) + version = serializers.CharField() diff --git a/arkindex/dataimport/tasks/base.py b/arkindex/dataimport/tasks/base.py index 045b0b7002..cafb6d1c9a 100644 --- a/arkindex/dataimport/tasks/base.py +++ b/arkindex/dataimport/tasks/base.py @@ -58,8 +58,9 @@ def build_volume(self, files, dataimport, server_id=settings.LOCAL_IMAGESERVER_I self.report_message("Imported {} pages into {}".format(len(pages), volume.name)) generate_thumbnail(volume.id) return { - 'volume': str(volume.id), 'pages': pages, + 'classifier': settings.ML_DEFAULT_CLASSIFIER, + 'recognizer': settings.ML_DEFAULT_RECOGNIZER, } diff --git a/arkindex/dataimport/tasks/git.py b/arkindex/dataimport/tasks/git.py index 2b24a2ad42..eab2e061b5 100644 --- a/arkindex/dataimport/tasks/git.py +++ b/arkindex/dataimport/tasks/git.py @@ -59,6 +59,8 @@ def parse_config(self, dataimport): self.report_message('Configuration file version {}'.format(config.version)) self.report_message('Git branches to trigger imports on: {}'.format(', '.join(config.branches))) + self.report_message('Classifier: {}'.format(config.classifier.name)) + self.report_message('Recognizer: {}'.format(config.recognizer.name)) self.report_message('Active import types: {}'.format(', '.join(import_type.name for import_type in config.imports))) for import_type in config.imports: diff --git a/arkindex/dataimport/tests/manifest_samples/.arkindex.yml b/arkindex/dataimport/tests/manifest_samples/.arkindex.yml index 377895537e..77902be919 100644 --- 
a/arkindex/dataimport/tests/manifest_samples/.arkindex.yml +++ b/arkindex/dataimport/tests/manifest_samples/.arkindex.yml @@ -1,6 +1,8 @@ version: 1 branches: - master +classifier: dummy_classifier +recognizer: dummy_recognizer volumes: format: iiif diff --git a/arkindex/dataimport/tests/ml_tools/classifier/config.yml b/arkindex/dataimport/tests/ml_tools/classifier/config.yml new file mode 100644 index 0000000000..e1e9ddc0f3 --- /dev/null +++ b/arkindex/dataimport/tests/ml_tools/classifier/config.yml @@ -0,0 +1,8 @@ +--- +name: Unit test classifier +type: classifier +slug: dummy_classifier +version: 1.0.0 +classes: + - class_1 + - class_2 diff --git a/arkindex/dataimport/tests/ml_tools/recognizer/config.yml b/arkindex/dataimport/tests/ml_tools/recognizer/config.yml new file mode 100644 index 0000000000..83cb84d587 --- /dev/null +++ b/arkindex/dataimport/tests/ml_tools/recognizer/config.yml @@ -0,0 +1,6 @@ +--- +name: Unit test recognizer +type: recognizer +slug: dummy_recognizer +version: 0.4.2 +tesseract: [] diff --git a/arkindex/dataimport/tests/test_config.py b/arkindex/dataimport/tests/test_config.py new file mode 100644 index 0000000000..9dedcb0d63 --- /dev/null +++ b/arkindex/dataimport/tests/test_config.py @@ -0,0 +1,137 @@ +from django.core.exceptions import ValidationError +from django.test import override_settings +from arkindex.project.tests import FixtureTestCase +from arkindex.images.models import ImageServer +from arkindex.dataimport.config import \ + ConfigFile, ImportType, VolumesImportFormat, SurfacesImportFormat +from arkindex_common.ml_tool import MLToolType +import os.path + +ML_TOOLS = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + 'ml_tools', +) + + +@override_settings(ML_CLASSIFIERS_DIR=ML_TOOLS) +class TestConfigFile(FixtureTestCase): + """ + Tests for the .arkindex.yml configuration file parser + """ + + @classmethod + def setUpTestData(cls): + super().setUpTestData() + cls.sample_server = 
ImageServer.objects.create(name='sample', url='http://example.com/iiif/') + + def setUp(self): + # A valid configuration; tests will edit/remove values to check for validations + self.base_data = { + "version": 1, + "branches": ["master", "dev"], + "classifier": "dummy_classifier", + "recognizer": "dummy_recognizer", + "volumes": { + "format": "txt", + "paths": ["volumes/*", "manifests/*"], + "image_servers": { + "sample": "http://example.com/iiif/", + }, + "lazy_checks": True, + "autoconvert_https": True, + }, + "surfaces": { + "format": "xml", + "paths": ["surfaces/*"], + } + } + + def test_base(self): + """ + Test parsing a valid configuration file + """ + cfg = ConfigFile(data=self.base_data) + self.assertEqual(cfg.version, 1) + self.assertListEqual(cfg.branches, ['master', 'dev']) + self.assertCountEqual(cfg.imports, [ImportType.Volumes, ImportType.Surfaces]) + + self.assertEqual(cfg.volumes_format, VolumesImportFormat.TXT) + self.assertEqual(cfg.surfaces_format, SurfacesImportFormat.XML) + self.assertListEqual(cfg.volumes_paths, ['volumes/*', 'manifests/*']) + self.assertListEqual(cfg.surfaces_paths, ['surfaces/*']) + + self.assertEqual(cfg.classifier.slug, 'dummy_classifier') + self.assertEqual(cfg.recognizer.slug, 'dummy_recognizer') + + self.assertEqual(str(cfg.classifier.version), '1.0.0') + self.assertEqual(str(cfg.recognizer.version), '0.4.2') + + self.assertEqual(cfg.classifier.name, 'Unit test classifier') + self.assertEqual(cfg.classifier.type, MLToolType.Classifier) + + self.assertEqual(cfg.recognizer.name, 'Unit test recognizer') + self.assertEqual(cfg.recognizer.type, MLToolType.Recognizer) + + self.assertTrue(cfg.volumes_lazy_checks) + self.assertTrue(cfg.volumes_autoconvert_https) + self.assertEqual(cfg.volumes_image_servers, [self.sample_server]) + + def test_version(self): + """ + Test the ConfigFile handles versioning + """ + self.base_data['version'] = 42 # Bad version number + with self.assertRaisesRegex(ValidationError, 'version'): + 
ConfigFile(data=self.base_data) + + def test_branches(self): + """ + Test the ConfigFile handles branches + """ + self.base_data['branches'] = ['this', 'is', 'a', {'bad': 'format'}] + with self.assertRaisesRegex(ValidationError, 'branches'): + ConfigFile(data=self.base_data) + + def test_ml_tools(self): + """ + Test the config parser checks for existence of ML tools + """ + self.base_data['classifier'] = 'nope' + with self.assertRaisesRegex(ValidationError, 'Classifier'): + ConfigFile(data=self.base_data) + + self.base_data['classifier'] = 'dummy_classifier' + self.base_data['recognizer'] = 'nope' + with self.assertRaisesRegex(ValidationError, 'Recognizer'): + ConfigFile(data=self.base_data) + + def test_no_imports(self): + """ + Test the config parser requires at least one import + """ + del self.base_data['volumes'] + del self.base_data['surfaces'] + with self.assertRaisesRegex(ValidationError, 'import types'): + ConfigFile(data=self.base_data) + + def test_import_paths(self): + """ + Test the config parser checks for valid paths + """ + del self.base_data['volumes']['paths'] + with self.assertRaisesRegex(ValidationError, 'paths'): + ConfigFile(data=self.base_data) + + self.base_data['volumes']['paths'] = ['bad', {'paths': 'format'}] + with self.assertRaisesRegex(ValidationError, 'paths'): + ConfigFile(data=self.base_data) + + def test_import_formats(self): + """ + Test omitting import formats sets to default values + """ + del self.base_data['volumes']['format'] + del self.base_data['surfaces']['format'] + cfg = ConfigFile(data=self.base_data) + self.assertEqual(cfg.volumes_format, VolumesImportFormat.IIIF) + self.assertEqual(cfg.surfaces_format, SurfacesImportFormat.XML) diff --git a/arkindex/dataimport/tests/test_iiif.py b/arkindex/dataimport/tests/test_iiif.py index dea8b9a44b..ea7399675f 100644 --- a/arkindex/dataimport/tests/test_iiif.py +++ b/arkindex/dataimport/tests/test_iiif.py @@ -1,4 +1,5 @@ from unittest.mock import patch +from django.test import 
override_settings from arkindex.project.tests import RedisMockMixin, FixtureTestCase from arkindex.documents.models import Element, ElementType, Page, MetaType from arkindex.images.models import ImageStatus, ImageServer @@ -13,8 +14,13 @@ FIXTURES = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'manifest_samples', ) +ML_TOOLS = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + 'ml_tools', +) +@override_settings(ML_CLASSIFIERS_DIR=ML_TOOLS) class TestManifestParser(RedisMockMixin, FixtureTestCase): @classmethod @@ -222,6 +228,7 @@ class TestManifestParser(RedisMockMixin, FixtureTestCase): gl_mock().projects.get().files.get.side_effect = gl_get_file + # Helper method to prepare commits def copy_commit(message, src=[], dst=[]): src = [os.path.join(FIXTURES, path) for path in src] dst = [os.path.join(self.repo_dir, path) for path in dst] @@ -229,6 +236,7 @@ class TestManifestParser(RedisMockMixin, FixtureTestCase): repo.index.add(dst) return repo.index.commit(message) + # Helper method to run a Git import from a Git commit def run_import(commit): """ Create a revision and run a synchronous import diff --git a/arkindex/dataimport/urls.py b/arkindex/dataimport/urls.py index ed4abd173a..fbd89318b3 100644 --- a/arkindex/dataimport/urls.py +++ b/arkindex/dataimport/urls.py @@ -9,4 +9,5 @@ urlpatterns = [ url(r'^repos/?$', FrontendView.as_view(), name='repositories'), url(r'^repos/new/?$', FrontendView.as_view(), name='repositories-create'), url(r'^credentials/?$', FrontendView.as_view(), name='credentials'), + url(r'^mltools/?$', FrontendView.as_view(), name='mltools-list'), ] diff --git a/arkindex/documents/models.py b/arkindex/documents/models.py index 93c58684fb..bd4adfc0d7 100644 --- a/arkindex/documents/models.py +++ b/arkindex/documents/models.py @@ -4,7 +4,6 @@ from django.contrib.postgres.fields import JSONField from django.utils.functional import cached_property from enumfields import EnumField, Enum from arkindex.project.models import 
IndexableModel -from arkindex.project.celery import app as celery_app from arkindex.project.fields import ArrayField from arkindex.documents.managers import ElementManager, CorpusManager import uuid @@ -398,32 +397,6 @@ class Page(Element): ] } - def classify(self): - ''' - Use a machine learning worker to classify the page - using its image - Celery is triggered through an external signature - ''' - # Classify a page through ML worker - signature = celery_app.signature('arkindex_ml.tasks.classify') - task = signature.delay(self.zone.image.get_thumbnail_url(max_width=500)) - - # Wait for result - self.classification = task.get() - - def ocr(self): - ''' - Use a machine learning worker to extract text from the page - using its image - Celery is triggered through an external signature - ''' - # Classify a page through ML worker - signature = celery_app.signature('arkindex_ml.tasks.extract_text') - task = signature.delay(self.zone.image.get_thumbnail_url(max_width=None)) - - # Wait for result - self.text = task.get() - def build_text(self): self.text = '\n'.join( t.text diff --git a/arkindex/project/__init__.py b/arkindex/project/__init__.py index e3f694af2e..213122c24e 100644 --- a/arkindex/project/__init__.py +++ b/arkindex/project/__init__.py @@ -10,3 +10,4 @@ __all__ = ['celery_app'] # Register system checks register(checks.api_urls_check) +register(checks.ml_default_tools_check) diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py index 75126096e6..a22dd503ce 100644 --- a/arkindex/project/api_v1.py +++ b/arkindex/project/api_v1.py @@ -15,7 +15,7 @@ from arkindex.dataimport.api import ( DataImportsList, DataImportDetails, DataImportFailures, DataImportDemo, DataFileList, DataFileRetrieve, DataFileUpload, DataImportFromFiles, RepositoryList, RepositoryRetrieve, RepositoryStartImport, - GitRepositoryImportHook, AvailableRepositoriesList, ElementHistory, + GitRepositoryImportHook, AvailableRepositoriesList, ElementHistory, MLToolList, ) from 
arkindex.users.api import ( ProvidersList, CredentialsList, CredentialsRetrieve, OAuthSignIn, OAuthRetry, OAuthCallback, @@ -88,19 +88,22 @@ api = [ # Import workflows url(r'^imports/$', DataImportsList.as_view(), name='import-list'), - url(r'^imports/repos/$', RepositoryList.as_view(), name='repository-list'), - url(r'^imports/repos/(?P<pk>[\w\-]+)/$', RepositoryRetrieve.as_view(), name='repository-retrieve'), - url(r'^imports/repos/(?P<pk>[\w\-]+)/start/$', RepositoryStartImport.as_view(), name='repository-import'), - url(r'^imports/repos/search/(?P<pk>[\w\-]+)/$', - AvailableRepositoriesList.as_view(), - name='available-repositories'), url(r'^imports/fromfiles/$', DataImportFromFiles.as_view(), name='import-from-files'), + url(r'^imports/mltools/$', MLToolList.as_view(), name='ml-tool-list'), url(r'^imports/(?P<pk>[\w\-]+)/$', DataImportDetails.as_view(), name='import-details'), url(r'^imports/(?P<pk>[\w\-]+)/failures/$', DataImportFailures.as_view(), name='import-failures'), url(r'^imports/demo/(?P<pk>[\w\-]+)/$', DataImportDemo.as_view(), name='import-demo'), url(r'^imports/files/(?P<pk>[\w\-]+)/$', DataFileList.as_view(), name='file-list'), url(r'^imports/file/(?P<pk>[\w\-]+)/$', DataFileRetrieve.as_view(), name='file-retrieve'), url(r'^imports/upload/(?P<pk>[\w\-]+)/$', DataFileUpload.as_view(), name='file-upload'), + + # Git import workflows + url(r'^imports/repos/$', RepositoryList.as_view(), name='repository-list'), + url(r'^imports/repos/(?P<pk>[\w\-]+)/$', RepositoryRetrieve.as_view(), name='repository-retrieve'), + url(r'^imports/repos/(?P<pk>[\w\-]+)/start/$', RepositoryStartImport.as_view(), name='repository-import'), + url(r'^imports/repos/search/(?P<pk>[\w\-]+)/$', + AvailableRepositoriesList.as_view(), + name='available-repositories'), url(r'^imports/hook/(?P<pk>[\w\-]+)/$', GitRepositoryImportHook.as_view(), name='import-hook'), # Manage OAuth integrations diff --git a/arkindex/project/checks.py b/arkindex/project/checks.py index 
a6b65e1372..39dd39ae41 100644 --- a/arkindex/project/checks.py +++ b/arkindex/project/checks.py @@ -15,3 +15,35 @@ def api_urls_check(*args, **kwargs): for url in api if not str(url.pattern).endswith('/$') ] + + +def ml_default_tools_check(*args, **kwargs): + """ + Check that the default ML tools defined in settings actually exist + """ + from django.conf import settings + from arkindex_common.ml_tool import MLTool, MLToolType + tools = ( + (MLToolType.Classifier, 'ML_DEFAULT_CLASSIFIER'), + (MLToolType.Recognizer, 'ML_DEFAULT_RECOGNIZER'), + ) + errors = [] + for tool_type, tool_setting in tools: + try: + slug = getattr(settings, tool_setting) + except AttributeError: + errors.append(Error( + 'Default {} has not been set'.format(tool_type.name), + hint='settings.{}'.format(tool_setting), + id='arkindex.E002', + )) + continue + try: + MLTool.get(settings.ML_CLASSIFIERS_DIR, tool_type, slug) + except ValueError: + errors.append(Error( + 'Default {} does not exist'.format(tool_type.name), + hint='settings.{} = "{}"'.format(tool_setting, slug), + id='arkindex.E003', + )) + return errors diff --git a/arkindex/project/settings.py b/arkindex/project/settings.py index bac50cba76..7a2d2e5bfc 100644 --- a/arkindex/project/settings.py +++ b/arkindex/project/settings.py @@ -335,6 +335,11 @@ CELERY_WORKING_DIR = os.environ.get('CELERY_WORKING_DIR', os.path.join(BASE_DIR, CELERY_WORKER_CONCURRENCY = int(os.environ.get('CELERY_WORKER_CONCURRENCY', 1)) CELERY_WORKER_PREFETCH_MULTIPLIER = 1 +# ML worker +ML_CLASSIFIERS_DIR = os.environ.get('ML_CLASSIFIERS_DIR', os.path.join(BASE_DIR, '../../ml-classifiers')) +ML_DEFAULT_CLASSIFIER = 'tobacco' +ML_DEFAULT_RECOGNIZER = 'tesseract' + # Email EMAIL_SUBJECT_PREFIX = '[Arkindex {}] '.format(ARKINDEX_ENV) if os.environ.get('EMAIL_HOST'): diff --git a/arkindex/project/tests/test_checks.py b/arkindex/project/tests/test_checks.py index a09c45f0c5..5e4f0cec43 100644 --- 
b/arkindex/project/tests/test_checks.py @@ -1,4 +1,5 @@ -from unittest import TestCase +from unittest.mock import patch +from django.test import TestCase from django.conf.urls import url from django.core.checks import Error @@ -29,3 +30,31 @@ class ChecksTestCase(TestCase): ) ] ) + + @patch('arkindex_common.ml_tool.MLTool.get') + def test_ml_default_tools_check(self, ml_get_mock): + """ + Test the default ML tools existence checks + """ + from arkindex.project.checks import ml_default_tools_check + + self.assertListEqual(ml_default_tools_check(), []) + + ml_get_mock.side_effect = ValueError + + with self.settings(ML_DEFAULT_CLASSIFIER='fail1', ML_DEFAULT_RECOGNIZER='fail2'): + self.assertListEqual( + ml_default_tools_check(), + [ + Error( + 'Default Classifier does not exist', + hint='settings.ML_DEFAULT_CLASSIFIER = "fail1"', + id='arkindex.E003', + ), + Error( + 'Default Recognizer does not exist', + hint='settings.ML_DEFAULT_RECOGNIZER = "fail2"', + id='arkindex.E003', + ), + ] + ) diff --git a/arkindex/templates/base.html b/arkindex/templates/base.html index d7844636ff..8585063641 100644 --- a/arkindex/templates/base.html +++ b/arkindex/templates/base.html @@ -64,6 +64,7 @@ <div class="navbar-dropdown"> <a href="{% url 'credentials' %}" class="navbar-item">OAuth</a> <a href="{% url 'corpus-list' %}" class="navbar-item">Corpora</a> + <a href="{% url 'mltools-list' %}" class="navbar-item">ML tools</a> {% if user.is_admin %} <a href="{% url 'admin:index' %}" class="navbar-item">Admin</a> {% endif %} diff --git a/requirements.txt b/requirements.txt index ec73c8bdf8..529d594832 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ # -r ./base/requirements.txt +arkindex-common==0.1.0 celery==4.2.0 celery_once==2.0.0 certifi==2017.7.27.1 diff --git a/setup.py b/setup.py old mode 100644 new mode 100755 index 2b45d2f053..cba5daea3b --- a/setup.py +++ b/setup.py @@ -3,10 +3,18 @@ import os.path from setuptools import setup, find_packages +def 
_parse_requirement(line): + if '#egg=' not in line: + return line + # When a requirement is from Git, remove the Git part and keep the egg name. + # This is needed as setup.py does not want any Git requirements in install_requires + return line.rpartition('#egg=')[2] + + def requirements(path): assert os.path.exists(path), 'Missing requirements {}'.format(path) with open(path) as f: - return f.read().splitlines() + return list(map(_parse_requirement, f.read().splitlines())) with open('VERSION') as f: -- GitLab