diff --git a/arkindex/process/api.py b/arkindex/process/api.py index a129558bad4ea10829d4b1619ca2d62fd0bedd90..75b7acca9c9ba834514cdfbfaed63e9d317d7533 100644 --- a/arkindex/process/api.py +++ b/arkindex/process/api.py @@ -114,7 +114,7 @@ from arkindex.project.mixins import ( SelectionMixin, WorkerACLMixin, ) -from arkindex.project.openapi import UUID_OR_STR +from arkindex.project.openapi import UUID_OR_FALSE, UUID_OR_STR from arkindex.project.pagination import CountCursorPagination from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly from arkindex.project.tools import PercentileCont @@ -1877,15 +1877,39 @@ class ProcessWorkersActivity(ProcessACLMixin, WorkerActivityBase): OpenApiParameter( "process_id", type=UUID, - description="Filter worker activities by process ID", + description="Filter worker activities by process ID.", required=False, ), OpenApiParameter( "state", enum=[state.value for state in WorkerActivityState], - description="Filter worker activities by state", + description="Filter worker activities by state.", required=False, ), + OpenApiParameter( + "worker_version_id", + type=UUID, + required=False, + description="Filter worker activities by worker version.", + ), + OpenApiParameter( + "model_version_id", + type=UUID_OR_FALSE, + required=False, + description=dedent(""" + Filter worker activities by model version. + If set to false, only retrieve worker activities produced by a worker run with no model version. + """) + ), + OpenApiParameter( + "worker_configuration_id", + type=UUID_OR_FALSE, + required=False, + description=dedent(""" + Filter worker activities by worker configuration. + If set to false, only retrieve worker activities produced by a worker run with no worker configuration. + """) + ), ] ) ) @@ -1905,6 +1929,16 @@ class WorkerActivityList(CorpusACLMixin, ProcessACLMixin, ListAPIView): def get_queryset(self): return WorkerActivity.objects.filter(element__corpus=self.corpus).order_by("id") + def get_filter_value(self, key, errors): + value = self.request.query_params.get(key) + if value.lower() not in ("false", "0"): + try: + return UUID(value) + except (TypeError, ValueError): + errors[key] = ["Not a valid UUID."] + else: + return None + def filter_queryset(self, queryset): errors = {} @@ -1913,7 +1947,7 @@ class WorkerActivityList(CorpusACLMixin, ProcessACLMixin, ListAPIView): try: process_id = UUID(self.request.query_params["process_id"]) except (TypeError, ValueError): - errors["process_id"] = ["Process ID should be an UUID."] + errors["process_id"] = ["Not a valid UUID."] if process_id: try: @@ -1935,6 +1969,23 @@ class WorkerActivityList(CorpusACLMixin, ProcessACLMixin, ListAPIView): else: queryset = queryset.filter(state=state) + worker_version_id = None + if "worker_version_id" in self.request.query_params: + try: + worker_version_id = UUID(self.request.query_params["worker_version_id"]) + except (TypeError, ValueError): + errors["worker_version_id"] = ["Not a valid UUID."] + if worker_version_id: + queryset = queryset.filter(worker_version_id=worker_version_id) + + if "model_version_id" in self.request.query_params: + model_version_filter = self.get_filter_value("model_version_id", errors) + queryset = queryset.filter(model_version_id=model_version_filter) + + if "worker_configuration_id" in self.request.query_params: + worker_configuration_filter = self.get_filter_value("worker_configuration_id", errors) + queryset = queryset.filter(configuration_id=worker_configuration_filter) + if errors: raise ValidationError(errors) diff --git a/arkindex/process/tests/test_workeractivity.py b/arkindex/process/tests/test_workeractivity.py index 6ce01a536840166cd1db4441c9a1fd9432efa653..1c32f0240c1026700c1f0aead5e346a2cab320b1 100644 --- a/arkindex/process/tests/test_workeractivity.py +++ b/arkindex/process/tests/test_workeractivity.py @@ -884,87 +884,159 @@ class TestWorkerActivity(FixtureTestCase): self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), {"process_id": ["This process does not exist."]}) - def test_list_filter_process(self): - process = self.corpus.processes.create( + def test_list_filters(self): + # Delete all worker activities + WorkerActivity.objects.all().delete() + + # Create objects + process_1 = self.corpus.processes.create( mode=ProcessMode.Workers, creator=self.user, ) - activity = process.activities.create( + process_2 = self.corpus.processes.create( + mode=ProcessMode.Workers, + creator=self.user, + ) + worker_version_2 = WorkerVersion.objects.get(worker__slug="dla") + model_version_2 = self.model.versions.create( + state=ModelVersionState.Available, + hash="b", + ) + configuration_2 = WorkerConfiguration.objects.create(worker=self.worker_version.worker, name="Another configuration", configuration={"c": "d"}) + + # Create all the worker activities + activity_1 = process_1.activities.create( element=self.element, worker_version=self.worker_version, + state=WorkerActivityState.Processed, + model_version=self.model_version, + configuration=self.configuration + ) + activity_2 = process_1.activities.create( + element=self.element, + worker_version=worker_version_2, + state=WorkerActivityState.Error, + model_version=self.model_version, + configuration=self.configuration + ) + activity_3 = process_1.activities.create( + element=self.element, + worker_version=self.worker_version, + state=WorkerActivityState.Processed, + model_version=self.model_version, + configuration=configuration_2 + ) + activity_4 = process_1.activities.create( + element=self.element, + worker_version=self.worker_version, + state=WorkerActivityState.Processed, + model_version=model_version_2, + configuration=configuration_2 + ) + activity_5 = process_2.activities.create( + element=self.element, + worker_version=worker_version_2, + state=WorkerActivityState.Processed, + model_version=model_version_2, + configuration=self.configuration ) - self.client.force_login(self.user) - with self.assertNumQueries(6): - response = self.client.get( - reverse("api:corpus-activity", kwargs={"corpus": self.corpus.id}), - {"process_id": str(process.id)} - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) + # Worker activity responses + worker_activities = { + activity.id: { + "created": activity.created.isoformat().replace("+00:00", "Z"), + "updated": activity.updated.isoformat().replace("+00:00", "Z"), + "started": None, + "element_id": str(self.element.id), + "process_id": str(activity.process_id), + "worker_version_id": str(activity.worker_version.id), + "configuration_id": str(activity.configuration_id), + "model_version_id": str(activity.model_version_id), + "state": activity.state.value + } + for activity in WorkerActivity.objects.all() + } - self.assertDictEqual(response.json(), { - "count": 1, - "number": 1, - "previous": None, - "next": None, - "results": [ + # Test cases + cases = [ + ( + {"process_id": str(process_1.id)}, [activity_1.id, activity_2.id, activity_3.id, activity_4.id] + ), + ( + {"state": "processed"}, [activity_1.id, activity_3.id, activity_4.id, activity_5.id] + ), + ( + {"process_id": str(process_1.id), "state": "error"}, [activity_2.id] + ), + ( + {"process_id": str(process_1.id), "state": "processed", "worker_version_id": str(self.worker_version.id)}, + [activity_1.id, activity_3.id, activity_4.id] + ), + ( + {"worker_version_id": str(worker_version_2.id)}, [activity_2.id, activity_5.id] + ), + ( { - "created": activity.created.isoformat().replace("+00:00", "Z"), - "updated": activity.updated.isoformat().replace("+00:00", "Z"), - "started": None, - "element_id": str(self.element.id), - "process_id": str(process.id), + "process_id": str(process_1.id), + "state": "processed", "worker_version_id": str(self.worker_version.id), - "configuration_id": None, - "model_version_id": None, - "state": "queued" - } - ] - }) + "model_version_id": str(self.model_version.id) + }, + [activity_1.id, activity_3.id] + ), + ( + {"model_version_id": str(model_version_2.id)}, [activity_4.id, activity_5.id] + ), + ( + { + "process_id": str(process_1.id), + "state": "processed", + "worker_version_id": str(self.worker_version.id), + "model_version_id": str(self.model_version.id), + "worker_configuration_id": str(self.configuration.id) + }, + [activity_1.id] + ), + ( + {"worker_configuration_id": str(configuration_2.id)}, [activity_3.id, activity_4.id] + ), + ] - def test_list_filter_state(self): - element = Element.objects.get(name="Volume 1") - activity = element.activities.create( - process=self.process, - worker_version=self.worker_version, - state=WorkerActivityState.Processed, - ) - self.assertEqual(WorkerActivity.objects.filter(state=WorkerActivityState.Processed).count(), 1) self.client.force_login(self.user) + for filters, activity_ids in cases: + # Sort activities by ID, like in the API response + activity_ids.sort() - with self.assertNumQueries(5): - response = self.client.get( - reverse("api:corpus-activity", kwargs={"corpus": self.corpus.id}), - {"state": "processed"} - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) + # Filtering by process_id adds 1 query as it checks if the process exists + queries_count = 6 if "process_id" in filters else 5 + with self.assertNumQueries(queries_count): + response = self.client.get( + reverse("api:corpus-activity", kwargs={"corpus": self.corpus.id}), + {**filters} + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertDictEqual(response.json(), { - "count": 1, - "number": 1, - "previous": None, - "next": None, - "results": [ - { - "created": activity.created.isoformat().replace("+00:00", "Z"), - "updated": activity.updated.isoformat().replace("+00:00", "Z"), - "started": None, - "element_id": str(element.id), - "process_id": str(self.process.id), - "worker_version_id": str(self.worker_version.id), - "configuration_id": None, - "model_version_id": None, - "state": "processed" - } - ] - }) + self.assertDictEqual(response.json(), { + "count": len(activity_ids), + "number": 1, + "previous": None, + "next": None, + "results": [ + worker_activities[id] for id in activity_ids + ] + }) def test_list_invalid_filters(self): self.client.force_login(self.superuser) cases = [ ( - {"process_id": "a"}, - {"process_id": ["Process ID should be an UUID."]}, + {"process_id": "a", "worker_version_id": "neon", "model_version_id": "genesis", "worker_configuration_id": "evangelion"}, + { + "process_id": ["Not a valid UUID."], + "worker_version_id": ["Not a valid UUID."], + "model_version_id": ["Not a valid UUID."], + "worker_configuration_id": ["Not a valid UUID."] + }, ), ( {"process_id": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}, @@ -977,7 +1049,7 @@ class TestWorkerActivity(FixtureTestCase): ( {"process_id": "a", "state": "lol"}, { - "process_id": ["Process ID should be an UUID."], + "process_id": ["Not a valid UUID."], "state": ["'lol' is not a valid WorkerActivity state"], }, ),