diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py index 007ee2554191da8e8da9cd31263f1ba8c10c3f84..5f312ec8229d7fc2964544cccfe437728eb66330 100644 --- a/arkindex_worker/cache.py +++ b/arkindex_worker/cache.py @@ -43,6 +43,12 @@ class LocalDB(object): ) self.db.commit() - def fetch(self, table, where_clause=""): - self.cursor.execute(f"SELECT * FROM {table} {where_clause}") + def fetch(self, table, where=[]): + sql = f"SELECT * FROM {table}" + if where: + sql += " WHERE " + sql += " AND ".join( + [f"{field} {operator} (?)" for field, operator, _ in where] + ) + self.cursor.execute(sql, [value for _, _, value in where]) return self.cursor.fetchall() diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py index 33b5c636c06a12c08ae0f893746ee8fd8141ea90..23b3db34cafa5de4fd9029bb498d1a398afb48f8 100644 --- a/arkindex_worker/worker.py +++ b/arkindex_worker/worker.py @@ -897,17 +897,19 @@ class ElementsWorker(BaseWorker): "type", "worker_version", }, "When using the local cache, you can only filter by 'name', 'type' and/or 'worker_version'" - parent_id_hex = convert_str_uuid_to_hex(element.id) - name_condition = f" AND name LIKE '%{name}%'" if name else "" - type_condition = f" AND type='{type}'" if type else "" - worker_version_condition = ( - f" AND worker_version_id='{convert_str_uuid_to_hex(worker_version)}'" + + conditions = [("parent_id", "=", convert_str_uuid_to_hex(element.id))] + conditions += [("name", "LIKE", f"%{name}%")] if name else [] + conditions += [("type", "=", type)] if type else [] + conditions += ( + [("worker_version_id", "=", convert_str_uuid_to_hex(worker_version))] if worker_version - else "" + else [] ) + children = self.cache.fetch( "elements", - where_clause=f"WHERE parent_id='{parent_id_hex}'{name_condition}{type_condition}{worker_version_condition}", + where=conditions, ) children = [CachedElement(**dict(child)) for child in children] else: diff --git a/tests/test_cache.py b/tests/test_cache.py index 49767480fd88b2880fa01ff5326de17079d95563..5ec7c2b5da0ee402292e60b7519035ccf79232b3 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -146,6 +146,13 @@ def test_fetch_with_where(): cache.create_tables() rows = cache.fetch( "elements", - where_clause=f"WHERE parent_id='{convert_str_uuid_to_hex('12341234-1234-1234-1234-123412341234')}' AND name LIKE '%0%'", + where=[ + ( + "parent_id", + "=", + convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"), + ), + ("name", "LIKE", "%0%"), + ], ) assert [CachedElement(**dict(row)) for row in rows] == [ELEMENTS_TO_INSERT[0]]