From 49654fca0f9e0412293ac77791ff951c6ed2c2a0 Mon Sep 17 00:00:00 2001
From: Manon Blanco <blanco@teklia.com>
Date: Thu, 28 Mar 2024 15:54:56 +0000
Subject: [PATCH] Port init elements code

---
 .arkindex.yml                  |  16 +
 install_ponos_slurm.sh         |   3 +
 tests/__init__.py              |  31 ++
 tests/conftest.py              |  89 ++++-
 tests/test_activity.py         | 123 +++++++
 tests/test_run.py              | 654 +++++++++++++++++++++++++++++++++
 tests/test_worker.py           |  12 -
 worker_init_elements/worker.py | 273 +++++++++++++-
 8 files changed, 1181 insertions(+), 20 deletions(-)
 create mode 100755 install_ponos_slurm.sh
 create mode 100644 tests/__init__.py
 create mode 100644 tests/test_activity.py
 create mode 100644 tests/test_run.py
 delete mode 100644 tests/test_worker.py

diff --git a/.arkindex.yml b/.arkindex.yml
index 8e23acd..a8dad8f 100644
--- a/.arkindex.yml
+++ b/.arkindex.yml
@@ -7,3 +7,19 @@ workers:
     type: extractor
     docker:
       build: Dockerfile
+    user_configuration:
+      chunks_number:
+        title: Number of chunks to split workflow into after initialisation
+        type: int
+        default: 1
+        required: true
+      use_cache:
+        title: Enable SQLite database generation for worker caching
+        type: bool
+        default: false
+        required: false
+      sleep:
+        title: Throttle API requests by waiting for a given number of seconds
+        type: float
+        default: 0.0
+        required: true
diff --git a/install_ponos_slurm.sh b/install_ponos_slurm.sh
new file mode 100755
index 0000000..b6745b0
--- /dev/null
+++ b/install_ponos_slurm.sh
@@ -0,0 +1,3 @@
+#!/bin/sh -e
+
+pip install ${PIP_FLAGS} .
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e1f7fe4
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1,31 @@
+import json
+import sqlite3
+from pathlib import Path
+
+from arkindex_worker.cache import SQL_VERSION
+
+
+def check_json(json_path: Path, elements: list) -> None:
+    assert json_path.is_file()
+
+    assert json.loads(json_path.read_text()) == elements
+
+
+def check_db(db_path: Path, elements: list, images: list) -> None:
+    assert db_path.is_file()
+    db = sqlite3.connect(str(db_path))
+    db.row_factory = sqlite3.Row
+
+    assert list(map(dict, db.execute("select * from version").fetchall())) == [
+        {"version": SQL_VERSION}
+    ]
+    assert (
+        list(map(dict, db.execute("select * from elements order by id").fetchall()))
+        == elements
+    )
+    assert (
+        list(map(dict, db.execute("select * from images order by id").fetchall()))
+        == images
+    )
+
+    db.close()
diff --git a/tests/conftest.py b/tests/conftest.py
index 47295a4..40f1ef4 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,13 +1,20 @@
 import os
+import sys
 
 import pytest
 
 from arkindex.mock import MockApiClient
 from arkindex_worker.worker.base import BaseWorker
+from worker_init_elements.worker import InitElementsWorker
+
+
+@pytest.fixture()
+def mock_api_client() -> MockApiClient:
+    return MockApiClient()
 
 
 @pytest.fixture(autouse=True)
-def _setup_environment(responses, monkeypatch) -> None:
+def _setup_environment(mock_api_client: MockApiClient, responses, monkeypatch) -> None:
     """Setup needed environment variables"""
 
     # Allow accessing remote API schemas
@@ -27,6 +34,84 @@ def _setup_environment(responses, monkeypatch) -> None:
 
     # Setup a mock api client instead of using a real one
     def mock_setup_api_client(self):
-        self.api_client = MockApiClient()
+        self.api_client = mock_api_client
 
     monkeypatch.setattr(BaseWorker, "setup_api_client", mock_setup_api_client)
+
+
+@pytest.fixture()
+def _mock_worker_run_api(mock_api_client: MockApiClient) -> None:
+    """Provide a mock API response to get worker run information"""
+    mock_api_client.add_response(
+        "RetrieveWorkerRun",
+        id=os.getenv("ARKINDEX_WORKER_RUN_ID"),
+        response={
+            "id": os.getenv("ARKINDEX_WORKER_RUN_ID"),
+            "worker_version": {
+                "id": "12341234-1234-1234-1234-123412341234",
+                "revision": {"hash": "deadbeef1234"},
+                "worker": {"name": "Fake worker"},
+                "configuration": {
+                    "name": "Init Elements",
+                    "slug": "init-elements",
+                    "type": "extractor",
+                    "docker": {
+                        "build": "Dockerfile",
+                        "image": "",
+                        "command": None,
+                        "context": None,
+                        "shm_size": None,
+                        "environment": {},
+                    },
+                    "secrets": [],
+                    "description": None,
+                    "configuration": {},
+                    "user_configuration": {
+                        "chunks_number": {
+                            "type": "int",
+                            "title": "Chunks number",
+                            "default": 1,
+                            "required": True,
+                        },
+                        "use_cache": {
+                            "type": "bool",
+                            "title": "Use cache",
+                            "default": False,
+                        },
+                        "threshold_value": {
+                            "type": "float",
+                            "title": "Threshold Value",
+                            "default": 0.1,
+                            "subtype": "number",
+                            "required": False,
+                        },
+                        "sleep": {
+                            "type": "float",
+                            "title": "Sleep",
+                            "default": 0.0,
+                        },
+                    },
+                },
+            },
+            "configuration": None,
+            "process": {
+                "id": "process_id",
+                "corpus": os.getenv("ARKINDEX_CORPUS_ID"),
+                "activity_state": "disabled",
+            },
+            "summary": os.getenv("ARKINDEX_WORKER_RUN_ID") + " @ version 1",
+        },
+    )
+
+
+@pytest.fixture()
+def mock_worker(
+    _mock_worker_run_api, tmp_path_factory, monkeypatch
+) -> InitElementsWorker:
+    monkeypatch.setattr(sys, "argv", ["worker-init-elements"])
+
+    worker = InitElementsWorker()
+    worker.work_dir = tmp_path_factory.mktemp("data")
+    worker.configure()
+
+    return worker
diff --git a/tests/test_activity.py b/tests/test_activity.py
new file mode 100644
index 0000000..c2aeffc
--- /dev/null
+++ b/tests/test_activity.py
@@ -0,0 +1,123 @@
+import logging
+
+import pytest
+
+from tests import check_json
+from worker_init_elements.worker import INIT_PAGE_SIZE
+
+
+def test_activity_state_awaiting(mock_worker, monkeypatch):
+    """
+    Init task must wait until the backend has initialized worker activities for this process
+    """
+    mock_worker.process_information["activity_state"] = "pending"
+
+    sleep_args = iter([2, 4, 8])
+
+    def mock_sleep(seconds) -> None:
+        assert seconds == next(sleep_args)
+
+    monkeypatch.setattr("worker_init_elements.worker.sleep", mock_sleep)
+
+    # Report pending three times when the task is waiting for activity initialization.
+    for state in ["pending", "pending", "pending", "ready"]:
+        mock_worker.api_client.add_response(
+            "RetrieveProcess",
+            id=mock_worker.process_information["id"],
+            response={
+                "activity_state": state,
+                "corpus": "corpusid",
+            },
+        )
+    mock_worker.api_client.add_response(
+        "RetrieveCorpus",
+        id=mock_worker.process_information["corpus"],
+        response={
+            "id": mock_worker.process_information["corpus"],
+            "types": [{"id": "A", "slug": "class"}, {"id": "B", "slug": "student"}],
+        },
+    )
+    mock_worker.api_client.add_response(
+        "ListProcessElements",
+        id=mock_worker.process_information["id"],
+        with_image=False,
+        allow_missing_data=True,
+        page_size=INIT_PAGE_SIZE,
+        response=[
+            {
+                "id": "11111111-1111-1111-1111-111111111111",
+                "type_id": "A",
+                "name": "Class 1",
+            }
+        ],
+    )
+
+    mock_worker.process()
+
+    check_json(
+        json_path=mock_worker.work_dir / "elements.json",
+        elements=[
+            {"id": "11111111-1111-1111-1111-111111111111", "type": "class"},
+        ],
+    )
+
+    assert not mock_worker.api_client.responses
+
+
+def test_activity_state_timeout(mock_worker, caplog, monkeypatch):
+    """
+    Wait up to an hour for the workers activity to become ready before raising an error
+    """
+    caplog.set_level(logging.WARNING)
+
+    mock_worker.process_information["activity_state"] = "pending"
+
+    sleep_args = []
+    monkeypatch.setattr(
+        "worker_init_elements.worker.sleep", lambda seconds: sleep_args.append(seconds)
+    )
+
+    # Perpetually reply that the activity is in a pending state
+    for _ in range(12):
+        mock_worker.api_client.add_response(
+            "RetrieveProcess",
+            id=mock_worker.process_information["id"],
+            response={
+                "activity_state": "pending",
+                "corpus": "corpusid",
+            },
+        )
+    mock_worker.api_client.add_response(
+        "RetrieveCorpus",
+        id=mock_worker.process_information["corpus"],
+        response={
+            "id": mock_worker.process_information["corpus"],
+            "types": [{"id": "A", "slug": "class"}, {"id": "B", "slug": "student"}],
+        },
+    )
+    mock_worker.api_client.add_response(
+        "ListProcessElements",
+        id=mock_worker.process_information["id"],
+        with_image=False,
+        allow_missing_data=True,
+        page_size=INIT_PAGE_SIZE,
+        response=[
+            {
+                "id": "11111111-1111-1111-1111-111111111111",
+                "type_id": "A",
+                "name": "Class 1",
+            }
+        ],
+    )
+
+    with pytest.raises(Exception, match="Worker activity timeout"):
+        mock_worker.process()
+
+    assert sum(sleep_args) == 4094
+    assert [(record.levelname, record.message) for record in caplog.records] == [
+        (
+            "ERROR",
+            "Workers activity not initialized 68 minutes after starting the process."
+            " Please report this incident to an instance administrator.",
+        )
+    ]
diff --git a/tests/test_run.py b/tests/test_run.py
new file mode 100644
index 0000000..798495e
--- /dev/null
+++ b/tests/test_run.py
@@ -0,0 +1,654 @@
+import logging
+
+import pytest
+
+from tests import check_db, check_json
+from worker_init_elements.worker import INIT_PAGE_SIZE
+
+
+@pytest.mark.parametrize("use_cache", [True, False])
+def test_run_process(use_cache, mock_worker):
+    mock_worker.use_cache = use_cache
+
+    mock_worker.api_client.add_response(
+        "RetrieveCorpus",
+        id=mock_worker.process_information["corpus"],
+        response={
+            "id": mock_worker.process_information["corpus"],
+            "types": [{"id": "A", "slug": "class"}, {"id": "B", "slug": "student"}],
+        },
+    )
+    mock_worker.api_client.add_response(
+        "ListProcessElements",
+        id=mock_worker.process_information["id"],
+        with_image=mock_worker.use_cache,
+        allow_missing_data=True,
+        page_size=INIT_PAGE_SIZE,
+        response=[
+            {
+                "id": "11111111-1111-1111-1111-111111111111",
+                "type_id": "A",
+                "name": "Class 1",
+                "confidence": None,
+                **(
+                    {
+                        "image_id": "cafecafe-cafe-cafe-cafe-cafecafecafe",
+                        "image_width": 42,
+                        "image_height": 1337,
+                        "image_url": "http://cafe",
+                        "polygon": [[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]],
+                        "rotation_angle": 42,
+                        "mirrored": False,
+                    }
+                    if mock_worker.use_cache
+                    else {}
+                ),
+            },
+            {
+                "id": "22222222-2222-2222-2222-222222222222",
+                "type_id": "A",
+                "name": "Class 2",
+                "confidence": None,
+                **(
+                    {
+                        "image_id": "beefbeef-beef-beef-beef-beefbeefbeef",
+                        "image_width": 42,
+                        "image_height": 1337,
+                        "image_url": "http://beef",
+                        "polygon": [[0, 0], [0, 2], [2, 2], [2, 0], [0, 0]],
+                        "rotation_angle": 0,
+                        "mirrored": True,
+                    }
+                    if mock_worker.use_cache
+                    else {}
+                ),
+            },
+            {
+                "id": "33333333-3333-3333-3333-333333333333",
+                "type_id": "A",
+                "name": "Class 3",
+                "confidence": 0.42,
+                **(
+                    {
+                        "rotation_angle": 17,
+                        "mirrored": True,
+                        "image_id": None,
+                        "image_width": None,
+                        "image_height": None,
+                        "image_url": None,
+                        "polygon": None,
+                    }
+                    if mock_worker.use_cache
+                    else {}
+                ),
+            },
+        ],
+    )
+
+    mock_worker.process()
+
+    check_json(
+        json_path=mock_worker.work_dir / "elements.json",
+        elements=[
+            {"id": "11111111-1111-1111-1111-111111111111", "type": "class"},
+            {"id": "22222222-2222-2222-2222-222222222222", "type": "class"},
+            {"id": "33333333-3333-3333-3333-333333333333", "type": "class"},
+        ],
+    )
+
+    db_path = mock_worker.work_dir / "db.sqlite"
+    assert db_path.is_file() is use_cache
+    if use_cache:
+        check_db(
+            db_path=db_path,
+            elements=[
+                {
+                    "id": "11111111111111111111111111111111",
+                    "image_id": "cafecafecafecafecafecafecafecafe",
+                    "initial": 1,
+                    "parent_id": None,
+                    "polygon": "[[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]",
+                    "type": "class",
+                    "worker_version_id": None,
+                    "worker_run_id": None,
+                    "rotation_angle": 42,
+                    "mirrored": False,
+                    "confidence": None,
+                },
+                {
+                    "id": "22222222222222222222222222222222",
+                    "image_id": "beefbeefbeefbeefbeefbeefbeefbeef",
+                    "initial": 1,
+                    "parent_id": None,
+                    "polygon": "[[0, 0], [0, 2], [2, 2], [2, 0], [0, 0]]",
+                    "type": "class",
+                    "worker_version_id": None,
+                    "worker_run_id": None,
+                    "rotation_angle": 0,
+                    "mirrored": True,
+                    "confidence": None,
+                },
+                {
+                    "id": "33333333333333333333333333333333",
+                    "image_id": None,
+                    "initial": 1,
+                    "parent_id": None,
+                    "polygon": None,
+                    "type": "class",
+                    "worker_version_id": None,
+                    "worker_run_id": None,
+                    "rotation_angle": 17,
+                    "mirrored": True,
+                    "confidence": 0.42,
+                },
+            ],
+            images=[
+                {
+                    "id": "beefbeefbeefbeefbeefbeefbeefbeef",
+                    "width": 42,
+                    "height": 1337,
+                    "url": "http://beef",
+                },
+                {
+                    "id": "cafecafecafecafecafecafecafecafe",
+                    "width": 42,
+                    "height": 1337,
+                    "url": "http://cafe",
+                },
+            ],
+        )
+
+
+def test_run_distributed(mock_worker):
+    mock_worker.use_cache = True
+    mock_worker.chunks_number = 4
+
+    mock_worker.api_client.add_response(
+        "RetrieveCorpus",
+        id=mock_worker.process_information["corpus"],
+        response={
+            "id": mock_worker.process_information["corpus"],
+            "types": [{"id": "A", "slug": "class"}, {"id": "B", "slug": "student"}],
+        },
+    )
+    mock_worker.api_client.add_response(
+        "ListProcessElements",
+        id=mock_worker.process_information["id"],
+        with_image=True,
+        allow_missing_data=True,
+        page_size=INIT_PAGE_SIZE,
+        response=[
+            {
+                "id": "22222222-2222-2222-2222-222222222222",
+                "type_id": "A",
+                "name": "Class 2",
+                "image_id": None,
+                "image_width": None,
+                "image_height": None,
+                "image_url": None,
+                "polygon": None,
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+            {
+                "id": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
+                "type_id": "B",
+                "name": "Student 1",
+                "image_id": "cafecafe-cafe-cafe-cafe-cafecafecafe",
+                "image_width": 42,
+                "image_height": 1337,
+                "image_url": "http://cafe",
+                "polygon": [[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]],
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+            {
+                "id": "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb",
+                "type_id": "B",
+                "name": "Student 2",
+                "image_id": "beefbeef-beef-beef-beef-beefbeefbeef",
+                "image_width": 42,
+                "image_height": 1337,
+                "image_url": "http://beef",
+                "polygon": [[0, 0], [0, 2], [2, 2], [2, 0], [0, 0]],
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+            {
+                "id": "cccccccc-cccc-cccc-cccc-cccccccccccc",
+                "type_id": "B",
+                "name": "Student 3",
+                "image_id": "beefbeef-beef-beef-beef-beefbeefbeef",
+                "image_width": 42,
+                "image_height": 1337,
+                "image_url": "http://beef",
+                "polygon": [[0, 0], [0, 2], [2, 2], [2, 0], [0, 0]],
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+            {
+                "id": "dddddddd-dddd-dddd-dddd-dddddddddddd",
+                "type_id": "B",
+                "name": "Student 4",
+                "image_id": "cafecafe-cafe-cafe-cafe-cafecafecafe",
+                "image_width": 42,
+                "image_height": 1337,
+                "image_url": "http://cafe",
+                "polygon": [[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]],
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+            {
+                "id": "eeeeeeee-eeee-eeee-eeee-eeeeeeeeeeee",
+                "type_id": "B",
+                "name": "Student 5",
+                "image_id": None,
+                "image_width": None,
+                "image_height": None,
+                "image_url": None,
+                "polygon": None,
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+            {
+                "id": "ffffffff-ffff-ffff-ffff-ffffffffffff",
+                "type_id": "B",
+                "name": "Student 6",
+                "image_id": None,
+                "image_width": None,
+                "image_height": None,
+                "image_url": None,
+                "polygon": None,
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+        ],
+    )
+
+    mock_worker.process()
+
+    check_json(
+        json_path=mock_worker.work_dir / "elements_chunk_1.json",
+        elements=[
+            {"id": "22222222-2222-2222-2222-222222222222", "type": "class"},
+            {"id": "dddddddd-dddd-dddd-dddd-dddddddddddd", "type": "student"},
+        ],
+    )
+    check_json(
+        json_path=mock_worker.work_dir / "elements_chunk_2.json",
+        elements=[
+            {"id": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", "type": "student"},
+            {"id": "eeeeeeee-eeee-eeee-eeee-eeeeeeeeeeee", "type": "student"},
+        ],
+    )
+    check_json(
+        json_path=mock_worker.work_dir / "elements_chunk_3.json",
+        elements=[
+            {"id": "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", "type": "student"},
+            {"id": "ffffffff-ffff-ffff-ffff-ffffffffffff", "type": "student"},
+        ],
+    )
+    check_json(
+        json_path=mock_worker.work_dir / "elements_chunk_4.json",
+        elements=[{"id": "cccccccc-cccc-cccc-cccc-cccccccccccc", "type": "student"}],
+    )
+
+    check_db(
+        db_path=mock_worker.work_dir / "db_1.sqlite",
+        elements=[
+            {
+                "id": "22222222222222222222222222222222",
+                "image_id": None,
+                "initial": 1,
+                "parent_id": None,
+                "polygon": None,
+                "type": "class",
+                "worker_version_id": None,
+                "worker_run_id": None,
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+            {
+                "id": "dddddddddddddddddddddddddddddddd",
+                "image_id": "cafecafecafecafecafecafecafecafe",
+                "initial": 1,
+                "parent_id": None,
+                "polygon": "[[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]",
+                "type": "student",
+                "worker_version_id": None,
+                "worker_run_id": None,
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+        ],
+        images=[
+            {
+                "id": "cafecafecafecafecafecafecafecafe",
+                "width": 42,
+                "height": 1337,
+                "url": "http://cafe",
+            },
+        ],
+    )
+
+    check_db(
+        db_path=mock_worker.work_dir / "db_2.sqlite",
+        elements=[
+            {
+                "id": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
+                "image_id": "cafecafecafecafecafecafecafecafe",
+                "initial": 1,
+                "parent_id": None,
+                "polygon": "[[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]",
+                "type": "student",
+                "worker_version_id": None,
+                "worker_run_id": None,
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+            {
+                "id": "eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee",
+                "image_id": None,
+                "initial": 1,
+                "parent_id": None,
+                "polygon": None,
+                "type": "student",
+                "worker_version_id": None,
+                "worker_run_id": None,
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+        ],
+        images=[
+            {
+                "id": "cafecafecafecafecafecafecafecafe",
+                "width": 42,
+                "height": 1337,
+                "url": "http://cafe",
+            },
+        ],
+    )
+
+    check_db(
+        db_path=mock_worker.work_dir / "db_3.sqlite",
+        elements=[
+            {
+                "id": "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
+                "image_id": "beefbeefbeefbeefbeefbeefbeefbeef",
+                "initial": 1,
+                "parent_id": None,
+                "polygon": "[[0, 0], [0, 2], [2, 2], [2, 0], [0, 0]]",
+                "type": "student",
+                "worker_version_id": None,
+                "worker_run_id": None,
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+            {
+                "id": "ffffffffffffffffffffffffffffffff",
+                "image_id": None,
+                "initial": 1,
+                "parent_id": None,
+                "polygon": None,
+                "type": "student",
+                "worker_version_id": None,
+                "worker_run_id": None,
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+        ],
+        images=[
+            {
+                "id": "beefbeefbeefbeefbeefbeefbeefbeef",
+                "width": 42,
+                "height": 1337,
+                "url": "http://beef",
+            },
+        ],
+    )
+
+    check_db(
+        db_path=mock_worker.work_dir / "db_4.sqlite",
+        elements=[
+            {
+                "id": "cccccccccccccccccccccccccccccccc",
+                "image_id": "beefbeefbeefbeefbeefbeefbeefbeef",
+                "initial": 1,
+                "parent_id": None,
+                "polygon": "[[0, 0], [0, 2], [2, 2], [2, 0], [0, 0]]",
+                "type": "student",
+                "worker_version_id": None,
+                "worker_run_id": None,
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+        ],
+        images=[
+            {
+                "id": "beefbeefbeefbeefbeefbeefbeefbeef",
+                "width": 42,
+                "height": 1337,
+                "url": "http://beef",
+            },
+        ],
+    )
+
+
+def test_not_enough_elements(mock_worker):
+    mock_worker.chunks_number = 5
+
+    mock_worker.api_client.add_response(
+        "RetrieveCorpus",
+        id=mock_worker.process_information["corpus"],
+        response={
+            "id": mock_worker.process_information["corpus"],
+            "types": [{"id": "A", "slug": "class"}, {"id": "B", "slug": "student"}],
+        },
+    )
+    mock_worker.api_client.add_response(
+        "ListProcessElements",
+        id=mock_worker.process_information["id"],
+        with_image=False,
+        allow_missing_data=True,
+        page_size=INIT_PAGE_SIZE,
+        response=[
+            {
+                "id": "22222222-2222-2222-2222-222222222222",
+                "type_id": "A",
+                "name": "Class 2",
+                "confidence": 1.0,
+            }
+        ],
+    )
+    with pytest.raises(AssertionError, match="Too few elements have been retrieved"):
+        mock_worker.process()
+
+
+def test_run_duplicates(mock_worker, caplog):
+    caplog.set_level(logging.WARNING)
+
+    mock_worker.use_cache = True
+
+    mock_worker.api_client.add_response(
+        "RetrieveCorpus",
+        id=mock_worker.process_information["corpus"],
+        response={
+            "id": mock_worker.process_information["corpus"],
+            "types": [{"id": "A", "slug": "class"}, {"id": "B", "slug": "student"}],
+        },
+    )
+    mock_worker.api_client.add_response(
+        "ListProcessElements",
+        id=mock_worker.process_information["id"],
+        with_image=True,
+        allow_missing_data=True,
+        page_size=INIT_PAGE_SIZE,
+        response=[
+            {
+                "id": "11111111-1111-1111-1111-111111111111",
+                "type_id": "A",
+                "name": "Class 1",
+                "image_id": "cafecafe-cafe-cafe-cafe-cafecafecafe",
+                "image_width": 42,
+                "image_height": 1337,
+                "image_url": "http://cafe",
+                "polygon": [[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]],
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+            {
+                "id": "22222222-2222-2222-2222-222222222222",
+                "type_id": "A",
+                "name": "Class 2",
+                "image_id": "cafecafe-cafe-cafe-cafe-cafecafecafe",
+                "image_width": 42,
+                "image_height": 1337,
+                "image_url": "http://cafe",
+                "polygon": [[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]],
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+            {
+                "id": "22222222-2222-2222-2222-222222222222",
+                "type_id": "A",
+                "name": "Class 2",
+                "image_id": "cafecafe-cafe-cafe-cafe-cafecafecafe",
+                "image_width": 42,
+                "image_height": 1337,
+                "image_url": "http://cafe",
+                "polygon": [[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]],
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+            {
+                "id": "33333333-3333-3333-3333-333333333333",
+                "type_id": "A",
+                "name": "Class 3",
+                "image_id": None,
+                "image_width": None,
+                "image_height": None,
+                "image_url": None,
+                "polygon": None,
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+        ],
+    )
+
+    mock_worker.process()
+
+    check_json(
+        json_path=mock_worker.work_dir / "elements.json",
+        elements=[
+            {"id": "11111111-1111-1111-1111-111111111111", "type": "class"},
+            {"id": "22222222-2222-2222-2222-222222222222", "type": "class"},
+            {"id": "33333333-3333-3333-3333-333333333333", "type": "class"},
+        ],
+    )
+
+    check_db(
+        db_path=mock_worker.work_dir / "db.sqlite",
+        elements=[
+            {
+                "id": "11111111111111111111111111111111",
+                "image_id": "cafecafecafecafecafecafecafecafe",
+                "initial": 1,
+                "parent_id": None,
+                "polygon": "[[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]",
+                "type": "class",
+                "worker_version_id": None,
+                "worker_run_id": None,
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+            {
+                "id": "22222222222222222222222222222222",
+                "image_id": "cafecafecafecafecafecafecafecafe",
+                "initial": 1,
+                "parent_id": None,
+                "polygon": "[[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]",
+                "type": "class",
+                "worker_version_id": None,
+                "worker_run_id": None,
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+            {
+                "id": "33333333333333333333333333333333",
+                "image_id": None,
+                "initial": 1,
+                "parent_id": None,
+                "polygon": None,
+                "type": "class",
+                "worker_version_id": None,
+                "worker_run_id": None,
+                "rotation_angle": 0,
+                "mirrored": False,
+                "confidence": None,
+            },
+        ],
+        images=[
+            {
+                "id": "cafecafecafecafecafecafecafecafe",
+                "width": 42,
+                "height": 1337,
+                "url": "http://cafe",
+            },
+        ],
+    )
+
+    assert [(record.levelname, record.message) for record in caplog.records] == [
+        ("WARNING", "1 duplicate elements have been ignored.")
+    ]
+
+
+def test_run_empty(mock_worker, caplog):
+    caplog.set_level(logging.WARNING)
+
+    mock_worker.api_client.add_response(
+        "RetrieveCorpus",
+        id=mock_worker.process_information["corpus"],
+        response={
+            "id": mock_worker.process_information["corpus"],
+            "types": [{"id": "A", "slug": "class"}, {"id": "B", "slug": "student"}],
+        },
+    )
+    mock_worker.api_client.add_response(
+        "ListProcessElements",
+        id=mock_worker.process_information["id"],
+        with_image=False,
+        allow_missing_data=True,
+        page_size=INIT_PAGE_SIZE,
+        response=[],
+    )
+
+    with pytest.raises(SystemExit) as ctx:
+        mock_worker.process()
+
+    assert ctx.value.code == 1
+
+    assert list(mock_worker.work_dir.rglob("*")) == []
+
+    assert [(record.levelname, record.message) for record in caplog.records] == [
+        ("ERROR", "No elements found, aborting workflow."),
+    ]
diff --git a/tests/test_worker.py b/tests/test_worker.py
deleted file mode 100644
index 929a2fb..0000000
--- a/tests/test_worker.py
+++ /dev/null
@@ -1,12 +0,0 @@
-import importlib
-
-
-def test_dummy():
-    assert True
-
-
-def test_import():
-    """Import our newly created module, through importlib to avoid parsing issues"""
-    worker = importlib.import_module("worker_init_elements.worker")
-    assert hasattr(worker, "Demo")
-    assert hasattr(worker.Demo, "process_element")
diff --git a/worker_init_elements/worker.py b/worker_init_elements/worker.py
index 047866a..d167636 100644
--- a/worker_init_elements/worker.py
+++ b/worker_init_elements/worker.py
@@ -1,18 +1,279 @@
+import json
+import sqlite3
+import sys
+import uuid
+from collections import OrderedDict
+from collections.abc import Iterator
+from enum import Enum
 from logging import Logger, getLogger
+from time import sleep
 
-from arkindex_worker.models import Element
-from arkindex_worker.worker import ElementsWorker
+from arkindex_worker.cache import (
+    CachedElement,
+    CachedImage,
+    create_tables,
+    create_version_table,
+    init_cache_db,
+)
+from arkindex_worker.worker.base import BaseWorker
 
 logger: Logger = getLogger(__name__)
 
+# Increases the number of elements returned per page by the API
+INIT_PAGE_SIZE = 500
 
-class Demo(ElementsWorker):
-    def process_element(self, element: Element) -> None:
-        logger.info(f"Demo processing element ({element.id})")
+
def split_chunks(items: list, n: int) -> Iterator[list]:
    """
    Split ``items`` into ``n`` chunks with a balanced distribution, yielding
    one chunk at a time. Chunks are built by striding through the list, so
    sizes differ by at most one element.
    https://stackoverflow.com/questions/24483182/python-split-list-into-n-chunks#answer-54802737
    """
    yield from (items[offset::n] for offset in range(n))
+
+
class ActivityState(Enum):
    """
    State of the worker-activity tracking of a process.

    Initializing activities over a large element set is asynchronous: a
    started process only reaches ``Ready`` once activities have been
    created for all of its elements.
    """

    # Worker activities are disabled and will not be used
    Disabled = "disabled"

    # Worker activities have not been initialized yet
    Pending = "pending"

    # Worker activities are initialized and ready for use
    Ready = "ready"

    # An error occurred while initializing worker activities
    Error = "error"
+
+
class InitElementsWorker(BaseWorker):
    """
    Worker that lists the elements of an Arkindex process, removes duplicates,
    splits them into a configurable number of chunks and stores each chunk as
    a JSON artefact (plus an optional SQLite cache database), then waits for
    the process worker activities to be initialized.
    """

    def configure(self) -> None:
        """
        Parse CLI arguments and the worker configuration, then store the
        chunking, caching and API throttling settings on the instance.
        """
        # CLI args are stored on the instance so that implementations can access them
        self.args = self.parser.parse_args()

        if self.is_read_only:
            super().configure_for_developers()
        else:
            super().configure()

        # Retrieve the user configuration
        if self.user_configuration:
            self.config.update(self.user_configuration)
            logger.info("User configuration retrieved")

        self.chunks_number = self.config["chunks_number"]
        self.use_cache = self.config["use_cache"]
        # Throttle API requests by sleeping between calls
        self.api_client.sleep_duration = self.config["sleep"]

    def dump_json(self, elements: list[dict], filename: str = "elements.json") -> None:
        """
        Store elements in a JSON file.
        This file will become an artefact.

        :param elements: Elements to serialize.
        :param filename: Name of the JSON file, relative to the work directory.
        """
        path = self.work_dir / filename
        assert not path.exists(), f"JSON at {path} already exists"

        path.write_text(json.dumps(elements, indent=4))

    def dump_sqlite(self, elements: list[dict], filename: str = "db.sqlite") -> None:
        """
        Store elements in a SQLite database. Only images and elements will be added.
        This file will become an artefact.
        Does nothing when caching is disabled.

        :param elements: Elements to insert; image attributes are expected
            when ``image_id`` is set.
        :param filename: Name of the database file, relative to the work directory.
        """
        if not self.use_cache:
            return

        path = self.work_dir / filename
        assert not path.exists(), f"Database at {path} already exists"

        # init_cache_db opens and manages the database connection itself;
        # no separate sqlite3.connect handle is needed here.
        init_cache_db(path)
        create_version_table()
        create_tables()

        # Set of unique images found in the elements
        CachedImage.insert_many(
            {
                "id": uuid.UUID(element["image_id"]).hex,
                "width": element["image_width"],
                "height": element["image_height"],
                "url": element["image_url"],
            }
            for element in elements
            if element["image_id"]
        ).on_conflict_ignore(ignore=True).execute()

        # Fastest way to INSERT multiple rows.
        CachedElement.insert_many(
            {
                "id": uuid.UUID(element["id"]).hex,
                "type": element["type"],
                "image_id": (
                    uuid.UUID(element["image_id"]).hex if element["image_id"] else None
                ),
                "polygon": element["polygon"],
                "rotation_angle": element["rotation_angle"],
                "mirrored": element["mirrored"],
                "confidence": element["confidence"],
                "initial": True,
            }
            for element in elements
        ).execute()

    def dump_chunks(self, elements: list[dict]) -> None:
        """
        Store elements in a JSON file(s) and SQLite database(s).
        If several chunks are requested, the files will be suffixed with the chunk index.

        :param elements: Deduplicated elements to distribute over the chunks.
        :raises AssertionError: When there are fewer elements than chunks.
        """
        assert (
            len(elements) >= self.chunks_number
        ), f"Too few elements have been retrieved to distribute workflow among {self.chunks_number} branches"

        for index, chunk_elts in enumerate(
            split_chunks(elements, self.chunks_number),
            start=1,
        ):
            self.dump_json(
                elements=[
                    {
                        "id": element["id"],
                        "type": element["type"],
                    }
                    for element in chunk_elts
                ],
                **(
                    {"filename": f"elements_chunk_{index}.json"}
                    if self.chunks_number > 1
                    else {}
                ),
            )
            self.dump_sqlite(
                elements=chunk_elts,
                **(
                    {"filename": f"db_{index}.sqlite"} if self.chunks_number > 1 else {}
                ),
            )

        logger.info(
            f"Added {len(elements)} element{'s'[:len(elements) > 1]} to workflow configuration"
        )

    def list_process_elements(self) -> list[dict]:
        """
        List all elements linked to this process and remove duplicates.

        :returns: Unique elements, in API order, with their type ID resolved
            to the corpus type slug under the ``type`` key.
        :raises SystemExit: When the process has no elements at all.
        """
        assert self.process_information.get(
            "corpus"
        ), "This worker only supports processes on corpora."

        # Map type IDs to slugs, as ListProcessElements only returns type IDs
        corpus = self.request("RetrieveCorpus", id=self.process_information["corpus"])
        type_slugs = {
            element_type["id"]: element_type["slug"] for element_type in corpus["types"]
        }

        elements = [
            {**element, "type": type_slugs[element["type_id"]]}
            for element in self.api_client.paginate(
                "ListProcessElements",
                id=self.process_information["id"],
                # Image data is only needed to build the SQLite cache
                with_image=self.use_cache,
                allow_missing_data=True,
                page_size=INIT_PAGE_SIZE,
            )
        ]
        # Use a dict to make elements unique by ID, then turn them back into a elements.json-compatible list
        unique_elements = OrderedDict(
            [(element["id"], element) for element in elements]
        )

        logger.info(
            f"Retrieved {len(unique_elements)} element{'s'[:len(unique_elements) > 1]} from process {self.process_information['id']}"
        )

        duplicate_count = len(elements) - len(unique_elements)
        if duplicate_count:
            logger.warning(f"{duplicate_count} duplicate elements have been ignored.")

        if not unique_elements:
            logger.error("No elements found, aborting workflow.")
            sys.exit(1)

        return list(unique_elements.values())

    def check_worker_activity(self) -> bool:
        """
        Check if workers activity associated to this process is in a pending state.

        :returns: True when activities are ready, False when still pending.
        :raises SystemExit: When activity initialization failed.
        """
        activity_state = ActivityState(
            self.request("RetrieveProcess", id=self.process_information["id"])[
                "activity_state"
            ]
        )
        if activity_state == ActivityState.Error:
            logger.error(
                "Worker activities could not be initialized. Please report this incident to an instance administrator."
            )
            sys.exit(1)
        return activity_state == ActivityState.Ready

    def await_worker_activity(self) -> None:
        """
        Worker activities are initialized asynchronously after a process has been started.
        This worker should be running until all activities have moved to `Ready`.
        Returns immediately when activities are disabled on the process.
        """
        if (
            ActivityState(self.process_information["activity_state"])
            == ActivityState.Disabled
        ):
            return

        logger.info("Awaiting worker activities initialization")
        # Exponential backoff: sleep 2, 4, 8, … seconds between checks,
        # giving up once the doubled timer reaches 4096 (~68 minutes of waiting)
        timer = 1
        while True:
            if self.check_worker_activity():
                break
            timer *= 2
            if timer >= 3600:
                logger.error(
                    f"Workers activity not initialized {int(timer/60)} minutes after starting the process."
                    " Please report this incident to an instance administrator."
                )
                raise Exception("Worker activity timeout")
            sleep(timer)

    def process(self) -> None:
        """List, deduplicate and dump the process elements, then wait for activities."""
        elements = self.list_process_elements()
        self.dump_chunks(elements)

        self.await_worker_activity()

    def run(self) -> None:
        """Configure the worker from CLI/configuration, then run the full process."""
        self.configure()
        self.process()
 
 
 def main() -> None:
-    Demo(description="Worker to initialize Arkindex elements to process").run()
+    InitElementsWorker(
+        description="Worker to initialize Arkindex elements to process"
+    ).run()
 
 
 if __name__ == "__main__":
-- 
GitLab