# -*- coding: utf-8 -*-
import sqlite3
from collections import namedtuple

from arkindex_worker import logger

SQL_ELEMENTS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS elements (
    id VARCHAR(32) PRIMARY KEY,
    parent_id VARCHAR(32),
    type TEXT NOT NULL,
    polygon TEXT,
    initial BOOLEAN DEFAULT 0 NOT NULL,
    worker_version_id VARCHAR(32)
)"""
SQL_TRANSCRIPTIONS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS transcriptions (
    id VARCHAR(32) PRIMARY KEY,
    element_id VARCHAR(32) NOT NULL,
    text TEXT NOT NULL,
    confidence REAL NOT NULL,
    worker_version_id VARCHAR(32) NOT NULL,
    FOREIGN KEY(element_id) REFERENCES elements(id)
)"""

# Row type for the `elements` table. The field names match the table's columns
# so that rows can be rebuilt with CachedElement(**row).
# `parent_id` defaults to None and `initial` defaults to 0.
CachedElement = namedtuple(
    "CachedElement",
    ["id", "type", "polygon", "worker_version_id", "parent_id", "initial"],
    defaults=[None, 0],
)
# Row type for the `transcriptions` table (field names match its columns).
CachedTranscription = namedtuple(
    "CachedTranscription",
    ["id", "element_id", "text", "confidence", "worker_version_id"],
)

# Table name -> namedtuple type used to represent that table's rows.
_TABLE_TUPLES = {
    "elements": CachedElement,
    "transcriptions": CachedTranscription,
}


def convert_table_tuple(table):
    """Return the namedtuple type used for rows of the given cache table.

    :param table: Name of a cache table ("elements" or "transcriptions").
    :returns: The namedtuple class matching that table's columns.
    :raises NotImplementedError: If the table has no associated namedtuple.
    """
    try:
        return _TABLE_TUPLES[table]
    except KeyError:
        raise NotImplementedError(f"No tuple type for table {table!r}") from None


class LocalDB:
    """Wrapper around a SQLite database file used as a local cache."""

    def __init__(self, path):
        """Open (or create) the SQLite database at ``path``.

        :param path: Filesystem path passed to :func:`sqlite3.connect`.
        """
        self.db = sqlite3.connect(path)
        # sqlite3.Row supports dict(row), which fetch() uses to build the
        # namedtuples by column name.
        self.db.row_factory = sqlite3.Row
        self.cursor = self.db.cursor()
        logger.info(f"Connection to local cache {path} established.")

    def create_tables(self):
        """Create the `elements` and `transcriptions` tables if they are missing."""
        self.cursor.execute(SQL_ELEMENTS_TABLE_CREATION)
        self.cursor.execute(SQL_TRANSCRIPTIONS_TABLE_CREATION)

    def insert(self, table, lines):
        """Insert rows into ``table`` and commit.

        :param table: Target table name. It is interpolated into the SQL text,
            so it must come from trusted code.
        :param lines: Sequence of namedtuples. The column list is taken from the
            first item's ``_fields``; all items presumably share that type.
            Nothing is done if ``lines`` is empty.
        """
        if not lines:
            return
        columns = ", ".join(lines[0]._fields)
        placeholders = ", ".join(["?"] * len(lines[0]))
        values = [tuple(line) for line in lines]
        self.cursor.executemany(
            f"INSERT INTO {table} ({columns}) VALUES ({placeholders})", values
        )
        self.db.commit()

    def fetch(self, table, where=None):
        """Select rows from ``table``, optionally filtered, as namedtuples.

        ``where`` is a list of 3-value tuples ``(field, operator, value)``
        defining an SQL WHERE condition, e.g.
        ``where=[("id", "LIKE", "%0000%"), ("id", "NOT LIKE", "%1111%")]``
        stands for ``WHERE id LIKE '%0000%' AND id NOT LIKE '%1111%'``.
        Conditions are always joined with ``AND``; ``OR`` is not supported.

        Only ``value`` is bound as a query parameter. ``table``, ``field`` and
        ``operator`` are interpolated verbatim into the SQL text, so they must
        come from trusted code.

        :param table: Table to read from ("elements" or "transcriptions").
        :param where: Optional list of condition tuples; ``None`` or an empty
            list selects every row.
        :returns: List of the namedtuple type returned by
            :func:`convert_table_tuple` for ``table``.
        :raises NotImplementedError: If ``table`` has no associated namedtuple.
        """
        sql = f"SELECT * FROM {table}"
        params = []
        if where:
            assert isinstance(where, list), "where should be a list"
            assert all(
                isinstance(condition, tuple) and len(condition) == 3
                for condition in where
            ), "where conditions should be tuples of 3 values"
            sql += " WHERE " + " AND ".join(
                f"{field} {operator} (?)" for field, operator, _ in where
            )
            params = [value for _, _, value in where]

        self.cursor.execute(sql, params)
        tuple_type = convert_table_tuple(table)
        return [tuple_type(**dict(row)) for row in self.cursor.fetchall()]