Skip to content
Snippets Groups Projects
Commit 1d8948e7 authored by Eva Bardou
Browse files

Retrieve children elements from SQLite cache in list_element_children

parent 8a3c8ad5
No related branches found
No related tags found
No related merge requests found
...@@ -39,3 +39,7 @@ class LocalDB(object): ...@@ -39,3 +39,7 @@ class LocalDB(object):
f"INSERT INTO {table} ({columns}) VALUES ({placeholders})", values f"INSERT INTO {table} ({columns}) VALUES ({placeholders})", values
) )
self.db.commit() self.db.commit()
def fetch(self, table, where_clause="", params=()):
    """Return every row of ``table``, optionally filtered by a WHERE clause.

    :param table: Name of the table to read. It is formatted into the SQL
        text as-is (SQL identifiers cannot be bound as parameters), so it
        must never be built from untrusted input.
    :param where_clause: Optional SQL fragment appended after the table
        name, e.g. ``"WHERE parent_id=? AND name LIKE ?"``. Use ``?``
        placeholders rather than formatting values into the string.
    :param params: Values bound to the placeholders of ``where_clause``.
        Defaults to no parameters, which keeps existing callers that pass a
        fully formatted clause working unchanged.
    :return: The rows returned by ``cursor.fetchall()``.
    """
    query = f"SELECT * FROM {table} {where_clause}"
    # Let the driver bind the values: this avoids quoting bugs and SQL
    # injection when the filter values come from user-controlled data.
    self.cursor.execute(query, params)
    return self.cursor.fetchall()
...@@ -27,7 +27,7 @@ CACHE_DIR = f"/data/{os.environ.get('TASK_ID')}" ...@@ -27,7 +27,7 @@ CACHE_DIR = f"/data/{os.environ.get('TASK_ID')}"
class BaseWorker(object): class BaseWorker(object):
def __init__(self, description="Arkindex Base Worker"): def __init__(self, description="Arkindex Base Worker", use_cache=False):
self.parser = argparse.ArgumentParser(description=description) self.parser = argparse.ArgumentParser(description=description)
# Setup workdir either in Ponos environment or on host's home # Setup workdir either in Ponos environment or on host's home
...@@ -50,6 +50,8 @@ class BaseWorker(object): ...@@ -50,6 +50,8 @@ class BaseWorker(object):
logger.info(f"Worker will use {self.work_dir} as working directory") logger.info(f"Worker will use {self.work_dir} as working directory")
self.use_cache = use_cache
if os.path.isdir(CACHE_DIR): if os.path.isdir(CACHE_DIR):
cache_path = os.path.join(CACHE_DIR, "db.sqlite") cache_path = os.path.join(CACHE_DIR, "db.sqlite")
else: else:
...@@ -214,8 +216,8 @@ class ActivityState(Enum): ...@@ -214,8 +216,8 @@ class ActivityState(Enum):
class ElementsWorker(BaseWorker): class ElementsWorker(BaseWorker):
def __init__(self, description="Arkindex Elements Worker"): def __init__(self, description="Arkindex Elements Worker", use_cache=False):
super().__init__(description) super().__init__(description, use_cache)
# Add report concerning elements # Add report concerning elements
self.report = Reporter("unknown worker") self.report = Reporter("unknown worker")
...@@ -885,9 +887,29 @@ class ElementsWorker(BaseWorker): ...@@ -885,9 +887,29 @@ class ElementsWorker(BaseWorker):
), "worker_version should be of type str" ), "worker_version should be of type str"
query_params["worker_version"] = worker_version query_params["worker_version"] = worker_version
children = self.api_client.paginate( # Checking that we only have query_params handled by the cache
"ListElementChildren", id=element.id, **query_params if self.use_cache and set(query_params.keys()) <= {
) "name",
"type",
"worker_version",
}:
parent_id_hex = convert_str_uuid_to_hex(element.id)
name_condition = f" AND name LIKE '%{name}%'" if name else ""
type_condition = f" AND type='{type}'" if type else ""
worker_version_condition = (
f" AND worker_version_id='{convert_str_uuid_to_hex(worker_version)}'"
if worker_version
else ""
)
children = self.cache.fetch(
"elements",
where_clause=f"WHERE parent_id='{parent_id_hex}'{name_condition}{type_condition}{worker_version_condition}",
)
children = [CachedElement(**dict(child)) for child in children]
else:
children = self.api_client.paginate(
"ListElementChildren", id=element.id, **query_params
)
return children return children
......
...@@ -164,6 +164,16 @@ def mock_elements_worker(monkeypatch, mock_worker_version_api): ...@@ -164,6 +164,16 @@ def mock_elements_worker(monkeypatch, mock_worker_version_api):
return worker return worker
@pytest.fixture
def mock_elements_worker_with_cache(monkeypatch, mock_worker_version_api):
    """Return a configured ElementsWorker created with ``use_cache=True``.

    ``sys.argv`` is replaced by a fixed value before the worker is configured
    so that pytest's own command-line arguments are not seen by the worker.
    """
    fixed_argv = ["worker"]
    monkeypatch.setattr(sys, "argv", fixed_argv)

    cached_worker = ElementsWorker(use_cache=True)
    cached_worker.configure()
    return cached_worker
@pytest.fixture @pytest.fixture
def fake_page_element(): def fake_page_element():
with open(FIXTURES_DIR / "page_element.json", "r") as f: with open(FIXTURES_DIR / "page_element.json", "r") as f:
......
...@@ -130,3 +130,22 @@ def test_insert(): ...@@ -130,3 +130,22 @@ def test_insert():
) )
assert [CachedElement(**dict(row)) for row in generated_rows] == ELEMENTS_TO_INSERT assert [CachedElement(**dict(row)) for row in generated_rows] == ELEMENTS_TO_INSERT
def test_fetch_all():
    """Fetching the ``elements`` table without a filter returns every row."""
    cache = LocalDB(f"{FIXTURES}/lines.sqlite")
    cache.create_tables()

    fetched = [CachedElement(**dict(row)) for row in cache.fetch("elements")]
    assert fetched == ELEMENTS_TO_INSERT
def test_fetch_with_where():
    """Fetching with a WHERE clause only returns the rows matching it."""
    cache = LocalDB(f"{FIXTURES}/lines.sqlite")
    cache.create_tables()

    # Filter on the parent element and on names containing "0"
    parent_hex = convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234")
    condition = f"WHERE parent_id='{parent_hex}' AND name LIKE '%0%'"
    rows = cache.fetch("elements", where_clause=condition)

    fetched = [CachedElement(**dict(row)) for row in rows]
    assert fetched == [ELEMENTS_TO_INSERT[0]]
