diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py
index 095f703fcd52e74b8acf2e95fe5dafcd916c06ef..ea7adf1a1e366df8fea0aa8ff77f031daf077765 100644
--- a/arkindex_worker/cache.py
+++ b/arkindex_worker/cache.py
@@ -1,6 +1,8 @@
 # -*- coding: utf-8 -*-
 import json
 import logging
+import os
+import sqlite3
 
 from peewee import (
     BooleanField,
@@ -33,21 +35,6 @@ class JSONField(Field):
         return json.loads(value)
 
 
-def merge_parents_caches(self, parents_cache_paths):
-    for idx, parent_cache in enumerate(parents_cache_paths):
-        statements = [
-            "PRAGMA page_size=80000;",
-            "PRAGMA synchronous=OFF;",
-            f"ATTACH DATABASE '{parent_cache}' AS source{idx};",
-            f"REPLACE INTO elements SELECT * FROM source{idx}.elements;",
-            f"REPLACE INTO transcriptions SELECT * FROM source{idx}.transcriptions;",
-        ]
-
-        for statement in statements:
-            self.cursor.execute(statement)
-        self.db.commit()
-
-
 class CachedElement(Model):
     id = UUIDField(primary_key=True)
     parent_id = UUIDField(null=True)
@@ -91,3 +78,51 @@ def create_tables():
     Creates the tables in the cache DB only if they do not already exist.
     """
     db.create_tables([CachedElement, CachedTranscription])
+
+
+def merge_parents_caches(parent_ids, current_database, data_dir="/data"):
+    """
+    Merge all the potential parent tasks' databases into the existing local one
+    """
+    assert isinstance(parent_ids, list)
+    assert os.path.isdir(data_dir)
+    assert os.path.exists(current_database)
+
+    # TODO: handle chunk
+
+    # Find all the paths for these databases, materialized as a list
+    # since a lazy filter object would always be truthy in the check below
+    paths = list(
+        filter(
+            os.path.isfile,
+            [os.path.join(data_dir, parent, "db.sqlite") for parent in parent_ids],
+        )
+    )
+
+    if not paths:
+        logger.info("No parents cache to use")
+        return
+
+    # Open a connection on the current database
+    # sqlite3 cursors are not context managers, so close them explicitly
+    connection = sqlite3.connect(current_database)
+    cursor = connection.cursor()
+
+    for idx, path in enumerate(paths):
+        # Merge each table from the parent database into the local one
+        statements = [
+            "PRAGMA page_size=80000;",
+            "PRAGMA synchronous=OFF;",
+            f"ATTACH DATABASE '{path}' AS source_{idx};",
+            f"REPLACE INTO elements SELECT * FROM source_{idx}.elements;",
+            f"REPLACE INTO transcriptions SELECT * FROM source_{idx}.transcriptions;",
+        ]
+
+        for statement in statements:
+            cursor.execute(statement)
+        connection.commit()
+
+    cursor.close()
+    connection.close()
+
+    # TODO: maybe reopen peewee connection?
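For context, the merge strategy above relies on SQLite's `ATTACH DATABASE` plus `REPLACE INTO ... SELECT`, where rows from the attached parent database overwrite local rows that share a primary key. Below is a minimal, self-contained sketch of that technique on a toy `elements` table; the file names and table schema are illustrative only and are not part of this patch.

```python
import os
import sqlite3
import tempfile

tmp_dir = tempfile.mkdtemp()
local_path = os.path.join(tmp_dir, "db.sqlite")
parent_path = os.path.join(tmp_dir, "parent.sqlite")

# Local database, as the current task would have after create_tables()
local = sqlite3.connect(local_path)
local.execute("CREATE TABLE elements (id TEXT PRIMARY KEY, name TEXT)")
local.execute("INSERT INTO elements VALUES ('a', 'local')")
local.commit()

# Parent database with one overlapping row ('a') and one new row ('b')
parent = sqlite3.connect(parent_path)
parent.execute("CREATE TABLE elements (id TEXT PRIMARY KEY, name TEXT)")
parent.executemany(
    "INSERT INTO elements VALUES (?, ?)",
    [("a", "parent"), ("b", "parent")],
)
parent.commit()
parent.close()

# REPLACE INTO resolves the primary-key conflict by keeping the parent's row
cursor = local.cursor()
cursor.execute(f"ATTACH DATABASE '{parent_path}' AS source_0")
cursor.execute("REPLACE INTO elements SELECT * FROM source_0.elements")
local.commit()

print(cursor.execute("SELECT * FROM elements ORDER BY id").fetchall())
# -> [('a', 'parent'), ('b', 'parent')]
local.close()
```

One consequence of this design: when several parents are merged in order, a later parent's row silently wins any primary-key conflict with an earlier one.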
diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py
index 81c43f67a4d0733b6ccd311fe93cec6d48376e63..97c3a72d743dc3a6705257a4946f6f35965649fa 100644
--- a/arkindex_worker/worker.py
+++ b/arkindex_worker/worker.py
@@ -28,13 +28,12 @@ from arkindex_worker.cache import (
     CachedTranscription,
     create_tables,
     init_cache_db,
+    merge_parents_caches,
 )
 from arkindex_worker.models import Element
 from arkindex_worker.reporting import Reporter
 
 MANUAL_SLUG = "manual"
-DATA_DIR = "/data"
-CACHE_DIR = f"/data/{os.environ.get('TASK_ID')}"
 
 
 def _is_500_error(exc):
@@ -74,12 +73,14 @@ class BaseWorker(object):
         self.use_cache = use_cache
 
         if self.use_cache is True:
-            if os.environ.get("TASK_ID") and os.path.isdir(CACHE_DIR):
-                cache_path = os.path.join(CACHE_DIR, "db.sqlite")
+            if os.environ.get("TASK_ID"):
+                cache_dir = f"/data/{os.environ.get('TASK_ID')}"
+                assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}"
+                self.cache_path = os.path.join(cache_dir, "db.sqlite")
             else:
-                cache_path = os.path.join(os.getcwd(), "db.sqlite")
+                self.cache_path = os.path.join(os.getcwd(), "db.sqlite")
 
-            init_cache_db(cache_path)
+            init_cache_db(self.cache_path)
             create_tables()
         else:
             logger.debug("Cache is disabled")
@@ -157,23 +158,7 @@ class BaseWorker(object):
         task = self.api_client.request(
             "RetrieveTaskFromAgent", id=os.environ.get("TASK_ID")
         )
-
-        parents_cache_paths = []
-        for parent in task["parents"]:
-            parent_cache_path = f"{DATA_DIR}/{parent}/db.sqlite"
-            if os.path.isfile(parent_cache_path):
-                parents_cache_paths.append(parent_cache_path)
-
-        # Only one parent cache, we can just copy it into our current task local cache
-        if len(parents_cache_paths) == 1:
-            with open(self.cache.path, "rb+") as cache_file, open(
-                parents_cache_paths[0], "rb"
-            ) as parent_cache_file:
-                cache_file.truncate(0)
-                cache_file.write(parent_cache_file.read())
-        # Many parents caches, we have to merge all of them in our current task local cache
-        elif len(parents_cache_paths) > 1:
-            self.cache.merge_parents_caches(parents_cache_paths)
+        merge_parents_caches(task["parents"], self.cache_path)
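The worker change replaces the module-level `CACHE_DIR` constant with path resolution done at construction time. A minimal sketch of that logic as a standalone function is shown below; `resolve_cache_path` is a hypothetical name introduced here for illustration, since the patch inlines this directly in `BaseWorker.__init__`.

```python
import os


def resolve_cache_path(task_id=None, data_dir="/data"):
    """Mirror of the cache path resolution now inlined in BaseWorker.__init__."""
    if task_id:
        # Inside an Arkindex task container the cache must already be mounted
        cache_dir = os.path.join(data_dir, task_id)
        assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}"
        return os.path.join(cache_dir, "db.sqlite")
    # Local development fallback: use a database in the working directory
    return os.path.join(os.getcwd(), "db.sqlite")


# In a task container (TASK_ID set, /data/<task_id> mounted):
#   resolve_cache_path(os.environ.get("TASK_ID")) -> "/data/<task_id>/db.sqlite"
# Running locally without TASK_ID:
#   resolve_cache_path() -> "<cwd>/db.sqlite"
```

Storing the result on `self.cache_path` lets `configure()` hand the same path to `merge_parents_caches`, which also absorbs the single-parent case: copying one parent's database and merging one database yield the same result, so the special-cased file copy could be dropped.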