diff --git a/arkindex/process/models.py b/arkindex/process/models.py
index fa349eb6c435c0f71f4a64c852f7ce1a281cc3f2..e32dd410fe26622c3d7ea6506dac24cea163457b 100644
--- a/arkindex/process/models.py
+++ b/arkindex/process/models.py
@@ -28,6 +28,7 @@ from arkindex.process.managers import (
 from arkindex.project.aws import S3FileMixin, S3FileStatus
 from arkindex.project.fields import ArrayField, MD5HashField
 from arkindex.project.models import IndexableModel
+from arkindex.project.tools import is_prefetched
 from arkindex.project.validators import MaxValueValidator
 from arkindex.training.models import ModelVersion, ModelVersionState
 from arkindex.users.models import Role
@@ -224,11 +225,7 @@ class Process(IndexableModel):
 
         See https://stackoverflow.com/a/19651840/5990435
         """
-        return (
-            hasattr(self, "_prefetched_objects_cache")
-            and self.tasks.field.remote_field.get_cache_name()
-            in self._prefetched_objects_cache
-        )
+        return is_prefetched(self.tasks)
 
     @property
     def expiry(self):
diff --git a/arkindex/project/serializer_fields.py b/arkindex/project/serializer_fields.py
index 0563c935e041781853614adfac40cc2abac620e0..230bafd37351cccedda4d2f0c8298f0b3cc1f916 100644
--- a/arkindex/project/serializer_fields.py
+++ b/arkindex/project/serializer_fields.py
@@ -11,6 +11,7 @@ from arkindex.documents.models import MetaType
 from arkindex.ponos.utils import get_process_from_task_auth
 from arkindex.process.models import ProcessMode, WorkerRun
 from arkindex.project.gis import ensure_linear_ring
+from arkindex.project.tools import is_prefetched
 
 
 class EnumField(serializers.ChoiceField):
@@ -284,9 +285,8 @@ class DatasetSetsCountField(serializers.DictField):
     def get_attribute(self, instance):
         # Skip this field if sets are not prefetched, or if they are missing a count
         if (
-            instance.sets.field.remote_field.get_cache_name
-            not in getattr(instance, "_prefetched_objects_cache", {})
-            or not all(hasattr(set, "element_count") for set in instance.sets)
+            not is_prefetched(instance.sets)
+            or not all(hasattr(set, "element_count") for set in instance.sets.all())
         ):
             return None
 
diff --git a/arkindex/project/tools.py b/arkindex/project/tools.py
index 72fde5b89d017d4f47ccafbfd2c69de69392281d..74bc0dd81f3629a722797d88824f3a0d7caa5695 100644
--- a/arkindex/project/tools.py
+++ b/arkindex/project/tools.py
@@ -188,3 +188,28 @@ def fake_now():
     Fake creation date for fixtures and test objects
     """
     return datetime(2020, 2, 2, 1, 23, 45, 678000, tzinfo=timezone.utc)
+
+
+def is_prefetched(related_manager) -> bool:
+    """
+    Determines whether the related items for a reverse foreign key have been prefetched;
+    that is, if calling `instance.things.all()` will not cause an SQL query.
+    Usage: `is_prefetched(instance.things)`
+    """
+    return (
+        related_manager.field.remote_field.get_cache_name()
+        in getattr(related_manager.instance, "_prefetched_objects_cache", {})
+    )
+
+
+def add_as_prefetch(related_manager, items) -> None:
+    """
+    Manually set a list of related items on an instance, as if they were actually prefetched from the database.
+    Usage: `add_as_prefetch(instance.things, [thing1, thing2])`
+    """
+    if not (isinstance(items, list) and all(isinstance(item, related_manager.model) for item in items)):
+        # Raise explicitly instead of asserting: asserts are stripped under `python -O`.
+        raise TypeError(f"Prefetched items should be a list of {related_manager.model} instances.")
+    cache = getattr(related_manager.instance, "_prefetched_objects_cache", {})
+    cache[related_manager.field.remote_field.get_cache_name()] = items
+    related_manager.instance._prefetched_objects_cache = cache
diff --git a/arkindex/training/api.py b/arkindex/training/api.py
index 71f274f34c7d7da4be4178fd6f374263046332a4..c77e858fe80606e5ae77a3fc6e7a4609c778f94d 100644
--- a/arkindex/training/api.py
+++ b/arkindex/training/api.py
@@ -1003,7 +1003,7 @@ class DatasetClone(CorpusACLMixin, CreateAPIView):
             DatasetSet(dataset_id=clone.id, name=set.name)
             for set in dataset.sets.all()
         ])
-        set_map = {set.name: set.id for set in cloned_sets}
+        set_map = {set.name: set for set in cloned_sets}
 
         # Associate all elements to the clone
         DatasetElement.objects.bulk_create([