def fetch(self, table, where_clause="", params=()):
    """
    Return all rows of ``table``, optionally restricted by ``where_clause``.

    :param table: Name of the table to read. It is interpolated directly into
        the statement, so it must come from trusted code and never from
        external input.
    :param where_clause: Optional SQL fragment appended after the table name,
        e.g. ``"WHERE parent_id = ? AND name LIKE ?"``. Use ``?`` or ``:name``
        placeholders instead of inlining values into the fragment.
    :param params: Values bound to the placeholders found in ``where_clause``:
        a sequence for ``?`` placeholders, a mapping for ``:name`` ones.
        Defaults to no parameters, so existing two-argument calls behave
        exactly as before.
    :return: The list of rows produced by ``cursor.fetchall()``.
    """
    # Values travel separately from the SQL text so that sqlite3 handles
    # quoting; a value such as "it's" can neither break nor alter the query.
    self.cursor.execute(f"SELECT * FROM {table} {where_clause}", params)
    return self.cursor.fetchall()
@pytest.fixture
def mock_elements_worker_with_cache(monkeypatch, mock_worker_version_api):
    """
    Provide a configured ElementsWorker created with ``use_cache=True``.

    ``sys.argv`` is replaced by a fixed value before the worker is configured
    so that pytest's own command-line arguments are not seen by the worker.
    """
    fixed_argv = ["worker"]
    monkeypatch.setattr(sys, "argv", fixed_argv)

    cached_worker = ElementsWorker(use_cache=True)
    cached_worker.configure()

    return cached_worker
def test_fetch_with_where():
    """
    Check that LocalDB.fetch applies the given WHERE clause.

    The clause restricts rows to one parent_id and to names containing "0";
    only the first inserted element is expected back.
    """
    cache = LocalDB(f"{FIXTURES}/lines.sqlite")
    cache.create_tables()

    matching_rows = cache.fetch(
        "elements",
        where_clause=f"WHERE parent_id='{convert_str_uuid_to_hex('12341234-1234-1234-1234-123412341234')}' AND name LIKE '%0%'",
    )

    # Rebuild CachedElement objects from the raw rows before comparing
    fetched_elements = [CachedElement(**dict(row)) for row in matching_rows]
    assert fetched_elements == [ELEMENTS_TO_INSERT[0]]
def test_list_element_children_with_cache(responses, mock_elements_worker_with_cache):
    """
    Calls list_element_children on a worker using the cache, with only
    query params the cache handles (name, type, worker_version).
    Children must come from the SQLite cache and the children API endpoint
    must never be called.
    """
    worker = mock_elements_worker_with_cache
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})

    # Whole lists are compared on purpose: asserting item by item inside a
    # `for ... in enumerate(result)` loop passes silently when the result is
    # empty or shorter than expected.
    # NOTE(review): this assumes the cache holds no elements before the insert below.
    assert list(worker.list_element_children(element=elt)) == []

    # Initialize SQLite cache with some elements
    worker.cache.insert("elements", ELEMENTS_TO_INSERT)

    # No filter: every child of the parent is returned
    assert list(worker.list_element_children(element=elt)) == ELEMENTS_TO_INSERT

    # The name filter is a substring match: "1" matches both "1" and "10"
    assert (
        list(worker.list_element_children(element=elt, name="1"))
        == ELEMENTS_TO_INSERT[1:]
    )

    assert list(worker.list_element_children(element=elt, type="page")) == [
        ELEMENTS_TO_INSERT[1]
    ]

    assert (
        list(
            worker.list_element_children(
                element=elt, worker_version="56785678-5678-5678-5678-567856785678"
            )
        )
        == ELEMENTS_TO_INSERT[:2]
    )

    # All filters combined. The name is the digit "0" (not the letter "O"):
    # it matches "0" and "10", and the worker version then excludes "10".
    assert list(
        worker.list_element_children(
            element=elt,
            name="0",
            type="something",
            worker_version="56785678-5678-5678-5678-567856785678",
        )
    ) == [ELEMENTS_TO_INSERT[0]]

    # The children endpoint must not appear among the API calls
    assert len(responses.calls) == 2
    assert [call.request.url for call in responses.calls] == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
    ]