From cae7adfb46f1da38e0ffe60628cc0ec75d4d6f79 Mon Sep 17 00:00:00 2001
From: Erwan Rouchet <rouchet@teklia.com>
Date: Fri, 14 Sep 2018 08:25:28 +0000
Subject: [PATCH] Fix ES limiting nested results to 3

---
 arkindex/documents/api/iiif.py               |   1 +
 arkindex/documents/api/search.py             |  91 ++++-------
 arkindex/documents/search.py                 |   5 +
 arkindex/documents/serializers/search.py     |   2 +
 arkindex/documents/tests/test_search.py      | 154 +++++++++++++++++++
 arkindex/documents/tests/test_search_post.py |  97 ------------
 arkindex/project/mixins.py                   |  67 +++++++-
 arkindex/project/settings.py                 |   2 +
 arkindex/templates/elastic/search_acts.json  |   6 +-
 9 files changed, 261 insertions(+), 164 deletions(-)
 create mode 100644 arkindex/documents/tests/test_search.py
 delete mode 100644 arkindex/documents/tests/test_search_post.py

diff --git a/arkindex/documents/api/iiif.py b/arkindex/documents/api/iiif.py
index f863891f68..dbe108a444 100644
--- a/arkindex/documents/api/iiif.py
+++ b/arkindex/documents/api/iiif.py
@@ -111,6 +111,7 @@ class TranscriptionSearchAnnotationList(RetrieveAPIView):
                     'query': self.request.query_params.get('q'),
                     'type': 'word',
                     'corpus_id': elt.corpus_id,
+                    'inner_hits_size': settings.ES_INNER_RESULTS_LIMIT,
                 },
             ),
             es_index=settings.ES_INDEX_TRANSCRIPTIONS,
diff --git a/arkindex/documents/api/search.py b/arkindex/documents/api/search.py
index 073da3a940..a57034eb8b 100644
--- a/arkindex/documents/api/search.py
+++ b/arkindex/documents/api/search.py
@@ -1,85 +1,46 @@
 from django.conf import settings
 from rest_framework.generics import ListAPIView
-from rest_framework.exceptions import ValidationError, PermissionDenied
-from arkindex.documents.models import Transcription, Act, Corpus
+from arkindex.documents.models import Transcription, Act
 from arkindex.documents.serializers.search import TranscriptionSearchResultSerializer, ActSearchResultSerializer
 from arkindex.documents.search import search_transcriptions_post, search_acts_post
-from arkindex.project.elastic import ESQuerySet
-from arkindex.project.tools import elasticsearch_escape
+from arkindex.project.mixins import SearchAPIMixin
 
 
-class SearchAPIMixin(object):
-    def get(self, request, *args, **kwargs):
-        q = request.query_params.get('q')
-        if not q or q.isspace():
-            raise ValidationError('A search query is required')
-        return super().get(request, *args, **kwargs)
+class SearchAPIView(SearchAPIMixin, ListAPIView):
+    """
+    Base class for API views that run an Elasticsearch query and paginate its results
+    """
+    template_path = None
+    es_source = True
+    es_query = None
+    es_index = None
+    es_type = None
+    es_sort = None
 
 
-class TranscriptionSearch(SearchAPIMixin, ListAPIView):
+class TranscriptionSearch(SearchAPIView):
     """
     Search and list transcriptions, using pagination
     """
     serializer_class = TranscriptionSearchResultSerializer
+    template_path = 'elastic/search_transcriptions.json'
+    es_sort = {"score": {"order": "desc", "mode": "max"}}
+    es_index = settings.ES_INDEX_TRANSCRIPTIONS
+    es_type = Transcription.INDEX_TYPE
 
