diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index 5501704746d70367b9af8c4e51c59555f290c39f..43109d899dfd625c43951357d42624a688f69b23 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -336,7 +336,7 @@ class DatasetExtractor(DatasetWorker):
         """
         logger.info(f"Inserting dataset ({dataset.id})")
         with cache_database.atomic():
-            return CachedDataset.create(
+            self.cached_dataset = CachedDataset.create(
                 id=dataset.id,
                 name=dataset.name,
                 state=dataset.state,
@@ -347,11 +347,11 @@ class DatasetExtractor(DatasetWorker):
         # Configure temporary storage for the dataset data (cache + images)
         self.configure_storage()
 
-        splits = self.list_dataset_elements_per_split(dataset)
-        self.cached_dataset = self.insert_dataset(dataset)
+        # Insert dataset in cache database
+        self.insert_dataset(dataset)
 
         # Iterate over given splits
-        for split_name, elements in splits:
+        for split_name, elements in self.list_dataset_elements_per_split(dataset):
             casted_elements = list(map(_format_element, elements))
             self.process_split(split_name, casted_elements)