Commit d572f849 authored by Yoann Schneider

Merge branch 'init-element-worker' into 'master'

Port init elements code

Closes #2

See merge request !2
parents 272fefcb 49654fca
Pipeline #165792 passed
@@ -7,3 +7,19 @@ workers:
type: extractor
docker:
build: Dockerfile
user_configuration:
chunks_number:
title: Number of chunks to split workflow into after initialisation
type: int
default: 1
required: true
use_cache:
title: Enable SQLite database generation for worker caching
type: bool
default: false
required: false
sleep:
title: Throttle API requests by waiting for a given number of seconds
type: float
default: 0.0
required: true
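# How these keys are consumed (see worker.py further down): chunks_number
# drives dump_chunks, use_cache toggles dump_sqlite, and sleep throttles the
# API client between requests.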
#!/bin/sh -e
pip install ${PIP_FLAGS} .
import json
import sqlite3
from pathlib import Path
from arkindex_worker.cache import SQL_VERSION
def check_json(json_path: Path, elements: list) -> None:
assert json_path.is_file()
assert json.loads(json_path.read_text()) == elements
def check_db(db_path: Path, elements: list, images: list) -> None:
assert db_path.is_file()
db = sqlite3.connect(str(db_path))
db.row_factory = sqlite3.Row
assert list(map(dict, db.execute("select * from version").fetchall())) == [
{"version": SQL_VERSION}
]
assert (
list(map(dict, db.execute("select * from elements order by id").fetchall()))
== elements
)
assert (
list(map(dict, db.execute("select * from images order by id").fetchall()))
== images
)
db.close()
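# These helpers are shared across the test suite below: check_json validates
# the elements.json artefact, check_db the optional SQLite cache artefact
# (its version, elements and images tables).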
import os
import sys
import pytest
from arkindex.mock import MockApiClient
from arkindex_worker.worker.base import BaseWorker
from worker_init_elements.worker import InitElementsWorker
@pytest.fixture()
def mock_api_client() -> MockApiClient:
return MockApiClient()
@pytest.fixture(autouse=True)
def _setup_environment(mock_api_client: MockApiClient, responses, monkeypatch) -> None:
"""Setup needed environment variables"""
# Allow accessing remote API schemas
@@ -27,6 +34,84 @@ def _setup_environment(responses, monkeypatch) -> None:
# Setup a mock api client instead of using a real one
def mock_setup_api_client(self):
self.api_client = mock_api_client
monkeypatch.setattr(BaseWorker, "setup_api_client", mock_setup_api_client)
@pytest.fixture()
def _mock_worker_run_api(mock_api_client: MockApiClient) -> None:
"""Provide a mock API response to get worker run information"""
mock_api_client.add_response(
"RetrieveWorkerRun",
id=os.getenv("ARKINDEX_WORKER_RUN_ID"),
response={
"id": os.getenv("ARKINDEX_WORKER_RUN_ID"),
"worker_version": {
"id": "12341234-1234-1234-1234-123412341234",
"revision": {"hash": "deadbeef1234"},
"worker": {"name": "Fake worker"},
"configuration": {
"name": "Init Elements",
"slug": "init-elements",
"type": "extractor",
"docker": {
"build": "Dockerfile",
"image": "",
"command": None,
"context": None,
"shm_size": None,
"environment": {},
},
"secrets": [],
"description": None,
"configuration": {},
"user_configuration": {
"chunks_number": {
"type": "int",
"title": "Chunks number",
"default": 1,
"required": True,
},
"use_cache": {
"type": "bool",
"title": "Use cache",
"default": False,
},
"threshold_value": {
"type": "float",
"title": "Threshold Value",
"default": 0.1,
"subtype": "number",
"required": False,
},
"sleep": {
"type": "float",
"title": "Sleep",
"default": 0.0,
},
},
},
},
"configuration": None,
"process": {
"id": "process_id",
"corpus": os.getenv("ARKINDEX_CORPUS_ID"),
"activity_state": "disabled",
},
"summary": os.getenv("ARKINDEX_WORKER_RUN_ID") + " @ version 1",
},
)
@pytest.fixture()
def mock_worker(
_mock_worker_run_api, tmp_path_factory, monkeypatch
) -> InitElementsWorker:
monkeypatch.setattr(sys, "argv", ["worker-init-elements"])
worker = InitElementsWorker()
worker.work_dir = tmp_path_factory.mktemp("data")
worker.configure()
return worker
import logging
import pytest
from tests import check_json
from worker_init_elements.worker import INIT_PAGE_SIZE
def test_activity_state_awaiting(mock_worker, monkeypatch):
"""
Init task must wait until the backend has initialized worker activities for this process
"""
mock_worker.process_information["activity_state"] = "pending"
sleep_args = iter([2, 4, 8])
def mock_sleep(seconds) -> None:
assert seconds == next(sleep_args)
monkeypatch.setattr("worker_init_elements.worker.sleep", mock_sleep)
# Reply with a pending state three times, then ready, while the task waits for activity initialization.
for state in ["pending", "pending", "pending", "ready"]:
mock_worker.api_client.add_response(
"RetrieveProcess",
id=mock_worker.process_information["id"],
response={
"activity_state": state,
"corpus": "corpusid",
},
)
mock_worker.api_client.add_response(
"RetrieveCorpus",
id=mock_worker.process_information["corpus"],
response={
"id": mock_worker.process_information["corpus"],
"types": [{"id": "A", "slug": "class"}, {"id": "B", "slug": "student"}],
},
)
mock_worker.api_client.add_response(
"ListProcessElements",
id=mock_worker.process_information["id"],
with_image=False,
allow_missing_data=True,
page_size=INIT_PAGE_SIZE,
response=[
{
"id": "11111111-1111-1111-1111-111111111111",
"type_id": "A",
"name": "Class 1",
}
],
)
mock_worker.process()
check_json(
json_path=mock_worker.work_dir / "elements.json",
elements=[
{"id": "11111111-1111-1111-1111-111111111111", "type": "class"},
],
)
assert not mock_worker.api_client.responses
def test_activity_state_timeout(mock_worker, caplog, monkeypatch):
"""
Wait up to an hour for worker activities to be ready before raising an error
"""
caplog.set_level(logging.WARNING)
mock_worker.process_information["activity_state"] = "pending"
sleep_args = []
monkeypatch.setattr(
"worker_init_elements.worker.sleep", lambda seconds: sleep_args.append(seconds)
)
# Perpetually reply that the activity is in a pending state
for _ in range(12):
mock_worker.api_client.add_response(
"RetrieveProcess",
id=mock_worker.process_information["id"],
response={
"activity_state": "pending",
"corpus": "corpusid",
},
)
mock_worker.api_client.add_response(
"RetrieveCorpus",
id=mock_worker.process_information["corpus"],
response={
"id": mock_worker.process_information["corpus"],
"types": [{"id": "A", "slug": "class"}, {"id": "B", "slug": "student"}],
},
)
mock_worker.api_client.add_response(
"ListProcessElements",
id=mock_worker.process_information["id"],
with_image=False,
allow_missing_data=True,
page_size=INIT_PAGE_SIZE,
response=[
{
"id": "11111111-1111-1111-1111-111111111111",
"type_id": "A",
"name": "Class 1",
}
],
)
with pytest.raises(Exception, match="Worker activity timeout"):
mock_worker.process()
assert sum(sleep_args) == 4094
assert [(record.levelname, record.message) for record in caplog.records] == [
(
"ERROR",
"Workers activity not initialized 68 minutes after starting the process."
" Please report this incident to an instance administrator.",
)
]
import importlib
def test_dummy():
assert True
def test_import():
"""Import our newly created module, through importlib to avoid parsing issues"""
worker = importlib.import_module("worker_init_elements.worker")
assert hasattr(worker, "Demo")
assert hasattr(worker.Demo, "process_element")
import json
import sqlite3
import sys
import uuid
from collections import OrderedDict
from collections.abc import Iterator
from enum import Enum
from logging import Logger, getLogger
from time import sleep
from arkindex_worker.models import Element
from arkindex_worker.worker import ElementsWorker
from arkindex_worker.cache import (
CachedElement,
CachedImage,
create_tables,
create_version_table,
init_cache_db,
)
from arkindex_worker.worker.base import BaseWorker
logger: Logger = getLogger(__name__)
# Increases the number of elements returned per page by the API
INIT_PAGE_SIZE = 500
class Demo(ElementsWorker):
def process_element(self, element: Element) -> None:
logger.info(f"Demo processing element ({element.id})")
def split_chunks(items: list, n: int) -> Iterator[list]:
"""
Yield n chunks from the given list with a balanced distribution
https://stackoverflow.com/questions/24483182/python-split-list-into-n-chunks#answer-54802737
"""
for i in range(0, n):
yield items[i::n]
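# A minimal illustration with hypothetical values: slicing with a stride of n
# interleaves the items, so chunk sizes differ by at most one:
#   list(split_chunks([1, 2, 3, 4, 5, 6, 7], 3))
#   -> [[1, 4, 7], [2, 5], [3, 6]]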
class ActivityState(Enum):
"""
Store the state of the workers activity tracking for a process.
To support large element sets, the state is asynchronously set to `ready` after a process
has been started and worker activities have been initialized on its elements.
"""
Disabled = "disabled"
"""
Worker activities are disabled and will not be used
"""
Pending = "pending"
"""
Worker activities are not yet initialized
"""
Ready = "ready"
"""
Worker activities are initialized and ready for use
"""
Error = "error"
"""
An error occurred when initializing worker activities
"""
class InitElementsWorker(BaseWorker):
def configure(self) -> None:
# CLI args are stored on the instance so that implementations can access them
self.args = self.parser.parse_args()
if self.is_read_only:
super().configure_for_developers()
else:
super().configure()
# Retrieve the user configuration
if self.user_configuration:
self.config.update(self.user_configuration)
logger.info("User configuration retrieved")
self.chunks_number = self.config["chunks_number"]
self.use_cache = self.config["use_cache"]
self.api_client.sleep_duration = self.config["sleep"]
def dump_json(self, elements: list[dict], filename: str = "elements.json") -> None:
"""
Store elements in a JSON file.
This file will become an artefact.
"""
path = self.work_dir / filename
assert not path.exists(), f"JSON at {path} already exists"
path.write_text(json.dumps(elements, indent=4))
def dump_sqlite(self, elements: list[dict], filename: str = "db.sqlite") -> None:
"""
Store elements in a SQLite database. Only images and elements will be added.
This file will become an artefact.
"""
if not self.use_cache:
return
path = self.work_dir / filename
assert not path.exists(), f"Database at {path} already exists"
init_cache_db(path)
create_version_table()
create_tables()
# Set of unique images found in the elements
CachedImage.insert_many(
{
"id": uuid.UUID(element["image_id"]).hex,
"width": element["image_width"],
"height": element["image_height"],
"url": element["image_url"],
}
for element in elements
if element["image_id"]
).on_conflict_ignore(ignore=True).execute()
# Fastest way to INSERT multiple rows.
CachedElement.insert_many(
{
"id": uuid.UUID(element["id"]).hex,
"type": element["type"],
"image_id": (
uuid.UUID(element["image_id"]).hex if element["image_id"] else None
),
"polygon": element["polygon"],
"rotation_angle": element["rotation_angle"],
"mirrored": element["mirrored"],
"confidence": element["confidence"],
"initial": True,
}
for element in elements
).execute()
def dump_chunks(self, elements: list[dict]) -> None:
"""
Store elements in JSON file(s) and SQLite database(s).
If several chunks are requested, the files will be suffixed with the chunk index.
"""
assert (
len(elements) >= self.chunks_number
), f"Too few elements have been retrieved to distribute workflow among {self.chunks_number} branches"
for index, chunk_elts in enumerate(
split_chunks(elements, self.chunks_number),
start=1,
):
self.dump_json(
elements=[
{
"id": element["id"],
"type": element["type"],
}
for element in chunk_elts
],
**(
{"filename": f"elements_chunk_{index}.json"}
if self.chunks_number > 1
else {}
),
)
self.dump_sqlite(
elements=chunk_elts,
**(
{"filename": f"db_{index}.sqlite"} if self.chunks_number > 1 else {}
),
)
logger.info(
f"Added {len(elements)} element{'s'[:len(elements) > 1]} to workflow configuration"
)
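# Artefact layout sketch (hypothetical run): with chunks_number == 3 this
# writes elements_chunk_1.json / db_1.sqlite through elements_chunk_3.json /
# db_3.sqlite; with chunks_number == 1 it writes a single elements.json
# (plus db.sqlite when use_cache is enabled).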
def list_process_elements(self) -> list[dict]:
"""
List all elements linked to this process and remove duplicates
"""
assert self.process_information.get(
"corpus"
), "This worker only supports processes on corpora."
corpus = self.request("RetrieveCorpus", id=self.process_information["corpus"])
type_slugs = {
element_type["id"]: element_type["slug"] for element_type in corpus["types"]
}
elements = [
{**element, "type": type_slugs[element["type_id"]]}
for element in self.api_client.paginate(
"ListProcessElements",
id=self.process_information["id"],
with_image=self.use_cache,
allow_missing_data=True,
page_size=INIT_PAGE_SIZE,
)
]
# Use a dict to make elements unique by ID, then turn them back into an elements.json-compatible list
unique_elements = OrderedDict(
[(element["id"], element) for element in elements]
)
logger.info(
f"Retrieved {len(unique_elements)} element{'s'[:len(unique_elements) > 1]} from process {self.process_information['id']}"
)
duplicate_count = len(elements) - len(unique_elements)
if duplicate_count:
logger.warning(f"{duplicate_count} duplicate elements have been ignored.")
if not unique_elements:
logger.error("No elements found, aborting workflow.")
sys.exit(1)
return list(unique_elements.values())
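# Deduplication sketch (hypothetical IDs): if pagination yields
# [{"id": "a"}, {"id": "a"}, {"id": "b"}], the OrderedDict keeps the first
# occurrence of each ID, so two elements survive and one duplicate is logged.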
def check_worker_activity(self) -> bool:
"""
Check the workers activity state of this process: return True once activities are ready, abort if their initialization failed
"""
activity_state = ActivityState(
self.request("RetrieveProcess", id=self.process_information["id"])[
"activity_state"
]
)
if activity_state == ActivityState.Error:
logger.error(
"Worker activities could not be initialized. Please report this incident to an instance administrator."
)
sys.exit(1)
return activity_state == ActivityState.Ready
def await_worker_activity(self) -> None:
"""
Worker activities are initialized asynchronously after a process has been started.
This worker keeps running until the activity state has moved to `Ready`.
"""
if (
ActivityState(self.process_information["activity_state"])
== ActivityState.Disabled
):
return
logger.info("Awaiting worker activities initialization")
# Wait 2, 4, 8, ... seconds (doubling each time) for worker activities to be initialized, up to roughly an hour
timer = 1
while True:
if self.check_worker_activity():
break
timer *= 2
if timer >= 3600:
logger.error(
f"Workers activity not initialized {int(timer/60)} minutes after starting the process."
" Please report this incident to an instance administrator."
)
raise Exception("Worker activity timeout")
sleep(timer)
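# Backoff arithmetic (derived from the loop above): sleeps double from 2 up to
# 2048 seconds, totalling sum(2 ** i for i in range(1, 12)) == 4094 seconds,
# so the timeout fires roughly 68 minutes in, matching the assertions in
# test_activity_state_timeout.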
def process(self) -> None:
elements = self.list_process_elements()
self.dump_chunks(elements)
self.await_worker_activity()
def run(self) -> None:
self.configure()
self.process()
def main() -> None:
InitElementsWorker(
description="Worker to initialize Arkindex elements to process"
).run()
if __name__ == "__main__":
main()