-    def get_queryset(self):
-        query = self.request.query_params.get('q')
-        if not query:
-            return
-        context = {
-            'query': elasticsearch_escape(query),
-            'type': self.request.query_params.get('type'),
-        }
-        if 'corpus' in self.request.query_params:
-            try:
-                context['corpus_id'] = Corpus.objects.readable(self.request.user) \
-                                                     .get(id=self.request.query_params['corpus']).id
-            except Corpus.DoesNotExist:
-                raise PermissionDenied()
-        else:
-            context['corpora_ids'] = Corpus.objects.readable(self.request.user).values_list('id', flat=True)
-
-        return ESQuerySet(
-            query=ESQuerySet.make_query(
-                'elastic/search_transcriptions.json',
-                ctx=context,
-            ),
-            sort={"score": {"order": "desc", "mode": "max"}},
-            es_index=settings.ES_INDEX_TRANSCRIPTIONS,
-            es_type=Transcription.INDEX_TYPE,
-            post_process=search_transcriptions_post,
-        )
+    def post_process(self, *args, **kwargs):
+        return search_transcriptions_post(*args)
 
 
-class ActSearch(SearchAPIMixin, ListAPIView):
+class ActSearch(SearchAPIView):
     """
     Search for acts containing a specific word
     """
     serializer_class = ActSearchResultSerializer
+    template_path = 'elastic/search_acts.json'
+    es_source = False
+    es_index = settings.ES_INDEX_ACTS
+    es_type = Act.INDEX_TYPE
 
-    def get_queryset(self):
-        query = self.request.query_params.get('q')
-        if not query:
-            return
-        context = {
-            'query': elasticsearch_escape(query),
-            'type': self.request.query_params.get('type'),
-        }
-        if 'corpus' in self.request.query_params:
-            try:
-                context['corpus_id'] = Corpus.objects.readable(self.request.user) \
-                                                     .get(id=self.request.query_params['corpus']).id
-            except Corpus.DoesNotExist:
-                raise PermissionDenied()
-        else:
-            context['corpora_ids'] = Corpus.objects.readable(self.request.user).values_list('id', flat=True)
-        return ESQuerySet(
-            _source=False,
-            query=ESQuerySet.make_query(
-                'elastic/search_acts.json',
-                ctx=context,
-            ),
-            es_index=settings.ES_INDEX_ACTS,
-            es_type=Act.INDEX_TYPE,
-            post_process=search_acts_post,
-        )
+    def post_process(self, *args, **kwargs):
+        return search_acts_post(*args)
diff --git a/arkindex/documents/search.py b/arkindex/documents/search.py
index ec0653715a..af7c9f3c72 100644
--- a/arkindex/documents/search.py
+++ b/arkindex/documents/search.py
@@ -38,6 +38,10 @@ def search_acts_post(data):
         for result in results
         for hit in result['inner_hits']['transcriptions']['hits']['hits']
     ]
