diff --git a/arkindex/process/api.py b/arkindex/process/api.py index 75b7acca9c9ba834514cdfbfaed63e9d317d7533..0eb43c779eabac248342f8305a53271ee4a77405 100644 --- a/arkindex/process/api.py +++ b/arkindex/process/api.py @@ -1929,25 +1929,24 @@ class WorkerActivityList(CorpusACLMixin, ProcessACLMixin, ListAPIView): def get_queryset(self): return WorkerActivity.objects.filter(element__corpus=self.corpus).order_by("id") - def get_filter_value(self, key, errors): + def validate_uuid(self, value, key, errors): + try: + return UUID(value) + except (TypeError, ValueError): + errors[key] = ["Not a valid UUID."] + + def get_uuid_or_none(self, key, errors): value = self.request.query_params.get(key) if value.lower() not in ("false", "0"): - try: - return UUID(value) - except (TypeError, ValueError): - errors[key] = ["Not a valid UUID."] - else: - return None + return self.validate_uuid(value, key, errors) + return None def filter_queryset(self, queryset): errors = {} process_id = None if "process_id" in self.request.query_params: - try: - process_id = UUID(self.request.query_params["process_id"]) - except (TypeError, ValueError): - errors["process_id"] = ["Not a valid UUID."] + process_id = self.validate_uuid(self.request.query_params["process_id"], "process_id", errors) if process_id: try: @@ -1971,19 +1970,16 @@ class WorkerActivityList(CorpusACLMixin, ProcessACLMixin, ListAPIView): worker_version_id = None if "worker_version_id" in self.request.query_params: - try: - worker_version_id = UUID(self.request.query_params["worker_version_id"]) - except (TypeError, ValueError): - errors["worker_version_id"] = ["Not a valid UUID."] + worker_version_id = self.validate_uuid(self.request.query_params["worker_version_id"], "worker_version_id", errors) if worker_version_id: queryset = queryset.filter(worker_version_id=worker_version_id) if "model_version_id" in self.request.query_params: - model_version_filter = self.get_filter_value("model_version_id", errors) + model_version_filter = self.get_uuid_or_none("model_version_id", errors) queryset = queryset.filter(model_version_id=model_version_filter) if "worker_configuration_id" in self.request.query_params: - worker_configuration_filter = self.get_filter_value("worker_configuration_id", errors) + worker_configuration_filter = self.get_uuid_or_none("worker_configuration_id", errors) queryset = queryset.filter(configuration_id=worker_configuration_filter) if errors: diff --git a/arkindex/process/tests/test_workeractivity.py b/arkindex/process/tests/test_workeractivity.py index 1c32f0240c1026700c1f0aead5e346a2cab320b1..be66bd618b4d1dce6a572e9b8082291c5917975c 100644 --- a/arkindex/process/tests/test_workeractivity.py +++ b/arkindex/process/tests/test_workeractivity.py @@ -940,6 +940,13 @@ class TestWorkerActivity(FixtureTestCase): model_version=model_version_2, configuration=self.configuration ) + activity_6 = process_1.activities.create( + element=self.element, + worker_version=self.worker_version, + state=WorkerActivityState.Processed, + model_version=self.model_version, + configuration=None + ) # Worker activity responses worker_activities = { @@ -950,8 +957,8 @@ class TestWorkerActivity(FixtureTestCase): "element_id": str(self.element.id), "process_id": str(activity.process_id), "worker_version_id": str(activity.worker_version.id), - "configuration_id": str(activity.configuration_id), - "model_version_id": str(activity.model_version_id), + "configuration_id": str(activity.configuration_id) if activity.configuration_id else None, + "model_version_id": str(activity.model_version_id) if activity.model_version_id else None, "state": activity.state.value } for activity in WorkerActivity.objects.all() @@ -960,17 +967,17 @@ class TestWorkerActivity(FixtureTestCase): # Test cases cases = [ ( - {"process_id": str(process_1.id)}, [activity_1.id, activity_2.id, activity_3.id, activity_4.id] + {"process_id": str(process_1.id)}, [activity_1.id, activity_2.id, activity_3.id, activity_4.id, activity_6.id] ), ( - {"state": "processed"}, [activity_1.id, activity_3.id, activity_4.id, activity_5.id] + {"state": "processed"}, [activity_1.id, activity_3.id, activity_4.id, activity_5.id, activity_6.id] ), ( {"process_id": str(process_1.id), "state": "error"}, [activity_2.id] ), ( {"process_id": str(process_1.id), "state": "processed", "worker_version_id": str(self.worker_version.id)}, - [activity_1.id, activity_3.id, activity_4.id] + [activity_1.id, activity_3.id, activity_4.id, activity_6.id] ), ( {"worker_version_id": str(worker_version_2.id)}, [activity_2.id, activity_5.id] @@ -982,7 +989,7 @@ class TestWorkerActivity(FixtureTestCase): "worker_version_id": str(self.worker_version.id), "model_version_id": str(self.model_version.id) }, - [activity_1.id, activity_3.id] + [activity_1.id, activity_3.id, activity_6.id] ), ( {"model_version_id": str(model_version_2.id)}, [activity_4.id, activity_5.id] @@ -1000,6 +1007,26 @@ class TestWorkerActivity(FixtureTestCase): ( {"worker_configuration_id": str(configuration_2.id)}, [activity_3.id, activity_4.id] ), + ( + { + "process_id": str(process_1.id), + "state": "processed", + "worker_version_id": str(self.worker_version.id), + "model_version_id": False, + "worker_configuration_id": str(self.configuration.id) + }, + [] + ), + ( + { + "process_id": str(process_1.id), + "state": "processed", + "worker_version_id": str(self.worker_version.id), + "model_version_id": str(self.model_version.id), + "worker_configuration_id": False + }, + [activity_6.id] + ), ] self.client.force_login(self.user) @@ -1009,6 +1036,9 @@ class TestWorkerActivity(FixtureTestCase): # Filtering by process_id adds 1 query as it checks if the process exists queries_count = 6 if "process_id" in filters else 5 + # If there are no results returned, only the request to return a 'count' is made so we get one less query + if not len(activity_ids): + queries_count -= 1 with self.assertNumQueries(queries_count): response = self.client.get( reverse("api:corpus-activity", kwargs={"corpus": self.corpus.id}),