# -*- coding: utf-8 -*-
import sqlite3
from collections import namedtuple
from arkindex_worker import logger
# DDL statement creating the ``elements`` cache table if it does not exist yet.
# ``id`` is the primary key, ``type`` is mandatory and ``initial`` defaults to 0;
# ``parent_id``, ``polygon`` and ``worker_version_id`` may be NULL.
SQL_ELEMENTS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS elements (
id VARCHAR(32) PRIMARY KEY,
parent_id VARCHAR(32),
type TEXT NOT NULL,
polygon TEXT,
initial BOOLEAN DEFAULT 0 NOT NULL,
worker_version_id VARCHAR(32)
)"""
# DDL statement creating the ``transcriptions`` cache table if it does not
# exist yet. ``id`` is the primary key and every other column is NOT NULL;
# ``element_id`` is declared as a foreign key to ``elements(id)``.
# NOTE(review): SQLite only enforces FOREIGN KEY constraints when
# ``PRAGMA foreign_keys = ON`` has been issued on the connection -- confirm
# whether enforcement is expected here.
SQL_TRANSCRIPTIONS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS transcriptions (
id VARCHAR(32) PRIMARY KEY,
element_id VARCHAR(32) NOT NULL,
text TEXT NOT NULL,
confidence REAL NOT NULL,
worker_version_id VARCHAR(32) NOT NULL,
FOREIGN KEY(element_id) REFERENCES elements(id)
)"""
# In-memory representation of one row of the ``elements`` table.
# The field names match the table's column names; their order does not need to
# match the column order because rows are written with an explicit column list
# and rebuilt from column-name keywords.
CachedElement = namedtuple(
    "CachedElement",
    ["id", "type", "polygon", "worker_version_id", "parent_id", "initial"],
)
# In-memory representation of one row of the ``transcriptions`` table;
# the field names match the table's column names.
CachedTranscription = namedtuple(
    typename="CachedTranscription",
    field_names=["id", "element_id", "text", "confidence", "worker_version_id"],
)
def convert_table_tuple(table):
    """Return the namedtuple class used to represent rows of ``table``.

    Args:
        table: Name of a cache table, ``"elements"`` or ``"transcriptions"``.

    Returns:
        ``CachedElement`` for ``"elements"`` and ``CachedTranscription`` for
        ``"transcriptions"``.

    Raises:
        NotImplementedError: If no row type is defined for ``table``.
    """
    if table == "elements":
        return CachedElement
    # The transcriptions table has its own row type, so it is mapped here too.
    if table == "transcriptions":
        return CachedTranscription
    raise NotImplementedError(f"No cached row type is defined for table {table!r}")
class LocalDB:
    """Small wrapper around a SQLite connection used as a local cache.

    The instance keeps one connection (``self.db``) and one cursor
    (``self.cursor``) that every method shares.
    """

    def __init__(self, path):
        """Open the SQLite database located at ``path``.

        Args:
            path: Database location, passed unchanged to ``sqlite3.connect``.
        """
        self.db = sqlite3.connect(path)
        # sqlite3.Row lets fetch() turn each row into a dict keyed by column name.
        self.db.row_factory = sqlite3.Row
        self.cursor = self.db.cursor()
        logger.info(f"Connection to local cache {path} established.")

    def create_tables(self):
        """Create the elements and transcriptions tables if they are missing."""
        self.cursor.execute(SQL_ELEMENTS_TABLE_CREATION)
        self.cursor.execute(SQL_TRANSCRIPTIONS_TABLE_CREATION)

    def insert(self, table, lines):
        """Insert several rows into ``table`` and commit.

        Args:
            table: Name of the target table. It is interpolated into the SQL
                text, so it must come from trusted code, never from user input.
            lines: List of namedtuple instances of the same type. The column
                list is taken from the ``_fields`` of the first item. An empty
                list is a no-op.
        """
        if not lines:
            return

        first = lines[0]
        columns = ", ".join(first._fields)
        # One "?" placeholder per value; the values themselves are bound by sqlite3.
        placeholders = ", ".join(["?"] * len(first))
        values = [tuple(line) for line in lines]
        self.cursor.executemany(
            f"INSERT INTO {table} ({columns}) VALUES ({placeholders})", values
        )
        self.db.commit()

    def fetch(self, table, where=None):
        """Select rows from ``table``, optionally filtered by AND-ed conditions.

        where parameter is a list containing 3-values tuples defining an SQL WHERE condition.
        e.g: where=[("id", "LIKE", "%0000%"), ("id", "NOT LIKE", "%1111%")]
        stands for "WHERE id LIKE '%0000%' AND id NOT LIKE '%1111%'" in SQL.
        This method only supports 'AND' SQL conditions.

        Only the third item of each condition (the value) is bound as a SQL
        parameter; the table name, field and operator are interpolated into
        the SQL text and must therefore come from trusted code.

        Args:
            table: Name of the table to read.
            where: Optional list of ``(field, operator, value)`` tuples.
                ``None`` or an empty list selects every row.

        Returns:
            A list of namedtuples of the type returned by
            ``convert_table_tuple(table)``, one per selected row.

        Raises:
            AssertionError: If ``where`` is non-empty and is not a list of
                3-value tuples.
            NotImplementedError: Propagated from ``convert_table_tuple`` when
                ``table`` has no row type.
        """
        sql = f"SELECT * FROM {table}"
        params = []
        if where:
            assert isinstance(where, list), "where should be a list"
            assert all(
                isinstance(condition, tuple) and len(condition) == 3
                for condition in where
            ), "where conditions should be tuples of 3 values"
            sql += " WHERE "
            sql += " AND ".join(
                f"{field} {operator} (?)" for field, operator, _ in where
            )
            params = [value for _, _, value in where]

        self.cursor.execute(sql, params)
        tuple_type = convert_table_tuple(table)
        return [tuple_type(**dict(row)) for row in self.cursor.fetchall()]