Skip to content
Snippets Groups Projects
Commit e6f99ec5 authored by Eva Bardou's avatar Eva Bardou Committed by Bastien Abadie
Browse files

Retrieve children elements from SQLite cache in list_element_children

parent 32459852
No related branches found
No related tags found
1 merge request!68Retrieve children elements from SQLite cache in list_element_children
Pipeline #78327 passed
...@@ -21,6 +21,13 @@ CachedElement = namedtuple( ...@@ -21,6 +21,13 @@ CachedElement = namedtuple(
) )
def convert_table_tuple(table):
if table == "elements":
return CachedElement
else:
raise NotImplementedError
class LocalDB(object): class LocalDB(object):
def __init__(self, path): def __init__(self, path):
self.db = sqlite3.connect(path) self.db = sqlite3.connect(path)
...@@ -41,3 +48,29 @@ class LocalDB(object): ...@@ -41,3 +48,29 @@ class LocalDB(object):
f"INSERT INTO {table} ({columns}) VALUES ({placeholders})", values f"INSERT INTO {table} ({columns}) VALUES ({placeholders})", values
) )
self.db.commit() self.db.commit()
def fetch(self, table, where=[]):
"""
where parameter is a list containing 3-values tuples defining an SQL WHERE condition.
e.g: where=[("id", "LIKE", "%0000%"), ("id", "NOT LIKE", "%1111%")]
stands for "WHERE id LIKE '%0000%' AND id NOT LIKE '%1111%'" in SQL.
This method only supports 'AND' SQL conditions.
"""
sql = f"SELECT * FROM {table}"
if where:
assert isinstance(where, list), "where should be a list"
assert all(
isinstance(condition, tuple) and len(condition) == 3
for condition in where
), "where conditions should be tuples of 3 values"
sql += " WHERE "
sql += " AND ".join(
[f"{field} {operator} (?)" for field, operator, _ in where]
)
self.cursor.execute(sql, [value for _, _, value in where])
tuple_type = convert_table_tuple(table)
return [tuple_type(**dict(row)) for row in self.cursor.fetchall()]
...@@ -845,9 +845,29 @@ class ElementsWorker(BaseWorker): ...@@ -845,9 +845,29 @@ class ElementsWorker(BaseWorker):
), "worker_version should be of type str" ), "worker_version should be of type str"
query_params["worker_version"] = worker_version query_params["worker_version"] = worker_version
children = self.api_client.paginate( if self.cache:
"ListElementChildren", id=element.id, **query_params # Checking that we only received query_params handled by the cache
) assert set(query_params.keys()) <= {
"type",
"worker_version",
}, "When using the local cache, you can only filter by 'type' and/or 'worker_version'"
conditions = [("parent_id", "=", convert_str_uuid_to_hex(element.id))]
if type:
conditions.append(("type", "=", type))
if worker_version:
conditions.append(
("worker_version_id", "=", convert_str_uuid_to_hex(worker_version))
)
children = self.cache.fetch(
"elements",
where=conditions,
)
else:
children = self.api_client.paginate(
"ListElementChildren", id=element.id, **query_params
)
return children return children
......
...@@ -128,3 +128,29 @@ def test_insert(): ...@@ -128,3 +128,29 @@ def test_insert():
) )
assert [CachedElement(**dict(row)) for row in generated_rows] == ELEMENTS_TO_INSERT assert [CachedElement(**dict(row)) for row in generated_rows] == ELEMENTS_TO_INSERT
def test_fetch_all():
db_path = f"{FIXTURES}/lines.sqlite"
cache = LocalDB(db_path)
cache.create_tables()
children = cache.fetch("elements")
assert children == ELEMENTS_TO_INSERT
def test_fetch_with_where():
db_path = f"{FIXTURES}/lines.sqlite"
cache = LocalDB(db_path)
cache.create_tables()
children = cache.fetch(
"elements",
where=[
(
"parent_id",
"=",
convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
),
("id", "LIKE", "%1111%"),
],
)
assert children == [ELEMENTS_TO_INSERT[0]]
...@@ -14,6 +14,35 @@ from arkindex_worker.utils import convert_str_uuid_to_hex ...@@ -14,6 +14,35 @@ from arkindex_worker.utils import convert_str_uuid_to_hex
from arkindex_worker.worker import ElementsWorker from arkindex_worker.worker import ElementsWorker
CACHE_DIR = Path(__file__).absolute().parent.parent / "data/cache" CACHE_DIR = Path(__file__).absolute().parent.parent / "data/cache"
ELEMENTS_TO_INSERT = [
CachedElement(
id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
type="something",
polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]),
worker_version_id=convert_str_uuid_to_hex(
"56785678-5678-5678-5678-567856785678"
),
),
CachedElement(
id=convert_str_uuid_to_hex("22222222-2222-2222-2222-222222222222"),
parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
type="page",
polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]),
worker_version_id=convert_str_uuid_to_hex(
"56785678-5678-5678-5678-567856785678"
),
),
CachedElement(
id=convert_str_uuid_to_hex("33333333-3333-3333-3333-333333333333"),
parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
type="something",
polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]),
worker_version_id=convert_str_uuid_to_hex(
"90129012-9012-9012-9012-901290129012"
),
),
]
def test_list_elements_elements_list_arg_wrong_type(monkeypatch, mock_elements_worker): def test_list_elements_elements_list_arg_wrong_type(monkeypatch, mock_elements_worker):
...@@ -905,3 +934,70 @@ def test_list_element_children(responses, mock_elements_worker): ...@@ -905,3 +934,70 @@ def test_list_element_children(responses, mock_elements_worker):
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
] ]
def test_list_element_children_with_cache_unhandled_param(
mock_elements_worker_with_cache,
):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
with pytest.raises(AssertionError) as e:
mock_elements_worker_with_cache.list_element_children(
element=elt, with_corpus=True
)
assert (
str(e.value)
== "When using the local cache, you can only filter by 'type' and/or 'worker_version'"
)
def test_list_element_children_with_cache(responses, mock_elements_worker_with_cache):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
for idx, child in enumerate(
mock_elements_worker_with_cache.list_element_children(element=elt)
):
assert child == []
# Initialize SQLite cache with some elements
mock_elements_worker_with_cache.cache.insert("elements", ELEMENTS_TO_INSERT)
expected_children = ELEMENTS_TO_INSERT
for idx, child in enumerate(
mock_elements_worker_with_cache.list_element_children(element=elt)
):
assert child == expected_children[idx]
expected_children = [ELEMENTS_TO_INSERT[1]]
for idx, child in enumerate(
mock_elements_worker_with_cache.list_element_children(element=elt, type="page")
):
assert child == expected_children[idx]
expected_children = ELEMENTS_TO_INSERT[:2]
for idx, child in enumerate(
mock_elements_worker_with_cache.list_element_children(
element=elt, worker_version="56785678-5678-5678-5678-567856785678"
)
):
assert child == expected_children[idx]
expected_children = [ELEMENTS_TO_INSERT[0]]
for idx, child in enumerate(
mock_elements_worker_with_cache.list_element_children(
element=elt,
type="something",
worker_version="56785678-5678-5678-5678-567856785678",
)
):
assert child == expected_children[idx]
assert len(responses.calls) == 2
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment