From 4b27f5068a07a139bc5d7cbf082aa7eca87f7e9e Mon Sep 17 00:00:00 2001
From: Eva Bardou <ebardou@teklia.com>
Date: Thu, 25 Mar 2021 12:03:20 +0100
Subject: [PATCH] Add logic to merge parents caches into the current task one

---
 arkindex_worker/cache.py  | 17 ++++++++++++++++-
 arkindex_worker/worker.py | 21 +++++++++++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py
index c67d1d9e..81ec1765 100644
--- a/arkindex_worker/cache.py
+++ b/arkindex_worker/cache.py
@@ -42,7 +42,8 @@ def convert_table_tuple(table):
 
 class LocalDB(object):
     def __init__(self, path):
-        self.db = sqlite3.connect(path)
+        self.path = path
+        self.db = sqlite3.connect(self.path)
         self.db.row_factory = sqlite3.Row
         self.cursor = self.db.cursor()
         logger.info(f"Connection to local cache {path} established.")
@@ -51,6 +52,20 @@ class LocalDB(object):
         self.cursor.execute(SQL_ELEMENTS_TABLE_CREATION)
         self.cursor.execute(SQL_TRANSCRIPTIONS_TABLE_CREATION)
 
+    def merge_parents_caches(self, parents_cache_paths):
+        """Merge elements and transcriptions of each parent cache into this cache."""
+        # synchronous=OFF: SQLite no longer waits for data to reach the disk
+        self.cursor.execute("PRAGMA synchronous=OFF;")
+        for parent_cache in parents_cache_paths:
+            # Bind the path so that quotes in it cannot break the statement
+            self.cursor.execute("ATTACH DATABASE ? AS source;", (str(parent_cache),))
+            for table in ("elements", "transcriptions"):
+                query = f"REPLACE INTO {table} SELECT * FROM source.{table};"
+                self.cursor.execute(query)
+            self.db.commit()
+            # SQLite caps attached databases (10 by default), so detach right away
+            self.cursor.execute("DETACH DATABASE source;")
+
     def insert(self, table, lines):
         if not lines:
             return
diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py
index 031b9da3..1364def6 100644
--- a/arkindex_worker/worker.py
+++ b/arkindex_worker/worker.py
@@ -129,6 +129,27 @@ class BaseWorker(object):
         # Load all required secrets
         self.secrets = {name: self.load_secret(name) for name in required_secrets}
 
+        # Fill the task local cache from its parents caches: copy a single one, merge several
+        if self.cache and os.environ.get("TASK_ID"):
+            task = self.api_client.request(
+                "RetrieveTaskFromAgent", id=os.environ.get("TASK_ID")
+            )
+
+            parents_cache_paths = []
+            for parent in task["parents"]:
+                parent_cache_path = f"/data/{parent}/db.sqlite"
+                if os.path.isfile(parent_cache_path):
+                    parents_cache_paths.append(parent_cache_path)
+
+            if len(parents_cache_paths) == 1:
+                with open(self.cache.path, "rb+") as cache_file, open(
+                    parents_cache_paths[0], "rb"
+                ) as parent_cache_file:
+                    cache_file.truncate(0)
+                    cache_file.write(parent_cache_file.read())
+            elif len(parents_cache_paths) > 1:
+                self.cache.merge_parents_caches(parents_cache_paths)
+
     def load_secret(self, name):
         """Load all secrets described in the worker configuration"""
         secret = None
-- 
GitLab