Skip to content
Snippets Groups Projects
Commit 0c4b8393 authored by ml bonhomme's avatar ml bonhomme :bee: Committed by Erwan Rouchet
Browse files

Add more filters on ListWorkerActivities

parent a92c4bc1
No related branches found
No related tags found
1 merge request!2311Add more filters on ListWorkerActivities
...@@ -114,7 +114,7 @@ from arkindex.project.mixins import ( ...@@ -114,7 +114,7 @@ from arkindex.project.mixins import (
SelectionMixin, SelectionMixin,
WorkerACLMixin, WorkerACLMixin,
) )
from arkindex.project.openapi import UUID_OR_STR from arkindex.project.openapi import UUID_OR_FALSE, UUID_OR_STR
from arkindex.project.pagination import CountCursorPagination from arkindex.project.pagination import CountCursorPagination
from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly
from arkindex.project.tools import PercentileCont from arkindex.project.tools import PercentileCont
...@@ -1877,15 +1877,39 @@ class ProcessWorkersActivity(ProcessACLMixin, WorkerActivityBase): ...@@ -1877,15 +1877,39 @@ class ProcessWorkersActivity(ProcessACLMixin, WorkerActivityBase):
OpenApiParameter( OpenApiParameter(
"process_id", "process_id",
type=UUID, type=UUID,
description="Filter worker activities by process ID", description="Filter worker activities by process ID.",
required=False, required=False,
), ),
OpenApiParameter( OpenApiParameter(
"state", "state",
enum=[state.value for state in WorkerActivityState], enum=[state.value for state in WorkerActivityState],
description="Filter worker activities by state", description="Filter worker activities by state.",
required=False, required=False,
), ),
OpenApiParameter(
"worker_version_id",
type=UUID,
required=False,
description="Filter worker activities by worker version.",
),
OpenApiParameter(
"model_version_id",
type=UUID_OR_FALSE,
required=False,
description=dedent("""
Filter worker activities by model version.
If set to false, only retrieve worker activities produced by a worker run with no model version.
""")
),
OpenApiParameter(
"worker_configuration_id",
type=UUID_OR_FALSE,
required=False,
description=dedent("""
Filter worker activities by worker configuration.
If set to false, only retrieve worker activities produced by a worker run with no worker configuration.
""")
),
] ]
) )
) )
...@@ -1905,15 +1929,24 @@ class WorkerActivityList(CorpusACLMixin, ProcessACLMixin, ListAPIView): ...@@ -1905,15 +1929,24 @@ class WorkerActivityList(CorpusACLMixin, ProcessACLMixin, ListAPIView):
def get_queryset(self): def get_queryset(self):
return WorkerActivity.objects.filter(element__corpus=self.corpus).order_by("id") return WorkerActivity.objects.filter(element__corpus=self.corpus).order_by("id")
def validate_uuid(self, value, key, errors):
try:
return UUID(value)
except (TypeError, ValueError):
errors[key] = ["Not a valid UUID."]
def get_uuid_or_none(self, key, errors):
value = self.request.query_params.get(key)
if value.lower() not in ("false", "0"):
return self.validate_uuid(value, key, errors)
return None
def filter_queryset(self, queryset): def filter_queryset(self, queryset):
errors = {} errors = {}
process_id = None process_id = None
if "process_id" in self.request.query_params: if "process_id" in self.request.query_params:
try: process_id = self.validate_uuid(self.request.query_params["process_id"], "process_id", errors)
process_id = UUID(self.request.query_params["process_id"])
except (TypeError, ValueError):
errors["process_id"] = ["Process ID should be an UUID."]
if process_id: if process_id:
try: try:
...@@ -1935,6 +1968,20 @@ class WorkerActivityList(CorpusACLMixin, ProcessACLMixin, ListAPIView): ...@@ -1935,6 +1968,20 @@ class WorkerActivityList(CorpusACLMixin, ProcessACLMixin, ListAPIView):
else: else:
queryset = queryset.filter(state=state) queryset = queryset.filter(state=state)
worker_version_id = None
if "worker_version_id" in self.request.query_params:
worker_version_id = self.validate_uuid(self.request.query_params["worker_version_id"], "worker_version_id", errors)
if worker_version_id:
queryset = queryset.filter(worker_version_id=worker_version_id)
if "model_version_id" in self.request.query_params:
model_version_filter = self.get_uuid_or_none("model_version_id", errors)
queryset = queryset.filter(model_version_id=model_version_filter)
if "worker_configuration_id" in self.request.query_params:
worker_configuration_filter = self.get_uuid_or_none("worker_configuration_id", errors)
queryset = queryset.filter(configuration_id=worker_configuration_filter)
if errors: if errors:
raise ValidationError(errors) raise ValidationError(errors)
......
...@@ -884,87 +884,189 @@ class TestWorkerActivity(FixtureTestCase): ...@@ -884,87 +884,189 @@ class TestWorkerActivity(FixtureTestCase):
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"process_id": ["This process does not exist."]}) self.assertDictEqual(response.json(), {"process_id": ["This process does not exist."]})
def test_list_filter_process(self): def test_list_filters(self):
process = self.corpus.processes.create( # Delete all worker activities
WorkerActivity.objects.all().delete()
# Create objects
process_1 = self.corpus.processes.create(
mode=ProcessMode.Workers,
creator=self.user,
)
process_2 = self.corpus.processes.create(
mode=ProcessMode.Workers, mode=ProcessMode.Workers,
creator=self.user, creator=self.user,
) )
activity = process.activities.create( worker_version_2 = WorkerVersion.objects.get(worker__slug="dla")
model_version_2 = self.model.versions.create(
state=ModelVersionState.Available,
hash="b",
)
configuration_2 = WorkerConfiguration.objects.create(worker=self.worker_version.worker, name="Another configuration", configuration={"c": "d"})
# Create all the worker activities
activity_1 = process_1.activities.create(
element=self.element, element=self.element,
worker_version=self.worker_version, worker_version=self.worker_version,
state=WorkerActivityState.Processed,
model_version=self.model_version,
configuration=self.configuration
) )
self.client.force_login(self.user) activity_2 = process_1.activities.create(
element=self.element,
with self.assertNumQueries(6): worker_version=worker_version_2,
response = self.client.get( state=WorkerActivityState.Error,
reverse("api:corpus-activity", kwargs={"corpus": self.corpus.id}), model_version=self.model_version,
{"process_id": str(process.id)} configuration=self.configuration
) )
self.assertEqual(response.status_code, status.HTTP_200_OK) activity_3 = process_1.activities.create(
element=self.element,
self.assertDictEqual(response.json(), {
"count": 1,
"number": 1,
"previous": None,
"next": None,
"results": [
{
"created": activity.created.isoformat().replace("+00:00", "Z"),
"updated": activity.updated.isoformat().replace("+00:00", "Z"),
"started": None,
"element_id": str(self.element.id),
"process_id": str(process.id),
"worker_version_id": str(self.worker_version.id),
"configuration_id": None,
"model_version_id": None,
"state": "queued"
}
]
})
def test_list_filter_state(self):
element = Element.objects.get(name="Volume 1")
activity = element.activities.create(
process=self.process,
worker_version=self.worker_version, worker_version=self.worker_version,
state=WorkerActivityState.Processed, state=WorkerActivityState.Processed,
model_version=self.model_version,
configuration=configuration_2
)
activity_4 = process_1.activities.create(
element=self.element,
worker_version=self.worker_version,
state=WorkerActivityState.Processed,
model_version=model_version_2,
configuration=configuration_2
)
activity_5 = process_2.activities.create(
element=self.element,
worker_version=worker_version_2,
state=WorkerActivityState.Processed,
model_version=model_version_2,
configuration=self.configuration
)
activity_6 = process_1.activities.create(
element=self.element,
worker_version=self.worker_version,
state=WorkerActivityState.Processed,
model_version=self.model_version,
configuration=None
) )
self.assertEqual(WorkerActivity.objects.filter(state=WorkerActivityState.Processed).count(), 1)
self.client.force_login(self.user)
with self.assertNumQueries(5): # Worker activity responses
response = self.client.get( worker_activities = {
reverse("api:corpus-activity", kwargs={"corpus": self.corpus.id}), activity.id: {
{"state": "processed"} "created": activity.created.isoformat().replace("+00:00", "Z"),
) "updated": activity.updated.isoformat().replace("+00:00", "Z"),
self.assertEqual(response.status_code, status.HTTP_200_OK) "started": None,
"element_id": str(self.element.id),
"process_id": str(activity.process_id),
"worker_version_id": str(activity.worker_version.id),
"configuration_id": str(activity.configuration_id) if activity.configuration_id else None,
"model_version_id": str(activity.model_version_id) if activity.model_version_id else None,
"state": activity.state.value
}
for activity in WorkerActivity.objects.all()
}
self.assertDictEqual(response.json(), { # Test cases
"count": 1, cases = [
"number": 1, (
"previous": None, {"process_id": str(process_1.id)}, [activity_1.id, activity_2.id, activity_3.id, activity_4.id, activity_6.id]
"next": None, ),
"results": [ (
{"state": "processed"}, [activity_1.id, activity_3.id, activity_4.id, activity_5.id, activity_6.id]
),
(
{"process_id": str(process_1.id), "state": "error"}, [activity_2.id]
),
(
{"process_id": str(process_1.id), "state": "processed", "worker_version_id": str(self.worker_version.id)},
[activity_1.id, activity_3.id, activity_4.id, activity_6.id]
),
(
{"worker_version_id": str(worker_version_2.id)}, [activity_2.id, activity_5.id]
),
(
{ {
"created": activity.created.isoformat().replace("+00:00", "Z"), "process_id": str(process_1.id),
"updated": activity.updated.isoformat().replace("+00:00", "Z"), "state": "processed",
"started": None,
"element_id": str(element.id),
"process_id": str(self.process.id),
"worker_version_id": str(self.worker_version.id), "worker_version_id": str(self.worker_version.id),
"configuration_id": None, "model_version_id": str(self.model_version.id)
"model_version_id": None, },
"state": "processed" [activity_1.id, activity_3.id, activity_6.id]
} ),
] (
}) {"model_version_id": str(model_version_2.id)}, [activity_4.id, activity_5.id]
),
(
{
"process_id": str(process_1.id),
"state": "processed",
"worker_version_id": str(self.worker_version.id),
"model_version_id": str(self.model_version.id),
"worker_configuration_id": str(self.configuration.id)
},
[activity_1.id]
),
(
{"worker_configuration_id": str(configuration_2.id)}, [activity_3.id, activity_4.id]
),
(
{
"process_id": str(process_1.id),
"state": "processed",
"worker_version_id": str(self.worker_version.id),
"model_version_id": False,
"worker_configuration_id": str(self.configuration.id)
},
[]
),
(
{
"process_id": str(process_1.id),
"state": "processed",
"worker_version_id": str(self.worker_version.id),
"model_version_id": str(self.model_version.id),
"worker_configuration_id": False
},
[activity_6.id]
),
]
self.client.force_login(self.user)
for filters, activity_ids in cases:
# Sort activities by ID, like in the API response
activity_ids.sort()
# Filtering by process_id adds 1 query as it checks if the process exists
queries_count = 6 if "process_id" in filters else 5
# If there are no results returned, only the request to return a 'count' is made so we get one less query
if not len(activity_ids):
queries_count -= 1
with self.assertNumQueries(queries_count):
response = self.client.get(
reverse("api:corpus-activity", kwargs={"corpus": self.corpus.id}),
{**filters}
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertDictEqual(response.json(), {
"count": len(activity_ids),
"number": 1,
"previous": None,
"next": None,
"results": [
worker_activities[id] for id in activity_ids
]
})
def test_list_invalid_filters(self): def test_list_invalid_filters(self):
self.client.force_login(self.superuser) self.client.force_login(self.superuser)
cases = [ cases = [
( (
{"process_id": "a"}, {"process_id": "a", "worker_version_id": "neon", "model_version_id": "genesis", "worker_configuration_id": "evangelion"},
{"process_id": ["Process ID should be an UUID."]}, {
"process_id": ["Not a valid UUID."],
"worker_version_id": ["Not a valid UUID."],
"model_version_id": ["Not a valid UUID."],
"worker_configuration_id": ["Not a valid UUID."]
},
), ),
( (
{"process_id": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}, {"process_id": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"},
...@@ -977,7 +1079,7 @@ class TestWorkerActivity(FixtureTestCase): ...@@ -977,7 +1079,7 @@ class TestWorkerActivity(FixtureTestCase):
( (
{"process_id": "a", "state": "lol"}, {"process_id": "a", "state": "lol"},
{ {
"process_id": ["Process ID should be an UUID."], "process_id": ["Not a valid UUID."],
"state": ["'lol' is not a valid WorkerActivity state"], "state": ["'lol' is not a valid WorkerActivity state"],
}, },
), ),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment