Skip to content
Snippets Groups Projects
Commit 8dd86ac9 authored by Manon Blanco's avatar Manon Blanco
Browse files

(re)use cache helpers + split into several functions

parent 45648da5
No related branches found
No related tags found
1 merge request!2Port init elements code
Pipeline #164734 passed
...@@ -6,6 +6,14 @@ from collections import OrderedDict ...@@ -6,6 +6,14 @@ from collections import OrderedDict
from logging import Logger, getLogger from logging import Logger, getLogger
from time import sleep from time import sleep
from arkindex_worker.cache import (
CachedElement,
CachedImage,
create_tables,
create_version_table,
init_cache_db,
)
from arkindex_worker.models import Element as BaseElement
from arkindex_worker.worker.base import BaseWorker from arkindex_worker.worker.base import BaseWorker
logger: Logger = getLogger(__name__) logger: Logger = getLogger(__name__)
...@@ -16,33 +24,6 @@ INIT_PAGE_SIZE = 500 ...@@ -16,33 +24,6 @@ INIT_PAGE_SIZE = 500
PENDING_STATE = "pending" PENDING_STATE = "pending"
ERROR_STATE = "error" ERROR_STATE = "error"
# Schema for the SQLite cache database produced by this worker.
# Three parts: a `version` marker table (schema version 3), the images
# referenced by elements, and the elements themselves.
# NOTE(review): ids appear to be 32-char hex strings (UUIDs without dashes,
# see the uuid.UUID(...).hex calls at the insertion sites) — confirm.
# The CHECK constraints enforce: rotation_angle in [0, 359], mirrored is a
# 0/1 flag, confidence is NULL or in [0, 1].
SQL_TABLES = """
CREATE TABLE version AS SELECT 3 AS version;
CREATE TABLE images (
id VARCHAR(32) PRIMARY KEY,
width INTEGER NOT NULL,
height INTEGER NOT NULL,
url TEXT NOT NULL
);
CREATE TABLE elements (
id VARCHAR(32) PRIMARY KEY,
parent_id VARCHAR(32),
type TEXT NOT NULL,
image_id VARCHAR(32),
polygon TEXT,
rotation_angle INTEGER NOT NULL DEFAULT 0,
mirrored INTEGER NOT NULL DEFAULT 0,
initial BOOLEAN DEFAULT 0 NOT NULL,
worker_version_id VARCHAR(32),
worker_run_id VARCHAR(32),
confidence REAL,
FOREIGN KEY(image_id) REFERENCES images(id),
CHECK (rotation_angle >= 0 AND rotation_angle <= 359),
CHECK (mirrored = 0 OR mirrored = 1),
CHECK (confidence IS NULL OR (confidence >= 0 AND confidence <= 1))
);
"""
def split_chunks(items, n): def split_chunks(items, n):
""" """
...@@ -53,6 +34,21 @@ def split_chunks(items, n): ...@@ -53,6 +34,21 @@ def split_chunks(items, n):
yield items[i::n] yield items[i::n]
class Element(BaseElement):
    """
    Some attributes in the `ListProcessElements` response conflict with the `BaseElement` property/function.
    Override them to access the API response directly.
    """

    @property
    def polygon(self) -> list[float]:
        """Raw ``polygon`` field from the API response, bypassing ``BaseElement``."""
        return self["polygon"]

    @property
    def image_url(self) -> str:
        # Fixed return annotation: this is the image URL string, not a polygon
        # (the previous `list[float]` was a copy-paste from `polygon` above).
        """Raw ``image_url`` field from the API response, bypassing ``BaseElement``."""
        return self["image_url"]
class InitElementWorker(BaseWorker): class InitElementWorker(BaseWorker):
def configure(self) -> None: def configure(self) -> None:
# CLI args are stored on the instance so that implementations can access them # CLI args are stored on the instance so that implementations can access them
...@@ -73,87 +69,114 @@ class InitElementWorker(BaseWorker): ...@@ -73,87 +69,114 @@ class InitElementWorker(BaseWorker):
self.use_cache = self.config["use_cache"] self.use_cache = self.config["use_cache"]
self.api_client.sleep_duration = self.config["sleep"] self.api_client.sleep_duration = self.config["sleep"]
assert self.worker_run_id, "Missing ARKINDEX_WORKER_RUN_ID environment variable, cannot list the elements to process" assert self.worker_run_id, "Missing ARKINDEX_WORKER_RUN_ID environment variable, cannot retrieve process information"
self.process = self.request("RetrieveWorkerRun", id=self.worker_run_id)[ self.process = self.request("RetrieveWorkerRun", id=self.worker_run_id)[
"process" "process"
] ]
def dump_json(self, elements, filename="elements.json"): def dump_json(
file_path = self.work_dir / filename self, elements: list[Element], filename: str = "elements.json"
assert len(elements), f"No elements could be written to {file_path}" ) -> None:
path = self.work_dir / filename
assert not path.exists(), f"JSON at {path} already exists"
with file_path.open("w") as f: path.write_text(json.dumps(elements, indent=4))
json.dump(elements, f, indent=4)
def dump_sqlite(self, elements, filename="db.sqlite"): def dump_sqlite(self, elements: list[Element], filename: str = "db.sqlite") -> None:
if not self.use_cache: if not self.use_cache:
return return
path = self.work_dir / filename path = self.work_dir / filename
assert not path.exists(), f"Database at {path} already exists" assert not path.exists(), f"Database at {path} already exists"
db = sqlite3.connect(str(path)) db = sqlite3.connect(str(path))
db.executescript(SQL_TABLES)
# Set of unique images found in the elements init_cache_db(path)
image_rows = { create_version_table()
( create_tables()
uuid.UUID(element["image_id"]).hex,
element["image_width"],
element["image_height"],
element["image_url"],
)
for element in elements
if element.get("image_id") is not None
}
db.executemany(
"INSERT INTO images (id, width, height, url) VALUES (?,?,?,?)", image_rows
)
element_rows = [ # Set of unique images found in the elements
( CachedImage.insert_many(
uuid.UUID(element["id"]).hex, [
self.type_slugs[element["type_id"]], {
uuid.UUID(element["image_id"]).hex if element.get("image_id") else None, "id": uuid.UUID(element.image_id).hex,
json.dumps(element["polygon"]) if element.get("polygon") else None, "width": element.image_width,
element.get("rotation_angle") or 0, "height": element.image_height,
element.get("mirrored") or 0, "url": element.image_url,
element.get("confidence"), }
) for element in elements
for element in elements if element.get("image_id") is not None
] ]
db.executemany( ).on_conflict_ignore(ignore=True).execute()
"INSERT INTO elements (id, type, image_id, polygon, rotation_angle, mirrored, confidence, initial) VALUES (?,?,?,?,?,?,?,1)",
element_rows, # Fastest way to INSERT multiple rows.
) CachedElement.insert_many(
[
{
"id": uuid.UUID(element.id).hex,
"type": element.type,
"image_id": uuid.UUID(element.image_id).hex
if element.get("image_id")
else None,
"polygon": (element.polygon if element.get("polygon") else None),
"rotation_angle": element.get("rotation_angle") or 0,
"mirrored": element.get("mirrored") or False,
"confidence": element.get("confidence"),
"initial": True,
}
for element in elements
]
).execute()
db.commit()
db.close() db.close()
def check_worker_activity(self): def dump_chunks(self, elements: list[Element]) -> None:
# Check if workers activity associated to this process is in a pending state assert (
process = self.request("RetrieveProcess", id=self.process["id"]) len(elements) >= self.chunks_number
if process.get("activity_state") == ERROR_STATE: ), f"Too few elements have been retrieved to distribute workflow among {self.chunks_number} branches"
logger.error(
"Workers activity could not be initialized. Please report this incident to an instance administrator." for index, chunk_elts in enumerate(
split_chunks(elements, self.chunks_number),
start=1,
):
self.dump_json(
elements=[
{
"id": element.id,
"type": element.type,
}
for element in chunk_elts
],
**(
{"filename": f"elements_chunk_{index}.json"}
if self.chunks_number > 1
else {}
),
)
self.dump_sqlite(
elements=chunk_elts,
**(
{"filename": f"db_{index}.sqlite"} if self.chunks_number > 1 else {}
),
) )
sys.exit(1)
return process.get("activity_state") != PENDING_STATE
def run(self): logger.info(
assert self.worker_run_id, "Missing ARKINDEX_WORKER_RUN_ID environment variable, cannot list the elements to process" f"Added {len(elements)} element{'s'[:len(elements) > 1]} to workflow configuration"
)
def list_process_elements(self) -> list[Element]:
assert self.process.get( assert self.process.get(
"corpus" "corpus"
), "init_elements only supports processes on corpora." ), "init_elements only supports processes on corpora."
corpus = self.request("RetrieveCorpus", id=self.process["corpus"]) corpus = self.request("RetrieveCorpus", id=self.process["corpus"])
self.type_slugs = { type_slugs = {
element_type["id"]: element_type["slug"] for element_type in corpus["types"] element_type["id"]: element_type["slug"] for element_type in corpus["types"]
} }
elements = list( elements = list(
self.api_client.paginate( Element(**element, type=type_slugs[element["type_id"]])
for element in self.api_client.paginate(
"ListProcessElements", "ListProcessElements",
id=self.process["id"], id=self.process["id"],
with_image=self.use_cache, with_image=self.use_cache,
...@@ -178,42 +201,19 @@ class InitElementWorker(BaseWorker): ...@@ -178,42 +201,19 @@ class InitElementWorker(BaseWorker):
logger.error("No elements found, aborting workflow.") logger.error("No elements found, aborting workflow.")
sys.exit(1) sys.exit(1)
assert ( return list(unique_elements.values())
len(unique_elements) >= self.chunks_number
), f"Too few elements have been retrieved to distribute workflow among {self.chunks_number} branches"
if self.chunks_number == 1: def check_worker_activity(self):
self.dump_json( # Check if workers activity associated to this process is in a pending state
elements=[ process = self.request("RetrieveProcess", id=self.process["id"])
{"id": element_id, "type": self.type_slugs[element["type_id"]]} if process.get("activity_state") == ERROR_STATE:
for element_id, element in unique_elements.items() logger.error(
] "Workers activity could not be initialized. Please report this incident to an instance administrator."
) )
self.dump_sqlite(unique_elements.values()) sys.exit(1)
else: return process.get("activity_state") != PENDING_STATE
for index, chunk_elts in enumerate(
split_chunks(list(unique_elements.values()), self.chunks_number),
start=1,
):
self.dump_json(
elements=[
{
"id": element["id"],
"type": self.type_slugs[element["type_id"]],
}
for element in chunk_elts
],
filename=f"elements_chunk_{index}.json",
)
self.dump_sqlite(
elements=chunk_elts,
filename=f"db_{index}.sqlite",
)
logger.info(
f"Added {len(unique_elements)} element{'s'[:len(unique_elements) > 1]} to workflow configuration"
)
def await_worker_activity(self) -> None:
logger.info("Awaiting workers activity initialization") logger.info("Awaiting workers activity initialization")
# Await worker activities to be initialized for 0, 2, 4, 8 seconds up to an hour # Await worker activities to be initialized for 0, 2, 4, 8 seconds up to an hour
timer = 1 timer = 1
...@@ -229,6 +229,12 @@ class InitElementWorker(BaseWorker): ...@@ -229,6 +229,12 @@ class InitElementWorker(BaseWorker):
raise Exception("Worker activity timeout") raise Exception("Worker activity timeout")
sleep(timer) sleep(timer)
def run(self):
elements = self.list_process_elements()
self.dump_chunks(elements)
self.await_worker_activity()
def main() -> None: def main() -> None:
InitElementWorker( InitElementWorker(
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment