diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index 5501704746d70367b9af8c4e51c59555f290c39f..43109d899dfd625c43951357d42624a688f69b23 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -336,7 +336,7 @@ class DatasetExtractor(DatasetWorker):
         """
         logger.info(f"Inserting dataset ({dataset.id})")
         with cache_database.atomic():
-            return CachedDataset.create(
+            self.cached_dataset = CachedDataset.create(
                 id=dataset.id,
                 name=dataset.name,
                 state=dataset.state,
@@ -347,11 +347,11 @@ class DatasetExtractor(DatasetWorker):
 
         # Configure temporary storage for the dataset data (cache + images)
         self.configure_storage()
 
-        splits = self.list_dataset_elements_per_split(dataset)
-        self.cached_dataset = self.insert_dataset(dataset)
+        # Insert dataset in cache database
+        self.insert_dataset(dataset)
 
         # Iterate over given splits
-        for split_name, elements in splits:
+        for split_name, elements in self.list_dataset_elements_per_split(dataset):
             casted_elements = list(map(_format_element, elements))
             self.process_split(split_name, casted_elements)