Skip to content
Snippets Groups Projects
Commit f0547d54 authored by Bastien Abadie
Browse files

Merge branch 'commands' into 'master'

Remove duplicate argument validations in commands

See merge request !204
parents 7e129760 3c23dfc7
No related branches found
No related tags found
1 merge request: !204 Remove duplicate argument validations in commands
#!/usr/bin/env python3 #!/usr/bin/env python3
from django.core.management.base import BaseCommand, CommandError from django.core.management.base import BaseCommand
from arkindex_common.ml_tool import MLToolType from arkindex_common.ml_tool import MLToolType
from arkindex.dataimport.models import DataImport, DataImportMode from arkindex.project.argparse import DataImportArgument
from arkindex.dataimport.models import DataImportMode
from arkindex.dataimport.tasks import extract_pdf_images, populate_volume, setup_ml_analysis, check_images from arkindex.dataimport.tasks import extract_pdf_images, populate_volume, setup_ml_analysis, check_images
from arkindex.dataimport.git import GitFlow from arkindex.dataimport.git import GitFlow
from django.conf import settings from django.conf import settings
...@@ -23,22 +24,17 @@ class Command(BaseCommand): ...@@ -23,22 +24,17 @@ class Command(BaseCommand):
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument( parser.add_argument(
'data_import', 'data_import',
type=str, type=DataImportArgument(),
help='ID of the DataImport to run' help='ID of the DataImport to run'
) )
def handle(self, *args, **options): def handle(self, *args, data_import=None, **options):
# Use default ML tools here # Use default ML tools here
ml_tools = ( ml_tools = (
(MLToolType.Classifier, settings.ML_DEFAULT_CLASSIFIER), (MLToolType.Classifier, settings.ML_DEFAULT_CLASSIFIER),
(MLToolType.Recognizer, settings.ML_DEFAULT_RECOGNIZER), (MLToolType.Recognizer, settings.ML_DEFAULT_RECOGNIZER),
) )
try:
data_import = DataImport.objects.get(pk=options['data_import'])
except DataImport.DoesNotExist:
raise CommandError('Missing DataImport')
# Use shared directory when running in docker # Use shared directory when running in docker
# Fallback to a temp directory while developing # Fallback to a temp directory while developing
task_dir = os.environ.get('PONOS_DATA', tempfile.mkdtemp(suffix='-ponos')) task_dir = os.environ.get('PONOS_DATA', tempfile.mkdtemp(suffix='-ponos'))
......
#!/usr/bin/env python3
from django.core.management.base import BaseCommand, CommandError from django.core.management.base import BaseCommand, CommandError
from arkindex.dataimport.models import DataImport, DataImportMode, Repository from arkindex.project.argparse import RepositoryArgument
from arkindex.dataimport.models import DataImport, DataImportMode
class Command(BaseCommand): class Command(BaseCommand):
...@@ -8,8 +8,9 @@ class Command(BaseCommand): ...@@ -8,8 +8,9 @@ class Command(BaseCommand):
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument( parser.add_argument(
'repository', 'repo',
help='ID of the repository to check on', type=RepositoryArgument(),
help='ID or part of the URL of the repository to check on',
) )
parser.add_argument( parser.add_argument(
'--hash', '--hash',
...@@ -17,14 +18,9 @@ class Command(BaseCommand): ...@@ -17,14 +18,9 @@ class Command(BaseCommand):
default=None, default=None,
) )
def handle(self, *args, **options): def handle(self, *args, repo=None, **options):
try:
repo = Repository.objects.get(id=options['repository'])
except Repository.DoesNotExist:
raise CommandError('Repository {} not found'.format(options['repository']))
if repo.provider_class is None: if repo.provider_class is None:
raise ValueError("No repository provider found for {}".format(repo.url)) raise CommandError("No repository provider found for {}".format(repo.url))
if 'hash' in options and options['hash'] is not None: if 'hash' in options and options['hash'] is not None:
rev, created = repo.provider.get_or_create_revision(repo, options['hash']) rev, created = repo.provider.get_or_create_revision(repo, options['hash'])
...@@ -32,7 +28,7 @@ class Command(BaseCommand): ...@@ -32,7 +28,7 @@ class Command(BaseCommand):
rev, created = repo.provider.get_or_create_latest_revision(repo) rev, created = repo.provider.get_or_create_latest_revision(repo)
if created: if created:
print('Created revision {} "{}" on repository {}'.format(rev.hash, rev.message, repo.url)) self.stdout.write('Created revision {} "{}" on repository {}'.format(rev.hash, rev.message, repo.url))
di = DataImport.objects.create( di = DataImport.objects.create(
creator=repo.credentials.user, creator=repo.credentials.user,
...@@ -42,5 +38,5 @@ class Command(BaseCommand): ...@@ -42,5 +38,5 @@ class Command(BaseCommand):
) )
di.start() di.start()
print('Successfully built DataImport {}'.format(di)) self.stdout.write(self.style.SUCCESS('Successfully built DataImport {}'.format(di)))
print('To test the import manually, run: ./manage.py import {}'.format(di.id)) self.stdout.write('To test the import manually, run: ./manage.py import {}'.format(di.id))
#!/usr/bin/env python3 #!/usr/bin/env python3
from django.core.management.base import CommandError from django.core.management.base import CommandError
from django.core.exceptions import ValidationError
from django.conf import settings from django.conf import settings
from ponos.management.base import PonosCommand from ponos.management.base import PonosCommand
from arkindex.documents.models import Corpus, Element, ElementType from arkindex.project.argparse import CorpusArgument, ElementArgument
from arkindex.documents.models import Element, ElementType
class Command(PonosCommand): class Command(PonosCommand):
...@@ -13,20 +13,21 @@ class Command(PonosCommand): ...@@ -13,20 +13,21 @@ class Command(PonosCommand):
def add_arguments(self, parser): def add_arguments(self, parser):
super().add_arguments(parser) super().add_arguments(parser)
# TODO: Mutually exclusive group parser.add_argument(
'--all',
help='Create thumbnails for every volume in every corpus',
action='store_true',
default=False,
)
parser.add_argument( parser.add_argument(
'--corpus', '--corpus',
help='ID or part of the name of the corpus to fetch volumes from', help='ID or part of the name of the corpus to fetch volumes from',
type=CorpusArgument(),
) )
parser.add_argument( parser.add_argument(
'--element', '--element',
help='ID or part of the name of a single element to build a thumbnail for', help='ID or part of the name of a single element to build a thumbnail for',
) type=ElementArgument(),
parser.add_argument(
'--all',
help='Create thumbnails for every volume in every corpus',
action='store_true',
default=False,
) )
parser.add_argument( parser.add_argument(
'--force', '--force',
...@@ -35,34 +36,23 @@ class Command(PonosCommand): ...@@ -35,34 +36,23 @@ class Command(PonosCommand):
default=False, default=False,
) )
def validate_args(self, **options): def validate_args(self, corpus=None, element=None, all=False, force=False, **options):
if options['all']: if all:
if options['corpus'] or options['element']: if corpus or element:
raise CommandError('--all cannot be used together with --corpus or --element') raise CommandError('--all cannot be used together with --corpus or --element')
return {'elements': Element.objects.filter(type=ElementType.Volume), 'force': options['force']} return {'elements': Element.objects.filter(type=ElementType.Volume), 'force': force}
try: if not corpus:
corpus = Corpus.objects.get(pk=options['corpus']) raise CommandError('--corpus is required when not using --all')
except (Corpus.DoesNotExist, ValidationError):
try:
corpus = Corpus.objects.get(name__icontains=options['corpus'])
except Corpus.DoesNotExist:
raise CommandError('Corpus "{}" does not exist'.format(options['corpus']))
if options['element']:
try:
elt = corpus.elements.get(pk=options['element'])
except (Element.DoesNotExist, ValidationError):
try:
elt = corpus.elements.get(name__icontains=options['element'])
except Element.DoesNotExist:
raise CommandError('Element "{}" does not exist'.format(options['element']))
elts = [elt, ] if element:
if not corpus.elements.filter(pk=element.pk).exists():
raise CommandError('Element {} is not in corpus {}'.format(element, corpus))
elts = [element, ]
else: else:
elts = Element.objects.filter(corpus=corpus, type=ElementType.Volume) elts = corpus.elements.filter(type=ElementType.Volume)
return {'elements': elts, 'force': options['force']} return {'elements': elts, 'force': force}
def run(self, elements=[], force=False): def run(self, elements=[], force=False):
for element in elements: for element in elements:
......
#!/usr/bin/env python3 #!/usr/bin/env python3
from django.core.management.base import CommandError
from django.core.exceptions import ValidationError
from django.conf import settings from django.conf import settings
from ponos.management.base import PonosCommand from ponos.management.base import PonosCommand
from arkindex.project.argparse import ElementArgument
from arkindex.documents.indexer import Indexer from arkindex.documents.indexer import Indexer
from arkindex.documents.models import Element, ElementType, Act, Transcription, Page from arkindex.documents.models import Element, ElementType, Act, Transcription, Page
import logging import logging
...@@ -87,6 +86,7 @@ class Command(PonosCommand): ...@@ -87,6 +86,7 @@ class Command(PonosCommand):
parser.add_argument( parser.add_argument(
'--volume', '--volume',
help='Restrict reindexing to a specific volume by ID or part of the name', help='Restrict reindexing to a specific volume by ID or part of the name',
type=ElementArgument(type=ElementType.Volume),
) )
parser.add_argument( parser.add_argument(
'--drop', '--drop',
...@@ -100,26 +100,13 @@ class Command(PonosCommand): ...@@ -100,26 +100,13 @@ class Command(PonosCommand):
for k in self.index_methods.keys(): for k in self.index_methods.keys():
options[k] = True options[k] = True
volume = None
if options['volume']:
try:
volume = Element.objects.get(type=ElementType.Volume, pk=options['volume'])
except (Element.DoesNotExist, ValidationError):
try:
volume = Element.objects.get(
type=ElementType.Volume,
name__icontains=options['volume'],
)
except Element.DoesNotExist:
raise CommandError('Volume "{}" not found'.format(options['volume']))
return { return {
'methods': [ 'methods': [
key key
for key in self.index_methods.keys() for key in self.index_methods.keys()
if options.get(key) if options.get(key)
], ],
'volume': volume, 'volume': options['volume'],
'drop': options.get('drop', False), 'drop': options.get('drop', False),
} }
......
from django.core.management import call_command from django.core.management import call_command
from django.core.management.base import CommandError
from unittest.mock import patch, call from unittest.mock import patch, call
from arkindex.project.tests import FixtureTestCase from arkindex.project.tests import FixtureTestCase
from arkindex.documents.models import Element, ElementType from arkindex.documents.models import Corpus, ElementType
class TestGenerateThumbnailsCommand(FixtureTestCase): class TestGenerateThumbnailsCommand(FixtureTestCase):
...@@ -12,9 +13,11 @@ class TestGenerateThumbnailsCommand(FixtureTestCase): ...@@ -12,9 +13,11 @@ class TestGenerateThumbnailsCommand(FixtureTestCase):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
super().setUpTestData() super().setUpTestData()
cls.reg = Element.objects.get(type=ElementType.Register, name="Register 1") cls.reg = cls.corpus.elements.get(type=ElementType.Register, name="Register 1")
cls.vol1 = Element.objects.get(type=ElementType.Volume, name="Volume 1") cls.vol1 = cls.corpus.elements.get(type=ElementType.Volume, name="Volume 1")
cls.vol2 = Element.objects.get(type=ElementType.Volume, name="Volume 2") cls.vol2 = cls.corpus.elements.get(type=ElementType.Volume, name="Volume 2")
corpus2 = Corpus.objects.create(name='Other corpus')
cls.vol3 = corpus2.elements.create(type=ElementType.Volume, name='Volume 3')
cls.thumb_patch = patch('arkindex.images.models.Thumbnail.create') cls.thumb_patch = patch('arkindex.images.models.Thumbnail.create')
def setUp(self): def setUp(self):
...@@ -27,11 +30,11 @@ class TestGenerateThumbnailsCommand(FixtureTestCase): ...@@ -27,11 +30,11 @@ class TestGenerateThumbnailsCommand(FixtureTestCase):
def test_start_corpus(self): def test_start_corpus(self):
""" """
Test generate_thumbnails starts a Celery task for each volume of a corpus Test generate_thumbnails runs generation for each volume of a corpus
""" """
call_command( call_command(
'generate_thumbnails', 'generate_thumbnails',
corpus=str(self.corpus.id), corpus=self.corpus,
) )
self.assertCountEqual(self.thumb_mock.call_args_list, [ self.assertCountEqual(self.thumb_mock.call_args_list, [
call(self.vol1), call(self.vol1),
...@@ -40,54 +43,71 @@ class TestGenerateThumbnailsCommand(FixtureTestCase): ...@@ -40,54 +43,71 @@ class TestGenerateThumbnailsCommand(FixtureTestCase):
def test_start_element(self): def test_start_element(self):
""" """
Test generate_thumbnails starts a Celery task for an element Test generate_thumbnails runs generation for an element
""" """
call_command( call_command(
'generate_thumbnails', 'generate_thumbnails',
corpus=str(self.corpus.id), corpus=self.corpus,
element=str(self.reg.id), element=self.reg,
) )
self.assertCountEqual(self.thumb_mock.call_args_list, [ self.assertCountEqual(self.thumb_mock.call_args_list, [
call(self.reg), call(self.reg),
]) ])
def test_corpus_name(self): @patch('arkindex.documents.models.Element.generate_thumbnail')
def test_force(self, gen_mock):
""" """
Test generate_thumbnails accepts a part of the corpus name instead of an ID Test generate_thumbnails passes the --force argument to tasks
""" """
call_command( call_command(
'generate_thumbnails', 'generate_thumbnails',
corpus='tests', corpus=self.corpus,
element=self.reg,
force=True,
) )
self.assertCountEqual(self.thumb_mock.call_args_list, [ self.assertCountEqual(gen_mock.call_args_list, [
call(self.vol1), call(force=True),
call(self.vol2),
]) ])
def test_element_name(self): def test_all(self):
""" """
Test generate_thumbnails accepts a part of an element name instead of an ID Test generate_thumbnails picks all volumes when using --all
""" """
call_command( call_command(
'generate_thumbnails', 'generate_thumbnails',
corpus=str(self.corpus.id), all=True,
element='register 1',
) )
self.assertCountEqual(self.thumb_mock.call_args_list, [ self.assertCountEqual(self.thumb_mock.call_args_list, [
call(self.reg), call(self.vol1),
call(self.vol2),
call(self.vol3),
]) ])
@patch('arkindex.documents.models.Element.generate_thumbnail') def test_all_xor_corpus(self):
def test_force(self, gen_mock):
""" """
Test generate_thumbnails passes the --force argument to tasks Test generate_thumbnails does not allow --all and --corpus simultaneously
""" """
call_command( with self.assertRaisesRegex(CommandError, r'--all.+--corpus'):
'generate_thumbnails', call_command(
corpus=str(self.corpus.id), 'generate_thumbnails',
element='register 1', all=True,
force=True, corpus=self.corpus,
) )
self.assertCountEqual(gen_mock.call_args_list, [
call(force=True), def test_corpus_required(self):
]) """
Test generate_thumbnails requires either --all or --corpus
"""
with self.assertRaisesRegex(CommandError, '--corpus'):
call_command('generate_thumbnails')
def test_element_in_corpus(self):
"""
Test generate_thumbnails requires --element to be inside --corpus
"""
with self.assertRaisesRegex(CommandError, 'not in corpus'):
call_command(
'generate_thumbnails',
corpus=self.corpus,
element=self.vol3,
)
...@@ -181,25 +181,13 @@ class TestReindexCommand(FixtureTestCase): ...@@ -181,25 +181,13 @@ class TestReindexCommand(FixtureTestCase):
self.assertEqual(self.indexer_mock().setup.call_count, 0) self.assertEqual(self.indexer_mock().setup.call_count, 0)
self._assert_all() self._assert_all()
def test_volume_id(self): def test_volume(self):
""" """
Test the reindex command can restrict indexing to a specific volume by ID Test the reindex command can restrict indexing to a specific volume
""" """
call_command( call_command(
'reindex', 'reindex',
volume=str(self.vol.id), volume=self.vol,
)
self.assertEqual(self.indexer_mock().drop_index.call_count, 0)
self.assertEqual(self.indexer_mock().setup.call_count, 0)
self._assert_volume()
def test_volume_name(self):
"""
Test the reindex command can restrict indexing to a specific volume by name
"""
call_command(
'reindex',
volume="volume 1",
) )
self.assertEqual(self.indexer_mock().drop_index.call_count, 0) self.assertEqual(self.indexer_mock().drop_index.call_count, 0)
self.assertEqual(self.indexer_mock().setup.call_count, 0) self.assertEqual(self.indexer_mock().setup.call_count, 0)
......
...@@ -2,6 +2,7 @@ from django.db.models import Model ...@@ -2,6 +2,7 @@ from django.db.models import Model
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.core.management.base import CommandError from django.core.management.base import CommandError
from arkindex.documents.models import Corpus, Element, DataSource from arkindex.documents.models import Corpus, Element, DataSource
from arkindex.dataimport.models import DataImport, Repository
class ModelArgument(object): class ModelArgument(object):
...@@ -20,17 +21,22 @@ class ModelArgument(object): ...@@ -20,17 +21,22 @@ class ModelArgument(object):
except (ValidationError, Element.DoesNotExist): except (ValidationError, Element.DoesNotExist):
pass pass
text_filter = {'{}__icontains'.format(self.text_search_field): arg}
if self.many:
return qs.filter(**text_filter)
try: try:
return qs.get(**text_filter) return self.text_search(qs, arg)
except self.model.DoesNotExist: except self.model.DoesNotExist:
raise CommandError('{} "{}" does not exist'.format(self.model.__name__, arg)) raise CommandError('{} "{}" does not exist'.format(self.model.__name__, arg))
except self.model.MultipleObjectsReturned: except self.model.MultipleObjectsReturned:
raise CommandError('"{}" matches multiple {} instances'.format(arg, self.model.__name__)) raise CommandError('"{}" matches multiple {} instances'.format(arg, self.model.__name__))
def text_search(self, qs, arg):
if not self.text_search_field:
raise self.model.DoesNotExist
text_filter = {'{}__icontains'.format(self.text_search_field): arg}
if self.many:
return qs.filter(**text_filter)
else:
return qs.get(**text_filter)
class CorpusArgument(ModelArgument): class CorpusArgument(ModelArgument):
model = Corpus model = Corpus
...@@ -43,3 +49,13 @@ class ElementArgument(ModelArgument): ...@@ -43,3 +49,13 @@ class ElementArgument(ModelArgument):
class DataSourceArgument(ModelArgument): class DataSourceArgument(ModelArgument):
model = DataSource model = DataSource
text_search_field = 'slug' text_search_field = 'slug'
class DataImportArgument(ModelArgument):
model = DataImport
text_search_field = None
class RepositoryArgument(ModelArgument):
model = Repository
text_search_field = 'url'
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment