From e6f99ec557af9ec6dacd5f23bf5672f8f712b4be Mon Sep 17 00:00:00 2001
From: Eva Bardou <ebardou@teklia.com>
Date: Wed, 24 Mar 2021 16:31:05 +0000
Subject: [PATCH] Retrieve children elements from SQLite cache in
 list_element_children

---
 arkindex_worker/cache.py                    | 33 +++++++
 arkindex_worker/worker.py                   | 26 +++++-
 tests/test_cache.py                         | 26 ++++++
 tests/test_elements_worker/test_elements.py | 96 +++++++++++++++++++++
 4 files changed, 178 insertions(+), 3 deletions(-)

diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py
index 9bb6c0fa..4c45d134 100644
--- a/arkindex_worker/cache.py
+++ b/arkindex_worker/cache.py
@@ -21,6 +21,13 @@ CachedElement = namedtuple(
 )
 
 
+def convert_table_tuple(table):
+    if table == "elements":
+        return CachedElement
+    else:
+        raise NotImplementedError
+
+
 class LocalDB(object):
     def __init__(self, path):
         self.db = sqlite3.connect(path)
@@ -41,3 +48,29 @@ class LocalDB(object):
             f"INSERT INTO {table} ({columns}) VALUES ({placeholders})", values
         )
         self.db.commit()
+
+    def fetch(self, table, where=[]):
+        """
+        where parameter is a list containing 3-values tuples defining an SQL WHERE condition.
+
+        e.g: where=[("id", "LIKE", "%0000%"), ("id", "NOT LIKE", "%1111%")]
+             stands for "WHERE id LIKE '%0000%' AND id NOT LIKE '%1111%'" in SQL.
+
+        This method only supports 'AND' SQL conditions.
+        """
+
+        sql = f"SELECT * FROM {table}"
+        if where:
+            assert isinstance(where, list), "where should be a list"
+            assert all(
+                isinstance(condition, tuple) and len(condition) == 3
+                for condition in where
+            ), "where conditions should be tuples of 3 values"
+
+            sql += " WHERE "
+            sql += " AND ".join(
+                [f"{field} {operator} (?)" for field, operator, _ in where]
+            )
+        self.cursor.execute(sql, [value for _, _, value in where])
+        tuple_type = convert_table_tuple(table)
+        return [tuple_type(**dict(row)) for row in self.cursor.fetchall()]
diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py
index 00a3f550..19ae2a7f 100644
--- a/arkindex_worker/worker.py
+++ b/arkindex_worker/worker.py
@@ -845,9 +845,29 @@ class ElementsWorker(BaseWorker):
             ), "worker_version should be of type str"
             query_params["worker_version"] = worker_version
 
-        children = self.api_client.paginate(
-            "ListElementChildren", id=element.id, **query_params
-        )
+        if self.cache:
+            # Checking that we only received query_params handled by the cache
+            assert set(query_params.keys()) <= {
+                "type",
+                "worker_version",
+            }, "When using the local cache, you can only filter by 'type' and/or 'worker_version'"
+
+            conditions = [("parent_id", "=", convert_str_uuid_to_hex(element.id))]
+            if type:
+                conditions.append(("type", "=", type))
+            if worker_version:
+                conditions.append(
+                    ("worker_version_id", "=", convert_str_uuid_to_hex(worker_version))
+                )
+
+            children = self.cache.fetch(
+                "elements",
+                where=conditions,
+            )
+        else:
+            children = self.api_client.paginate(
+                "ListElementChildren", id=element.id, **query_params
+            )
 
         return children
 
diff --git a/tests/test_cache.py b/tests/test_cache.py
index fb3c28a1..a978e363 100644
--- a/tests/test_cache.py
+++ b/tests/test_cache.py
@@ -128,3 +128,29 @@ def test_insert():
     )
 
     assert [CachedElement(**dict(row)) for row in generated_rows] == ELEMENTS_TO_INSERT
