Commit dc7145cb authored by Bastien Abadie

Better pseudo code

parent e2d14dea
1 merge request: !76 Merge parents caches into the current task one
 # -*- coding: utf-8 -*-
 import json
 import logging
+import os
+import sqlite3
 from peewee import (
     BooleanField,
@@ -33,21 +35,6 @@ class JSONField(Field):
         return json.loads(value)
-
-    def merge_parents_caches(self, parents_cache_paths):
-        for idx, parent_cache in enumerate(parents_cache_paths):
-            statements = [
-                "PRAGMA page_size=80000;",
-                "PRAGMA synchronous=OFF;",
-                f"ATTACH DATABASE '{parent_cache}' AS source{idx};",
-                f"REPLACE INTO elements SELECT * FROM source{idx}.elements;",
-                f"REPLACE INTO transcriptions SELECT * FROM source{idx}.transcriptions;",
-            ]
-            for statement in statements:
-                self.cursor.execute(statement)
-            self.db.commit()
-
 class CachedElement(Model):
     id = UUIDField(primary_key=True)
     parent_id = UUIDField(null=True)
@@ -91,3 +78,44 @@ def create_tables():
     Creates the tables in the cache DB only if they do not already exist.
     """
     db.create_tables([CachedElement, CachedTranscription])
+
+
+def merge_parents_cache(parent_ids, current_database, data_dir="/data"):
+    """
+    Merge all the potential parent tasks' databases into the existing local one
+    """
+    assert isinstance(parent_ids, list)
+    assert os.path.isdir(data_dir)
+    assert os.path.exists(current_database)
+
+    # TODO: handle chunk
+
+    # Find all the paths for these databases, materialized into a list:
+    # a bare filter object is always truthy, which would break the check below
+    paths = list(
+        filter(
+            os.path.isfile,
+            [os.path.join(data_dir, parent, "db.sqlite") for parent in parent_ids],
+        )
+    )
+
+    if not paths:
+        logger.info("No parents cache to use")
+        return
+
+    # Open a connection on the current database; sqlite3 cursors do not
+    # support the context manager protocol, so the cursor is opened directly
+    connection = sqlite3.connect(current_database)
+    cursor = connection.cursor()
+
+    for idx, path in enumerate(paths):
+        # Merge each table into the local database, relying on REPLACE INTO
+        # to overwrite rows that share a primary key.
+        # NB: SQLite only accepts page_size values that are powers of two
+        # between 512 and 65536, so the PRAGMA below is silently ignored
+        statements = [
+            "PRAGMA page_size=80000;",
+            "PRAGMA synchronous=OFF;",
+            f"ATTACH DATABASE '{path}' AS source_{idx};",
+            f"REPLACE INTO elements SELECT * FROM source_{idx}.elements;",
+            f"REPLACE INTO transcriptions SELECT * FROM source_{idx}.transcriptions;",
+        ]
+        for statement in statements:
+            cursor.execute(statement)
+        connection.commit()
+
+    # TODO: maybe reopen peewee connection ?
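
The REPLACE INTO statements above are what make the merge safe to repeat: a row coming from a parent cache overwrites any local row with the same primary key instead of raising a uniqueness error. Below is a minimal sketch of that behaviour on a toy schema; the file names and the two-column elements table are illustrative, not part of this commit.

# Toy demonstration of the ATTACH + REPLACE INTO merge used above
import sqlite3

# Build a parent cache holding one element
parent = sqlite3.connect("parent.sqlite")
parent.execute("CREATE TABLE elements (id TEXT PRIMARY KEY, type TEXT)")
parent.execute("INSERT INTO elements VALUES ('a', 'page')")
parent.commit()
parent.close()

# Build a current cache with a conflicting row for the same primary key
current = sqlite3.connect("current.sqlite")
current.execute("CREATE TABLE elements (id TEXT PRIMARY KEY, type TEXT)")
current.execute("INSERT INTO elements VALUES ('a', 'stale'), ('b', 'line')")
current.commit()

# Merge: the parent's row replaces the conflicting local one
current.execute("ATTACH DATABASE 'parent.sqlite' AS source_0")
current.execute("REPLACE INTO elements SELECT * FROM source_0.elements")
current.commit()

print(current.execute("SELECT * FROM elements ORDER BY id").fetchall())
# [('a', 'page'), ('b', 'line')]: the parent won on 'a', 'b' was kept

REPLACE INTO is SQLite's alias for INSERT OR REPLACE: it deletes the conflicting row before inserting, so when several parents contain the same id, the last one merged wins.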
@@ -28,13 +28,12 @@ from arkindex_worker.cache import (
     CachedTranscription,
     create_tables,
     init_cache_db,
+    merge_parents_cache,
 )
 from arkindex_worker.models import Element
 from arkindex_worker.reporting import Reporter

 MANUAL_SLUG = "manual"
-DATA_DIR = "/data"
-CACHE_DIR = f"/data/{os.environ.get('TASK_ID')}"


 def _is_500_error(exc):
@@ -74,12 +73,14 @@ class BaseWorker(object):
         self.use_cache = use_cache

         if self.use_cache is True:
-            if os.environ.get("TASK_ID") and os.path.isdir(CACHE_DIR):
-                cache_path = os.path.join(CACHE_DIR, "db.sqlite")
+            if os.environ.get("TASK_ID"):
+                cache_dir = f"/data/{os.environ.get('TASK_ID')}"
+                assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}"
+                self.cache_path = os.path.join(cache_dir, "db.sqlite")
             else:
-                cache_path = os.path.join(os.getcwd(), "db.sqlite")
+                self.cache_path = os.path.join(os.getcwd(), "db.sqlite")

-            init_cache_db(cache_path)
+            init_cache_db(self.cache_path)
             create_tables()
         else:
             logger.debug("Cache is disabled")
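
The behaviour change in this hunk is easy to miss: the old code silently fell back to a local db.sqlite when the task cache directory was missing, while the new code asserts and fails loudly. The same resolution logic, restated as a standalone function for clarity; resolve_cache_path is a hypothetical name used only in this sketch.

# Hypothetical restatement of the new cache path resolution above
import os

def resolve_cache_path():
    task_id = os.environ.get("TASK_ID")
    if task_id:
        cache_dir = f"/data/{task_id}"
        # Fail loudly instead of silently falling back to the working directory
        assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}"
        return os.path.join(cache_dir, "db.sqlite")
    return os.path.join(os.getcwd(), "db.sqlite")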
@@ -157,23 +158,7 @@ class BaseWorker(object):
         task = self.api_client.request(
             "RetrieveTaskFromAgent", id=os.environ.get("TASK_ID")
         )
-
-        parents_cache_paths = []
-        for parent in task["parents"]:
-            parent_cache_path = f"{DATA_DIR}/{parent}/db.sqlite"
-            if os.path.isfile(parent_cache_path):
-                parents_cache_paths.append(parent_cache_path)
-
-        # Only one parent cache, we can just copy it into our current task local cache
-        if len(parents_cache_paths) == 1:
-            with open(self.cache.path, "rb+") as cache_file, open(
-                parents_cache_paths[0], "rb"
-            ) as parent_cache_file:
-                cache_file.truncate(0)
-                cache_file.write(parent_cache_file.read())
-
-        # Many parents caches, we have to merge all of them in our current task local cache
-        elif len(parents_cache_paths) > 1:
-            self.cache.merge_parents_caches(parents_cache_paths)
+        merge_parents_cache(task["parents"], self.cache_path)

     def load_secret(self, name):
         """Load all secrets described in the worker configuration"""