diff --git a/tests/conftest.py b/tests/conftest.py
index 0121c272d8c5ef29181e34e3b6cb88cc871a84f7..b0da596c961c47f0afc907acc0435a883f7a1dc3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -22,6 +22,8 @@ def setup_environment(responses, monkeypatch):
         "https://arkindex.teklia.com/api/v1/openapi/?format=openapi-json",
     )
     responses.add_passthru(schema_url)
+    # Allow image downloads from the IIIF server
+    responses.add_passthru("https://europe-gamma.iiif.teklia.com/iiif/2")
 
     # Set schema url in environment
     os.environ["ARKINDEX_API_SCHEMA_URL"] = schema_url
diff --git a/tests/test_worker.py b/tests/test_worker.py
index 3c3703f63277493da0e2b6ab75ec3757f3e70bd0..cf638a86a38e8513721c3ff9a78195427225a986 100644
--- a/tests/test_worker.py
+++ b/tests/test_worker.py
@@ -1,8 +1,6 @@
 # -*- coding: utf-8 -*-
 
-import tempfile
 from argparse import Namespace
-from pathlib import Path
 
 from arkindex_worker.cache import (
     CachedClassification,
@@ -15,7 +13,7 @@ from arkindex_worker.cache import (
 from worker_generic_training_dataset.worker import DatasetExtractor
 
 
-def test_process_split():
+def test_process_split(tmp_path):
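+    # tmp_path is pytest's built-in fixture providing a per-test temporary directory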
     # Parent is train folder
     parent_id = "a0c4522d-2d80-4766-a01c-b9d686f41f6a"
 
@@ -26,9 +24,9 @@ def test_process_split():
     worker.cached_images = dict()
 
     # Where to save the downloaded images
-    image_folder = Path(tempfile.mkdtemp())
+    worker.image_folder = tmp_path
 
-    worker.process_split("train", parent_id, image_folder)
+    worker.process_split("train", parent_id)
 
     # Should have created two pages under root folder
     assert (
@@ -53,9 +51,9 @@ def test_process_split():
 
     # Should have created two images
     assert CachedImage.select().count() == 2
-    assert sorted(image_folder.rglob("*")) == [
-        image_folder / "80a84b30-1ae1-4c13-95d6-7d0d8ee16c51.jpg",
-        image_folder / "e3c755f2-0e1c-468e-ae4c-9206f0fd267a.jpg",
+    assert sorted(tmp_path.rglob("*")) == [
+        tmp_path / "80a84b30-1ae1-4c13-95d6-7d0d8ee16c51.jpg",
+        tmp_path / "e3c755f2-0e1c-468e-ae4c-9206f0fd267a.jpg",
     ]
 
     # Should have created a transcription linked to first line of first page
diff --git a/worker_generic_training_dataset/cache.py b/worker_generic_training_dataset/cache.py
deleted file mode 100644
index b3b40634c50d4f3213bda4537ed5d8ae9c617aa3..0000000000000000000000000000000000000000
--- a/worker_generic_training_dataset/cache.py
+++ /dev/null
@@ -1,300 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Database mappings and helper methods for the experimental worker caching feature.
-
-On methods that support caching, the database will be used for all reads,
-and writes will go both to the Arkindex API and the database,
-reducing network usage.
-"""
-
-import json
-from pathlib import Path
-from typing import Optional, Union
-
-from arkindex_worker import logger
-from peewee import (
-    BooleanField,
-    CharField,
-    Check,
-    CompositeKey,
-    Field,
-    FloatField,
-    ForeignKeyField,
-    IntegerField,
-    Model,
-    OperationalError,
-    SqliteDatabase,
-    TextField,
-    UUIDField,
-)
-from PIL import Image
-
-db = SqliteDatabase(None)
-
-
-class JSONField(Field):
-    """
-    A Peewee field that stores a JSON payload as a string and parses it automatically.
-    """
-
-    field_type = "text"
-
-    def db_value(self, value):
-        if value is None:
-            return
-        return json.dumps(value)
-
-    def python_value(self, value):
-        if value is None:
-            return
-        return json.loads(value)
-
-
-class Version(Model):
-    """
-    Cache version table, used to warn about incompatible cache databases
-    when a worker uses an outdated version of ``base-worker``.
-    """
-
-    version = IntegerField(primary_key=True)
-
-    class Meta:
-        database = db
-        table_name = "version"
-
-
-class CachedImage(Model):
-    """
-    Cache image table
-    """
-
-    id = UUIDField(primary_key=True)
-    width = IntegerField()
-    height = IntegerField()
-    url = TextField()
-
-    class Meta:
-        database = db
-        table_name = "images"
-
-
-class CachedElement(Model):
-    """
-    Cache element table
-    """
-
-    id = UUIDField(primary_key=True)
-    parent_id = UUIDField(null=True)
-    type = CharField(max_length=50)
-    image = ForeignKeyField(CachedImage, backref="elements", null=True)
-    polygon = JSONField(null=True)
-    rotation_angle = IntegerField(default=0)
-    mirrored = BooleanField(default=False)
-    initial = BooleanField(default=False)
-    # Needed to filter elements with cache
-    worker_version_id = UUIDField(null=True)
-    worker_run_id = UUIDField(null=True)
-    confidence = FloatField(null=True)
-
-    class Meta:
-        database = db
-        table_name = "elements"
-
-    def open_image(self, *args, max_size: Optional[int] = None, **kwargs) -> Image:
-        """
-        Open this element's image as a Pillow image.
-        This does not crop the image to the element's polygon.
-        IIIF servers with maxWidth, maxHeight or maxArea restrictions on image size are not supported.
-
-        :param *args: Positional arguments passed to [arkindex_worker.image.open_image][]
-        :param max_size: Subresolution of the image.
-        :param **kwargs: Keyword arguments passed to [arkindex_worker.image.open_image][]
-        :raises ValueError: When this element does not have an image ID or a polygon.
-        :return: A Pillow image.
-        """
-        from arkindex_worker.image import open_image, polygon_bounding_box
-
-        if not self.image_id or not self.polygon:
-            raise ValueError(f"Element {self.id} has no image")
-
-        # Always fetch the image from the bounding box when size differs from full image
-        bounding_box = polygon_bounding_box(self.polygon)
-        if (
-            bounding_box.width != self.image.width
-            or bounding_box.height != self.image.height
-        ):
-            box = f"{bounding_box.x},{bounding_box.y},{bounding_box.width},{bounding_box.height}"
-        else:
-            box = "full"
-
-        if max_size is None:
-            resize = "full"
-        else:
-            # Do not resize for polygons that do not exactly match the images
-            # as the resize is made directly by the IIIF server using the box parameter
-            if (
-                bounding_box.width != self.image.width
-                or bounding_box.height != self.image.height
-            ):
-                resize = "full"
-
-            # Do not resize when the image is below the maximum size
-            elif self.image.width <= max_size and self.image.height <= max_size:
-                resize = "full"
-            else:
-                ratio = max_size / max(self.image.width, self.image.height)
-                new_width, new_height = int(self.image.width * ratio), int(
-                    self.image.height * ratio
-                )
-                resize = f"{new_width},{new_height}"
-
-        url = self.image.url
-        if not url.endswith("/"):
-            url += "/"
-
-        return open_image(
-            f"{url}{box}/{resize}/0/default.jpg",
-            *args,
-            rotation_angle=self.rotation_angle,
-            mirrored=self.mirrored,
-            **kwargs,
-        )
-
-
-class CachedTranscription(Model):
-    """
-    Cache transcription table
-    """
-
-    id = UUIDField(primary_key=True)
-    element = ForeignKeyField(CachedElement, backref="transcriptions")
-    text = TextField()
-    confidence = FloatField(null=True)
-    orientation = CharField(max_length=50)
-    # Needed to filter transcriptions with cache
-    worker_version_id = UUIDField(null=True)
-    worker_run_id = UUIDField(null=True)
-
-    class Meta:
-        database = db
-        table_name = "transcriptions"
-
-
-class CachedClassification(Model):
-    """
-    Cache classification table
-    """
-
-    id = UUIDField(primary_key=True)
-    element = ForeignKeyField(CachedElement, backref="classifications")
-    class_name = TextField()
-    confidence = FloatField()
-    state = CharField(max_length=10)
-    worker_run_id = UUIDField(null=True)
-
-    class Meta:
-        database = db
-        table_name = "classifications"
-
-
-class CachedEntity(Model):
-    """
-    Cache entity table
-    """
-
-    id = UUIDField(primary_key=True)
-    type = CharField(max_length=50)
-    name = TextField()
-    validated = BooleanField(default=False)
-    metas = JSONField(null=True)
-    worker_run_id = UUIDField(null=True)
-
-    class Meta:
-        database = db
-        table_name = "entities"
-
-
-class CachedTranscriptionEntity(Model):
-    """
-    Cache transcription entity table
-    """
-
-    transcription = ForeignKeyField(
-        CachedTranscription, backref="transcription_entities"
-    )
-    entity = ForeignKeyField(CachedEntity, backref="transcription_entities")
-    offset = IntegerField(constraints=[Check("offset >= 0")])
-    length = IntegerField(constraints=[Check("length > 0")])
-    worker_run_id = UUIDField(null=True)
-    confidence = FloatField(null=True)
-
-    class Meta:
-        primary_key = CompositeKey("transcription", "entity")
-        database = db
-        table_name = "transcription_entities"
-
-
-# Add all the managed models in that list
-# It's used here, but also in unit tests
-MODELS = [
-    CachedImage,
-    CachedElement,
-    CachedTranscription,
-    CachedClassification,
-    CachedEntity,
-    CachedTranscriptionEntity,
-]
-SQL_VERSION = 2
-
-
-def init_cache_db(path: str):
-    """
-    Create the cache database on the given path
-    :param path: Where the new database should be created
-    """
-    db.init(
-        path,
-        pragmas={
-            # SQLite ignores foreign keys and check constraints by default!
-            "foreign_keys": 1,
-            "ignore_check_constraints": 0,
-        },
-    )
-    db.connect()
-    logger.info(f"Connected to cache on {path}")
-
-
-def create_tables():
-    """
-    Creates the tables in the cache DB only if they do not already exist.
-    """
-    db.create_tables(MODELS)
-
-
-def create_version_table():
-    """
-    Creates the Version table in the cache DB.
-    This step must be independent from other tables creation since we only
-    want to create the table and add the one and only Version entry when the
-    cache is created from scratch.
-    """
-    db.create_tables([Version])
-    Version.create(version=SQL_VERSION)
-
-
-def check_version(cache_path: Union[str, Path]):
-    """
-    Check the validity of the SQLite version
-
-    :param cache_path: Path towards a local SQLite database
-    """
-    with SqliteDatabase(cache_path) as provided_db:
-        with provided_db.bind_ctx([Version]):
-            try:
-                version = Version.get().version
-            except OperationalError:
-                version = None
-
-            assert (
-                version == SQL_VERSION
-            ), f"The SQLite database {cache_path} does not have the correct cache version, it should be {SQL_VERSION}"
diff --git a/worker_generic_training_dataset/db.py b/worker_generic_training_dataset/db.py
index e81e652051f7e93f62dca597857506b7196d1b01..bb18f7edeb4082116f5e093751ab28a41e9db746 100644
--- a/worker_generic_training_dataset/db.py
+++ b/worker_generic_training_dataset/db.py
@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
-from typing import NamedTuple
 
-from arkindex_export import Classification, Image
+from arkindex_export import Classification
 from arkindex_export.models import (
     Element,
     Entity,
@@ -9,15 +8,7 @@ from arkindex_export.models import (
     Transcription,
     TranscriptionEntity,
 )
-from arkindex_export.queries import list_children
-from arkindex_worker.cache import (
-    CachedElement,
-    CachedEntity,
-    CachedTranscription,
-    CachedTranscriptionEntity,
-)
-
-DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"
+from arkindex_worker.cache import CachedElement, CachedTranscription
 
 
 def retrieve_element(element_id: str):
@@ -28,72 +19,14 @@ def list_classifications(element_id: str):
     return Classification.select().where(Classification.element_id == element_id)
 
 
-def parse_transcription(transcription: Transcription, element: CachedElement):
-    return CachedTranscription(
-        id=transcription.id,
-        element=element,
-        text=transcription.text,
-        # Dodge not-null constraint for now
-        confidence=transcription.confidence or 1.0,
-        orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
-        worker_version_id=transcription.worker_version.id
-        if transcription.worker_version
-        else None,
-    )
-
-
 def list_transcriptions(element: CachedElement):
-    query = Transcription.select().where(Transcription.element_id == element.id)
-    return [parse_transcription(transcription, element) for transcription in query]
-
-
-def parse_entities(data: NamedTuple, transcription: CachedTranscription):
-    entity = CachedEntity(
-        id=data.entity_id,
-        type=data.type,
-        name=data.name,
-        validated=data.validated,
-        metas=data.metas,
-    )
-    return entity, CachedTranscriptionEntity(
-        id=data.transcription_entity_id,
-        transcription=transcription,
-        entity=entity,
-        offset=data.offset,
-        length=data.length,
-        confidence=data.confidence,
-    )
+    return Transcription.select().where(Transcription.element_id == element.id)
 
 
-def retrieve_entities(transcription: CachedTranscription):
-    query = (
-        TranscriptionEntity.select(
-            TranscriptionEntity.id.alias("transcription_entity_id"),
-            TranscriptionEntity.length.alias("length"),
-            TranscriptionEntity.offset.alias("offset"),
-            TranscriptionEntity.confidence.alias("confidence"),
-            Entity.id.alias("entity_id"),
-            EntityType.name.alias("type"),
-            Entity.name,
-            Entity.validated,
-            Entity.metas,
-        )
+def list_transcription_entities(transcription: CachedTranscription):
+    return (
+        TranscriptionEntity.select()
         .where(TranscriptionEntity.transcription_id == transcription.id)
         .join(Entity, on=TranscriptionEntity.entity)
         .join(EntityType, on=Entity.type)
     )
-    data = [
-        parse_entities(entity_data, transcription)
-        for entity_data in query.namedtuples()
-    ]
-    if not data:
-        return [], []
-
-    return zip(*data)
-
-
-def get_children(parent_id: UUID, element_type=None):
-    query = list_children(parent_id).join(Image)
-    if element_type:
-        query = query.where(Element.type == element_type)
-    return query
diff --git a/worker_generic_training_dataset/exceptions.py b/worker_generic_training_dataset/exceptions.py
deleted file mode 100644
index 062d580a5da6a9ca487dc599dd6732ec60868257..0000000000000000000000000000000000000000
--- a/worker_generic_training_dataset/exceptions.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# -*- coding: utf-8 -*-
-
-
-class ElementProcessingError(Exception):
-    """
-    Raised when a problem is encountered while processing an element
-    """
-
-    element_id: str
-    """
-    ID of the element being processed.
-    """
-
-    def __init__(self, element_id: str, *args: object) -> None:
-        super().__init__(*args)
-        self.element_id = element_id
-
-
-class ImageDownloadError(ElementProcessingError):
-    """
-    Raised when an element's image could not be downloaded
-    """
-
-    error: Exception
-    """
-    Error encountered.
-    """
-
-    def __init__(self, element_id: str, error: Exception, *args: object) -> None:
-        super().__init__(element_id, *args)
-        self.error = error
-
-    def __str__(self) -> str:
-        return (
-            f"Couldn't retrieve image of element ({self.element_id}: {str(self.error)})"
-        )
diff --git a/worker_generic_training_dataset/utils.py b/worker_generic_training_dataset/utils.py
index 01ee0f9c5d112d78403f460ee47207d766f4a932..6bfb51910ffe43e299140274bf06d3e3721a7d40 100644
--- a/worker_generic_training_dataset/utils.py
+++ b/worker_generic_training_dataset/utils.py
@@ -1,16 +1,9 @@
 # -*- coding: utf-8 -*-
 import ast
 import logging
-import time
-from pathlib import Path
 from urllib.parse import urljoin
 
-import cv2
-import imageio.v2 as iio
-from worker_generic_training_dataset.exceptions import ImageDownloadError
-
 logger = logging.getLogger(__name__)
-MAX_RETRIES = 5
 
 
 def bounding_box(polygon: list):
@@ -28,27 +21,3 @@ def build_image_url(element):
     return urljoin(
         element.image.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
     )
-
-
-def download_image(element, folder: Path):
-    """
-    Download the image to `folder / {element.image.id}.jpg`
-    """
-    tries = 1
-    # retry loop
-    while True:
-        if tries > MAX_RETRIES:
-            raise ImageDownloadError(element.id, Exception("Maximum retries reached."))
-        try:
-            image = iio.imread(build_image_url(element))
-            cv2.imwrite(
-                str(folder / f"{element.image.id}.jpg"),
-                cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
-            )
-            break
-        except TimeoutError:
-            logger.warning("Timeout, retry in 1 second.")
-            time.sleep(1)
-            tries += 1
-        except Exception as e:
-            raise ImageDownloadError(element.id, e)
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index 81743500786c39f1571c2935874b6dac823c8d03..518f5624ea9ae1aa200c09bbc45110cd77a73e8d 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -1,14 +1,14 @@
 # -*- coding: utf-8 -*-
 import logging
 import operator
-import shutil
 import tempfile
 from pathlib import Path
-from typing import Optional
+from typing import List, Optional
 from uuid import UUID
 
 from apistar.exceptions import ErrorResponse
-from arkindex_export import open_database
+from arkindex_export import Element, Image, open_database
+from arkindex_export.queries import list_children
 from arkindex_worker.cache import (
     CachedClassification,
     CachedElement,
@@ -21,20 +21,21 @@ from arkindex_worker.cache import (
 )
 from arkindex_worker.cache import db as cache_database
 from arkindex_worker.cache import init_cache_db
+from arkindex_worker.image import download_image
 from arkindex_worker.utils import create_tar_zst_archive
 from arkindex_worker.worker.base import BaseWorker
 from worker_generic_training_dataset.db import (
-    get_children,
     list_classifications,
+    list_transcription_entities,
     list_transcriptions,
     retrieve_element,
-    retrieve_entities,
 )
-from worker_generic_training_dataset.utils import download_image
+from worker_generic_training_dataset.utils import build_image_url
 
 logger = logging.getLogger(__name__)
 
 BULK_BATCH_SIZE = 50
+DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"
 
 
 class DatasetExtractor(BaseWorker):
@@ -49,11 +50,6 @@ class DatasetExtractor(BaseWorker):
             logger.info("Overriding with user_configuration")
             self.config.update(self.user_configuration)
 
-        # database arg is mandatory in dev mode
-        assert (
-            not self.is_read_only or self.args.database is not None
-        ), "`--database` arg is mandatory in developer mode."
-
         # Read process information
         self.read_training_related_information()
 
@@ -66,6 +62,10 @@ class DatasetExtractor(BaseWorker):
         # CachedImage downloaded and created in DB
         self.cached_images = dict()
 
+        # Where to save the downloaded images
+        self.image_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data"))
+        logger.info(f"Images will be saved at `{self.image_folder}`.")
+
     def read_training_related_information(self):
         """
         Read from process information
@@ -87,15 +87,11 @@ class DatasetExtractor(BaseWorker):
         self.testing_folder_id = UUID(test_folder_id) if test_folder_id else None
 
     def initialize_database(self):
-        # Create db at
-        # - self.workdir / "db.sqlite" in Arkindex mode
-        # - self.args.database in dev mode
-        database_path = (
-            self.args.database
-            if self.is_read_only
-            else self.work_dir / "db.sqlite"
-        )
-        if database_path.exists():
-            database_path.unlink()
+        """
+        Create an SQLite database compatible with the base-worker cache and initialize it.
+        """
+        database_path = self.work_dir / "db.sqlite"
+        # Remove previous execution result if present
+        database_path.unlink(missing_ok=True)
 
         init_cache_db(database_path)
@@ -105,12 +101,17 @@ class DatasetExtractor(BaseWorker):
         create_tables()
 
     def download_latest_export(self):
-        # Find export of corpus
+        """
+        Download the latest export of the current corpus.
+        Export must be in `"done"` state.
+        """
         try:
-            exports = self.api_client.request(
-                "ListExports",
-                id=self.corpus_id,
-            )["results"]
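+            # Paginate to retrieve every export, not just the first page of results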
+            exports = list(
+                self.api_client.paginate(
+                    "ListExports",
+                    id=self.corpus_id,
+                )
+            )
         except ErrorResponse as e:
             logger.error(
                 f"Could not list exports of corpus ({self.corpus_id}): {str(e.content)}"
@@ -123,7 +124,9 @@ class DatasetExtractor(BaseWorker):
             key=operator.itemgetter("updated"),
             reverse=True,
         )
-        assert len(exports) > 0, f"No available exports found for the corpus {self.corpus_id}."
+        assert (
+            len(exports) > 0
+        ), f"No available exports found for the corpus {self.corpus_id}."
 
         # Download the latest export to a tmpfile
         try:
@@ -137,18 +140,129 @@ class DatasetExtractor(BaseWorker):
             open_database(self.export.name)
         except ErrorResponse as e:
             logger.error(
-                f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e)}"
+                f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e.content)}"
             )
             raise e
 
-    def insert_element(
-        self, element, image_folder: Path, parent_id: Optional[str] = None
-    ):
+    def insert_classifications(self, element: CachedElement) -> None:
+        logger.info("Listing classifications")
+        classifications = [
+            CachedClassification(
+                id=classification.id,
+                element=element,
+                class_name=classification.class_name,
+                confidence=classification.confidence,
+                state=classification.state,
+            )
+            for classification in list_classifications(element.id)
+        ]
+        if classifications:
+            logger.info(f"Inserting {len(classifications)} classification(s)")
+            with cache_database.atomic():
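+                # Single transaction; batch_size caps how many rows go into each INSERT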
+                CachedClassification.bulk_create(
+                    model_list=classifications,
+                    batch_size=BULK_BATCH_SIZE,
+                )
+
+    def insert_transcriptions(
+        self, element: CachedElement
+    ) -> List[CachedTranscription]:
+        logger.info("Listing transcriptions")
+        transcriptions = [
+            CachedTranscription(
+                id=transcription.id,
+                element=element,
+                text=transcription.text,
+                # Dodge not-null constraint for now
+                confidence=transcription.confidence or 1.0,
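+                # The cache schema requires an orientation, so apply the default to every transcription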
+                orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
+                worker_version_id=transcription.worker_version.id
+                if transcription.worker_version
+                else None,
+            )
+            for transcription in list_transcriptions(element)
+        ]
+        if not transcriptions:
+            return []
+
+        logger.info(f"Inserting {len(transcriptions)} transcription(s)")
+        with cache_database.atomic():
+            CachedTranscription.bulk_create(
+                model_list=transcriptions,
+                batch_size=BULK_BATCH_SIZE,
+            )
+        return transcriptions
+
+    def insert_entities(self, transcriptions: List[CachedTranscription]):
+        logger.info("Listing entities")
+        extracted_entities = []
+        for transcription in transcriptions:
+            for transcription_entity in list_transcription_entities(transcription):
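+                # Build both the entity and its link to the transcription; both are bulk-inserted below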
+                entity = CachedEntity(
+                    id=transcription_entity.entity.id,
+                    type=transcription_entity.entity.type.name,
+                    name=transcription_entity.entity.name,
+                    validated=transcription_entity.entity.validated,
+                    metas=transcription_entity.entity.metas,
+                )
+                extracted_entities.append(
+                    (
+                        entity,
+                        CachedTranscriptionEntity(
+                            id=transcription_entity.id,
+                            transcription=transcription,
+                            entity=entity,
+                            offset=transcription_entity.offset,
+                            length=transcription_entity.length,
+                            confidence=transcription_entity.confidence,
+                        ),
+                    )
+                )
+        if not extracted_entities:
+            # Early return if no entities found
+            return
+
+        entities, transcription_entities = zip(*extracted_entities)
+
+        # First insert entities since they are foreign keys on transcription entities
+        logger.info(f"Inserting {len(entities)} entities")
+        with cache_database.atomic():
+            CachedEntity.bulk_create(
+                model_list=entities,
+                batch_size=BULK_BATCH_SIZE,
+            )
+
+        if transcription_entities:
+            # Insert transcription entities
+            logger.info(
+                f"Inserting {len(transcription_entities)} transcription entities"
+            )
+            with cache_database.atomic():
+                CachedTranscriptionEntity.bulk_create(
+                    model_list=transcription_entities,
+                    batch_size=BULK_BATCH_SIZE,
+                )
+
+    def insert_element(self, element: Element, parent_id: Optional[str] = None):
+        """
+        Insert the given element in the cache database.
+        Its image will also be saved to disk, if it wasn't already.
+
+        The insertion of an element includes:
+        - its classifications
+        - its transcriptions
+        - its transcriptions' entities (both Entity and TranscriptionEntity)
+
+        :param element: Element to insert.
+        :param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements.
+        """
         logger.info(f"Processing element ({element.id})")
         if element.image and element.image.id not in self.cached_images:
             # Download image
             logger.info("Downloading image")
-            download_image(element, folder=image_folder)
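+            # download_image returns a PIL Image, saved as JPEG named after the image UUID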
+            download_image(url=build_image_url(element)).save(
+                self.image_folder / f"{element.image.id}.jpg"
+            )
             # Insert image
             logger.info("Inserting image")
             # Store images in case some other elements use it as well
@@ -169,6 +283,6 @@
             polygon=element.polygon,
             rotation_angle=element.rotation_angle,
             mirrored=element.mirrored,
-            worker_version_id=element.worker_version
+            worker_version_id=element.worker_version.id
             if element.worker_version
             else None,
@@ -178,89 +290,48 @@ class DatasetExtractor(BaseWorker):
         )
 
         # Insert classifications
-        logger.info("Listing classifications")
-        classifications = [
-            CachedClassification(
-                id=classification.id,
-                element=cached_element,
-                class_name=classification.class_name,
-                confidence=classification.confidence,
-                state=classification.state,
-            )
-            for classification in list_classifications(element.id)
-        ]
-        if classifications:
-            logger.info(f"Inserting {len(classifications)} classification(s)")
-            with cache_database.atomic():
-                CachedClassification.bulk_create(
-                    model_list=classifications,
-                    batch_size=BULK_BATCH_SIZE,
-                )
+        self.insert_classifications(cached_element)
 
         # Insert transcriptions
-        logger.info("Listing transcriptions")
-        transcriptions = list_transcriptions(cached_element)
-        if transcriptions:
-            logger.info(f"Inserting {len(transcriptions)} transcription(s)")
-            with cache_database.atomic():
-                CachedTranscription.bulk_create(
-                    model_list=transcriptions,
-                    batch_size=BULK_BATCH_SIZE,
-                )
+        transcriptions = self.insert_transcriptions(cached_element)
 
         # Insert entities
-        logger.info("Listing entities")
-        entities, transcription_entities = zip(*[retrieve_entities(transcription) for transcription in transcriptions])
-
-        if entities:
-            logger.info(f"Inserting {len(entities)} entities")
-            with cache_database.atomic():
-                CachedEntity.bulk_create(
-                    model_list=entities,
-                    batch_size=BULK_BATCH_SIZE,
-                )
-            # Insert transcription entities
-            logger.info(
-                f"Inserting {len(transcription_entities)} transcription entities"
-            )
-            with cache_database.atomic():
-                CachedTranscriptionEntity.bulk_create(
-                    model_list=transcription_entities,
-                    batch_size=BULK_BATCH_SIZE,
-                )
+        self.insert_entities(transcriptions)
 
-    def process_split(self, split_name, split_id, image_folder):
+    def process_split(self, split_name: str, split_id: UUID):
+        """
+        Insert all elements under the given parent folder.
+        - `page` elements are linked to this folder (via parent_id foreign key)
+        - children of `page` elements are linked to their `page` parent (via parent_id foreign key)
+        """
         logger.info(
             f"Filling the Base-Worker cache with information from children under element ({split_id})"
         )
         # Fill cache
         # Retrieve parent and create parent
         parent = retrieve_element(split_id)
-        self.insert_element(parent, image_folder)
+        self.insert_element(parent)
 
         # First list all pages
-        pages = get_children(parent_id=split_id, element_type="page")
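+        # Inner join on Image so only pages that have an image are listed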
+        pages = list_children(split_id).join(Image).where(Element.type == "page")
         nb_pages = pages.count()
         for idx, page in enumerate(pages, start=1):
             logger.info(f"Processing `{split_name}` page ({idx}/{nb_pages})")
 
             # Insert page
-            self.insert_element(page, image_folder, parent_id=split_id)
+            self.insert_element(page, parent_id=split_id)
 
             # List children
-            children = get_children(parent_id=page.id)
+            children = list_children(page.id)
             nb_children = children.count()
             for child_idx, child in enumerate(children, start=1):
                 logger.info(f"Processing child ({child_idx}/{nb_children})")
                 # Insert child
-                self.insert_element(child, image_folder, parent_id=page.id)
+                self.insert_element(child, parent_id=page.id)
 
     def run(self):
         self.configure()
 
-        # Where to save the downloaded images
-        image_folder = Path(tempfile.mkdtemp())
-
         # Iterate over given split
         for split_name, split_id in [
             ("Train", self.training_folder_id),
@@ -269,20 +340,18 @@ class DatasetExtractor(BaseWorker):
         ]:
             if not split_id:
                 continue
-            self.process_split(split_name, split_id, image_folder)
+            self.process_split(split_name, split_id)
 
         # TAR + ZSTD Image folder and store as task artifact
         zstd_archive_path = self.work_dir / "arkindex_data.zstd"
         logger.info(f"Compressing the images to {zstd_archive_path}")
-        create_tar_zst_archive(source=image_folder, destination=zstd_archive_path)
-
-        # Cleanup image folder
-        shutil.rmtree(image_folder)
+        create_tar_zst_archive(source=self.image_folder, destination=zstd_archive_path)
 
 
 def main():
     DatasetExtractor(
-        description="Fill base-worker cache with information about dataset and extract images", support_cache=True
+        description="Fill base-worker cache with information about dataset and extract images",
+        support_cache=True,
     ).run()