diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py index dbcb8dd210af40aae3b17064e58f163b1f02cc4a..1452ba437c614b0aa8f56a6f9134cd296ab6da09 100644 --- a/arkindex_worker/cache.py +++ b/arkindex_worker/cache.py @@ -42,3 +42,7 @@ class LocalDB(object): f"INSERT INTO {table} ({columns}) VALUES ({placeholders})", values ) self.db.commit() + + def fetch(self, table, where_clause=""): + self.cursor.execute(f"SELECT * FROM {table} {where_clause}") + return self.cursor.fetchall() diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py index 839ee99f9d0514f7732ac5eacd197cd0bdac27c4..4c56bf1ef681e6b032769af5089a86b3c93f0ae1 100644 --- a/arkindex_worker/worker.py +++ b/arkindex_worker/worker.py @@ -889,9 +889,29 @@ class ElementsWorker(BaseWorker): ), "worker_version should be of type str" query_params["worker_version"] = worker_version - children = self.api_client.paginate( - "ListElementChildren", id=element.id, **query_params - ) + # Checking that we only have query_params handled by the cache + if self.use_cache and set(query_params.keys()) <= { + "name", + "type", + "worker_version", + }: + parent_id_hex = convert_str_uuid_to_hex(element.id) + name_condition = f" AND name LIKE '%{name}%'" if name else "" + type_condition = f" AND type='{type}'" if type else "" + worker_version_condition = ( + f" AND worker_version_id='{convert_str_uuid_to_hex(worker_version)}'" + if worker_version + else "" + ) + children = self.cache.fetch( + "elements", + where_clause=f"WHERE parent_id='{parent_id_hex}'{name_condition}{type_condition}{worker_version_condition}", + ) + children = [CachedElement(**dict(child)) for child in children] + else: + children = self.api_client.paginate( + "ListElementChildren", id=element.id, **query_params + ) return children diff --git a/tests/test_cache.py b/tests/test_cache.py index 07754469378861d41d2954f714dc14970d0877a5..b3f649455b20abc100712f99b5526ea24cc8ab57 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -130,3 +130,22 @@ def test_insert(): ) assert [CachedElement(**dict(row)) for row in generated_rows] == ELEMENTS_TO_INSERT + + +def test_fetch_all(): + db_path = f"{FIXTURES}/lines.sqlite" + cache = LocalDB(db_path) + cache.create_tables() + rows = cache.fetch("elements") + assert [CachedElement(**dict(row)) for row in rows] == ELEMENTS_TO_INSERT + + +def test_fetch_with_where(): + db_path = f"{FIXTURES}/lines.sqlite" + cache = LocalDB(db_path) + cache.create_tables() + rows = cache.fetch( + "elements", + where_clause=f"WHERE parent_id='{convert_str_uuid_to_hex('12341234-1234-1234-1234-123412341234')}' AND name LIKE '%0%'", + ) + assert [CachedElement(**dict(row)) for row in rows] == [ELEMENTS_TO_INSERT[0]] diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py index 776885fab29a040d43ec9969f16f9d0e65b66913..c79a106a6548f4a2dfbf5bec8f945dd53310b55e 100644 --- a/tests/test_elements_worker/test_elements.py +++ b/tests/test_elements_worker/test_elements.py @@ -14,6 +14,38 @@ from arkindex_worker.utils import convert_str_uuid_to_hex from arkindex_worker.worker import ElementsWorker CACHE_DIR = Path(__file__).absolute().parent.parent / "data/cache" +ELEMENTS_TO_INSERT = [ + CachedElement( + id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"), + parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"), + name="0", + type="something", + polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]), + worker_version_id=convert_str_uuid_to_hex( + "56785678-5678-5678-5678-567856785678" + ), + ), + CachedElement( + id=convert_str_uuid_to_hex("22222222-2222-2222-2222-222222222222"), + parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"), + name="1", + type="page", + polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]), + worker_version_id=convert_str_uuid_to_hex( + "56785678-5678-5678-5678-567856785678" + ), + ), + CachedElement( + id=convert_str_uuid_to_hex("33333333-3333-3333-3333-333333333333"), + parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"), + name="10", + type="something", + polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]), + worker_version_id=convert_str_uuid_to_hex( + "90129012-9012-9012-9012-901290129012" + ), + ), +] def test_list_elements_elements_list_arg_wrong_type(monkeypatch, mock_elements_worker): @@ -906,3 +938,111 @@ def test_list_element_children(responses, mock_elements_worker): "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/", ] + + +def test_list_element_children_with_cache_unhandled_param( + responses, mock_elements_worker_with_cache +): + """ + Calls list_elements_children on a worker using the cache. + The cache doesn't contain any information about an element corpus and with_corpus query param is set to True. + The list_elements_children function will call the API (instead of the cache) to use the with_corpus param. + """ + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + expected_children = [ + { + "id": "0000", + "type": "page", + "name": "Test", + "corpus": {}, + "thumbnail_url": None, + "zone": {}, + "best_classes": None, + "has_children": None, + "worker_version_id": None, + } + ] + responses.add( + responses.GET, + "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/?with_corpus=True", + status=200, + json={ + "count": 1, + "next": None, + "results": expected_children, + }, + ) + + for idx, child in enumerate( + mock_elements_worker_with_cache.list_element_children( + element=elt, with_corpus=True + ) + ): + assert child == expected_children[idx] + + assert len(responses.calls) == 3 + assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", + "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/?with_corpus=True", + ] + + +def test_list_element_children_with_cache(responses, mock_elements_worker_with_cache): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + for idx, child in enumerate( + mock_elements_worker_with_cache.list_element_children(element=elt) + ): + assert child == [] + + # Initialize SQLite cache with some elements + mock_elements_worker_with_cache.cache.insert("elements", ELEMENTS_TO_INSERT) + + expected_children = ELEMENTS_TO_INSERT + + for idx, child in enumerate( + mock_elements_worker_with_cache.list_element_children(element=elt) + ): + assert child == expected_children[idx] + + expected_children = ELEMENTS_TO_INSERT[1:] + + for idx, child in enumerate( + mock_elements_worker_with_cache.list_element_children(element=elt, name="1") + ): + assert child == expected_children[idx] + + expected_children = [ELEMENTS_TO_INSERT[1]] + + for idx, child in enumerate( + mock_elements_worker_with_cache.list_element_children(element=elt, type="page") + ): + assert child == expected_children[idx] + + expected_children = ELEMENTS_TO_INSERT[:2] + + for idx, child in enumerate( + mock_elements_worker_with_cache.list_element_children( + element=elt, worker_version="56785678-5678-5678-5678-567856785678" + ) + ): + assert child == expected_children[idx] + + expected_children = [ELEMENTS_TO_INSERT[1]] + + for idx, child in enumerate( + mock_elements_worker_with_cache.list_element_children( + element=elt, + name="O", + type="something", + worker_version="56785678-5678-5678-5678-567856785678", + ) + ): + assert child == expected_children[idx] + + assert len(responses.calls) == 2 + assert [call.request.url for call in responses.calls] == [ + "http://testserver/api/v1/user/", + "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", + ]