From 0090147a6bfe3c426b3d63a91a2033c0e928c634 Mon Sep 17 00:00:00 2001
From: Eva Bardou <ebardou@teklia.com>
Date: Wed, 24 Mar 2021 18:06:22 +0100
Subject: [PATCH] Retrieve transcriptions from local cache in
 list_transcriptions

---
 arkindex_worker/cache.py  |  2 ++
 arkindex_worker/worker.py | 27 ++++++++++++++++++++++++---
 2 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py
index c67d1d9e..1e804717 100644
--- a/arkindex_worker/cache.py
+++ b/arkindex_worker/cache.py
@@ -36,6 +36,8 @@ CachedTranscription = namedtuple(
 def convert_table_tuple(table):
     if table == "elements":
         return CachedElement
+    elif table == "transcriptions":
+        return CachedTranscription
     else:
         raise NotImplementedError
 
diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py
index bddb65e0..03553220 100644
--- a/arkindex_worker/worker.py
+++ b/arkindex_worker/worker.py
@@ -842,9 +842,30 @@ class ElementsWorker(BaseWorker):
             ), "worker_version should be of type str"
             query_params["worker_version"] = worker_version
 
-        transcriptions = self.api_client.paginate(
-            "ListTranscriptions", id=element.id, **query_params
-        )
+        if self.cache and recursive is None:
+            # Checking that we only received query_params handled by the cache
+            assert set(query_params.keys()) <= {
+                "worker_version",
+            }, "When using the local cache, you can only filter by 'worker_version'"
+
+            conditions = [("element_id", "=", convert_str_uuid_to_hex(element.id))]
+            if worker_version:
+                conditions.append(
+                    ("worker_version_id", "=", convert_str_uuid_to_hex(worker_version))
+                )
+
+            transcriptions = self.cache.fetch(
+                "transcriptions",
+                where=conditions,
+            )
+        else:
+            if self.cache:
+                logger.warning(
+                    "'recursive' filter was set, results will be retrieved from the API since the local cache doesn't handle this filter."
+                )
+            transcriptions = self.api_client.paginate(
+                "ListTranscriptions", id=element.id, **query_params
+            )
 
         return transcriptions
 
-- 
GitLab