diff --git a/arkindex/dataimport/api.py b/arkindex/dataimport/api.py index 847b785a8adc319cc97cfd0816530833044886b4..e886bccbca1729f56006aa836f88e7546b156882 100644 --- a/arkindex/dataimport/api.py +++ b/arkindex/dataimport/api.py @@ -5,6 +5,7 @@ from django.conf import settings from django.db import transaction from django.db.models import Count, F, Max, Q from django.shortcuts import get_object_or_404 +from django.utils.functional import cached_property from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view from rest_framework import permissions, status from rest_framework.exceptions import NotFound, PermissionDenied, ValidationError @@ -29,6 +30,8 @@ from arkindex.dataimport.models import ( RepositoryType, Revision, Worker, + WorkerActivity, + WorkerActivityState, WorkerRun, WorkerVersion, ) @@ -47,6 +50,7 @@ from arkindex.dataimport.serializers.imports import ( ) from arkindex.dataimport.serializers.workers import ( RepositorySerializer, + WorkerActivitySerializer, WorkerSerializer, WorkerVersionEditSerializer, WorkerVersionSerializer, @@ -62,7 +66,7 @@ from arkindex.project.mixins import ( SelectionMixin, WorkerACLMixin, ) -from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly +from arkindex.project.permissions import IsInternalUser, IsVerified, IsVerifiedOrReadOnly from arkindex.users.models import OAuthCredentials, Role from arkindex.users.utils import get_max_level from ponos.models import STATES_ORDERING, State @@ -1072,3 +1076,61 @@ class ListProcessElements(CustomPaginationViewMixin, CorpusACLMixin, ListAPIView return process.list_elements().values('id', 'type__slug', 'name') except AssertionError as e: raise ValidationError({'__all__': [str(e)]}) + + +@extend_schema(tags=['ml']) +class UpdateWorkerActivity(GenericAPIView): + """ + Makes a worker (internal user) able to update its activity on an element + Only allow defined evolutions of the element's state + """ + permission_classes = (IsInternalUser, ) + serializer_class = WorkerActivitySerializer + + @cached_property + def allowed_transitions(self): + # Defines a list of allowed previous states for any transition + queued = WorkerActivityState.Queued.value + started = WorkerActivityState.Started.value + error = WorkerActivityState.Error.value + processed = WorkerActivityState.Processed.value + return { + queued: [error], + started: [queued, error], + error: [started], + processed: [started], + } + + @extend_schema( + operation_id='UpdateWorkerActivity', + description=( + 'Updates the activity of a worker version on an element.\n\n' + 'The user must be **internal** to perform this request.' + ), + ) + def put(self, request, *args, **kwarg): + serializer = self.get_serializer(data=request.data, partial=False) + serializer.is_valid(raise_exception=True) + + worker_version_id = self.kwargs['pk'] + element_id = serializer.validated_data['element_id'] + state = serializer.validated_data['state'].value + + # We use the fact that only one worker activity may match the filter due to DB constraint + activity = WorkerActivity.objects.filter( + worker_version_id=worker_version_id, + element_id=element_id, + state__in=self.allowed_transitions[state] + ) + update_count = activity.update(state=state) + + if not update_count: + # As no row has been updated the provided data was incorrect + raise ValidationError({ + '__all__': [ + 'Either this worker activity does not exists ' + f"or updating the state to '{state}' is forbidden." + ] + }) + + return Response(serializer.data) diff --git a/arkindex/dataimport/serializers/workers.py b/arkindex/dataimport/serializers/workers.py index ec54d32ee1ca627bdffdcd793aa152d4e954ec04..653ee91017b761775c63852cc06181d8e140566a 100644 --- a/arkindex/dataimport/serializers/workers.py +++ b/arkindex/dataimport/serializers/workers.py @@ -4,7 +4,16 @@ from drf_spectacular.utils import extend_schema_field from rest_framework import serializers from rest_framework.exceptions import ValidationError -from arkindex.dataimport.models import Repository, RepositoryType, Revision, Worker, WorkerVersion, WorkerVersionState +from arkindex.dataimport.models import ( + Repository, + RepositoryType, + Revision, + Worker, + WorkerActivity, + WorkerActivityState, + WorkerVersion, + WorkerVersionState, +) from arkindex.dataimport.serializers.git import RevisionWithRefsSerializer from arkindex.project.serializer_fields import EnumField @@ -145,3 +154,18 @@ class RepositorySerializer(serializers.ModelSerializer): 'url': {'read_only': True}, 'type': {'read_only': True}, } + + +class WorkerActivitySerializer(serializers.ModelSerializer): + """ + Serialize a repository + """ + state = EnumField(WorkerActivityState) + element_id = serializers.UUIDField() + + class Meta: + model = WorkerActivity + fields = ( + 'element_id', + 'state', + ) diff --git a/arkindex/dataimport/tests/test_workeractivity.py b/arkindex/dataimport/tests/test_workeractivity.py new file mode 100644 index 0000000000000000000000000000000000000000..2517a861d79d2c5262edd2084ae3ce69ae209699 --- /dev/null +++ b/arkindex/dataimport/tests/test_workeractivity.py @@ -0,0 +1,141 @@ +import uuid + +from django.urls import reverse +from rest_framework import status + +from arkindex.dataimport.models import WorkerActivityState, WorkerVersion +from arkindex.documents.models import Element +from arkindex.project.tests import FixtureTestCase + + +class TestWorkerActivity(FixtureTestCase): + + @classmethod + def setUpTestData(cls): + super().setUpTestData() + cls.worker_version = WorkerVersion.objects.get(worker__slug='reco') + cls.element = Element.objects.get(name='Volume 1, page 2r') + # Create a queued activity for this element + cls.activity = cls.element.activities.create(worker_version=cls.worker_version, state=WorkerActivityState.Queued) + + def test_put_activity_requires_internal(self): + """ + Only internal users (workers) are able to update the state of a worker activity + """ + cases = ( + (None, status.HTTP_403_FORBIDDEN, 0), + (self.user, status.HTTP_403_FORBIDDEN, 2), + (self.superuser, status.HTTP_403_FORBIDDEN, 2), + (self.internal_user, status.HTTP_200_OK, 3), + ) + for user, status_code, requests_count in cases: + if user: + self.client.force_login(user) + with self.assertNumQueries(requests_count): + response = self.client.put( + reverse('api:update-worker-activity', kwargs={'pk': str(self.worker_version.id)}), + {'element_id': str(self.element.id), 'state': WorkerActivityState.Started.value}, + content_type='application/json', + ) + self.assertEqual(response.status_code, status_code) + + def test_put_activity_wrong_worker_version(self): + """ + Raises a generic error in case the worker version does not exists because a single SQL request is performed + """ + self.client.force_login(self.internal_user) + with self.assertNumQueries(3): + response = self.client.put( + reverse('api:update-worker-activity', kwargs={'pk': str(uuid.uuid4())}), + {'element_id': str(self.element.id), 'state': WorkerActivityState.Started.value}, + content_type='application/json', + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + '__all__': [ + 'Either this worker activity does not exists or ' + f"updating the state to '{WorkerActivityState.Started.value}' is forbidden." + ] + }) + + def test_put_activity_unexisting(self): + """ + Raises a generic error in case no activity exists for this element + """ + self.client.force_login(self.internal_user) + with self.assertNumQueries(3): + response = self.client.put( + reverse('api:update-worker-activity', kwargs={'pk': str(self.worker_version.id)}), + {'element_id': str(uuid.uuid4()), 'state': WorkerActivityState.Started.value}, + content_type='application/json', + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + '__all__': [ + 'Either this worker activity does not exists or ' + f"updating the state to '{WorkerActivityState.Started.value}' is forbidden." + ] + }) + + def test_put_activity_allowed_states(self): + """ + Check each case state update is allowed depending on the actual state of the activity + """ + queued, started, error, processed = ( + WorkerActivityState[state].value + for state in ['Queued', 'Started', 'Error', 'Processed'] + ) + allowed_states_update = ( + (queued, started), + (started, processed), + (started, error), + (error, queued), + (error, started), + ) + self.client.force_login(self.internal_user) + for state, payload_state in allowed_states_update: + self.activity.state = state + self.activity.save() + with self.assertNumQueries(3): + response = self.client.put( + reverse('api:update-worker-activity', kwargs={'pk': str(self.worker_version.id)}), + {'element_id': str(self.element.id), 'state': payload_state}, + content_type='application/json', + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.activity.refresh_from_db() + self.assertEqual(self.activity.state.value, payload_state) + + def test_put_activity_forbidden_states(self): + """ + Check state update is forbidden for some non consistant cases + """ + queued, started, error, processed = ( + WorkerActivityState[state].value + for state in ['Queued', 'Started', 'Error', 'Processed'] + ) + forbidden_states_update = ( + (queued, queued), + (started, queued), + (queued, error), + (processed, error), + ) + self.client.force_login(self.internal_user) + for state, payload_state in forbidden_states_update: + self.activity.state = state + self.activity.save() + with self.assertNumQueries(3): + response = self.client.put( + reverse('api:update-worker-activity', kwargs={'pk': str(self.worker_version.id)}), + {'element_id': str(self.element.id), 'state': payload_state}, + content_type='application/json', + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + '__all__': [ + 'Either this worker activity does not exists or ' + f"updating the state to '{payload_state}' is forbidden." + ] + }) + self.activity.refresh_from_db() + self.assertEqual(self.activity.state.value, state) diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py index 837a7beef1fcd4878b952d457fef97b7214aba3d..57b79617eac116bbdb4bdb6e45cb41111e8b88c8 100644 --- a/arkindex/project/api_v1.py +++ b/arkindex/project/api_v1.py @@ -20,6 +20,7 @@ from arkindex.dataimport.api import ( RepositoryRetrieve, RevisionRetrieve, StartProcess, + UpdateWorkerActivity, WorkerList, WorkerRetrieve, WorkerRunDetails, @@ -195,6 +196,7 @@ api = [ path('workers/<uuid:pk>/', WorkerRetrieve.as_view(), name='worker-retrieve'), path('workers/<uuid:pk>/versions/', WorkerVersionList.as_view(), name='worker-versions'), path('workers/versions/<uuid:pk>/', WorkerVersionRetrieve.as_view(), name='version-retrieve'), + path('workers/versions/<uuid:pk>/activity/', UpdateWorkerActivity.as_view(), name='update-worker-activity'), # Import workflows path('imports/', DataImportsList.as_view(), name='import-list'), diff --git a/arkindex/project/config.py b/arkindex/project/config.py index 8ba06ddc33cadbeea70539b5f043ae7e7aa5b2a9..686e9bc4c965fce28dde906478d1db17c686b3e1 100644 --- a/arkindex/project/config.py +++ b/arkindex/project/config.py @@ -129,6 +129,7 @@ def get_settings_parser(base_dir): features_parser.add_option('search', type=bool, default=True) features_parser.add_option('transkribus', type=bool, default=True) features_parser.add_option('workers', type=bool, default=False) + features_parser.add_option('workers_activity', type=bool, default=False) cache_parser = ConfigParser() cache_parser.add_option('type', type=CacheType, default=None) diff --git a/arkindex/project/permissions.py b/arkindex/project/permissions.py index 6de81363ad56b200e36ea329dfaa847bd6985627..aca0d5518ca1c8da16006dd29a8c616e2f936820 100644 --- a/arkindex/project/permissions.py +++ b/arkindex/project/permissions.py @@ -2,6 +2,14 @@ from rest_framework import permissions from rest_framework.exceptions import PermissionDenied +class AllowNone(object): + """ + Systematically refuse permission + """ + def has_permission(self, request, view): + return None + + class InternalGroupPermissionMixin(object): """ Immediately authenticate any non-admin Internal users. @@ -95,6 +103,10 @@ class IsAdminUser(InternalGroupPermissionMixin, pass +class IsInternalUser(InternalGroupPermissionMixin, AllowNone): + pass + + class IsVerified(VerifiedEmailPermissionMixin, IsAuthenticated): pass diff --git a/arkindex/project/tests/config_samples/defaults.yaml b/arkindex/project/tests/config_samples/defaults.yaml index de50e38a5de58a1e7198149b1a411d36f67fceb4..48c92d2ac69018a27edc9b31bea8b3638c3fde2b 100644 --- a/arkindex/project/tests/config_samples/defaults.yaml +++ b/arkindex/project/tests/config_samples/defaults.yaml @@ -34,6 +34,7 @@ features: signup: true transkribus: true workers: false + workers_activity: false gitlab: app_id: null app_secret: null diff --git a/arkindex/project/tests/config_samples/override.yaml b/arkindex/project/tests/config_samples/override.yaml index e56b28323b7bfd25d39affd150e411f7b1fe8bf8..5d5dfbf68984ad2a14d8bf2a93404595afa670ec 100644 --- a/arkindex/project/tests/config_samples/override.yaml +++ b/arkindex/project/tests/config_samples/override.yaml @@ -48,6 +48,7 @@ features: signup: false transkribus: false workers: true + workers_activity: true gitlab: app_id: a app_secret: b