...@@ -14,6 +14,38 @@ from arkindex_worker.utils import convert_str_uuid_to_hex ...@@ -14,6 +14,38 @@ from arkindex_worker.utils import convert_str_uuid_to_hex
from arkindex_worker.worker import ElementsWorker from arkindex_worker.worker import ElementsWorker
CACHE_DIR = Path(__file__).absolute().parent.parent / "data/cache" CACHE_DIR = Path(__file__).absolute().parent.parent / "data/cache"
# Cached elements used to populate the SQLite cache in tests. All of them share
# the same parent and polygon; only the id, name, type and worker version vary.
ELEMENTS_TO_INSERT = [
    CachedElement(
        id=convert_str_uuid_to_hex(element_id),
        parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"),
        name=element_name,
        type=element_type,
        polygon=json.dumps([[1, 1], [2, 2], [2, 1], [1, 2]]),
        worker_version_id=convert_str_uuid_to_hex(version_id),
    )
    for element_id, element_name, element_type, version_id in (
        (
            "11111111-1111-1111-1111-111111111111",
            "0",
            "something",
            "56785678-5678-5678-5678-567856785678",
        ),
        (
            "22222222-2222-2222-2222-222222222222",
            "1",
            "page",
            "56785678-5678-5678-5678-567856785678",
        ),
        (
            "33333333-3333-3333-3333-333333333333",
            "10",
            "something",
            "90129012-9012-9012-9012-901290129012",
        ),
    )
]
def test_list_elements_elements_list_arg_wrong_type(monkeypatch, mock_elements_worker): def test_list_elements_elements_list_arg_wrong_type(monkeypatch, mock_elements_worker):
...@@ -906,3 +938,111 @@ def test_list_element_children(responses, mock_elements_worker): ...@@ -906,3 +938,111 @@ def test_list_element_children(responses, mock_elements_worker):
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/", "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/",
] ]
def test_list_element_children_with_cache_unhandled_param(
    responses, mock_elements_worker_with_cache
):
    """
    Calls list_element_children on a worker using the cache.
    The cache doesn't contain any information about an element corpus and with_corpus query param is set to True.
    The list_element_children function will call the API (instead of the cache) to use the with_corpus param.
    """
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    expected_children = [
        {
            "id": "0000",
            "type": "page",
            "name": "Test",
            "corpus": {},
            "thumbnail_url": None,
            "zone": {},
            "best_classes": None,
            "has_children": None,
            "worker_version_id": None,
        }
    ]
    responses.add(
        responses.GET,
        "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/?with_corpus=True",
        status=200,
        json={
            "count": 1,
            "next": None,
            "results": expected_children,
        },
    )

    # Compare the whole result at once: asserting inside a loop over the
    # children would pass silently if nothing was returned at all.
    children = list(
        mock_elements_worker_with_cache.list_element_children(
            element=elt, with_corpus=True
        )
    )
    assert children == expected_children

    assert len(responses.calls) == 3
    assert [call.request.url for call in responses.calls] == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/?with_corpus=True",
    ]
def test_list_element_children_with_cache(responses, mock_elements_worker_with_cache):
    """
    Calls list_element_children on a worker using the cache, with only
    filters handled by the cache (name, type, worker_version).
    Children must be read from the SQLite cache and no API call must be made
    to list them.
    """
    worker = mock_elements_worker_with_cache
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})

    # Every result is compared as a whole list: asserting inside a loop over
    # the children would pass silently whenever the query returns nothing.

    # Nothing has been inserted yet, so no child is expected
    assert list(worker.list_element_children(element=elt)) == []

    # Initialize SQLite cache with some elements
    worker.cache.insert("elements", ELEMENTS_TO_INSERT)

    # No filter: all children of the parent element
    assert list(worker.list_element_children(element=elt)) == ELEMENTS_TO_INSERT

    # Name filter is a substring match: "1" matches both "1" and "10"
    assert (
        list(worker.list_element_children(element=elt, name="1"))
        == ELEMENTS_TO_INSERT[1:]
    )

    # Type filter
    assert list(worker.list_element_children(element=elt, type="page")) == [
        ELEMENTS_TO_INSERT[1]
    ]

    # Worker version filter
    assert (
        list(
            worker.list_element_children(
                element=elt, worker_version="56785678-5678-5678-5678-567856785678"
            )
        )
        == ELEMENTS_TO_INSERT[:2]
    )

    # All filters combined: the name must be the digit "0" (not the letter
    # "O"), and only the first element has a name containing "0", the type
    # "something" and this worker version.
    assert list(
        worker.list_element_children(
            element=elt,
            name="0",
            type="something",
            worker_version="56785678-5678-5678-5678-567856785678",
        )
    ) == [ELEMENTS_TO_INSERT[0]]

    # Only the calls made while configuring the worker reached the API
    assert len(responses.calls) == 2
    assert [call.request.url for call in responses.calls] == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
    ]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment