From 4b27f5068a07a139bc5d7cbf082aa7eca87f7e9e Mon Sep 17 00:00:00 2001
From: Eva Bardou <ebardou@teklia.com>
Date: Thu, 25 Mar 2021 12:03:20 +0100
Subject: [PATCH] Add logic to merge parents caches into the current task one

---
Review notes (text in this section is ignored by git am):
- worker.py called self.cache.merge_parent_caches(), but the method added
  to LocalDB in cache.py is named merge_parents_caches(); the call site now
  uses the defined name, otherwise the multi-parent branch raises
  AttributeError.
- Added a one-line docstring to LocalDB.merge_parents_caches(); hunk line
  counts and the diffstat below are updated accordingly.
- The "index" blob hashes are kept from the original patch and no longer
  match the edited post-image.

 arkindex_worker/cache.py  | 18 +++++++++++++++++-
 arkindex_worker/worker.py | 21 +++++++++++++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py
index c67d1d9e..81ec1765 100644
--- a/arkindex_worker/cache.py
+++ b/arkindex_worker/cache.py
@@ -42,7 +42,8 @@ def convert_table_tuple(table):
 
 class LocalDB(object):
     def __init__(self, path):
-        self.db = sqlite3.connect(path)
+        self.path = path
+        self.db = sqlite3.connect(self.path)
         self.db.row_factory = sqlite3.Row
         self.cursor = self.db.cursor()
         logger.info(f"Connection to local cache {path} established.")
@@ -51,6 +52,21 @@ class LocalDB(object):
         self.cursor.execute(SQL_ELEMENTS_TABLE_CREATION)
         self.cursor.execute(SQL_TRANSCRIPTIONS_TABLE_CREATION)
 
+    def merge_parents_caches(self, parents_cache_paths):
+        """Copy every parent cache's elements and transcriptions into this cache."""
+        for idx, parent_cache in enumerate(parents_cache_paths):
+            statements = [
+                "PRAGMA page_size=80000;",
+                "PRAGMA synchronous=OFF;",
+                f"ATTACH DATABASE '{parent_cache}' AS source{idx};",
+                f"REPLACE INTO elements SELECT * FROM source{idx}.elements;",
+                f"REPLACE INTO transcriptions SELECT * FROM source{idx}.transcriptions;",
+            ]
+
+            for statement in statements:
+                self.cursor.execute(statement)
+            self.db.commit()
+
     def insert(self, table, lines):
         if not lines:
             return
diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py
index 031b9da3..1364def6 100644
--- a/arkindex_worker/worker.py
+++ b/arkindex_worker/worker.py
@@ -129,6 +129,27 @@ class BaseWorker(object):
         # Load all required secrets
         self.secrets = {name: self.load_secret(name) for name in required_secrets}
 
+        # Merging parents caches (if there are any) in the current task local cache
+        if self.cache and os.environ.get("TASK_ID"):
+            task = self.api_client.request(
+                "RetrieveTaskFromAgent", id=os.environ.get("TASK_ID")
+            )
+
+            parents_cache_paths = []
+            for parent in task["parents"]:
+                parent_cache_path = f"/data/{parent}/db.sqlite"
+                if os.path.isfile(parent_cache_path):
+                    parents_cache_paths.append(parent_cache_path)
+
+            if len(parents_cache_paths) == 1:
+                with open(self.cache.path, "rb+") as cache_file, open(
+                    parents_cache_paths[0], "rb"
+                ) as parent_cache_file:
+                    cache_file.truncate(0)
+                    cache_file.write(parent_cache_file.read())
+            elif len(parents_cache_paths) > 1:
+                self.cache.merge_parents_caches(parents_cache_paths)
+
     def load_secret(self, name):
         """Load all secrets described in the worker configuration"""
         secret = None
-- 
GitLab