From 0c9d58431f83f2fba1f26b8802f499743bd0b8b4 Mon Sep 17 00:00:00 2001 From: Erwan Rouchet <rouchet@teklia.com> Date: Thu, 21 Nov 2019 13:09:13 +0000 Subject: [PATCH] Reindex via a Channels worker --- arkindex/dataimport/models.py | 2 +- arkindex/documents/acts.py | 4 +- arkindex/documents/api/admin.py | 16 ++ arkindex/documents/api/ml.py | 18 +- arkindex/documents/consumers.py | 50 +++++ arkindex/documents/managers.py | 14 ++ arkindex/documents/pagexml.py | 15 -- arkindex/documents/serializers/admin.py | 42 +++++ arkindex/documents/tei.py | 14 +- .../documents/tests/test_acts_importer.py | 15 +- arkindex/documents/tests/test_admin_api.py | 74 ++++++++ arkindex/documents/tests/test_consumers.py | 171 ++++++++++++++++++ arkindex/documents/tests/test_pagexml.py | 28 ++- .../tests/test_transcription_create.py | 102 ++++++++--- arkindex/images/importer.py | 32 +--- arkindex/project/api_v1.py | 4 + arkindex/project/routing.py | 7 +- arkindex/project/settings.py | 3 +- arkindex/project/triggers.py | 51 ++++++ openapi/patch.yml | 2 + tests-requirements.txt | 1 + 21 files changed, 563 insertions(+), 102 deletions(-) create mode 100644 arkindex/documents/api/admin.py create mode 100644 arkindex/documents/consumers.py create mode 100644 arkindex/documents/serializers/admin.py create mode 100644 arkindex/documents/tests/test_admin_api.py create mode 100644 arkindex/documents/tests/test_consumers.py create mode 100644 arkindex/project/triggers.py diff --git a/arkindex/dataimport/models.py b/arkindex/dataimport/models.py index c3e47fea78..0a53eef9ac 100644 --- a/arkindex/dataimport/models.py +++ b/arkindex/dataimport/models.py @@ -77,7 +77,7 @@ class DataImport(IndexableModel): 'DB_PASSWORD': settings.DATABASES['default']['PASSWORD'], 'DB_NAME': settings.DATABASES['default']['NAME'], 'LOCAL_IMAGESERVER_ID': settings.LOCAL_IMAGESERVER_ID, - 'ES_HOST': settings.ELASTIC_SEARCH_HOSTS[0], + 'REDIS_HOST': settings.REDIS_HOST, # Some empty folder to bypass the system check 'ML_CLASSIFIERS_DIR': '/data/current', }, diff --git a/arkindex/documents/acts.py b/arkindex/documents/acts.py index 363b0382b2..144d9a1667 100644 --- a/arkindex/documents/acts.py +++ b/arkindex/documents/acts.py @@ -4,8 +4,8 @@ from django.db.models import CharField, Value, Prefetch from django.db.models.functions import Concat from arkindex_common.enums import MetaType from arkindex.project.polygon import Polygon +from arkindex.project.triggers import reindex_start from arkindex.images.models import Image -from arkindex.documents.indexer import Indexer from arkindex.documents.models import Element, ElementType, Corpus, MetaData import csv import logging @@ -154,4 +154,4 @@ class ActsImporter(object): assert failed < count, 'No acts were imported' logger.info('Updating search index') - Indexer().run_index(Element.objects.get_descending(self.folder.id).filter(type=self.act_type)) + reindex_start(element=self.folder, elements=True) diff --git a/arkindex/documents/api/admin.py b/arkindex/documents/api/admin.py new file mode 100644 index 0000000000..eb48ce5050 --- /dev/null +++ b/arkindex/documents/api/admin.py @@ -0,0 +1,16 @@ +from rest_framework.generics import CreateAPIView +from arkindex.project.permissions import IsAdminUser +from arkindex.documents.serializers.admin import ReindexConfigSerializer + + +class ReindexStart(CreateAPIView): + """ + Run an ElasticSearch indexation from the API. 
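The new ReindexStart view is wired to POST /api/v1/reindex/ in the api_v1.py hunk further down. A rough sketch of how an admin client might call it; requests and the token Authorization header are assumptions rather than part of this patch, and the payload fields mirror ReindexConfigSerializer:

    import requests  # assumed HTTP client; any other works

    def trigger_reindex(base_url, token, corpus_id=None):
        # Queue a reindex through the admin API; the view answers 201 Created.
        payload = {
            'transcriptions': True,
            'elements': True,
            'entities': True,
            'drop': False,  # rejected by the serializer if corpus/element is set
        }
        if corpus_id:
            payload['corpus'] = str(corpus_id)  # optional: restrict to one corpus
        response = requests.post(
            base_url + '/api/v1/reindex/',
            json=payload,
            headers={'Authorization': 'Token ' + token},  # assumed auth scheme
        )
        response.raise_for_status()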
+ """ + permission_classes = (IsAdminUser, ) + serializer_class = ReindexConfigSerializer + openapi_overrides = { + 'operationId': 'Reindex', + 'description': 'Manually reindex elements, transcriptions and entities for search APIs', + 'tags': ['management'] + } diff --git a/arkindex/documents/api/ml.py b/arkindex/documents/api/ml.py index 7ec27dfc6b..eae17521bf 100644 --- a/arkindex/documents/api/ml.py +++ b/arkindex/documents/api/ml.py @@ -17,14 +17,14 @@ from arkindex.documents.serializers.ml import ( TranscriptionsSerializer, TranscriptionCreateSerializer, DataSourceStatsSerializer, ClassificationsSelectionSerializer ) -from arkindex.documents.indexer import Indexer from arkindex.documents.pagexml import PageXmlParser from arkindex.images.models import Zone -from arkindex.images.importer import build_transcriptions, save_transcriptions, index_transcriptions +from arkindex.images.importer import build_transcriptions, save_transcriptions +from arkindex.project.mixins import SelectionMixin from arkindex.project.parsers import XMLParser from arkindex.project.permissions import IsVerified, IsAdminUser from arkindex.project.polygon import Polygon -from arkindex.project.mixins import SelectionMixin +from arkindex.project.triggers import reindex_start import os.path import logging @@ -63,9 +63,7 @@ class TranscriptionCreate(CreateAPIView): ts.save() # Index in ES - Indexer().run_index( - Transcription.objects.filter(id=ts.id) - ) + reindex_start(element=element, transcriptions=True, elements=True) return ts def create(self, request, *args, **kwargs): @@ -150,7 +148,7 @@ class TranscriptionBulk(CreateAPIView, UpdateAPIView): source=source, )) - index_transcriptions(transcriptions) + reindex_start(element=parent, transcriptions=True, elements=True) return Response(None, status=status.HTTP_201_CREATED, headers=headers) @@ -375,11 +373,7 @@ class PageXmlTranscriptionsImport(CreateModelMixin, APIView): len(entities_id) )) - if entities_id: - # Reindex entities to add new entities in ElasticSearch - Indexer().run_index( - Entity.objects.filter(id__in=entities_id) - ) + reindex_start(element=element, transcriptions=True, elements=True, entities=True) return Response( status=status.HTTP_201_CREATED, diff --git a/arkindex/documents/consumers.py b/arkindex/documents/consumers.py new file mode 100644 index 0000000000..e2d77a02b8 --- /dev/null +++ b/arkindex/documents/consumers.py @@ -0,0 +1,50 @@ +from channels.consumer import SyncConsumer +from django.db.models import Q +from arkindex.documents.indexer import Indexer +from arkindex.documents.models import Element, Transcription, Entity + + +class ReindexConsumer(SyncConsumer): + + def reindex_start(self, message): + corpus_id, element_id, transcriptions, elements, entities, drop = map( + message.get, + ('corpus', 'element', 'transcriptions', 'elements', 'entities', 'drop'), + (None, None, True, True, True, False), # Default values + ) + indexer = Indexer() + + if drop: + if transcriptions: + indexer.drop_index(Transcription.es_document) + if elements: + indexer.drop_index(Element.es_document) + if entities: + indexer.drop_index(Entity.es_document) + indexer.setup() + + if element_id or corpus_id: + if element_id: + # Pick this element, and all its children + elements_queryset = list(Element.objects.get_descending(element_id)) + elements_queryset.append(Element.objects.get(id=element_id)) + else: + # Pick all elements in the corpus + elements_queryset = Element.objects.filter(corpus_id=corpus_id) + + transcriptions_queryset = 
Transcription.objects.filter(element__in=elements_queryset) + entities_queryset = Entity.objects.filter( + Q(metadatas__element__in=elements_queryset) + | Q(transcriptions__element__in=elements_queryset) + ) + else: + transcriptions_queryset = Transcription.objects.all() + elements_queryset = Element.objects.all() + entities_queryset = Entity.objects.all() + + if transcriptions: + indexer.run_index(transcriptions_queryset, bulk_size=400) + if elements: + indexer.run_index(elements_queryset, bulk_size=100) + if entities: + indexer.run_index(entities_queryset, bulk_size=400) diff --git a/arkindex/documents/managers.py b/arkindex/documents/managers.py index 25eb0fbaf8..efd832dc71 100644 --- a/arkindex/documents/managers.py +++ b/arkindex/documents/managers.py @@ -170,3 +170,17 @@ class CorpusManager(models.Manager): # Authenticated users can write only on corpora with ACL return qs.filter(corpus_right__user=user, corpus_right__can_write=True).distinct() + + def admin(self, user): + # An anonymous user cannot manage anything + if user.is_anonymous: + return super().none() + + qs = super().get_queryset().order_by('name') + + # Admins and internal users have access to every corpus + if user.is_admin or user.is_internal: + return qs.all() + + # Authenticated users can manage only corpora with ACL + return qs.filter(corpus_right__user=user, corpus_right__can_admin=True).distinct() diff --git a/arkindex/documents/pagexml.py b/arkindex/documents/pagexml.py index e95663bc98..e8e8fc76e8 100644 --- a/arkindex/documents/pagexml.py +++ b/arkindex/documents/pagexml.py @@ -5,7 +5,6 @@ from arkindex.project.polygon import Polygon from arkindex_common.enums import TranscriptionType from arkindex.documents.models import \ DataSource, Element, Entity, EntityRole, EntityLink, TranscriptionEntity -from arkindex.documents.indexer import Indexer import functools import string import Levenshtein @@ -68,19 +67,6 @@ class PageXmlParser(object): score=1 if region.confidence is None else region.confidence, ) - def index(self, element): - """ - Create or Update ElasticSearch index for transcriptions and related elements - """ - assert isinstance(element, Element), 'Element should be an Arkindex element' - es_indexer = Indexer() - es_indexer.run_index( - Element.objects.filter(id=element.id), - ) - es_indexer.run_index( - element.transcriptions.all(), - ) - def save(self, element): assert isinstance(element, Element), 'Element should be an Arkindex element' if self.pagexml_page.page.text_regions is None or not len(self.pagexml_page.page.text_regions): @@ -115,7 +101,6 @@ class PageXmlParser(object): region_ts_count, line_ts_count, )) - self.index(element) return transcriptions def merge(self, blocks): diff --git a/arkindex/documents/serializers/admin.py b/arkindex/documents/serializers/admin.py new file mode 100644 index 0000000000..0377d233a1 --- /dev/null +++ b/arkindex/documents/serializers/admin.py @@ -0,0 +1,42 @@ +from rest_framework import serializers +from arkindex.project.triggers import reindex_start +from arkindex.documents.models import Corpus, Element + + +class ReindexConfigSerializer(serializers.Serializer): + corpus = serializers.PrimaryKeyRelatedField(queryset=Corpus.objects.all(), required=False) + element = serializers.PrimaryKeyRelatedField(queryset=Element.objects.all(), required=False) + transcriptions = serializers.BooleanField(default=True) + elements = serializers.BooleanField(default=True) + entities = serializers.BooleanField(default=True) + drop = serializers.BooleanField(default=False) + + 
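The boolean fields above default to reindexing everything except drop; validate() further down enforces the three error cases listed next. A quick illustration of the serializer's behaviour in isolation, as a sketch for a test or Django shell, where corpus is a hypothetical corpus the user can administer:

    from arkindex.documents.serializers.admin import ReindexConfigSerializer

    # Disabling every index type trips 'reindex_nothing'.
    s = ReindexConfigSerializer(data={
        'transcriptions': False, 'elements': False, 'entities': False,
    })
    assert not s.is_valid()

    # drop only makes sense for a full reindex, never with a filter.
    s = ReindexConfigSerializer(data={'corpus': str(corpus.id), 'drop': True})
    assert not s.is_valid()

    # A valid payload hands off to reindex_start() via save(), which
    # requires the Redis channel layer to be reachable.
    s = ReindexConfigSerializer(data={'corpus': str(corpus.id)})
    assert s.is_valid()
    s.save()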
default_error_messages = { + 'reindex_nothing': 'At least one index type is required.', + 'filtered_drop': '`drop` can only be used when reindexing everything.', + 'element_in_corpus': 'The selected element is not in the selected corpus.', + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not self.context.get('request'): + return + user = self.context.get('request').user + self.fields['corpus'].queryset = Corpus.objects.admin(user) + self.fields['element'].queryset = Element.objects.filter(corpus__in=Corpus.objects.admin(user)) + + def validate(self, data): + corpus, element, transcriptions, elements, entities, drop = map( + data.get, + ('corpus', 'element', 'transcriptions', 'elements', 'entities', 'drop') + ) + if not (transcriptions or elements or entities): + self.fail('reindex_nothing') + if drop and (corpus or element): + self.fail('filtered_drop') + if element and corpus and element.corpus_id != corpus.id: + self.fail('element_in_corpus') + return data + + def save(self): + reindex_start(**self.validated_data) diff --git a/arkindex/documents/tei.py b/arkindex/documents/tei.py index 078892f754..0da96c338e 100644 --- a/arkindex/documents/tei.py +++ b/arkindex/documents/tei.py @@ -2,8 +2,8 @@ from itertools import groupby from arkindex_common.tei import TeiElement, TeiParser as BaseTeiParser from arkindex_common.enums import MetaType from arkindex.project.tools import find_closest -from arkindex.documents.indexer import Indexer -from arkindex.documents.models import Element, Entity, MetaData, DataSource, MLToolType +from arkindex.project.triggers import reindex_start +from arkindex.documents.models import Element, Entity, DataSource, MLToolType import logging @@ -147,11 +147,7 @@ class TeiParser(BaseTeiParser): """ Create or Update ElasticSearch index for indexed elements """ - md_ids = (md.id for md in db_metadatas) - date_metadatas = MetaData.objects.filter(id__in=md_ids, type=MetaType.Date) - - elements = Element.objects \ - .filter(metadatas__in=date_metadatas, type__folder=False, type__hidden=False) \ - .prefetch_related('metadatas', 'transcriptions') + md_ids = (md.id for md in db_metadatas if md.type == MetaType.Date) + elements = Element.objects.filter(metadatas__id__in=md_ids, type__folder=False, type__hidden=False) if elements.exists(): - Indexer().run_index(elements) + reindex_start(corpus=self.db_corpus, elements=True) diff --git a/arkindex/documents/tests/test_acts_importer.py b/arkindex/documents/tests/test_acts_importer.py index 2216131d03..36bd4cb290 100644 --- a/arkindex/documents/tests/test_acts_importer.py +++ b/arkindex/documents/tests/test_acts_importer.py @@ -1,5 +1,6 @@ from unittest.mock import patch from pathlib import Path +from asyncmock import AsyncMock from arkindex_common.enums import MetaType from arkindex.project.tests import FixtureTestCase from arkindex.project.polygon import Polygon @@ -22,7 +23,9 @@ class TestActsImporter(FixtureTestCase): cls.surface_type = cls.corpus.types.get(slug='surface') cls.volume = cls.corpus.elements.get(name='Volume 2', type=cls.folder_type) - def test_import(self): + @patch('arkindex.project.triggers.get_channel_layer') + def test_import(self, get_layer_mock): + get_layer_mock().send = AsyncMock() self.assertFalse( Element.objects .get_descending(self.volume.id) @@ -83,6 +86,16 @@ class TestActsImporter(FixtureTestCase): self.assertEqual(surface.zone.image, self.img5) self.assertEqual(surface.zone.polygon, Polygon.from_coords(0, 500, 1000, 500)) # Second half + 
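Two mocking details make assertions like the one just below work. channel_layer.send is a coroutine that async_to_sync() must await, so the tests replace it with an AsyncMock (hence the asyncmock pin added to tests-requirements.txt); and because calling a MagicMock returns its return_value, get_layer_mock() is the very object whose send was configured. A compressed illustration, independent of the fixtures:

    from unittest.mock import MagicMock
    from asyncmock import AsyncMock

    get_layer_mock = MagicMock()
    get_layer_mock.return_value.send = AsyncMock()

    # Calling the mock returns its return_value, so both spellings
    # reach the same AsyncMock:
    assert get_layer_mock() is get_layer_mock.return_value
    assert get_layer_mock().send is get_layer_mock.return_value.send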
get_layer_mock().send.assert_called_once_with('reindex', { + 'type': 'reindex.start', + 'element': str(self.volume.id), + 'corpus': None, + 'transcriptions': False, + 'elements': True, + 'entities': False, + 'drop': False, + }) + def test_epic_fail(self): importer = ActsImporter( ACT_SAMPLES / 'volume 2.csv', diff --git a/arkindex/documents/tests/test_admin_api.py b/arkindex/documents/tests/test_admin_api.py new file mode 100644 index 0000000000..d1cf40a44b --- /dev/null +++ b/arkindex/documents/tests/test_admin_api.py @@ -0,0 +1,74 @@ +from unittest.mock import patch +from asyncmock import AsyncMock +from django.urls import reverse +from rest_framework import status +from arkindex.project.tests import FixtureTestCase +from arkindex.documents.models import Corpus + + +class TestAdminAPI(FixtureTestCase): + + def test_reindex_requires_login(self): + response = self.client.post(reverse('api:reindex-start'), {}) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_reindex_requires_admin(self): + self.client.force_login(self.user) + response = self.client.post(reverse('api:reindex-start'), {}) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_reindex_requires_index_types(self): + self.client.force_login(self.superuser) + response = self.client.post(reverse('api:reindex-start'), { + 'transcriptions': False, + 'elements': False, + 'entities': False, + }) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + 'non_field_errors': ['At least one index type is required.'], + }) + + def test_reindex_drop_only_all(self): + self.client.force_login(self.superuser) + response = self.client.post(reverse('api:reindex-start'), { + 'corpus': str(self.corpus.id), + 'drop': True, + }) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + 'non_field_errors': ['`drop` can only be used when reindexing everything.'], + }) + + def test_reindex_element_in_corpus(self): + corpus2 = Corpus.objects.create(name='some corpus') + self.client.force_login(self.superuser) + response = self.client.post(reverse('api:reindex-start'), { + 'corpus': str(corpus2.id), + 'element': str(self.corpus.elements.first().id), + }) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + 'non_field_errors': ['The selected element is not in the selected corpus.'], + }) + + @patch('arkindex.project.triggers.get_channel_layer') + def test_reindex(self, get_layer_mock): + get_layer_mock.return_value.send = AsyncMock() + + self.client.force_login(self.superuser) + response = self.client.post(reverse('api:reindex-start'), {}) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + get_layer_mock.return_value.send.assert_called_once_with( + 'reindex', + { + 'element': None, + 'corpus': None, + 'transcriptions': True, + 'elements': True, + 'entities': True, + 'drop': False, + 'type': 'reindex.start', + }, + ) diff --git a/arkindex/documents/tests/test_consumers.py b/arkindex/documents/tests/test_consumers.py new file mode 100644 index 0000000000..13e82f0293 --- /dev/null +++ b/arkindex/documents/tests/test_consumers.py @@ -0,0 +1,171 @@ +from unittest.mock import patch +from django.db.models import Q +from arkindex_common.enums import TranscriptionType, MetaType, EntityType +from arkindex.project.tests import FixtureTestCase +from arkindex.documents.consumers import ReindexConsumer +from 
arkindex.documents.models import Corpus, Element, Transcription, Entity, DataSource + + +@patch('arkindex.documents.consumers.Indexer') +class TestConsumers(FixtureTestCase): + + @classmethod + def setUpTestData(cls): + super().setUpTestData() + source = DataSource.objects.get(slug='test') + + cls.folder = cls.corpus.elements.get(name='Volume 1') + cls.folder.metadatas.create( + type=MetaType.Entity, + name='something', + value='Some entity', + entity=cls.corpus.entities.create( + type=EntityType.Person, + name='Some entity', + source=source, + ) + ) + + corpus2 = Corpus.objects.create(name='Another corpus') + element2 = corpus2.elements.create( + type=corpus2.types.create(display_name='Element'), + name='An element', + ) + ts = element2.transcriptions.create( + score=0.8, + text='something', + type=TranscriptionType.Word, + source=source, + ) + ts.transcription_entities.create( + entity=corpus2.entities.create( + type=EntityType.Misc, + name='Some other entity', + source=source, + ), + offset=0, + length=1, + ) + + def assertQuerysetEqual(self, queryset1, queryset2, **options): + """ + Make Django's assertQuerysetEqual slightly nicer to use + """ + options.setdefault('ordered', False) + return super().assertQuerysetEqual(queryset1, map(repr, queryset2), **options) + + def _assert_all_elements(self, call_args): + (queryset, ), kwargs = call_args + self.assertQuerysetEqual(queryset, Element.objects.all()) + self.assertDictEqual(kwargs, {'bulk_size': 100}) + + def _assert_all_entities(self, call_args): + (queryset, ), kwargs = call_args + self.assertQuerysetEqual(queryset, Entity.objects.all()) + self.assertDictEqual(kwargs, {'bulk_size': 400}) + + def _assert_all_transcriptions(self, call_args): + (queryset, ), kwargs = call_args + self.assertQuerysetEqual(queryset, Transcription.objects.all()) + self.assertDictEqual(kwargs, {'bulk_size': 400}) + + def _assert_all(self, mock): + self.assertEqual(mock().run_index.call_count, 3) + elements_call, entities_call, ts_call = sorted(mock().run_index.call_args_list, key=repr) + self._assert_all_elements(elements_call) + self._assert_all_entities(entities_call) + self._assert_all_transcriptions(ts_call) + + def test_reindex_all(self, mock): + ReindexConsumer({}).reindex_start({}) + self.assertEqual(mock().drop_index.call_count, 0) + self._assert_all(mock) + + def test_reindex_drop(self, mock): + ReindexConsumer({}).reindex_start({ + 'drop': True, + }) + self.assertEqual(mock().drop_index.call_count, 3) + mock().drop_index.assert_any_call(Element.es_document) + mock().drop_index.assert_any_call(Entity.es_document) + mock().drop_index.assert_any_call(Transcription.es_document) + self._assert_all(mock) + + def test_reindex_only_transcriptions(self, mock): + ReindexConsumer({}).reindex_start({ + 'transcriptions': True, + 'entities': False, + 'elements': False, + }) + self.assertEqual(mock().drop_index.call_count, 0) + self.assertEqual(mock().run_index.call_count, 1) + self._assert_all_transcriptions(mock().run_index.call_args) + + def test_reindex_only_elements(self, mock): + ReindexConsumer({}).reindex_start({ + 'transcriptions': False, + 'entities': False, + 'elements': True, + }) + self.assertEqual(mock().drop_index.call_count, 0) + self.assertEqual(mock().run_index.call_count, 1) + self._assert_all_elements(mock().run_index.call_args) + + def test_reindex_only_entities(self, mock): + ReindexConsumer({}).reindex_start({ + 'transcriptions': False, + 'entities': True, + 'elements': False, + }) + self.assertEqual(mock().drop_index.call_count, 0) + 
self.assertEqual(mock().run_index.call_count, 1) + self._assert_all_entities(mock().run_index.call_args) + + def test_reindex_corpus(self, mock): + ReindexConsumer({}).reindex_start({ + 'corpus': str(self.corpus.id), + }) + self.assertEqual(mock().drop_index.call_count, 0) + self.assertEqual(mock().run_index.call_count, 3) + elements_call, entities_call, ts_call = sorted(mock().run_index.call_args_list, key=repr) + + (queryset, ), kwargs = elements_call + self.assertQuerysetEqual(queryset, self.corpus.elements.all()) + self.assertDictEqual(kwargs, {'bulk_size': 100}) + + (queryset, ), kwargs = entities_call + self.assertQuerysetEqual(queryset, self.corpus.entities.all()) + self.assertDictEqual(kwargs, {'bulk_size': 400}) + + (queryset, ), kwargs = ts_call + self.assertQuerysetEqual(queryset, Transcription.objects.filter(element__corpus_id=self.corpus.id)) + self.assertDictEqual(kwargs, {'bulk_size': 400}) + + def test_reindex_element(self, mock): + ReindexConsumer({}).reindex_start({ + 'element': str(self.folder.id), + }) + self.assertEqual(mock().drop_index.call_count, 0) + self.assertEqual(mock().run_index.call_count, 3) + entities_call, ts_call, elements_call = sorted(mock().run_index.call_args_list, key=repr) + + elements_list = list(Element.objects.get_descending(self.folder.id)) + elements_list.append(self.folder) + + (queryset, ), kwargs = elements_call + self.assertQuerysetEqual(queryset, elements_list) + self.assertDictEqual(kwargs, {'bulk_size': 100}) + + (queryset, ), kwargs = entities_call + self.assertQuerysetEqual( + queryset, + Entity.objects.filter( + Q(metadatas__element__in=elements_list) + | Q(transcriptions__element__in=elements_list) + ), + ) + self.assertDictEqual(kwargs, {'bulk_size': 400}) + + (queryset, ), kwargs = ts_call + self.assertQuerysetEqual(queryset, Transcription.objects.filter(element__in=elements_list)) + self.assertDictEqual(kwargs, {'bulk_size': 400}) diff --git a/arkindex/documents/tests/test_pagexml.py b/arkindex/documents/tests/test_pagexml.py index 4213a017bc..902a60cdfa 100644 --- a/arkindex/documents/tests/test_pagexml.py +++ b/arkindex/documents/tests/test_pagexml.py @@ -1,4 +1,6 @@ from pathlib import Path +from unittest.mock import patch +from asyncmock import AsyncMock from django.urls import reverse from rest_framework import status from arkindex.project.tests import FixtureAPITestCase @@ -37,7 +39,9 @@ class TestPageXml(FixtureAPITestCase): ) self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN) - def test_pagexml_import(self): + @patch('arkindex.project.triggers.get_channel_layer') + def test_pagexml_import(self, get_layer_mock): + get_layer_mock.return_value.send = AsyncMock() self.assertFalse(self.page.transcriptions.exists()) self.client.force_login(self.user) with (FIXTURES / 'transcript.xml').open() as f: @@ -60,6 +64,15 @@ class TestPageXml(FixtureAPITestCase): ]) # All transcriptions have a score of 100% self.assertFalse(self.page.transcriptions.exclude(score=1.0).exists()) + get_layer_mock().send.assert_called_once_with('reindex', { + 'type': 'reindex.start', + 'element': str(self.page.id), + 'corpus': None, + 'transcriptions': True, + 'entities': True, + 'elements': True, + 'drop': False, + }) def test_pagexml_import_requires_zone(self): volume = self.corpus.elements.get(name='Volume 1') @@ -74,7 +87,9 @@ class TestPageXml(FixtureAPITestCase): ) self.assertEqual(resp.status_code, status.HTTP_404_NOT_FOUND) - def test_pagexml_import_any_element(self): + @patch('arkindex.project.triggers.get_channel_layer') + def 
test_pagexml_import_any_element(self, get_layer_mock): + get_layer_mock.return_value.send = AsyncMock() surface = self.corpus.elements.get(name='Surface A') self.client.force_login(self.user) with (FIXTURES / 'transcript.xml').open() as f: @@ -97,6 +112,15 @@ class TestPageXml(FixtureAPITestCase): ]) # All transcriptions have a score of 100% self.assertFalse(surface.transcriptions.exclude(score=1.0).exists()) + get_layer_mock().send.assert_called_once_with('reindex', { + 'type': 'reindex.start', + 'element': str(surface.id), + 'corpus': None, + 'transcriptions': True, + 'elements': True, + 'entities': True, + 'drop': False, + }) def test_pagexml_create_blocks(self): parser = PageXmlParser(FIXTURES / 'create_blocks.xml') diff --git a/arkindex/documents/tests/test_transcription_create.py b/arkindex/documents/tests/test_transcription_create.py index 999351ce98..050acf8170 100644 --- a/arkindex/documents/tests/test_transcription_create.py +++ b/arkindex/documents/tests/test_transcription_create.py @@ -1,5 +1,6 @@ from django.urls import reverse from unittest.mock import patch +from asyncmock import AsyncMock from rest_framework import status from arkindex.project.tests import FixtureAPITestCase from arkindex.project.polygon import Polygon @@ -24,11 +25,13 @@ class TestTranscriptionCreate(FixtureAPITestCase): response = self.client.post(reverse('api:transcription-create'), format='json') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - @patch('arkindex.documents.api.ml.Indexer') - def test_create_transcription(self, indexer): + @patch('arkindex.project.triggers.get_channel_layer') + def test_create_transcription(self, get_layer_mock): """ Checks the view creates transcriptions, zones, links, paths and runs ES indexing """ + get_layer_mock.return_value.send = AsyncMock() + self.client.force_login(self.user) response = self.client.post(reverse('api:transcription-create'), format='json', data={ "type": "word", @@ -39,18 +42,29 @@ class TestTranscriptionCreate(FixtureAPITestCase): "score": 0.83, }) self.assertEqual(response.status_code, status.HTTP_201_CREATED) + new_ts = Transcription.objects.get(text="NEKUDOTAYIM", type=TranscriptionType.Word) self.assertEqual(new_ts.zone.polygon, Polygon.from_coords(0, 0, 100, 100)) self.assertEqual(new_ts.score, 0.83) self.assertEqual(new_ts.source, self.src) self.assertTrue(self.page.transcriptions.filter(pk=new_ts.id).exists()) - self.assertTrue(indexer.return_value.run_index.called) - @patch('arkindex.documents.api.ml.Indexer') - def test_unique_zone(self, indexer): + get_layer_mock().send.assert_called_once_with('reindex', { + 'type': 'reindex.start', + 'element': str(self.page.id), + 'corpus': None, + 'transcriptions': True, + 'elements': True, + 'entities': False, + 'drop': False, + }) + + @patch('arkindex.project.triggers.get_channel_layer') + def test_unique_zone(self, get_layer_mock): """ Checks the view reuses zones when available """ + get_layer_mock.return_value.send = AsyncMock() self.client.force_login(self.user) ts = Transcription.objects.get(zone__image__path='img1', text="PARIS") response = self.client.post(reverse('api:transcription-create'), format='json', data={ @@ -63,12 +77,22 @@ class TestTranscriptionCreate(FixtureAPITestCase): }) self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(Transcription.objects.get(text="GLOUBIBOULGA").zone.id, ts.zone.id) + get_layer_mock().send.assert_called_once_with('reindex', { + 'type': 'reindex.start', + 'element': str(self.page.id), + 'corpus': None, + 
'transcriptions': True, + 'elements': True, + 'entities': False, + 'drop': False, + }) - @patch('arkindex.documents.api.ml.Indexer') - def test_update_transcription(self, indexer): + @patch('arkindex.project.triggers.get_channel_layer') + def test_update_transcription(self, get_layer_mock): """ Checks the view updates transcriptions when they already exist """ + get_layer_mock.return_value.send = AsyncMock() self.client.force_login(self.user) ts = Transcription.objects.get(zone__image__path='img1', text="PARIS") self.assertNotEqual(ts.score, 0.99) @@ -85,6 +109,16 @@ class TestTranscriptionCreate(FixtureAPITestCase): ts.refresh_from_db() self.assertEqual(ts.score, 0.99) + get_layer_mock().send.assert_called_once_with('reindex', { + 'type': 'reindex.start', + 'element': str(self.page.id), + 'corpus': None, + 'transcriptions': True, + 'elements': True, + 'entities': False, + 'drop': False, + }) + def test_invalid_data(self): """ Checks the view validates data properly @@ -123,12 +157,13 @@ class TestTranscriptionCreate(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @patch('arkindex.project.serializer_fields.MLTool.get') - @patch('arkindex.images.importer.Indexer') - def test_create_bulk_transcription(self, indexer, ml_get_mock): + @patch('arkindex.project.triggers.get_channel_layer') + def test_create_bulk_transcription(self, get_layer_mock, ml_get_mock): """ Checks the view creates transcriptions, zones, links, paths and runs ES indexing Using bulk_transcriptions """ + get_layer_mock.return_value.send = AsyncMock() ml_get_mock.return_value.type = self.src.type ml_get_mock.return_value.slug = self.src.slug ml_get_mock.return_value.name = self.src.name @@ -176,18 +211,15 @@ class TestTranscriptionCreate(FixtureAPITestCase): self.assertEqual(page_ts.source, self.src) self.assertEqual(page_ts.zone, self.page.zone) - # Indexer called - self.assertEqual(indexer.return_value.run_index.call_count, 2) - indexer.reset_mock() - - # Call once more, ensuring creating the exact same transcriptions does not make duplicates - response = self.client.post(reverse('api:transcription-bulk'), format='json', data=request_data) - self.assertEqual(self.page.transcriptions.count(), 3) - self.assertCountEqual( - self.page.transcriptions.values_list('id', flat=True), - [word_ts.id, line_ts.id, page_ts.id], - ) - self.assertEqual(indexer.return_value.run_index.call_count, 0) + get_layer_mock().send.assert_called_once_with('reindex', { + 'type': 'reindex.start', + 'corpus': None, + 'element': str(self.page.id), + 'transcriptions': True, + 'elements': True, + 'entities': False, + 'drop': False, + }) @patch('arkindex.project.serializer_fields.MLTool.get') def test_bulk_transcription_no_zone(self, ml_get_mock): @@ -216,8 +248,9 @@ class TestTranscriptionCreate(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @patch('arkindex.project.serializer_fields.MLTool.get') - @patch('arkindex.images.importer.Indexer') - def test_update_bulk_transcription(self, indexer, ml_get_mock): + @patch('arkindex.project.triggers.get_channel_layer') + def test_update_bulk_transcription(self, get_layer_mock, ml_get_mock): + get_layer_mock.return_value.send = AsyncMock() ml_get_mock.return_value.type = self.src.type ml_get_mock.return_value.slug = self.src.slug ml_get_mock.return_value.name = self.src.name @@ -258,11 +291,18 @@ class TestTranscriptionCreate(FixtureAPITestCase): self.assertEqual(page_ts.source, self.src) self.assertEqual(page_ts.zone, self.page.zone) 
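Each endpoint test in this patch now ends with the same eight-key message assertion. A small hypothetical helper could factor that repetition out; this is a sketch, not part of the change, with defaults following reindex_start() in arkindex/project/triggers.py:

    def assert_reindex_sent(get_layer_mock, **expected):
        # Check that exactly one reindex.start message was queued,
        # with every unspecified flag at its trigger default.
        message = {
            'type': 'reindex.start',
            'corpus': None,
            'element': None,
            'transcriptions': False,
            'elements': False,
            'entities': False,
            'drop': False,
        }
        message.update(expected)
        get_layer_mock.return_value.send.assert_called_once_with('reindex', message)

Used as, for instance, assert_reindex_sent(get_layer_mock, element=str(self.page.id), transcriptions=True, elements=True).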
- # Indexer called - self.assertEqual(indexer.return_value.run_index.call_count, 2) + get_layer_mock().send.assert_called_once_with('reindex', { + 'type': 'reindex.start', + 'corpus': None, + 'element': str(self.page.id), + 'transcriptions': True, + 'elements': True, + 'entities': False, + 'drop': False, + }) # Update again - indexer.reset_mock() + get_layer_mock().send.reset_mock() data['transcriptions'] = [ { "type": "word", @@ -280,7 +320,15 @@ class TestTranscriptionCreate(FixtureAPITestCase): response = self.client.put(reverse('api:transcription-bulk'), format='json', data=data) self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertEqual(indexer.return_value.run_index.call_count, 2) + get_layer_mock().send.assert_called_once_with('reindex', { + 'type': 'reindex.start', + 'corpus': None, + 'element': str(self.page.id), + 'transcriptions': True, + 'elements': True, + 'entities': False, + 'drop': False, + }) # Previous transcriptions should be replaced by two Word and one Page transcription self.assertEqual(self.page.transcriptions.count(), 3) diff --git a/arkindex/images/importer.py b/arkindex/images/importer.py index 46313dc619..3b7332f19e 100644 --- a/arkindex/images/importer.py +++ b/arkindex/images/importer.py @@ -2,8 +2,7 @@ from collections import namedtuple from django.db import connection from arkindex_common.enums import TranscriptionType from arkindex.project.polygon import Polygon -from arkindex.documents.indexer import Indexer -from arkindex.documents.models import Element, Transcription +from arkindex.documents.models import Element from arkindex.images.models import Image, Zone import csv import io @@ -162,32 +161,3 @@ def save_transcriptions(*tr_polygons, delimiter='\t', quotechar='"'): ) return transcriptions - - -def index_transcriptions(items): - ''' - Index in ElasticSearch new transcriptions built above - TODO: Delete if both the PonosCommand deletion and Channels reindex are merged! 
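This removal is exactly what the TODO above asked for: with the Channels worker merged, importer callers stop indexing synchronously. The one-line replacement already used in api/ml.py above is the pattern for any remaining caller; a sketch, where parent is the hypothetical element owning the new transcriptions:

    from arkindex.project.triggers import reindex_start

    # Instead of building Transcription objects and running Indexer().run_index()
    # inline, queue the work for the reindex worker:
    reindex_start(element=parent, transcriptions=True, elements=True)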
- ''' - transcriptions = [] - for item in items: - if isinstance(item, tuple): - transcriptions.append(Transcription( - **dict(zip( - ('id', 'element_id', 'source_id', 'zone_id', 'type', 'text', 'score'), - item, - )) - )) - elif isinstance(item, Transcription): - transcriptions.append(item) - else: - raise ValueError('Items must be a list of tuples or Transcription instances') - - # Index transcriptions directly (IIIF search) - indexer = Indexer() - indexer.run_index(transcriptions) - - # Index transcriptions in elements - elements = Element.objects.filter(id__in=[t.element_id for t in transcriptions]) - if elements.exists(): - indexer.run_index(elements) diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py index 4af5f90b57..ff0c471401 100644 --- a/arkindex/project/api_v1.py +++ b/arkindex/project/api_v1.py @@ -18,6 +18,7 @@ from arkindex.documents.api.entities import ( ElementLinks ) from arkindex.documents.api.iiif import FolderManifest, ElementAnnotationList, TranscriptionSearchAnnotationList +from arkindex.documents.api.admin import ReindexStart from arkindex.dataimport.api import ( DataImportsList, DataImportDetails, DataImportRetry, DataFileList, DataFileRetrieve, DataFileUpload, DataImportFromFiles, @@ -149,4 +150,7 @@ api = [ path('user/token/', UserEmailVerification.as_view(), name='user-token'), path('user/password-reset/', PasswordReset.as_view(), name='password-reset'), path('user/password-reset/confirm/', PasswordResetConfirm.as_view(), name='password-reset-confirm'), + + # Management tools + path('reindex/', ReindexStart.as_view(), name='reindex-start'), ] diff --git a/arkindex/project/routing.py b/arkindex/project/routing.py index 60f19878e0..8fb6189ba6 100644 --- a/arkindex/project/routing.py +++ b/arkindex/project/routing.py @@ -1,6 +1,8 @@ from channels.auth import AuthMiddlewareStack -from channels.routing import ProtocolTypeRouter, URLRouter +from channels.routing import ProtocolTypeRouter, URLRouter, ChannelNameRouter from channels.security.websocket import AllowedHostsOriginValidator +from arkindex.project.triggers import REINDEX_CHANNEL +from arkindex.documents.consumers import ReindexConsumer application = ProtocolTypeRouter({ 'websocket': AllowedHostsOriginValidator( @@ -9,4 +11,7 @@ application = ProtocolTypeRouter({ ]), ), ), + 'channel': ChannelNameRouter({ + REINDEX_CHANNEL: ReindexConsumer, + }) }) diff --git a/arkindex/project/settings.py b/arkindex/project/settings.py index d860f56880..eab4794635 100644 --- a/arkindex/project/settings.py +++ b/arkindex/project/settings.py @@ -280,12 +280,13 @@ CACHES = { } # Django Channels layer using Redis +REDIS_HOST = os.environ.get('REDIS_HOST', 'localhost') CHANNEL_LAYERS = { "default": { "BACKEND": "channels_redis.core.RedisChannelLayer", "CONFIG": { "hosts": [ - (os.environ.get('REDIS_HOST', 'localhost'), 6379) + (REDIS_HOST, 6379) ], }, }, diff --git a/arkindex/project/triggers.py b/arkindex/project/triggers.py new file mode 100644 index 0000000000..472c915da0 --- /dev/null +++ b/arkindex/project/triggers.py @@ -0,0 +1,51 @@ +""" +Helper methods to trigger tasks in asynchronous workers +""" +from typing import Union +from uuid import UUID +from asgiref.sync import async_to_sync +from channels.layers import get_channel_layer +from arkindex.documents.models import Element, Corpus + +REINDEX_CHANNEL = 'reindex' + +ACTION_REINDEX_START = 'reindex.start' + + +def reindex_start(*, + element: Union[Element, UUID, str] = None, + corpus: Union[Corpus, UUID, str] = None, + transcriptions: bool = False, + 
elements: bool = False, + entities: bool = False, + drop: bool = False) -> None: + """ + Reindex elements into ElasticSearch. + + If `element` and `corpus` are left unspecified, all elements will be picked. + If `element` is specified, its children will be included. + + `drop` will REMOVE all indexes entirely and recreate them before indexing. + This is only allowed when reindexing all elements. + """ + element_id = None + corpus_id = None + if isinstance(element, Element): + element_id = str(element.id) + elif element: + element_id = str(element) + + if isinstance(corpus, Corpus): + corpus_id = str(corpus.id) + elif corpus: + corpus_id = str(corpus) + + async_to_sync(get_channel_layer().send)(REINDEX_CHANNEL, { + 'type': ACTION_REINDEX_START, + 'element': element_id, + 'corpus': corpus_id, + 'transcriptions': transcriptions, + 'elements': elements, + 'entities': entities, + 'drop': drop, + }) diff --git a/openapi/patch.yml b/openapi/patch.yml index 9d6b6ae748..3dad786b97 100644 --- a/openapi/patch.yml +++ b/openapi/patch.yml @@ -39,6 +39,8 @@ tags: description: Machine Learning tools and results - name: entities - name: users + - name: management + description: Admin-only tools paths: /api/v1/classification/bulk/: post: diff --git a/tests-requirements.txt b/tests-requirements.txt index 6c1358c164..8381422153 100644 --- a/tests-requirements.txt +++ b/tests-requirements.txt @@ -4,3 +4,4 @@ django-nose coverage uritemplate==3 responses +asyncmock==0.4.1 -- GitLab
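Taken together, the patch swaps every synchronous Indexer() call in views, importers and parsers for one message on the 'reindex' channel, consumed by the worker registered in routing.py. The trigger can also be exercised by hand from a Django shell; a sketch, assuming Redis is reachable and a worker (for example manage.py runworker reindex, the usual Channels 2 invocation) is consuming the channel:

    from arkindex.documents.models import Corpus
    from arkindex.project.triggers import reindex_start

    corpus = Corpus.objects.first()  # hypothetical target corpus
    # Returns as soon as the message is queued; the worker does the indexing.
    reindex_start(corpus=corpus, transcriptions=True, elements=True, entities=True)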