From 18f725a4d82349ed8827fda99c1e07c5851975dd Mon Sep 17 00:00:00 2001
From: Eva Bardou <ebardou@teklia.com>
Date: Wed, 24 Mar 2021 13:02:09 +0100
Subject: [PATCH] Prevent SQL injections

---
 arkindex_worker/cache.py  | 10 ++++++++--
 arkindex_worker/worker.py | 16 +++++++++-------
 tests/test_cache.py       |  9 ++++++++-
 3 files changed, 25 insertions(+), 10 deletions(-)

diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py
index 78b2131d..ed56e86c 100644
--- a/arkindex_worker/cache.py
+++ b/arkindex_worker/cache.py
@@ -42,6 +42,12 @@ class LocalDB(object):
         )
         self.db.commit()
 
-    def fetch(self, table, where_clause=""):
-        self.cursor.execute(f"SELECT * FROM {table} {where_clause}")
+    def fetch(self, table, where=[]):
+        sql = f"SELECT * FROM {table}"
+        if where:
+            sql += " WHERE "
+            sql += " AND ".join(
+                [f"{field} {operator} (?)" for field, operator, _ in where]
+            )
+        self.cursor.execute(sql, [value for _, _, value in where])
         return self.cursor.fetchall()
diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py
index 1068e684..776b4939 100644
--- a/arkindex_worker/worker.py
+++ b/arkindex_worker/worker.py
@@ -896,17 +896,19 @@ class ElementsWorker(BaseWorker):
                 "type",
                 "worker_version",
             }, "When using the local cache, you can only filter by 'name', 'type' and/or 'worker_version'"
-            parent_id_hex = convert_str_uuid_to_hex(element.id)
-            name_condition = f" AND name LIKE '%{name}%'" if name else ""
-            type_condition = f" AND type='{type}'" if type else ""
-            worker_version_condition = (
-                f" AND worker_version_id='{convert_str_uuid_to_hex(worker_version)}'"
+
+            conditions = [("parent_id", "=", convert_str_uuid_to_hex(element.id))]
+            conditions += [("name", "LIKE", f"%{name}%")] if name else []
+            conditions += [("type", "=", type)] if type else []
+            conditions += (
+                [("worker_version_id", "=", convert_str_uuid_to_hex(worker_version))]
                 if worker_version
-                else ""
+                else []
             )
+
             children = self.cache.fetch(
                 "elements",
-                where_clause=f"WHERE parent_id='{parent_id_hex}'{name_condition}{type_condition}{worker_version_condition}",
+                where=conditions,
             )
             children = [CachedElement(**dict(child)) for child in children]
         else:
diff --git a/tests/test_cache.py b/tests/test_cache.py
index 76304653..685cc1b3 100644
--- a/tests/test_cache.py
+++ b/tests/test_cache.py
@@ -144,6 +144,13 @@ def test_fetch_with_where():
     cache.create_tables()
     rows = cache.fetch(
         "elements",
-        where_clause=f"WHERE parent_id='{convert_str_uuid_to_hex('12341234-1234-1234-1234-123412341234')}' AND name LIKE '%0%'",
+        where=[
+            (
+                "parent_id",
+                "=",
+                convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
+            ),
+            ("name", "LIKE", "%0%"),
+        ],
     )
     assert [CachedElement(**dict(row)) for row in rows] == [ELEMENTS_TO_INSERT[0]]
-- 
GitLab