diff --git a/worker_generic_training_dataset/db.py b/worker_generic_training_dataset/db.py
index f04319f5c464ffa904cfb57e62379cb2a7e5d0d0..a608ab93e5cfe626b273d022966282067a9fab89 100644
--- a/worker_generic_training_dataset/db.py
+++ b/worker_generic_training_dataset/db.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 from typing import NamedTuple
 
-from arkindex_export import Classification, ElementPath, Image
+from arkindex_export import Classification, Image
 from arkindex_export.models import (
     Element,
     Entity,
@@ -9,6 +9,7 @@ from arkindex_export.models import (
     Transcription,
     TranscriptionEntity,
 )
+from arkindex_export.queries import list_children
 from arkindex_worker.cache import (
     CachedElement,
     CachedEntity,
@@ -92,45 +93,8 @@ def retrieve_entities(transcription: CachedTranscription):
     return zip(*data)
 
 
-def list_children(parent_id):
-    # First, build the base query to get direct children
-    base = (
-        ElementPath.select(
-            ElementPath.child_id, ElementPath.parent_id, ElementPath.ordering
-        )
-        .where(ElementPath.parent_id == parent_id)
-        .cte("children", recursive=True, columns=("child_id", "parent_id", "ordering"))
-    )
-
-    # Then build the second recursive query, using an alias to join both queries on the same table
-    EP = ElementPath.alias()
-    recursion = EP.select(EP.child_id, EP.parent_id, EP.ordering).join(
-        base, on=(EP.parent_id == base.c.child_id)
-    )
-
-    # Combine both queries, using UNION and not UNION ALL to deduplicate parents
-    # that might be found multiple times with complex element structures
-    cte = base.union(recursion)
-
-    # And load all the elements found in the CTE
-    query = (
-        Element.select(
-            Element.id,
-            Element.type,
-            Image.id.alias("image_id"),
-            Image.width.alias("image_width"),
-            Image.height.alias("image_height"),
-            Image.url.alias("image_url"),
-            Element.polygon,
-            Element.rotation_angle,
-            Element.mirrored,
-            Element.worker_version,
-            Element.confidence,
-            cte.c.parent_id,
-        )
-        .with_cte(cte)
-        .join(cte, on=(Element.id == cte.c.child_id))
-        .join(Image, on=(Element.image_id == Image.id))
-        .order_by(cte.c.ordering.asc())
-    )
+def get_children(parent_id, element_type=None):
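+    """
+    Recursively list all elements under parent_id, joined with their image.
+
+    element_type optionally restricts the results to a single type, e.g.
+    get_children(parent_id=folder_id, element_type="page").
+    """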
+    query = list_children(parent_id).join(Image)
+    if element_type:
+        query = query.where(Element.type == element_type)
     return query
diff --git a/worker_generic_training_dataset/utils.py b/worker_generic_training_dataset/utils.py
index 47f2bfc44638a46328c9f0551a6c034c69488fd9..39b034ce98001e23880d7970fc2d6a1027542e75 100644
--- a/worker_generic_training_dataset/utils.py
+++ b/worker_generic_training_dataset/utils.py
@@ -30,7 +30,7 @@ def bounding_box(polygon: list):
 def build_image_url(element):
     x, y, width, height = bounding_box(ast.literal_eval(element.polygon))
     return urljoin(
-        element.image_url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
+        element.image.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
     )
 
 
@@ -46,7 +46,7 @@ def download_image(element, folder: Path):
         try:
             image = iio.imread(build_image_url(element))
             cv2.imwrite(
-                str(folder / f"{element.image_id}.png"),
+                str(folder / f"{element.image.id}.jpg"),
                 cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
             )
             break
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index f24e062f0d9cde9d5e88c19fd480fc23f2f65a9c..4c475a4ae5331d271f286a6ab12da4e94002127b 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -1,8 +1,11 @@
 # -*- coding: utf-8 -*-
 import logging
 import operator
+import shutil
 import tempfile
 from pathlib import Path
+from typing import Optional
+from uuid import UUID
 
 from apistar.exceptions import ErrorResponse
 from arkindex_export import open_database
@@ -18,9 +21,9 @@ from arkindex_worker.cache import (
 )
 from arkindex_worker.cache import db as cache_database
 from arkindex_worker.cache import init_cache_db
-from arkindex_worker.worker import ElementsWorker
+from arkindex_worker.worker.base import BaseWorker
 from worker_generic_training_dataset.db import (
-    list_children,
+    get_children,
     list_classifications,
     list_transcriptions,
     retrieve_element,
@@ -36,15 +39,27 @@ logger = logging.getLogger(__name__)
 BULK_BATCH_SIZE = 50
 
 
-class DatasetExtractor(ElementsWorker):
+class DatasetExtractor(BaseWorker):
     def configure(self):
-        super().configure()
+        self.args = self.parser.parse_args()
+        if self.is_read_only:
+            super().configure_for_developers()
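+            # In developer mode the process information is not retrieved from
+            # the API, so hard-coded folder ids serve as local defaults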
+            self.process_information = {
+                "train_folder_id": "47a0e07b-d07a-4969-aced-44450d132f0d",
+                "validation_folder_id": "8cbc4b53-9e07-4a72-b4e6-93f7f5b0cbed",
+                "test_folder_id": "659a37ea-3b26-42f0-8b65-78964f9e433e",
+            }
+        else:
+            super().configure()
 
         # database arg is mandatory in dev mode
         assert (
             not self.is_read_only or self.args.database is not None
         ), "`--database` arg is mandatory in developer mode."
 
+        # Read process information
+        self.read_training_related_information()
+
         # Download corpus
         self.download_latest_export()
 
@@ -54,6 +69,27 @@ class DatasetExtractor(ElementsWorker):
         # Cached Images downloaded and created in DB
         self.cached_images = dict()
 
+    def read_training_related_information(self):
+        """
+        Read the dataset split folder ids from the process information:
+            - train_folder_id
+            - validation_folder_id
+            - test_folder_id (optional)
+        """
+        logger.info("Retrieving information from process_information")
+
+        train_folder_id = self.process_information.get("train_folder_id")
+        assert train_folder_id, "A training folder id is necessary to use this worker"
+        self.training_folder_id = UUID(train_folder_id)
+
+        val_folder_id = self.process_information.get("validation_folder_id")
+        assert val_folder_id, "A validation folder id is necessary to use this worker"
+        self.validation_folder_id = UUID(val_folder_id)
+
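+        # The test folder is optional; run() skips this split when it is not set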
+        test_folder_id = self.process_information.get("test_folder_id")
+        self.testing_folder_id = UUID(test_folder_id) if test_folder_id else None
+
     def initialize_database(self):
         # Create db at
         # - self.workdir / "db.sqlite" in Arkindex mode
@@ -109,9 +145,11 @@ class DatasetExtractor(ElementsWorker):
             )
             raise e
 
-    def insert_element(self, element, image_folder: Path, root: bool = False):
+    def insert_element(
+        self, element, image_folder: Path, parent_id: Optional[str] = None
+    ):
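+        """
+        Insert the element into the Base-Worker cache, downloading and caching
+        its image the first time it is encountered.
+        """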
         logger.info(f"Processing element ({element.id})")
-        if element.image_id and element.image_id not in self.cached_images:
+        if element.image and element.image.id not in self.cached_images:
             # Download image
             logger.info("Downloading image")
             download_image(element, folder=image_folder)
@@ -119,19 +157,19 @@ class DatasetExtractor(ElementsWorker):
             logger.info("Inserting image")
             # Store images in case some other elements use it as well
             self.cached_images[element.image_id] = CachedImage.create(
-                id=element.image_id,
-                width=element.image_width,
-                height=element.image_height,
-                url=element.image_url,
+                id=element.image.id,
+                width=element.image.width,
+                height=element.image.height,
+                url=element.image.url,
             )
 
         # Insert element
         logger.info("Inserting element")
         cached_element = CachedElement.create(
             id=element.id,
-            parent_id=None if root else element.parent_id,
+            parent_id=parent_id,
             type=element.type,
-            image=self.cached_images[element.image_id] if element.image_id else None,
+            image=self.cached_images[element.image.id] if element.image else None,
             polygon=element.polygon,
             rotation_angle=element.rotation_angle,
             mirrored=element.mirrored,
@@ -196,29 +234,56 @@ class DatasetExtractor(ElementsWorker):
                     batch_size=BULK_BATCH_SIZE,
                 )
 
-    def process_element(self, element):
-        # Where to save the downloaded images
-        image_folder = Path(tempfile.mkdtemp())
-
+    def process_split(self, split_name, split_id, image_folder):
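+        """
+        Insert the split folder, its pages and their children into the
+        Base-Worker cache, downloading their images into image_folder.
+        """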
         logger.info(
-            f"Filling the Base-Worker cache with information from children under element ({element.id})"
+            f"Filling the Base-Worker cache with information from children under element ({split_id})"
         )
         # Fill cache
         # Retrieve parent and create parent
-        parent = retrieve_element(element.id)
-        self.insert_element(parent, image_folder, root=True)
-        # Create children
-        children = list_children(parent_id=element.id)
-        nb_children = children.count()
-        for idx, child in enumerate(children.namedtuples(), start=1):
-            logger.info(f"Processing child ({idx}/{nb_children})")
-            self.insert_element(child, image_folder)
+        parent = retrieve_element(split_id)
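+        # The split folder itself is stored as a root element (no parent_id)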
+        self.insert_element(parent, image_folder)
+
+        # First list all pages
+        pages = get_children(parent_id=split_id, element_type="page")
+        nb_pages = pages.count()
+        for idx, page in enumerate(pages, start=1):
+            logger.info(f"Processing `{split_name}` page ({idx}/{nb_pages})")
+
+            # Insert page
+            self.insert_element(page, image_folder, parent_id=split_id)
+
+            # List children
+            children = get_children(parent_id=page.id)
+            nb_children = children.count()
+            for child_idx, child in enumerate(children, start=1):
+                logger.info(f"Processing child ({child_idx}/{nb_children})")
+                # Insert child
+                self.insert_element(child, image_folder, parent_id=page.id)
+
+    def run(self):
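+        """
+        Fill the cache for each dataset split, then archive the downloaded
+        images as a task artifact.
+        """
+        # Unlike ElementsWorker, BaseWorker leaves the run() lifecycle to the
+        # subclass, so configuration is triggered explicitly here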
+        self.configure()
+
+        # Where to save the downloaded images
+        image_folder = Path(tempfile.mkdtemp())
+
+        # Iterate over the given splits
+        for split_name, split_id in [
+            ("Train", self.training_folder_id),
+            ("Validation", self.validation_folder_id),
+            ("Test", self.testing_folder_id),
+        ]:
+            if not split_id:
+                continue
+            self.process_split(split_name, split_id, image_folder)
 
         # TAR + ZSTD Image folder and store as task artifact
         zstd_archive_path = Path(self.work_dir) / "arkindex_data.zstd"
         logger.info(f"Compressing the images to {zstd_archive_path}")
         create_tar_zstd_archive(folder_path=image_folder, destination=zstd_archive_path)
 
+        # Clean up the image folder once it has been archived
+        shutil.rmtree(image_folder)
+
 
 def main():
     DatasetExtractor(