diff --git a/arkindex/process/models.py b/arkindex/process/models.py index fa349eb6c435c0f71f4a64c852f7ce1a281cc3f2..e32dd410fe26622c3d7ea6506dac24cea163457b 100644 --- a/arkindex/process/models.py +++ b/arkindex/process/models.py @@ -28,6 +28,7 @@ from arkindex.process.managers import ( from arkindex.project.aws import S3FileMixin, S3FileStatus from arkindex.project.fields import ArrayField, MD5HashField from arkindex.project.models import IndexableModel +from arkindex.project.tools import is_prefetched from arkindex.project.validators import MaxValueValidator from arkindex.training.models import ModelVersion, ModelVersionState from arkindex.users.models import Role @@ -224,11 +225,7 @@ class Process(IndexableModel): See https://stackoverflow.com/a/19651840/5990435 """ - return ( - hasattr(self, "_prefetched_objects_cache") - and self.tasks.field.remote_field.get_cache_name() - in self._prefetched_objects_cache - ) + return is_prefetched(self.tasks) @property def expiry(self): diff --git a/arkindex/project/serializer_fields.py b/arkindex/project/serializer_fields.py index 0563c935e041781853614adfac40cc2abac620e0..230bafd37351cccedda4d2f0c8298f0b3cc1f916 100644 --- a/arkindex/project/serializer_fields.py +++ b/arkindex/project/serializer_fields.py @@ -11,6 +11,7 @@ from arkindex.documents.models import MetaType from arkindex.ponos.utils import get_process_from_task_auth from arkindex.process.models import ProcessMode, WorkerRun from arkindex.project.gis import ensure_linear_ring +from arkindex.project.tools import is_prefetched class EnumField(serializers.ChoiceField): @@ -284,9 +285,8 @@ class DatasetSetsCountField(serializers.DictField): def get_attribute(self, instance): # Skip this field if sets are not prefetched, or if they are missing a count if ( - instance.sets.field.remote_field.get_cache_name - not in getattr(instance, "_prefetched_objects_cache", {}) - or not all(hasattr(set, "element_count") for set in instance.sets) + not is_prefetched(instance.sets) + or not all(hasattr(set, "element_count") for set in instance.sets.all()) ): return None diff --git a/arkindex/project/tools.py b/arkindex/project/tools.py index 72fde5b89d017d4f47ccafbfd2c69de69392281d..74bc0dd81f3629a722797d88824f3a0d7caa5695 100644 --- a/arkindex/project/tools.py +++ b/arkindex/project/tools.py @@ -188,3 +188,28 @@ def fake_now(): Fake creation date for fixtures and test objects """ return datetime(2020, 2, 2, 1, 23, 45, 678000, tzinfo=timezone.utc) + + +def is_prefetched(related_manager) -> bool: + """ + Determines whether the related items for a reverse foreign key have been prefetched; + that is, if calling `instance.things.all()` will not cause an SQL query. + Usage: `is_prefetched(instance.things)` + """ + return ( + related_manager.field.remote_field.get_cache_name() + in getattr(related_manager.instance, "_prefetched_objects_cache", {}) + ) + + +def add_as_prefetch(related_manager, items) -> None: + """ + Manually set a list of related items on an instance, as if they were actually prefetched from the database. + Usage: `add_as_prefetch(instance.things, [thing1, thing2])` + """ + assert ( + isinstance(items, list) and all(isinstance(item, related_manager.model) for item in items) + ), f"Prefetched items should be a list of {related_manager.model} instances." + cache = getattr(related_manager.instance, "_prefetched_objects_cache", {}) + cache[related_manager.field.remote_field.get_cache_name()] = items + related_manager.instance._prefetched_objects_cache = cache diff --git a/arkindex/training/api.py b/arkindex/training/api.py index 71f274f34c7d7da4be4178fd6f374263046332a4..c77e858fe80606e5ae77a3fc6e7a4609c778f94d 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -1003,7 +1003,7 @@ class DatasetClone(CorpusACLMixin, CreateAPIView): DatasetSet(dataset_id=clone.id, name=set.name) for set in dataset.sets.all() ]) - set_map = {set.name: set.id for set in cloned_sets} + set_map = {set.name: set for set in cloned_sets} # Associate all elements to the clone DatasetElement.objects.bulk_create([