+    tr_totals = {
+        uuid.UUID(result['_id']): result['inner_hits']['transcriptions']['hits']['total']
+        for result in results
+    }
 
     transcriptions = {
         t.id: t
@@ -61,6 +65,7 @@ def search_acts_post(data):
 
     for act in acts:
         act.transcriptions_results = [transcriptions[tid] for tid in acts_tr_ids[act.id]]
+        act.total_transcriptions = tr_totals[act.id]
         act.surfaces = [surf.zone for surf in all_surfaces.get(act.id, [])]
         act.parents = all_parents.get(act.id, [])
 
diff --git a/arkindex/documents/serializers/search.py b/arkindex/documents/serializers/search.py
index 2336b54772..846dc617ad 100644
--- a/arkindex/documents/serializers/search.py
+++ b/arkindex/documents/serializers/search.py
@@ -36,6 +36,7 @@ class ActSearchResultSerializer(serializers.ModelSerializer):
     Serialize an act
     """
     transcriptions = TranscriptionSerializer(many=True, source='transcriptions_results')
+    total_transcriptions = serializers.IntegerField()
     surfaces = ZoneSerializer(many=True)
     parents = serializers.ListField(
         child=serializers.ListField(
@@ -52,6 +53,7 @@ class ActSearchResultSerializer(serializers.ModelSerializer):
             'name',
             'number',
             'transcriptions',
+            'total_transcriptions',
             'surfaces',
             'parents',
             'viewer_url',
diff --git a/arkindex/documents/tests/test_search.py b/arkindex/documents/tests/test_search.py
new file mode 100644
index 0000000000..156190a29a
--- /dev/null
+++ b/arkindex/documents/tests/test_search.py
@@ -0,0 +1,154 @@
+from arkindex.project.tests import FixtureAPITestCase
+from arkindex.documents.models import Transcription, Act, Element
+from django.urls import reverse
+from rest_framework import status
+from unittest.mock import patch
+
+
+class TestSearchAPI(FixtureAPITestCase):
+
+    @classmethod
+    def setUpTestData(cls):
+        cls.es_mock = patch('arkindex.project.elastic.Elasticsearch').start()
+
+    def setUp(self):
+        self.es_mock.reset_mock()
+        self.es_mock().reset_mock()
+
+    def build_es_response(self, hits):
+        return {
+            "hits": {
+                "total": len(hits),
+                "max_score": None,
+                "hits": hits
+            },
+            "_shards": {
+                "total": 5,
+                "failed": 0,
+                "skipped": 0,
+                "successful": 5
+            },
+            "took": 42,
+            "timed_out": False
+        }
+
+    def make_transcription_hit(self, ts):
+        return {
+            "_id": str(ts.id.hex),
+            "_score": None,
+            "sort": [
+                ts.score
+            ],
+            "_index": "transcriptions",
+            "_type": Transcription.INDEX_TYPE,
+            "_source": ts.build_search_index()
+        }
+
+    def make_nested_transcription_hit(self, ts):
+        return {
+            "_source": {
+                "text": ts.text,
+                "id": str(ts.id),
+                "score": ts.score,
+            },
+            "_score": ts.score,
+            "_nested": {
+                "field": "transcriptions",
+                "offset": 1337
+            }
+        }
+
+    def make_act_hit(self, act, ts, score=1.0):
+        return {
+            "_score": score,
+            "_type": Act.INDEX_TYPE,
+            "_id": str(act.id.hex),
+            "_index": "acts",
+            "inner_hits": {
+                "transcriptions": {
+                    "hits": {
+                        "total": len(ts),
+                        "hits": list(map(self.make_nested_transcription_hit, ts)),
+                        "max_score": 1337,
+                    }
+                }
+            }
+        }
+
+    def test_transcription_search(self):
+        expected = Transcription.objects.filter(text="PARIS")
+
+        self.es_mock().count.return_value = {'count': len(expected)}
+        self.es_mock().search.return_value = self.build_es_response(
+            list(map(self.make_transcription_hit, expected)),
+        )
+
+        response = self.client.get(reverse('api:transcription-search'), {'q': "paris"})
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+        results = response.json()["results"]
+
+        self.assertCountEqual(
+            [r['id'] for r in results],
+            map(str, expected.values_list('id', flat=True)),
+        )
+
+    def test_act_search(self):
+        act = Act.objects.get(number="1")
+        ts = Transcription.objects.filter(text__in=["PARIS", "ROY"], zone__image__path='img1')
+        surf = Element.objects.get(name="Surface A").zone
+
+        self.es_mock().count.return_value = {'count': len(ts)}
+        self.es_mock().search.return_value = self.build_es_response(
+            [self.make_act_hit(act, ts), ],
+        )
+
+        response = self.client.get(reverse('api:act-search'), {'q': "paris roy"})
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+        results = response.json()["results"]
+        self.assertEqual(len(results), 1)
+        result = results[0]
+
+        self.assertEqual(result['id'], str(act.id))
+        self.assertCountEqual(
+            [t['id'] for t in result['transcriptions']],
+            map(str, ts.values_list('id', flat=True)),
+        )
+
+        self.assertEqual(len(result['surfaces']), 1)
+        self.assertEqual(result['surfaces'][0]['id'], str(surf.id))
+
+    def test_iiif_transcription_search(self):
+        # Filter to only get transcriptions from volume 1
+        unfiltered = Transcription.objects.filter(text="PARIS")
+        expected = Transcription.objects.filter(text="PARIS", zone__image__path__in=['img1', 'img2', 'img3'])
+
+        self.es_mock().count.return_value = {'count': len(unfiltered)}
+        self.es_mock().search.return_value = self.build_es_response(
+            list(map(self.make_transcription_hit, unfiltered))
+        )
+
+        response = self.client.get(reverse(
+            'api:ts-search-manifest',
+            kwargs={'pk': str(Element.objects.get(name='Volume 1').id)}
+        ), {'q': 'paris'})
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        data = response.json()
+
+        self.assertEqual(data['@context'], "http://iiif.io/api/search/0/context.json")
+        self.assertEqual(data['@type'], 'sc:AnnotationList')
+        self.assertEqual(data['startIndex'], 0)
+        self.assertEqual(data['within']['@type'], 'sc:Layer')
+        self.assertEqual(data['within']['total'], len(expected))
+
+        hits = data['hits']
+        self.assertTrue(all(hit['@type'] == 'search:Hit' for hit in hits))
+        self.assertTrue(all(hit['match'] == 'PARIS' for hit in hits))
+
+        resources = data['resources']
+        self.assertTrue(all(res['@type'] == 'oa:Annotation' for res in resources))
+        self.assertTrue(all(res['motivation'] == 'sc:painting' for res in resources))
+        self.assertTrue(all(res['resource']['@type'] == 'cnt:ContentAsText' for res in resources))
+        self.assertTrue(all(res['resource']['format'] == 'text/plain' for res in resources))
+        self.assertTrue(all(res['resource']['chars'] == 'PARIS' for res in resources))
diff --git a/arkindex/documents/tests/test_search_post.py b/arkindex/documents/tests/test_search_post.py
deleted file mode 100644
index 527c3cd0e0..0000000000
--- a/arkindex/documents/tests/test_search_post.py
+++ /dev/null
@@ -1,97 +0,0 @@
-from arkindex.project.tests import FixtureTestCase
-from arkindex.documents.models import Transcription, Act, Element
-from arkindex.documents.search import search_transcriptions_post, search_acts_post, search_transcriptions_filter_post
-
-
-class TestSearchPostProcess(FixtureTestCase):
-    """Test ElasticSearch post-processing functions"""
-
-    def build_es_response(self, hits):
-        return {
-            "hits": {
-                "total": len(hits),
-                "max_score": None,
-                "hits": hits
-            },
-            "_shards": {
-                "total": 5,
-                "failed": 0,
-                "skipped": 0,
-                "successful": 5
-            },
-            "took": 42,
-            "timed_out": False
-        }
-
-    def make_transcription_hit(self, ts):
-        return {
-            "_id": str(ts.id.hex),
-            "_score": None,
-            "sort": [
-                ts.score
-            ],
-            "_index": "transcriptions",
-            "_type": Transcription.INDEX_TYPE,
-            "_source": ts.build_search_index()
-        }
-
-    def make_nested_transcription_hit(self, ts):
-        return {
-            "_source": {
-                "text": ts.text,
-                "id": str(ts.id),
-                "score": ts.score,
-            },
-            "_score": ts.score,
-            "_nested": {
-                "field": "transcriptions",
-                "offset": 1337
-            }
-        }
-
-    def make_act_hit(self, act, ts, score=1.0):
-        return {
-            "_score": score,
-            "_type": Act.INDEX_TYPE,
-            "_id": str(act.id.hex),
-            "_index": "acts",
-            "inner_hits": {
-                "transcriptions": {
-                    "hits": {
-                        "total": len(ts),
-                        "hits": list(map(self.make_nested_transcription_hit, ts)),
-                        "max_score": 1337,
-                    }
-                }
-            }
-        }
-
-    def test_search_transcriptions_post(self):
-        expected = Transcription.objects.filter(text="PARIS")
-        results = search_transcriptions_post(self.build_es_response(
-            list(map(self.make_transcription_hit, expected))
-        ))
-        self.assertCountEqual(results, expected)
-
-    def test_search_acts_post(self):
-        act = Act.objects.get(number="1")
-        ts = Transcription.objects.filter(text__in=["PARIS", "ROY"], zone__image__path='img1')
-        surf = Element.objects.get(name="Surface A").zone
-        results = search_acts_post(self.build_es_response(
-            [self.make_act_hit(act, ts), ]
-        ))
-
-        self.assertCountEqual(results, [act])
-        self.assertTrue(hasattr(results[0], 'transcriptions'))
-        self.assertTrue(hasattr(results[0], 'surfaces'))
-        self.assertCountEqual(results[0].transcriptions_results, ts)
-        self.assertCountEqual(results[0].surfaces, [surf])
-
-    def test_search_transcriptions_filter_post(self):
-        # Filter to only get transcriptions from volume 1
-        unfiltered = Transcription.objects.filter(text="PARIS")
-        expected = Transcription.objects.filter(text="PARIS", zone__image__path__in=['img1', 'img2', 'img3'])
-        results = search_transcriptions_filter_post(self.build_es_response(
-            list(map(self.make_transcription_hit, unfiltered))
-        ), Element.objects.get(name="Volume 1").id)
-        self.assertCountEqual(results, expected)
diff --git a/arkindex/project/mixins.py b/arkindex/project/mixins.py
index fa2fc6238c..50c46cb191 100644
--- a/arkindex/project/mixins.py
+++ b/arkindex/project/mixins.py
@@ -1,6 +1,9 @@
+from django.conf import settings
 from django.shortcuts import get_object_or_404
-from rest_framework.exceptions import PermissionDenied
+from rest_framework.exceptions import PermissionDenied, ValidationError, APIException
 from arkindex.documents.models import Corpus, Right
+from arkindex.project.elastic import ESQuerySet
+from arkindex.project.tools import elasticsearch_escape
 
 
 class CorpusACLMixin(object):
@@ -24,3 +27,65 @@ class CorpusACLMixin(object):
         assert isinstance(corpus, Corpus)
         return self.request.user.is_admin or not self.request.user.is_anonymous and \
             Right.Admin in corpus.get_acl_rights(self.request.user)
+
+
+class SearchAPIMixin(CorpusACLMixin):
+
+    def get(self, request, *args, **kwargs):
+        q = request.query_params.get('q')
+        if not q or q.isspace():
+            raise ValidationError('A search query is required')
+        return super().get(request, *args, **kwargs)
+
+    def get_context(self):
+        context = {
+            'query': elasticsearch_escape(self.request.query_params['q']),
+            'type': self.request.query_params.get('type'),
+            'inner_hits_size': settings.ES_INNER_RESULTS_LIMIT,
+        }
+        if 'corpus' in self.request.query_params:
+            try:
+                context['corpus_id'] = self.get_corpus(self.request.query_params['corpus'])
+            except Corpus.DoesNotExist:
+                raise PermissionDenied
+        else:
+            context['corpora_ids'] = Corpus.objects.readable(self.request.user).values_list('id', flat=True)
+        return context
+
+    def get_template_path(self):
+        if self.template_path:
+            return self.template_path
+        raise APIException('A JSON template path is required to build an ElasticSearch query.')
+
+    def get_query(self):
+        if self.es_query:
+            return self.es_query
+
+        return ESQuerySet.make_query(
+            self.get_template_path(),
+            ctx=self.get_context(),
+        )
+
+    def get_sort(self):
+        return self.es_sort or {}
+
+    def get_index(self):
+        if not self.es_index:
+            raise APIException('An ElasticSearch index name is required.')
+        return self.es_index
+
+    def get_type(self):
+        return self.es_type or '_doc'
+
+    def get_queryset(self):
+        return ESQuerySet(
+            _source=self.es_source,
+            query=self.get_query(),
+            sort=self.get_sort(),
+            es_index=self.get_index(),
+            es_type=self.get_type(),
+            post_process=self.post_process,
+        )
+
+    def post_process(self, *args):
+        return args
diff --git a/arkindex/project/settings.py b/arkindex/project/settings.py
index bb71201372..8fa49c8687 100644
--- a/arkindex/project/settings.py
+++ b/arkindex/project/settings.py
@@ -216,6 +216,8 @@ ELASTIC_SEARCH_HOSTS = [
 ]
 # The Scroll API is required to go over 10K results
 ES_RESULTS_LIMIT = 10000
+# Elasticsearch returns at most 3 inner hits per nested query result unless inner_hits.size is set explicitly
+ES_INNER_RESULTS_LIMIT = 10
 ES_INDEX_TRANSCRIPTIONS = 'transcriptions'
 ES_INDEX_ACTS = 'acts'
 
diff --git a/arkindex/templates/elastic/search_acts.json b/arkindex/templates/elastic/search_acts.json
index 16c7355de9..b56f12f437 100644
--- a/arkindex/templates/elastic/search_acts.json
+++ b/arkindex/templates/elastic/search_acts.json
@@ -1,7 +1,11 @@
 {
     "nested": {
         "path": "transcriptions",
-        "inner_hits": {},
+        "inner_hits": {
+            {% if inner_hits_size %}
+            "size": {{ inner_hits_size }}
+            {% endif %}
+        },
         "score_mode": "sum",
         "query": {
             "function_score": {
-- 
GitLab