diff --git a/arkindex/training/api.py b/arkindex/training/api.py index 454663899ebd7c982c69f7a97ab8ac02fac47090..71f274f34c7d7da4be4178fd6f374263046332a4 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -1003,11 +1003,14 @@ class DatasetClone(CorpusACLMixin, CreateAPIView): DatasetSet(dataset_id=clone.id, name=set.name) for set in dataset.sets.all() ]) + set_map = {set.name: set.id for set in cloned_sets} + # Associate all elements to the clone DatasetElement.objects.bulk_create([ - DatasetElement(element_id=elt_id, set=next(new_set for new_set in cloned_sets if new_set.name == set_name)) + DatasetElement(element_id=elt_id, set=set_map[set_name]) for elt_id, set_name in DatasetElement.objects.filter(set__dataset_id=dataset.id) .values_list("element_id", "set__name") + .iterator() ]) # Add the set counts to the API response