diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py index ed56e86c5974fc7b1d4bd79522a794c6cc6f0d7b..4c45d1349e7d429fb6435126cfc70b97b834d4ce 100644 --- a/arkindex_worker/cache.py +++ b/arkindex_worker/cache.py @@ -21,6 +21,13 @@ CachedElement = namedtuple( ) +def convert_table_tuple(table): + if table == "elements": + return CachedElement + else: + raise NotImplementedError + + class LocalDB(object): def __init__(self, path): self.db = sqlite3.connect(path) @@ -43,11 +50,27 @@ class LocalDB(object): self.db.commit() def fetch(self, table, where=[]): + """ + where parameter is a list containing 3-values tuples defining an SQL WHERE condition. + + e.g: where=[("id", "LIKE", "%0000%"), ("id", "NOT LIKE", "%1111%")] + stands for "WHERE id LIKE '%0000%' AND id NOT LIKE '%1111%'" in SQL. + + This method only supports 'AND' SQL conditions. + """ + sql = f"SELECT * FROM {table}" if where: + assert isinstance(where, list), "where should be a list" + assert all( + isinstance(condition, tuple) and len(condition) == 3 + for condition in where + ), "where conditions should be tuples of 3 values" + sql += " WHERE " sql += " AND ".join( [f"{field} {operator} (?)" for field, operator, _ in where] ) self.cursor.execute(sql, [value for _, _, value in where]) - return self.cursor.fetchall() + tuple_type = convert_table_tuple(table) + return [tuple_type(**dict(row)) for row in self.cursor.fetchall()] diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py index dcb5feb7edd7befd2dbb5264eb43478903a9e34b..bfe402a96c50e0cd8c500f5a274d76c09a2ab452 100644 --- a/arkindex_worker/worker.py +++ b/arkindex_worker/worker.py @@ -897,18 +897,17 @@ class ElementsWorker(BaseWorker): }, "When using the local cache, you can only filter by 'type' and/or 'worker_version'" conditions = [("parent_id", "=", convert_str_uuid_to_hex(element.id))] - conditions += [("type", "=", type)] if type else [] - conditions += ( - [("worker_version_id", "=", convert_str_uuid_to_hex(worker_version))] - if worker_version - else [] - ) + if type: + conditions.append(("type", "=", type)) + if worker_version: + conditions.append( + ("worker_version_id", "=", convert_str_uuid_to_hex(worker_version)) + ) children = self.cache.fetch( "elements", where=conditions, ) - children = [CachedElement(**dict(child)) for child in children] else: children = self.api_client.paginate( "ListElementChildren", id=element.id, **query_params diff --git a/tests/test_cache.py b/tests/test_cache.py index 526f159946c586514998006a45ffcf0ba9fdcbb9..a978e3632ab165d32e49b86ba02230f7b3d2e7e3 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -134,15 +134,15 @@ def test_fetch_all(): db_path = f"{FIXTURES}/lines.sqlite" cache = LocalDB(db_path) cache.create_tables() - rows = cache.fetch("elements") - assert [CachedElement(**dict(row)) for row in rows] == ELEMENTS_TO_INSERT + children = cache.fetch("elements") + assert children == ELEMENTS_TO_INSERT def test_fetch_with_where(): db_path = f"{FIXTURES}/lines.sqlite" cache = LocalDB(db_path) cache.create_tables() - rows = cache.fetch( + children = cache.fetch( "elements", where=[ ( @@ -153,4 +153,4 @@ def test_fetch_with_where(): ("id", "LIKE", "%1111%"), ], ) - assert [CachedElement(**dict(row)) for row in rows] == [ELEMENTS_TO_INSERT[0]] + assert children == [ELEMENTS_TO_INSERT[0]]