+
+
+def test_fetch_all():
+    db_path = f"{FIXTURES}/lines.sqlite"
+    cache = LocalDB(db_path)
+    cache.create_tables()
+    children = cache.fetch("elements")
+    assert children == ELEMENTS_TO_INSERT
+
+
+def test_fetch_with_where():
+    db_path = f"{FIXTURES}/lines.sqlite"
+    cache = LocalDB(db_path)
+    cache.create_tables()
+    children = cache.fetch(
+        "elements",
+        where=[
+            (
+                "parent_id",
+                "=",
+                convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
+            ),
+            ("id", "LIKE", "%1111%"),
+        ],
+    )
+    assert children == [ELEMENTS_TO_INSERT[0]]
diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py
index b0ad3486..03028cad 100644
--- a/tests/test_elements_worker/test_elements.py
+++ b/tests/test_elements_worker/test_elements.py
@@ -14,6 +14,35 @@ from arkindex_worker.utils import convert_str_uuid_to_hex
 from arkindex_worker.worker import ElementsWorker
 
 CACHE_DIR = Path(__file__).absolute().parent.parent / "data/cache"
+ELEMENTS_TO_INSERT = [
+    CachedElement(
+        id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
+        parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
+        type="something",
+        polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]),
+        worker_version_id=convert_str_uuid_to_hex(
+            "56785678-5678-5678-5678-567856785678"
+        ),
+    ),
+    CachedElement(
+        id=convert_str_uuid_to_hex("22222222-2222-2222-2222-222222222222"),
+        parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
+        type="page",
+        polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]),
+        worker_version_id=convert_str_uuid_to_hex(
+            "56785678-5678-5678-5678-567856785678"
+        ),
+    ),
+    CachedElement(
+        id=convert_str_uuid_to_hex("33333333-3333-3333-3333-333333333333"),
+        parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
+        type="something",
+        polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]),
+        worker_version_id=convert_str_uuid_to_hex(
+            "90129012-9012-9012-9012-901290129012"
+        ),
+    ),
+]
 
 
 def test_list_elements_elements_list_arg_wrong_type(monkeypatch, mock_elements_worker):
@@ -905,3 +934,70 @@ def test_list_element_children(responses, mock_elements_worker):
         "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
         "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
     ]
+
+
+def test_list_element_children_with_cache_unhandled_param(
+    mock_elements_worker_with_cache,
+):
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker_with_cache.list_element_children(
+            element=elt, with_corpus=True
+        )
+    assert (
+        str(e.value)
+        == "When using the local cache, you can only filter by 'type' and/or 'worker_version'"
+    )
+
+
+def test_list_element_children_with_cache(responses, mock_elements_worker_with_cache):
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
+
+    for idx, child in enumerate(
+        mock_elements_worker_with_cache.list_element_children(element=elt)
+    ):
+        assert child == []
+
+    # Initialize SQLite cache with some elements
+    mock_elements_worker_with_cache.cache.insert("elements", ELEMENTS_TO_INSERT)
+
+    expected_children = ELEMENTS_TO_INSERT
+
+    for idx, child in enumerate(
+        mock_elements_worker_with_cache.list_element_children(element=elt)
+    ):
+        assert child == expected_children[idx]
+
+    expected_children = [ELEMENTS_TO_INSERT[1]]
+
+    for idx, child in enumerate(
+        mock_elements_worker_with_cache.list_element_children(element=elt, type="page")
+    ):
+        assert child == expected_children[idx]
+
+    expected_children = ELEMENTS_TO_INSERT[:2]
+
+    for idx, child in enumerate(
+        mock_elements_worker_with_cache.list_element_children(
+            element=elt, worker_version="56785678-5678-5678-5678-567856785678"
+        )
+    ):
+        assert child == expected_children[idx]
+
+    expected_children = [ELEMENTS_TO_INSERT[0]]
+
+    for idx, child in enumerate(
+        mock_elements_worker_with_cache.list_element_children(
+            element=elt,
+            type="something",
+            worker_version="56785678-5678-5678-5678-567856785678",
+        )
+    ):
+        assert child == expected_children[idx]
+
+    assert len(responses.calls) == 2
+    assert [call.request.url for call in responses.calls] == [
+        "http://testserver/api/v1/user/",
+        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+    ]
-- 
GitLab