diff --git a/arkindex/training/api.py b/arkindex/training/api.py
index 454663899ebd7c982c69f7a97ab8ac02fac47090..71f274f34c7d7da4be4178fd6f374263046332a4 100644
--- a/arkindex/training/api.py
+++ b/arkindex/training/api.py
@@ -1003,11 +1003,14 @@ class DatasetClone(CorpusACLMixin, CreateAPIView):
             DatasetSet(dataset_id=clone.id, name=set.name)
             for set in dataset.sets.all()
         ])
+        set_map = {set.name: set.id for set in cloned_sets}
+
         # Associate all elements to the clone
         DatasetElement.objects.bulk_create([
-            DatasetElement(element_id=elt_id, set=next(new_set for new_set in cloned_sets if new_set.name == set_name))
+            DatasetElement(element_id=elt_id, set=set_map[set_name])
             for elt_id, set_name in DatasetElement.objects.filter(set__dataset_id=dataset.id)
             .values_list("element_id", "set__name")
+            .iterator()
         ])
 
         # Add the set counts to the API response