diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py index 9bb6c0fa8aa3d5089354016db0dbed097fd82f54..4c45d1349e7d429fb6435126cfc70b97b834d4ce 100644 --- a/arkindex_worker/cache.py +++ b/arkindex_worker/cache.py @@ -21,6 +21,13 @@ CachedElement = namedtuple( ) +def convert_table_tuple(table): + if table == "elements": + return CachedElement + else: + raise NotImplementedError + + class LocalDB(object): def __init__(self, path): self.db = sqlite3.connect(path) @@ -41,3 +48,29 @@ class LocalDB(object): f"INSERT INTO {table} ({columns}) VALUES ({placeholders})", values ) self.db.commit() + + def fetch(self, table, where=[]): + """ + where parameter is a list containing 3-values tuples defining an SQL WHERE condition. + + e.g: where=[("id", "LIKE", "%0000%"), ("id", "NOT LIKE", "%1111%")] + stands for "WHERE id LIKE '%0000%' AND id NOT LIKE '%1111%'" in SQL. + + This method only supports 'AND' SQL conditions. + """ + + sql = f"SELECT * FROM {table}" + if where: + assert isinstance(where, list), "where should be a list" + assert all( + isinstance(condition, tuple) and len(condition) == 3 + for condition in where + ), "where conditions should be tuples of 3 values" + + sql += " WHERE " + sql += " AND ".join( + [f"{field} {operator} (?)" for field, operator, _ in where] + ) + self.cursor.execute(sql, [value for _, _, value in where]) + tuple_type = convert_table_tuple(table) + return [tuple_type(**dict(row)) for row in self.cursor.fetchall()] diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py index 00a3f5507f5340ef784d5be52d714e2413928a81..19ae2a7fc7dd3cc640e3679697e70362a49e4d82 100644 --- a/arkindex_worker/worker.py +++ b/arkindex_worker/worker.py @@ -845,9 +845,29 @@ class ElementsWorker(BaseWorker): ), "worker_version should be of type str" query_params["worker_version"] = worker_version - children = self.api_client.paginate( - "ListElementChildren", id=element.id, **query_params - ) + if self.cache: + # Checking that we only received query_params handled by the cache + assert set(query_params.keys()) <= { + "type", + "worker_version", + }, "When using the local cache, you can only filter by 'type' and/or 'worker_version'" + + conditions = [("parent_id", "=", convert_str_uuid_to_hex(element.id))] + if type: + conditions.append(("type", "=", type)) + if worker_version: + conditions.append( + ("worker_version_id", "=", convert_str_uuid_to_hex(worker_version)) + ) + + children = self.cache.fetch( + "elements", + where=conditions, + ) + else: + children = self.api_client.paginate( + "ListElementChildren", id=element.id, **query_params + ) return children diff --git a/tests/test_cache.py b/tests/test_cache.py index fb3c28a125417e1d480f19627fdcf54f341f11ae..a978e3632ab165d32e49b86ba02230f7b3d2e7e3 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -128,3 +128,29 @@ def test_insert(): ) assert [CachedElement(**dict(row)) for row in generated_rows] == ELEMENTS_TO_INSERT + + +def test_fetch_all(): + db_path = f"{FIXTURES}/lines.sqlite" + cache = LocalDB(db_path) + cache.create_tables() + children = cache.fetch("elements") + assert children == ELEMENTS_TO_INSERT + + +def test_fetch_with_where(): + db_path = f"{FIXTURES}/lines.sqlite" + cache = LocalDB(db_path) + cache.create_tables() + children = cache.fetch( + "elements", + where=[ + ( + "parent_id", + "=", + convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"), + ), + ("id", "LIKE", "%1111%"), + ], + ) + assert children == [ELEMENTS_TO_INSERT[0]] diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py index b0ad3486fbfad16061883aa5d948b5240d55c122..03028cadd74ce04ea62c0079ca74a493cb6bba34 100644 --- a/tests/test_elements_worker/test_elements.py +++ b/tests/test_elements_worker/test_elements.py @@ -14,6 +14,35 @@ from arkindex_worker.utils import convert_str_uuid_to_hex from arkindex_worker.worker import ElementsWorker CACHE_DIR = Path(__file__).absolute().parent.parent / "data/cache" +ELEMENTS_TO_INSERT = [ + CachedElement( + id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"), + parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"), + type="something", + polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]), + worker_version_id=convert_str_uuid_to_hex( + "56785678-5678-5678-5678-567856785678" + ), + ), + CachedElement( + id=convert_str_uuid_to_hex("22222222-2222-2222-2222-222222222222"), + parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"), + type="page", + polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]), + worker_version_id=convert_str_uuid_to_hex( + "56785678-5678-5678-5678-567856785678" + ), + ), + CachedElement( + id=convert_str_uuid_to_hex("33333333-3333-3333-3333-333333333333"), + parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"), + type="something", + polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]), + worker_version_id=convert_str_uuid_to_hex( + "90129012-9012-9012-9012-901290129012" + ), + ), +] def test_list_elements_elements_list_arg_wrong_type(monkeypatch, mock_elements_worker): @@ -905,3 +934,70 @@ def test_list_element_children(responses, mock_elements_worker): "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/", ] + + +def test_list_element_children_with_cache_unhandled_param( + mock_elements_worker_with_cache, +): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + mock_elements_worker_with_cache.list_element_children( + element=elt, with_corpus=True + ) + assert ( + str(e.value) + == "When using the local cache, you can only filter by 'type' and/or 'worker_version'" + ) + + +def test_list_element_children_with_cache(responses, mock_elements_worker_with_cache): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + for idx, child in enumerate( + mock_elements_worker_with_cache.list_element_children(element=elt) + ): + assert child == [] + + # Initialize SQLite cache with some elements + mock_elements_worker_with_cache.cache.insert("elements", ELEMENTS_TO_INSERT) + + expected_children = ELEMENTS_TO_INSERT + + for idx, child in enumerate( + mock_elements_worker_with_cache.list_element_children(element=elt) + ): + assert child == expected_children[idx] + + expected_children = [ELEMENTS_TO_INSERT[1]] + + for idx, child in enumerate( + mock_elements_worker_with_cache.list_element_children(element=elt, type="page") + ): + assert child == expected_children[idx] + + expected_children = ELEMENTS_TO_INSERT[:2] + + for idx, child in enumerate( + mock_elements_worker_with_cache.list_element_children( + element=elt, worker_version="56785678-5678-5678-5678-567856785678" + ) + ): + assert child == expected_children[idx] + + expected_children = [ELEMENTS_TO_INSERT[0]] + + for idx, child in enumerate( + mock_elements_worker_with_cache.list_element_children( + element=elt, + type="something", + worker_version="56785678-5678-5678-5678-567856785678", + ) + ): + assert child == expected_children[idx] + + assert len(responses.calls) == 2 + assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", + ]