Skip to content
Snippets Groups Projects
cache.py 2.83 KiB
Newer Older
# -*- coding: utf-8 -*-
import sqlite3
from collections import namedtuple

from arkindex_worker import logger

# DDL for the cached ``elements`` table. ``IF NOT EXISTS`` makes the statement
# safe to run against a database that already contains the table.
# ``initial`` defaults to 0 when not supplied on insert.
SQL_ELEMENTS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS elements (
    id VARCHAR(32) PRIMARY KEY,
    parent_id VARCHAR(32),
    type TEXT NOT NULL,
    polygon TEXT,
    initial BOOLEAN DEFAULT 0 NOT NULL,
    worker_version_id VARCHAR(32)
)"""
# DDL for the cached ``transcriptions`` table; every transcription references
# an element through ``element_id``. Unlike ``elements``, ``worker_version_id``
# is mandatory here.
# NOTE(review): SQLite only enforces FOREIGN KEY constraints when
# ``PRAGMA foreign_keys = ON`` is set on the connection; nothing in this module
# enables it — confirm whether enforcement is expected.
SQL_TRANSCRIPTIONS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS transcriptions (
    id VARCHAR(32) PRIMARY KEY,
    element_id VARCHAR(32) NOT NULL,
    text TEXT NOT NULL,
    confidence REAL NOT NULL,
    worker_version_id VARCHAR(32) NOT NULL,
    FOREIGN KEY(element_id) REFERENCES elements(id)
)"""


# One row of the ``elements`` table. The two rightmost fields are optional:
# ``parent_id`` defaults to None and ``initial`` to 0. Field order does not
# follow the SQL column order; rows are built by keyword when fetched.
CachedElement = namedtuple(
    "CachedElement",
    ["id", "type", "polygon", "worker_version_id", "parent_id", "initial"],
    defaults=[None, 0],
)
# One row of the ``transcriptions`` table; every field is required.
CachedTranscription = namedtuple(
    "CachedTranscription",
    ["id", "element_id", "text", "confidence", "worker_version_id"],
)


def convert_table_tuple(table):
    """Return the namedtuple type used to represent rows of ``table``.

    :param table: Name of a cache table (``"elements"`` or ``"transcriptions"``).
    :returns: ``CachedElement`` or ``CachedTranscription``.
    :raises NotImplementedError: If ``table`` has no associated row type.
    """
    if table == "elements":
        return CachedElement
    if table == "transcriptions":
        # The transcriptions table is created alongside elements, so rows
        # fetched from it must be convertible too.
        return CachedTranscription
    raise NotImplementedError(f"No cached row type for table {table!r}")


class LocalDB(object):
    """SQLite-backed local cache.

    Rows are written from namedtuples and read back as the namedtuple type
    that ``convert_table_tuple`` associates with the queried table.
    """

    def __init__(self, path):
        """Open (or create) the SQLite database at ``path``.

        Rows are returned as :class:`sqlite3.Row` so :meth:`fetch` can turn
        them into dicts keyed by column name.
        """
        self.db = sqlite3.connect(path)
        self.db.row_factory = sqlite3.Row
        self.cursor = self.db.cursor()
        logger.info(f"Connection to local cache {path} established.")

    def create_tables(self):
        """Create the ``elements`` and ``transcriptions`` tables if missing."""
        # The connection context manager commits on success and rolls back on
        # error, so the schema is persisted even if a transaction was open.
        with self.db:
            self.cursor.execute(SQL_ELEMENTS_TABLE_CREATION)
            self.cursor.execute(SQL_TRANSCRIPTIONS_TABLE_CREATION)

    def insert(self, table, lines):
        """Insert namedtuple rows into ``table`` as one atomic batch.

        :param table: Target table name. It is interpolated into the SQL text,
            so it must come from trusted code.
        :param lines: Namedtuples of a single type; the field names of the
            first one are used as column names. An empty list is a no-op.
        :raises sqlite3.Error: If any row fails to insert; the whole batch is
            then rolled back.
        """
        if not lines:
            return
        columns = ", ".join(lines[0]._fields)
        placeholders = ", ".join(["?"] * len(lines[0]))
        values = [tuple(line) for line in lines]
        # ``with self.db`` commits on success and rolls back on failure. This
        # stops a batch that fails halfway from leaving its earlier rows
        # pending in an open transaction for the next commit to persist.
        with self.db:
            self.cursor.executemany(
                f"INSERT INTO {table} ({columns}) VALUES ({placeholders})", values
            )

    def fetch(self, table, where=None):
        """Return the rows of ``table`` matching every condition in ``where``.

        ``where`` is a list of 3-value tuples ``(field, operator, value)``
        joined with ``AND``; omitting it (or passing ``None``) fetches all rows.

        e.g: where=[("id", "LIKE", "%0000%"), ("id", "NOT LIKE", "%1111%")]
             stands for "WHERE id LIKE '%0000%' AND id NOT LIKE '%1111%'" in SQL.

        This method only supports 'AND' SQL conditions. Only the values are
        bound as parameters; ``table``, fields and operators are interpolated
        into the SQL text and must come from trusted code.

        :returns: A list of instances of ``convert_table_tuple(table)``.
        :raises AssertionError: If ``where`` is not a list of 3-value tuples.
        :raises NotImplementedError: If ``table`` has no associated row type.
        """
        if where is None:
            where = []

        sql = f"SELECT * FROM {table}"
        if where:
            # Explicit raises instead of ``assert`` keep this validation active
            # under ``python -O`` while preserving the AssertionError type.
            if not isinstance(where, list):
                raise AssertionError("where should be a list")
            if not all(
                isinstance(condition, tuple) and len(condition) == 3
                for condition in where
            ):
                raise AssertionError("where conditions should be tuples of 3 values")

            sql += " WHERE "
            sql += " AND ".join(
                f"{field} {operator} (?)" for field, operator, _ in where
            )
        self.cursor.execute(sql, [value for _, _, value in where])
        tuple_type = convert_table_tuple(table)
        # Build rows by keyword so namedtuple field order need not match the
        # table's column order.
        return [tuple_type(**dict(row)) for row in self.cursor.fetchall()]