From b7abc0ca055829299cb8b5f0bf5dc8eb37e7d5ca Mon Sep 17 00:00:00 2001
From: Eva Bardou <ebardou@teklia.com>
Date: Fri, 23 Apr 2021 15:02:00 +0200
Subject: [PATCH] Optimize recursive listing of cached transcriptions

---
 arkindex_worker/worker/transcription.py       | 35 +++++++-------
 .../test_transcriptions.py                    | 47 ++++++++++++++++---
 2 files changed, 59 insertions(+), 23 deletions(-)

diff --git a/arkindex_worker/worker/transcription.py b/arkindex_worker/worker/transcription.py
index a1488a32..3c1c9926 100644
--- a/arkindex_worker/worker/transcription.py
+++ b/arkindex_worker/worker/transcription.py
@@ -249,27 +249,30 @@ class TranscriptionMixin(object):
             query_params["worker_version"] = worker_version
 
         if self.use_cache:
-            elements_query = CachedElement.select().where(
-                CachedElement.id == element.id
-            )
-            type_attr = CachedElement.type
-
-            if recursive:
-                base_case = elements_query.cte("base", recursive=True)
+            if not recursive:
+                if element_type and element_type != element.type:
+                    return CachedTranscription.select().where(False)
+                transcriptions = CachedTranscription.select().where(
+                    CachedTranscription.element_id == element.id
+                )
+            else:
+                base_case = (
+                    CachedElement.select()
+                    .where(CachedElement.id == element.id)
+                    .cte("base", recursive=True)
+                )
                 recursive = CachedElement.select().join(
                     base_case, on=(CachedElement.parent_id == base_case.c.id)
                 )
                 cte = base_case.union_all(recursive)
-                elements_query = cte.select_from(cte.c.id, cte.c.type)
-                type_attr = cte.c.type
-
-            if element_type:
-                elements_query = elements_query.where(type_attr == element_type)
+                transcriptions = (
+                    CachedTranscription.select()
+                    .join(cte, on=(CachedTranscription.element_id == cte.c.id))
+                    .with_cte(cte)
+                )
 
-            elements_ids = [elem.id for elem in elements_query]
-            transcriptions = CachedTranscription.select().where(
-                CachedTranscription.element_id.in_(elements_ids)
-            )
+                if element_type:
+                    transcriptions = transcriptions.where(cte.c.type == element_type)
 
             if worker_version:
                 transcriptions = transcriptions.where(
diff --git a/tests/test_elements_worker/test_transcriptions.py b/tests/test_elements_worker/test_transcriptions.py
index 3381bac4..586b2d2d 100644
--- a/tests/test_elements_worker/test_transcriptions.py
+++ b/tests/test_elements_worker/test_transcriptions.py
@@ -1268,14 +1268,38 @@ def test_list_transcriptions(responses, mock_elements_worker):
         # Filter on element should give first transcription
         (
             {
-                "element": CachedElement(id="11111111-1111-1111-1111-111111111111"),
+                "element": CachedElement(
+                    id="11111111-1111-1111-1111-111111111111", type="page"
+                ),
+            },
+            ("11111111-1111-1111-1111-111111111111",),
+        ),
+        # Filter on element and element_type should give first transcription
+        (
+            {
+                "element": CachedElement(
+                    id="11111111-1111-1111-1111-111111111111", type="page"
+                ),
+                "element_type": "page",
+            },
+            ("11111111-1111-1111-1111-111111111111",),
+        ),
+        # Filter on element and worker_version should give first transcription
+        (
+            {
+                "element": CachedElement(
+                    id="11111111-1111-1111-1111-111111111111", type="page"
+                ),
+                "worker_version": "56785678-5678-5678-5678-567856785678",
             },
             ("11111111-1111-1111-1111-111111111111",),
         ),
         # Filter recursively on element should give all transcriptions inserted
         (
             {
-                "element": CachedElement(id="11111111-1111-1111-1111-111111111111"),
+                "element": CachedElement(
+                    id="11111111-1111-1111-1111-111111111111", type="page"
+                ),
                 "recursive": True,
             },
             (
@@ -1286,19 +1310,28 @@ def test_list_transcriptions(responses, mock_elements_worker):
                 "55555555-5555-5555-5555-555555555555",
             ),
         ),
-        # Filter recursively on element and worker_version should give first transcription
+        # Filter recursively on element and worker_version should give four transcriptions
         (
             {
-                "element": CachedElement(id="11111111-1111-1111-1111-111111111111"),
-                "worker_version": "56785678-5678-5678-5678-567856785678",
+                "element": CachedElement(
+                    id="11111111-1111-1111-1111-111111111111", type="page"
+                ),
+                "worker_version": "90129012-9012-9012-9012-901290129012",
                 "recursive": True,
             },
-            ("11111111-1111-1111-1111-111111111111",),
+            (
+                "22222222-2222-2222-2222-222222222222",
+                "33333333-3333-3333-3333-333333333333",
+                "44444444-4444-4444-4444-444444444444",
+                "55555555-5555-5555-5555-555555555555",
+            ),
         ),
         # Filter recursively on element and element_type should give three transcriptions
         (
             {
-                "element": CachedElement(id="11111111-1111-1111-1111-111111111111"),
+                "element": CachedElement(
+                    id="11111111-1111-1111-1111-111111111111", type="page"
+                ),
                 "element_type": "something_else",
                 "recursive": True,
             },
-- 
GitLab