From cae7adfb46f1da38e0ffe60628cc0ec75d4d6f79 Mon Sep 17 00:00:00 2001 From: Erwan Rouchet <rouchet@teklia.com> Date: Fri, 14 Sep 2018 08:25:28 +0000 Subject: [PATCH] Fix ES limiting nested results to 3 --- arkindex/documents/api/iiif.py | 1 + arkindex/documents/api/search.py | 91 ++++------- arkindex/documents/search.py | 5 + arkindex/documents/serializers/search.py | 2 + arkindex/documents/tests/test_search.py | 154 +++++++++++++++++++ arkindex/documents/tests/test_search_post.py | 97 ------------ arkindex/project/mixins.py | 67 +++++++- arkindex/project/settings.py | 2 + arkindex/templates/elastic/search_acts.json | 6 +- 9 files changed, 261 insertions(+), 164 deletions(-) create mode 100644 arkindex/documents/tests/test_search.py delete mode 100644 arkindex/documents/tests/test_search_post.py diff --git a/arkindex/documents/api/iiif.py b/arkindex/documents/api/iiif.py index f863891f68..dbe108a444 100644 --- a/arkindex/documents/api/iiif.py +++ b/arkindex/documents/api/iiif.py @@ -111,6 +111,7 @@ class TranscriptionSearchAnnotationList(RetrieveAPIView): 'query': self.request.query_params.get('q'), 'type': 'word', 'corpus_id': elt.corpus_id, + 'inner_hits_size': settings.ES_INNER_RESULTS_LIMIT, }, ), es_index=settings.ES_INDEX_TRANSCRIPTIONS, diff --git a/arkindex/documents/api/search.py b/arkindex/documents/api/search.py index 073da3a940..a57034eb8b 100644 --- a/arkindex/documents/api/search.py +++ b/arkindex/documents/api/search.py @@ -1,85 +1,46 @@ from django.conf import settings from rest_framework.generics import ListAPIView -from rest_framework.exceptions import ValidationError, PermissionDenied -from arkindex.documents.models import Transcription, Act, Corpus +from arkindex.documents.models import Transcription, Act from arkindex.documents.serializers.search import TranscriptionSearchResultSerializer, ActSearchResultSerializer from arkindex.documents.search import search_transcriptions_post, search_acts_post -from arkindex.project.elastic import ESQuerySet -from arkindex.project.tools import elasticsearch_escape +from arkindex.project.mixins import SearchAPIMixin -class SearchAPIMixin(object): - def get(self, request, *args, **kwargs): - q = request.query_params.get('q') - if not q or q.isspace(): - raise ValidationError('A search query is required') - return super().get(request, *args, **kwargs) +class SearchAPIView(SearchAPIMixin, ListAPIView): + """ + A base class for ES search views + """ + template_path = None + es_source = True + es_query = None + es_index = None + es_type = None + es_sort = None -class TranscriptionSearch(SearchAPIMixin, ListAPIView): +class TranscriptionSearch(SearchAPIView): """ Search and list transcriptions, using pagination """ serializer_class = TranscriptionSearchResultSerializer + template_path = 'elastic/search_transcriptions.json' + es_sort = {"score": {"order": "desc", "mode": "max"}} + es_index = settings.ES_INDEX_TRANSCRIPTIONS + es_type = Transcription.INDEX_TYPE - def get_queryset(self): - query = self.request.query_params.get('q') - if not query: - return - context = { - 'query': elasticsearch_escape(query), - 'type': self.request.query_params.get('type'), - } - if 'corpus' in self.request.query_params: - try: - context['corpus_id'] = Corpus.objects.readable(self.request.user) \ - .get(id=self.request.query_params['corpus']).id - except Corpus.DoesNotExist: - raise PermissionDenied() - else: - context['corpora_ids'] = Corpus.objects.readable(self.request.user).values_list('id', flat=True) - - return ESQuerySet( - query=ESQuerySet.make_query( - 'elastic/search_transcriptions.json', - ctx=context, - ), - sort={"score": {"order": "desc", "mode": "max"}}, - es_index=settings.ES_INDEX_TRANSCRIPTIONS, - es_type=Transcription.INDEX_TYPE, - post_process=search_transcriptions_post, - ) + def post_process(self, *args, **kwargs): + return search_transcriptions_post(*args) -class ActSearch(SearchAPIMixin, ListAPIView): +class ActSearch(SearchAPIView): """ Search for acts containing a specific word """ serializer_class = ActSearchResultSerializer + template_path = 'elastic/search_acts.json' + es_source = False + es_index = settings.ES_INDEX_ACTS + es_type = Act.INDEX_TYPE - def get_queryset(self): - query = self.request.query_params.get('q') - if not query: - return - context = { - 'query': elasticsearch_escape(query), - 'type': self.request.query_params.get('type'), - } - if 'corpus' in self.request.query_params: - try: - context['corpus_id'] = Corpus.objects.readable(self.request.user) \ - .get(id=self.request.query_params['corpus']).id - except Corpus.DoesNotExist: - raise PermissionDenied() - else: - context['corpora_ids'] = Corpus.objects.readable(self.request.user).values_list('id', flat=True) - return ESQuerySet( - _source=False, - query=ESQuerySet.make_query( - 'elastic/search_acts.json', - ctx=context, - ), - es_index=settings.ES_INDEX_ACTS, - es_type=Act.INDEX_TYPE, - post_process=search_acts_post, - ) + def post_process(self, *args, **kwargs): + return search_acts_post(*args) diff --git a/arkindex/documents/search.py b/arkindex/documents/search.py index ec0653715a..af7c9f3c72 100644 --- a/arkindex/documents/search.py +++ b/arkindex/documents/search.py @@ -38,6 +38,10 @@ def search_acts_post(data): for result in results for hit in result['inner_hits']['transcriptions']['hits']['hits'] ] + tr_totals = { + uuid.UUID(result['_id']): result['inner_hits']['transcriptions']['hits']['total'] + for result in results + } transcriptions = { t.id: t @@ -61,6 +65,7 @@ def search_acts_post(data): for act in acts: act.transcriptions_results = [transcriptions[tid] for tid in acts_tr_ids[act.id]] + act.total_transcriptions = tr_totals[act.id] act.surfaces = [surf.zone for surf in all_surfaces.get(act.id, [])] act.parents = all_parents.get(act.id, []) diff --git a/arkindex/documents/serializers/search.py b/arkindex/documents/serializers/search.py index 2336b54772..846dc617ad 100644 --- a/arkindex/documents/serializers/search.py +++ b/arkindex/documents/serializers/search.py @@ -36,6 +36,7 @@ class ActSearchResultSerializer(serializers.ModelSerializer): Serialize an act """ transcriptions = TranscriptionSerializer(many=True, source='transcriptions_results') + total_transcriptions = serializers.IntegerField() surfaces = ZoneSerializer(many=True) parents = serializers.ListField( child=serializers.ListField( @@ -52,6 +53,7 @@ class ActSearchResultSerializer(serializers.ModelSerializer): 'name', 'number', 'transcriptions', + 'total_transcriptions', 'surfaces', 'parents', 'viewer_url', diff --git a/arkindex/documents/tests/test_search.py b/arkindex/documents/tests/test_search.py new file mode 100644 index 0000000000..156190a29a --- /dev/null +++ b/arkindex/documents/tests/test_search.py @@ -0,0 +1,154 @@ +from arkindex.project.tests import FixtureAPITestCase +from arkindex.documents.models import Transcription, Act, Element +from django.urls import reverse +from rest_framework import status +from unittest.mock import patch + + +class TestSearchAPI(FixtureAPITestCase): + + @classmethod + def setUpTestData(cls): + cls.es_mock = patch('arkindex.project.elastic.Elasticsearch').start() + + def setUp(self): + self.es_mock.reset_mock() + self.es_mock().reset_mock() + + def build_es_response(self, hits): + return { + "hits": { + "total": len(hits), + "max_score": None, + "hits": hits + }, + "_shards": { + "total": 5, + "failed": 0, + "skipped": 0, + "successful": 5 + }, + "took": 42, + "timed_out": False + } + + def make_transcription_hit(self, ts): + return { + "_id": str(ts.id.hex), + "_score": None, + "sort": [ + ts.score + ], + "_index": "transcriptions", + "_type": Transcription.INDEX_TYPE, + "_source": ts.build_search_index() + } + + def make_nested_transcription_hit(self, ts): + return { + "_source": { + "text": ts.text, + "id": str(ts.id), + "score": ts.score, + }, + "_score": ts.score, + "_nested": { + "field": "transcriptions", + "offset": 1337 + } + } + + def make_act_hit(self, act, ts, score=1.0): + return { + "_score": score, + "_type": Act.INDEX_TYPE, + "_id": str(act.id.hex), + "_index": "acts", + "inner_hits": { + "transcriptions": { + "hits": { + "total": len(ts), + "hits": list(map(self.make_nested_transcription_hit, ts)), + "max_score": 1337, + } + } + } + } + + def test_transcription_search(self): + expected = Transcription.objects.filter(text="PARIS") + + self.es_mock().count.return_value = {'count': len(expected)} + self.es_mock().search.return_value = self.build_es_response( + list(map(self.make_transcription_hit, expected)), + ) + + response = self.client.get(reverse('api:transcription-search'), {'q': "paris"}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + results = response.json()["results"] + + self.assertCountEqual( + [r['id'] for r in results], + map(str, expected.values_list('id', flat=True)), + ) + + def test_act_search(self): + act = Act.objects.get(number="1") + ts = Transcription.objects.filter(text__in=["PARIS", "ROY"], zone__image__path='img1') + surf = Element.objects.get(name="Surface A").zone + + self.es_mock().count.return_value = {'count': len(ts)} + self.es_mock().search.return_value = self.build_es_response( + [self.make_act_hit(act, ts), ], + ) + + response = self.client.get(reverse('api:act-search'), {'q': "paris roy"}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + results = response.json()["results"] + self.assertEqual(len(results), 1) + result = results[0] + + self.assertEqual(result['id'], str(act.id)) + self.assertCountEqual( + [t['id'] for t in result['transcriptions']], + map(str, ts.values_list('id', flat=True)), + ) + + self.assertEqual(len(result['surfaces']), 1) + self.assertEqual(result['surfaces'][0]['id'], str(surf.id)) + + def test_iiif_transcription_search(self): + # Filter to only get transcriptions from volume 1 + unfiltered = Transcription.objects.filter(text="PARIS") + expected = Transcription.objects.filter(text="PARIS", zone__image__path__in=['img1', 'img2', 'img3']) + + self.es_mock().count.return_value = {'count': len(unfiltered)} + self.es_mock().search.return_value = self.build_es_response( + list(map(self.make_transcription_hit, unfiltered)) + ) + + response = self.client.get(reverse( + 'api:ts-search-manifest', + kwargs={'pk': str(Element.objects.get(name='Volume 1').id)} + ), {'q': 'paris'}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + + self.assertEqual(data['@context'], "http://iiif.io/api/search/0/context.json") + self.assertEqual(data['@type'], 'sc:AnnotationList') + self.assertEqual(data['startIndex'], 0) + self.assertEqual(data['within']['@type'], 'sc:Layer') + self.assertEqual(data['within']['total'], len(expected)) + + hits = data['hits'] + self.assertTrue(all(hit['@type'] == 'search:Hit' for hit in hits)) + self.assertTrue(all(hit['match'] == 'PARIS' for hit in hits)) + + resources = data['resources'] + self.assertTrue(all(res['@type'] == 'oa:Annotation' for res in resources)) + self.assertTrue(all(res['motivation'] == 'sc:painting' for res in resources)) + self.assertTrue(all(res['resource']['@type'] == 'cnt:ContentAsText' for res in resources)) + self.assertTrue(all(res['resource']['format'] == 'text/plain' for res in resources)) + self.assertTrue(all(res['resource']['chars'] == 'PARIS' for res in resources)) diff --git a/arkindex/documents/tests/test_search_post.py b/arkindex/documents/tests/test_search_post.py deleted file mode 100644 index 527c3cd0e0..0000000000 --- a/arkindex/documents/tests/test_search_post.py +++ /dev/null @@ -1,97 +0,0 @@ -from arkindex.project.tests import FixtureTestCase -from arkindex.documents.models import Transcription, Act, Element -from arkindex.documents.search import search_transcriptions_post, search_acts_post, search_transcriptions_filter_post - - -class TestSearchPostProcess(FixtureTestCase): - """Test ElasticSearch post-processing functions""" - - def build_es_response(self, hits): - return { - "hits": { - "total": len(hits), - "max_score": None, - "hits": hits - }, - "_shards": { - "total": 5, - "failed": 0, - "skipped": 0, - "successful": 5 - }, - "took": 42, - "timed_out": False - } - - def make_transcription_hit(self, ts): - return { - "_id": str(ts.id.hex), - "_score": None, - "sort": [ - ts.score - ], - "_index": "transcriptions", - "_type": Transcription.INDEX_TYPE, - "_source": ts.build_search_index() - } - - def make_nested_transcription_hit(self, ts): - return { - "_source": { - "text": ts.text, - "id": str(ts.id), - "score": ts.score, - }, - "_score": ts.score, - "_nested": { - "field": "transcriptions", - "offset": 1337 - } - } - - def make_act_hit(self, act, ts, score=1.0): - return { - "_score": score, - "_type": Act.INDEX_TYPE, - "_id": str(act.id.hex), - "_index": "acts", - "inner_hits": { - "transcriptions": { - "hits": { - "total": len(ts), - "hits": list(map(self.make_nested_transcription_hit, ts)), - "max_score": 1337, - } - } - } - } - - def test_search_transcriptions_post(self): - expected = Transcription.objects.filter(text="PARIS") - results = search_transcriptions_post(self.build_es_response( - list(map(self.make_transcription_hit, expected)) - )) - self.assertCountEqual(results, expected) - - def test_search_acts_post(self): - act = Act.objects.get(number="1") - ts = Transcription.objects.filter(text__in=["PARIS", "ROY"], zone__image__path='img1') - surf = Element.objects.get(name="Surface A").zone - results = search_acts_post(self.build_es_response( - [self.make_act_hit(act, ts), ] - )) - - self.assertCountEqual(results, [act]) - self.assertTrue(hasattr(results[0], 'transcriptions')) - self.assertTrue(hasattr(results[0], 'surfaces')) - self.assertCountEqual(results[0].transcriptions_results, ts) - self.assertCountEqual(results[0].surfaces, [surf]) - - def test_search_transcriptions_filter_post(self): - # Filter to only get transcriptions from volume 1 - unfiltered = Transcription.objects.filter(text="PARIS") - expected = Transcription.objects.filter(text="PARIS", zone__image__path__in=['img1', 'img2', 'img3']) - results = search_transcriptions_filter_post(self.build_es_response( - list(map(self.make_transcription_hit, unfiltered)) - ), Element.objects.get(name="Volume 1").id) - self.assertCountEqual(results, expected) diff --git a/arkindex/project/mixins.py b/arkindex/project/mixins.py index fa2fc6238c..50c46cb191 100644 --- a/arkindex/project/mixins.py +++ b/arkindex/project/mixins.py @@ -1,6 +1,9 @@ +from django.conf import settings from django.shortcuts import get_object_or_404 -from rest_framework.exceptions import PermissionDenied +from rest_framework.exceptions import PermissionDenied, ValidationError, APIException from arkindex.documents.models import Corpus, Right +from arkindex.project.elastic import ESQuerySet +from arkindex.project.tools import elasticsearch_escape class CorpusACLMixin(object): @@ -24,3 +27,65 @@ class CorpusACLMixin(object): assert isinstance(corpus, Corpus) return self.request.user.is_admin or not self.request.user.is_anonymous and \ Right.Admin in corpus.get_acl_rights(self.request.user) + + +class SearchAPIMixin(CorpusACLMixin): + + def get(self, request, *args, **kwargs): + q = request.query_params.get('q') + if not q or q.isspace(): + raise ValidationError('A search query is required') + return super().get(request, *args, **kwargs) + + def get_context(self): + context = { + 'query': elasticsearch_escape(self.request.query_params['q']), + 'type': self.request.query_params.get('type'), + 'inner_hits_size': settings.ES_INNER_RESULTS_LIMIT, + } + if 'corpus' in self.request.query_params: + try: + context['corpus_id'] = self.get_corpus(self.request.query_params['corpus']) + except Corpus.DoesNotExist: + raise PermissionDenied + else: + context['corpora_ids'] = Corpus.objects.readable(self.request.user).values_list('id', flat=True) + return context + + def get_template_path(self): + if self.template_path: + return self.template_path + raise APIException('A JSON template path is required to build an ElasticSearch query.') + + def get_query(self): + if self.es_query: + return self.es_query + + return ESQuerySet.make_query( + self.get_template_path(), + ctx=self.get_context(), + ) + + def get_sort(self): + return self.es_sort or {} + + def get_index(self): + if not self.es_index: + raise APIException('An ElasticSearch index name is required.') + return self.es_index + + def get_type(self): + return self.es_type or '_doc' + + def get_queryset(self): + return ESQuerySet( + _source=self.es_source, + query=self.get_query(), + sort=self.get_sort(), + es_index=self.get_index(), + es_type=self.get_type(), + post_process=self.post_process, + ) + + def post_process(self, *args): + return args diff --git a/arkindex/project/settings.py b/arkindex/project/settings.py index bb71201372..8fa49c8687 100644 --- a/arkindex/project/settings.py +++ b/arkindex/project/settings.py @@ -216,6 +216,8 @@ ELASTIC_SEARCH_HOSTS = [ ] # The Scroll API is required to go over 10K results ES_RESULTS_LIMIT = 10000 +# ES defaults to three items returned in a nested query if the inner_hits size is not defined +ES_INNER_RESULTS_LIMIT = 10 ES_INDEX_TRANSCRIPTIONS = 'transcriptions' ES_INDEX_ACTS = 'acts' diff --git a/arkindex/templates/elastic/search_acts.json b/arkindex/templates/elastic/search_acts.json index 16c7355de9..b56f12f437 100644 --- a/arkindex/templates/elastic/search_acts.json +++ b/arkindex/templates/elastic/search_acts.json @@ -1,7 +1,11 @@ { "nested": { "path": "transcriptions", - "inner_hits": {}, + "inner_hits": { + {% if inner_hits_size %} + "size": {{ inner_hits_size }} + {% endif %} + }, "score_mode": "sum", "query": { "function_score": { -- GitLab