Skip to content
Snippets Groups Projects
Commit a3886fea authored by Erwan Rouchet's avatar Erwan Rouchet Committed by Bastien Abadie
Browse files

RQ overrides

parent 43bedfa4
No related branches found
No related tags found
No related merge requests found
Showing
with 582 additions and 23 deletions
......@@ -8,4 +8,4 @@ line_length = 120
default_section=FIRSTPARTY
known_first_party = arkindex_common,ponos,transkribus
known_third_party = boto3,botocore,corsheaders,django,django_admin_hstore_widget,django_rq,elasticsearch,elasticsearch_dsl,enumfields,gitlab,psycopg2,requests,responses,rest_framework,setuptools,sqlparse,tenacity,tripoli,yaml
known_third_party = boto3,botocore,corsheaders,django,django_admin_hstore_widget,django_rq,elasticsearch,elasticsearch_dsl,enumfields,gitlab,psycopg2,requests,responses,rest_framework,rq,setuptools,sqlparse,tenacity,tripoli,yaml
......@@ -416,7 +416,7 @@ class ElementsListMixin(object):
if not queryset.exists():
raise NotFound
element_trash(queryset, delete_children=delete_children)
element_trash(queryset, delete_children=delete_children, user_id=self.request.user.id)
return Response(status=status.HTTP_204_NO_CONTENT)
......@@ -822,7 +822,7 @@ class CorpusRetrieve(RetrieveUpdateDestroyAPIView):
self.permission_denied(request, message='You do not have write access to this corpus.')
def perform_destroy(self, instance):
corpus_delete(instance)
corpus_delete(instance, user_id=self.request.user.id)
class TranscriptionsPagination(PageNumberPagination):
......
......@@ -641,7 +641,7 @@ class ElementMLStats(MLStatsBase, RetrieveDestroyAPIView):
return Element.objects.filter(corpus__in=Corpus.objects.readable(self.request.user)).select_related('type')
def destroy(self, *args, **kwargs):
ml_results_delete(element=self.get_object())
ml_results_delete(element=self.get_object(), user_id=self.request.user.id)
return Response(status=status.HTTP_204_NO_CONTENT)
......@@ -657,5 +657,5 @@ class CorpusMLStats(MLStatsBase, RetrieveDestroyAPIView):
return Corpus.objects.readable(self.request.user).only('id')
def destroy(self, *args, **kwargs):
ml_results_delete(corpus=self.get_object())
ml_results_delete(corpus=self.get_object(), user_id=self.request.user.id)
return Response(status=status.HTTP_204_NO_CONTENT)
......@@ -40,4 +40,4 @@ class ReindexConfigSerializer(serializers.Serializer):
return data
def save(self):
reindex_start(**self.validated_data)
reindex_start(**self.validated_data, user_id=self.context['request'].user.id)
......@@ -5,6 +5,7 @@ from typing import Optional
from django.db.models import Q
from django.db.models.deletion import Collector
from django_rq import job
from rq import get_current_job
from arkindex.dataimport.models import DataImport, DataImportElement, WorkerRun
from arkindex.documents.indexer import Indexer
......@@ -156,6 +157,8 @@ def ml_results_delete(corpus_id: Optional[str] = None,
@job('high')
def corpus_delete(corpus_id: str) -> None:
# Note that this can be None when the task is run outside of a RQ worker (e.g. unit test)
rq_job = get_current_job()
corpus = Corpus.objects.get(id=corpus_id)
logger.info(f'Deleting {corpus!r}')
......@@ -189,7 +192,10 @@ def corpus_delete(corpus_id: str) -> None:
Corpus.objects.filter(id=corpus_id),
]
for queryset in querysets:
for i, queryset in enumerate(querysets):
if rq_job:
rq_job.set_progress(i / len(querysets))
rq_job.set_description(f'Deleting {queryset.model.__name__} on corpus {corpus.name}')
deleted_count = queryset._raw_delete(using='default')
logger.info(f'Deleted {deleted_count} {queryset.model.__name__}')
......
......@@ -69,6 +69,8 @@ class TestAdminAPI(FixtureTestCase):
elements=True,
entities=True,
drop=False,
user_id=self.superuser.id,
description='Full reindexation',
))
@override_settings(ARKINDEX_FEATURES={'search': False})
......
......@@ -93,6 +93,8 @@ class TestBulkElementTranscriptions(FixtureAPITestCase):
elements=True,
entities=False,
drop=False,
description=f'Indexation of element {self.page.id}',
user_id=None,
))
@override_settings(ARKINDEX_FEATURES={'search': False})
......@@ -383,4 +385,6 @@ class TestBulkElementTranscriptions(FixtureAPITestCase):
elements=True,
entities=False,
drop=False,
user_id=None,
description=f'Indexation of element {top_level.id}'
))
......@@ -390,4 +390,8 @@ class TestCorpus(FixtureAPITestCase):
response = self.client.delete(reverse('api:corpus-retrieve', kwargs={'pk': self.corpus_private.id}))
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
self.assertEqual(delay_mock.call_count, 1)
self.assertEqual(delay_mock.call_args, call(corpus_id=str(self.corpus_private.id)))
self.assertEqual(delay_mock.call_args, call(
corpus_id=str(self.corpus_private.id),
description=f'Deletion of corpus {self.corpus_private.name}',
user_id=self.user.id,
))
......@@ -97,6 +97,8 @@ class TestTranscriptionCreate(FixtureAPITestCase):
elements=True,
entities=False,
drop=False,
user_id=None,
description=f'Indexation of element {self.line.id}',
))
@patch('arkindex.project.triggers.tasks.reindex_start.delay')
......@@ -206,6 +208,8 @@ class TestTranscriptionCreate(FixtureAPITestCase):
elements=True,
entities=False,
drop=False,
user_id=None,
description=f'Indexation of element {self.line.id}',
))
def test_manual_transcription_forbidden_type(self):
......
......@@ -207,7 +207,11 @@ class TestDestroyElements(FixtureAPITestCase):
args, kwargs = delay_mock.call_args
self.assertEqual(len(args), 0)
self.assertCountEqual(list(kwargs.pop('queryset')), list(self.corpus.elements.all()))
self.assertDictEqual(kwargs, {'delete_children': True})
self.assertDictEqual(kwargs, {
'delete_children': True,
'user_id': self.user.id,
'description': 'Element deletion',
})
@patch('arkindex.project.triggers.tasks.element_trash.delay')
def test_destroy_corpus_elements_delete_children(self, delay_mock):
......@@ -223,7 +227,11 @@ class TestDestroyElements(FixtureAPITestCase):
args, kwargs = delay_mock.call_args
self.assertEqual(len(args), 0)
self.assertCountEqual(list(kwargs.pop('queryset')), list(self.corpus.elements.all()))
self.assertDictEqual(kwargs, {'delete_children': False})
self.assertDictEqual(kwargs, {
'delete_children': False,
'user_id': self.user.id,
'description': 'Element deletion',
})
@patch('arkindex.project.triggers.tasks.element_trash.delay')
def test_destroy_corpus_elements_rejected_filters(self, delay_mock):
......@@ -290,7 +298,11 @@ class TestDestroyElements(FixtureAPITestCase):
# Direct children of the volume
list(Element.objects.get_descending(self.vol.id).filter(paths__path__last=self.vol.id)),
)
self.assertDictEqual(kwargs, {'delete_children': True})
self.assertDictEqual(kwargs, {
'delete_children': True,
'user_id': self.user.id,
'description': 'Element deletion',
})
@patch('arkindex.project.triggers.tasks.element_trash.delay')
def test_destroy_element_children_delete_children(self, delay_mock):
......@@ -310,7 +322,11 @@ class TestDestroyElements(FixtureAPITestCase):
# Direct children of the volume
list(Element.objects.get_descending(self.vol.id).filter(paths__path__last=self.vol.id)),
)
self.assertDictEqual(kwargs, {'delete_children': False})
self.assertDictEqual(kwargs, {
'delete_children': False,
'user_id': self.user.id,
'description': 'Element deletion',
})
@patch('arkindex.project.triggers.tasks.element_trash.delay')
def test_destroy_element_children_rejected_filters(self, delay_mock):
......@@ -377,7 +393,11 @@ class TestDestroyElements(FixtureAPITestCase):
# Direct parents of the surface
list(Element.objects.get_ascending(self.surface.id, recursive=False)),
)
self.assertDictEqual(kwargs, {'delete_children': True})
self.assertDictEqual(kwargs, {
'delete_children': True,
'user_id': self.user.id,
'description': 'Element deletion',
})
@patch('arkindex.project.triggers.tasks.element_trash.delay')
def test_destroy_element_parents_delete_parents(self, delay_mock):
......@@ -397,7 +417,11 @@ class TestDestroyElements(FixtureAPITestCase):
# Direct parents of the surface
list(Element.objects.get_ascending(self.surface.id, recursive=False)),
)
self.assertDictEqual(kwargs, {'delete_children': False})
self.assertDictEqual(kwargs, {
'delete_children': False,
'user_id': self.user.id,
'description': 'Element deletion',
})
@patch('arkindex.project.triggers.tasks.element_trash.delay')
def test_destroy_element_parents_rejected_filters(self, delay_mock):
......
......@@ -269,6 +269,8 @@ class TestEntitiesAPI(FixtureAPITestCase):
transcriptions=False,
entities=True,
drop=False,
user_id=None,
description=f'Indexation of entity {entity.id}',
))
@patch('arkindex.project.triggers.tasks.reindex_start.delay')
......@@ -301,6 +303,8 @@ class TestEntitiesAPI(FixtureAPITestCase):
transcriptions=False,
entities=True,
drop=False,
user_id=None,
description=f'Indexation of entity {entity.id}',
))
@patch('arkindex.project.triggers.tasks.reindex_start.delay')
......@@ -333,6 +337,8 @@ class TestEntitiesAPI(FixtureAPITestCase):
transcriptions=False,
entities=True,
drop=False,
user_id=None,
description=f'Indexation of entity {entity.id}',
))
def test_create_entity_requires_login(self):
......@@ -377,6 +383,8 @@ class TestEntitiesAPI(FixtureAPITestCase):
transcriptions=False,
entities=True,
drop=False,
user_id=None,
description=f'Indexation of entity {entity.id}',
))
def test_create_link(self):
......@@ -806,6 +814,8 @@ class TestEntitiesAPI(FixtureAPITestCase):
transcriptions=False,
entities=True,
drop=False,
user_id=None,
description=f'Indexation of entity {self.entity_bis.id}',
))
@patch('arkindex.project.triggers.tasks.reindex_start.delay')
......@@ -828,6 +838,8 @@ class TestEntitiesAPI(FixtureAPITestCase):
transcriptions=False,
entities=True,
drop=False,
user_id=None,
description=f'Indexation of entity {self.entity_bis.id}',
))
def test_validated_entity_not_verified(self):
......@@ -926,6 +938,8 @@ class TestEntitiesAPI(FixtureAPITestCase):
transcriptions=False,
entities=True,
drop=False,
user_id=None,
description=f'Indexation of entity {entity.id}',
))
@patch('arkindex.project.triggers.tasks.reindex_start.delay')
......@@ -947,6 +961,8 @@ class TestEntitiesAPI(FixtureAPITestCase):
transcriptions=False,
entities=True,
drop=False,
user_id=None,
description=f'Indexation of entity {self.entity.id}',
))
@patch('arkindex.documents.api.entities.ESEntity')
......
......@@ -147,6 +147,8 @@ class TestMLResults(FixtureTestCase):
corpus_id=str(self.corpus.id),
element_id=None,
batch_size=1000,
user_id=self.superuser.id,
description=f'ML results deletion on corpus {self.corpus.id}',
))
self.assertEqual(reindex_delay_mock.call_count, 1)
self.assertEqual(reindex_delay_mock.call_args, call(
......@@ -168,6 +170,8 @@ class TestMLResults(FixtureTestCase):
corpus_id=str(self.corpus.id),
element_id=None,
batch_size=1000,
user_id=self.superuser.id,
description=f'ML results deletion on corpus {self.corpus.id}',
))
self.assertFalse(reindex_delay_mock.called)
......@@ -192,6 +196,8 @@ class TestMLResults(FixtureTestCase):
corpus_id=None,
element_id=str(self.page.id),
batch_size=1000,
user_id=self.superuser.id,
description=f'ML results deletion on element {self.page.id}',
))
self.assertEqual(reindex_delay_mock.call_count, 1)
self.assertEqual(reindex_delay_mock.call_args, call(
......@@ -213,5 +219,7 @@ class TestMLResults(FixtureTestCase):
corpus_id=None,
element_id=str(self.page.id),
batch_size=1000,
user_id=self.superuser.id,
description=f'ML results deletion on element {self.page.id}',
))
self.assertFalse(reindex_delay_mock.called)
......@@ -86,6 +86,8 @@ from arkindex.users.api import (
CredentialsList,
CredentialsRetrieve,
GroupsList,
JobList,
JobRetrieve,
OAuthCallback,
OAuthRetry,
OAuthSignIn,
......@@ -245,6 +247,10 @@ api = [
# Rights management
path('groups/', GroupsList.as_view(), name='groups-list'),
# Asynchronous jobs
path('jobs/', JobList.as_view(), name='jobs-list'),
path('jobs/<uuid:pk>/', JobRetrieve.as_view(), name='jobs-retrieve'),
# Management tools
path('reindex/', ReindexStart.as_view(), name='reindex-start'),
......
from typing import Optional
from django_rq.queues import DjangoRQ
from rq.compat import as_text, decode_redis_hash
from rq.job import Job as BaseJob
from rq.registry import BaseRegistry
def as_int(value) -> Optional[int]:
    """Cast a raw Redis hash value to an int, passing None through unchanged."""
    return None if value is None else int(value)
def as_float(value) -> Optional[float]:
    """Cast a raw Redis hash value to a float, passing None through unchanged."""
    return None if value is None else float(value)
class Job(BaseJob):
    """
    Extension of RQ jobs to provide description updates and completion percentage.

    Extra state is stored on the job's Redis hash under the ``progress`` and
    ``user_id`` keys, serialized via to_dict/restore below.
    """

    def __init__(self, *args, user_id: Optional[int] = None, **kwargs):
        # user_id links the job to a user so it can be indexed in a
        # UserRegistry; progress is lazily loaded from Redis on demand.
        super().__init__(*args, **kwargs)
        self._progress = None
        self.user_id = user_id

    @property
    def progress(self):
        # Cached value only — does not hit Redis.
        return self.get_progress(refresh=False)

    def get_progress(self, refresh: bool = True):
        """
        Return the job's completion ratio, or None when unset.
        When refresh is True, reload the value from the job's Redis hash.
        """
        if refresh:
            self._progress = as_float(self.connection.hget(self.key, 'progress'))
        return self._progress

    def set_progress(self, progress: float):
        """Store a completion ratio between 0 and 1 on the job's Redis hash."""
        progress = as_float(progress)
        assert progress is not None and 0.0 <= progress <= 1.0, 'Progress should be a float between 0 and 1'
        self._progress = progress
        # Write-through so other processes see the update immediately
        self.connection.hset(self.key, 'progress', self._progress)

    def set_description(self, description: Optional[str]):
        # NOTE(review): a None description would be written as-is to Redis via
        # as_text — presumably callers always pass a string; verify.
        self.description = as_text(description)
        self.connection.hset(self.key, 'description', self.description)

    def to_dict(self, *args, **kwargs):
        """
        Serialize the job into a dict for Redis storage
        """
        obj = super().to_dict(*args, **kwargs)
        # Never include None values as those are not accepted by Redis
        if self._progress is not None:
            obj['progress'] = self._progress
        if self.user_id is not None:
            obj['user_id'] = self.user_id
        return obj

    def restore(self, raw_data):
        """
        Update job attributes from the Redis hash
        """
        super().restore(raw_data)
        obj = decode_redis_hash(raw_data)
        # Missing keys restore to None (job created without progress/user)
        self._progress = as_float(obj.get('progress'))
        self.user_id = as_int(obj.get('user_id'))

    def delete(self, pipeline=None, remove_from_queue=True, **kwargs):
        """
        Overrides Job.delete, which already removes the job from all of RQ's registries and the queue
        before removing the job itself, to also remove the job from the UserRegistry when it has a user ID.
        Only remove when `remove_from_queue` is True, as users would otherwise 'lose' jobs as soon as they finish.
        """
        if remove_from_queue and self.user_id:
            registry = UserRegistry(
                self.origin,
                connection=self.connection,
                job_class=self.__class__,
                user_id=self.user_id,
            )
            registry.remove(self, pipeline=pipeline)
        super().delete(pipeline=pipeline, remove_from_queue=remove_from_queue, **kwargs)
class UserRegistry(BaseRegistry):
    """
    Job registry that indexes jobs per user_id.

    The user_id value is extracted by the Queue class below when a job
    is enqueued.
    """

    def __init__(self, *args, user_id=None, **kwargs):
        super().__init__(*args, **kwargs)
        # One Redis sorted set per user
        self.key = self.key_template.format('user:{}'.format(user_id))

    def cleanup(self):
        """
        Deliberate no-op: BaseRegistry.count() and get_job_ids() call this
        automatically, so it has to exist. The actual cleanup is implemented
        in the Job.delete method.
        """
class Queue(DjangoRQ):
    """
    Queue override that extracts the extra `user_id` and `description`
    keyword arguments that callers pass through .delay(), and indexes
    jobs per user in a UserRegistry.
    """

    def user_registry(self, user_id):
        """Build a UserRegistry listing jobs for a specific user"""
        return UserRegistry(user_id=user_id, queue=self, job_class=self.job_class)

    def create_job(self, *args, **kwargs):
        """
        Create a job, stripping our custom user_id/description arguments
        from the job's keyword arguments before RQ sees them.
        """
        # Extract user_id from delay() kwargs.
        # RQ's enqueue_call may pass kwargs=None explicitly when a job has no
        # keyword arguments, so guard against None before popping.
        job_kwargs = kwargs.get('kwargs') or {}
        user_id = job_kwargs.pop('user_id', None)
        description = job_kwargs.pop('description', None)
        # Build normal job
        job = super().create_job(*args, **kwargs)
        # Add the user ID to the job
        job.user_id = user_id
        # Add the description too, because .delay() does not allow a custom description
        # and instead uses the one defined in @job(description=…)
        if description:
            job.description = description
        # Add job to user registry
        if user_id is not None:
            reg = self.user_registry(user_id)
            reg.add(job)
        return job
......@@ -307,6 +307,11 @@ RQ_QUEUES = {
}
}
RQ = {
'JOB_CLASS': 'arkindex.project.rq_overrides.Job',
'QUEUE_CLASS': 'arkindex.project.rq_overrides.Queue'
}
LOGGING = {
'version': 1,
'disable_existing_loggers': False,
......
"""
Helper methods to trigger tasks in asynchronous workers
"""
from typing import Union
from typing import Optional, Union
from uuid import UUID
from django.conf import settings
......@@ -18,7 +18,8 @@ def reindex_start(*,
transcriptions: bool = False,
elements: bool = False,
entities: bool = False,
drop: bool = False) -> None:
drop: bool = False,
user_id: Optional[int] = None) -> None:
"""
Reindex elements into ElasticSearch.
......@@ -53,6 +54,15 @@ def reindex_start(*,
elif entity:
entity_id = str(entity)
if element_id:
description = f'Indexation of element {element_id}'
elif entity_id:
description = f'Indexation of entity {entity_id}'
elif corpus_id:
description = f'Indexation of corpus {corpus_id}'
else:
description = 'Full reindexation'
tasks.reindex_start.delay(
element_id=element_id,
corpus_id=corpus_id,
......@@ -61,13 +71,16 @@ def reindex_start(*,
elements=elements,
entities=entities,
drop=drop,
description=description,
user_id=user_id,
)
def ml_results_delete(*,
element: Union[Element, UUID, str] = None,
corpus: Union[Corpus, UUID, str] = None,
batch_size: int = 1000) -> None:
batch_size: int = 1000,
user_id: Optional[int] = None) -> None:
"""
Delete all ML results from all sources on a corpus
or an element and its *direct* (non-recursive) children.
......@@ -86,28 +99,48 @@ def ml_results_delete(*,
assert element_id or corpus_id, 'Missing element or corpus ID'
job = tasks.ml_results_delete.delay(corpus_id=corpus_id, element_id=element_id, batch_size=batch_size)
if element_id:
description = f'ML results deletion on element {element_id}'
else:
description = f'ML results deletion on corpus {corpus_id}'
job = tasks.ml_results_delete.delay(
corpus_id=corpus_id,
element_id=element_id,
batch_size=batch_size,
description=description,
user_id=user_id,
)
if settings.ARKINDEX_FEATURES['search']:
# Trigger a reindex afterwards to cleanup the deleted results
tasks.reindex_start.delay(corpus_id=corpus_id, element_id=element_id, depends_on=job)
def corpus_delete(corpus: Union[Corpus, UUID, str], user_id: Optional[int] = None) -> None:
    """
    Delete a whole corpus without killing a server by removing all related
    models first and lowering the amount of foreign keys Django has to handle.
    """
    is_instance = isinstance(corpus, Corpus)
    corpus_id = str(corpus.id) if is_instance else str(corpus)
    # Prefer the human-readable corpus name when we have the instance at hand
    label = corpus.name if is_instance else corpus_id
    description = f'Deletion of corpus {label}'
    tasks.corpus_delete.delay(corpus_id=corpus_id, description=description, user_id=user_id)
def element_trash(queryset: ElementQuerySet,
                  delete_children: bool = True,
                  user_id: Optional[int] = None) -> None:
    """
    Run ElementQuerySet.trash to delete a batch of elements.
    """
    assert isinstance(queryset, ElementQuerySet), 'Only Element querysets can be trashed'
    tasks.element_trash.delay(
        queryset=queryset,
        delete_children=delete_children,
        description='Element deletion',
        user_id=user_id,
    )
......@@ -12,8 +12,10 @@ from django.template.loader import render_to_string
from django.urls import reverse
from django.utils.http import urlsafe_base64_encode
from django.views.generic import RedirectView
from django_rq.queues import get_queue
from django_rq.settings import QUEUES
from rest_framework import status
from rest_framework.exceptions import AuthenticationFailed, PermissionDenied, ValidationError
from rest_framework.exceptions import AuthenticationFailed, NotFound, PermissionDenied, ValidationError
from rest_framework.generics import (
CreateAPIView,
ListAPIView,
......@@ -25,6 +27,7 @@ from rest_framework.generics import (
)
from rest_framework.response import Response
from rest_framework.views import APIView
from rq.job import JobStatus
from arkindex.documents.models import Corpus
from arkindex.project.permissions import IsAuthenticatedOrReadOnly, IsVerified
......@@ -33,6 +36,7 @@ from arkindex.users.providers import get_provider, oauth_providers
from arkindex.users.serializers import (
EmailLoginSerializer,
GroupSerializer,
JobSerializer,
NewUserSerializer,
OAuthCredentialsSerializer,
OAuthProviderClassSerializer,
......@@ -460,3 +464,40 @@ class GroupsList(ListCreateAPIView):
.filter(Q(public=True) | Q(users__in=[self.request.user])) \
.annotate(members_count=Count('users')) \
.order_by('id')
class JobList(ListAPIView):
    """
    List asynchronous jobs linked to the current user.
    """
    permission_classes = (IsVerified, )
    serializer_class = JobSerializer
    pagination_class = None

    def get_queryset(self):
        # self.request can be unset (e.g. during schema generation)
        if self.request:
            return self.request.user.get_rq_jobs()
        return []
class JobRetrieve(RetrieveDestroyAPIView):
    """
    Retrieve a single job by ID.
    """
    permission_classes = (IsVerified, )
    serializer_class = JobSerializer

    def get_object(self):
        job_id = str(self.kwargs['pk'])
        # Jobs can live on any queue, so probe each one in turn
        for queue_name in QUEUES:
            found = get_queue(queue_name).fetch_job(job_id)
            if not found:
                continue
            # Report other users' jobs as missing rather than forbidden
            if found.user_id != self.request.user.id:
                raise NotFound
            return found
        raise NotFound

    def perform_destroy(self, instance):
        # refresh=False: the status was already loaded along with the job
        if instance.get_status(refresh=False) == JobStatus.STARTED:
            raise ValidationError(['Cannot delete a running job.'])
        instance.delete()
......@@ -2,6 +2,8 @@ import uuid
from django.contrib.auth.models import AbstractBaseUser
from django.db import models
from django_rq.queues import get_queue
from django_rq.settings import QUEUES
from enumfields import Enum, EnumField
from arkindex.users.managers import UserManager
......@@ -69,6 +71,18 @@ class User(AbstractBaseUser):
# Simplest possible answer: All admins are staff
return self.is_admin
def get_rq_jobs(self):
    """
    List RQ jobs linked to this user's ID in all queues
    """
    for queue_name in QUEUES:
        queue = get_queue(queue_name)
        user_job_ids = queue.user_registry(self.id).get_job_ids()
        for job_id in user_job_ids:
            # queue.fetch_job detects missing job hashes, removes them
            # and returns None; only yield jobs that still exist
            job = queue.fetch_job(job_id)
            if job is not None:
                yield job
class Group(models.Model):
id = models.UUIDField(default=uuid.uuid4, primary_key=True, editable=False)
......
......@@ -237,3 +237,24 @@ class GroupSerializer(serializers.ModelSerializer):
# Associate the creator to the group
Membership.objects.create(user=self.context['request'].user, group=group, level=100)
return group
class JobSerializer(serializers.Serializer):
    """
    Serializes a RQ job.

    All fields are read-only: jobs are only created through task triggers,
    never through the API.
    """
    id = serializers.UUIDField(read_only=True)
    description = serializers.CharField(read_only=True)
    # Completion ratio between 0 and 1, or null when the job never reported progress
    progress = serializers.FloatField(min_value=0, max_value=1, read_only=True, allow_null=True)
    status = serializers.SerializerMethodField()
    enqueued_at = serializers.DateTimeField(read_only=True, allow_null=True)
    started_at = serializers.DateTimeField(read_only=True, allow_null=True)
    ended_at = serializers.DateTimeField(read_only=True, allow_null=True)

    def get_status(self, instance):
        """
        Avoid causing more Redis queries to fetch a job's current status
        Note that a job status is part of a JobStatus enum,
        but the enum is just a plain object and not an Enum for Py2 compatibility.
        """
        return instance.get_status(refresh=False)
from datetime import datetime
from unittest import expectedFailure
from unittest.mock import MagicMock, call, patch
from uuid import uuid4
from django.urls import reverse
from rest_framework import status
from rq.job import JobStatus
from arkindex.project.tests import FixtureAPITestCase
class MockedJob(object):
    """Stand-in for an RQ job that never touches Redis."""

    def __init__(self, user_id=None, status=JobStatus.QUEUED):
        self.id = str(uuid4())
        self.user_id = user_id
        self._status = status
        self.description = 'something'
        self.enqueued_at = datetime(2020, 1, 1, 13, 37, 42)
        self.started_at = None
        self.ended_at = None
        self.progress = None
        # MagicMock so tests can assert on deletions without a backend
        self.delete = MagicMock()

    def get_status(self, refresh=True):
        # Those endpoints should not be setting refresh to True as they do not have to reload from Redis
        assert refresh is False, 'refresh should be set to False'
        return self._status
@patch('arkindex.project.rq_overrides.Queue.fetch_job')
@patch('arkindex.project.rq_overrides.UserRegistry.get_job_ids')
class TestJobs(FixtureAPITestCase):
    """
    Tests for the jobs list/retrieve/destroy API endpoints.
    All Redis access is mocked out at the Queue and UserRegistry level,
    so each test receives (ids_mock, fetch_mock) from the class decorators.
    """

    def setUp(self):
        # One job per user to exercise per-user job isolation
        self.user_mocked_job = MockedJob(self.user.id)
        self.superuser_mocked_job = MockedJob(self.superuser.id)

    def test_list_requires_login(self, ids_mock, fetch_mock):
        with self.assertNumQueries(0):
            response = self.client.get(reverse('api:jobs-list'))
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
        self.assertFalse(ids_mock.called)
        self.assertFalse(fetch_mock.called)

    @expectedFailure
    def test_list_requires_verified(self, ids_mock, fetch_mock):
        """
        This test fails due to a bug in the IsVerified permission class.
        https://gitlab.com/arkindex/backend/-/issues/554
        """
        self.user.verified_email = False
        self.user.save()
        self.client.force_login(self.user)
        with self.assertNumQueries(2):
            response = self.client.get(reverse('api:jobs-list'))
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
        self.assertFalse(ids_mock.called)
        self.assertFalse(fetch_mock.called)

    def test_list(self, ids_mock, fetch_mock):
        # One job ID on the first queue, none on the second
        ids_mock.side_effect = [[self.user_mocked_job.id], []]
        fetch_mock.return_value = self.user_mocked_job
        self.client.force_login(self.user)
        with self.assertNumQueries(2):
            response = self.client.get(reverse('api:jobs-list'))
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertListEqual(response.json(), [
            {
                'id': self.user_mocked_job.id,
                'status': 'queued',
                'enqueued_at': '2020-01-01T13:37:42Z',
                'started_at': None,
                'ended_at': None,
                'progress': None,
                'description': 'something'
            }
        ])
        # get_job_ids is called once per queue, fetch once per job ID
        self.assertEqual(ids_mock.call_count, 2)
        self.assertEqual(fetch_mock.call_count, 1)

    def test_retrieve_requires_login(self, ids_mock, fetch_mock):
        with self.assertNumQueries(0):
            response = self.client.get(reverse('api:jobs-retrieve', kwargs={'pk': self.user_mocked_job.id}))
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
        self.assertFalse(ids_mock.called)
        self.assertFalse(fetch_mock.called)

    @expectedFailure
    def test_retrieve_requires_verified(self, ids_mock, fetch_mock):
        """
        This test fails due to a bug in the IsVerified permission class.
        https://gitlab.com/arkindex/backend/-/issues/554
        """
        self.user.verified_email = False
        self.user.save()
        self.client.force_login(self.user)
        with self.assertNumQueries(2):
            response = self.client.get(reverse('api:jobs-retrieve', kwargs={'pk': self.user_mocked_job.id}))
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
        self.assertFalse(ids_mock.called)
        self.assertFalse(fetch_mock.called)

    def test_retrieve(self, ids_mock, fetch_mock):
        fetch_mock.return_value = self.user_mocked_job
        self.client.force_login(self.user)
        with self.assertNumQueries(2):
            response = self.client.get(reverse('api:jobs-retrieve', kwargs={'pk': self.user_mocked_job.id}))
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertDictEqual(response.json(), {
            'id': self.user_mocked_job.id,
            'status': 'queued',
            'enqueued_at': '2020-01-01T13:37:42Z',
            'started_at': None,
            'ended_at': None,
            'progress': None,
            'description': 'something'
        })
        # Found on the first queue, so only one fetch
        self.assertEqual(fetch_mock.call_count, 1)
        self.assertEqual(fetch_mock.call_args, call(self.user_mocked_job.id))

    def test_retrieve_not_found(self, ids_mock, fetch_mock):
        fetch_mock.return_value = None
        self.client.force_login(self.user)
        with self.assertNumQueries(2):
            response = self.client.get(
                reverse('api:jobs-retrieve', kwargs={'pk': 'aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa'})
            )
        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
        # Called once per queue
        self.assertEqual(fetch_mock.call_count, 2)
        self.assertListEqual(fetch_mock.call_args_list, [
            call('aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa'),
            call('aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa')
        ])

    def test_retrieve_wrong_user(self, ids_mock, fetch_mock):
        # Another user's job must look like a missing job, not a forbidden one
        self.client.force_login(self.user)
        with self.assertNumQueries(2):
            response = self.client.get(reverse('api:jobs-retrieve', kwargs={'pk': self.superuser_mocked_job.id}))
        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

    def test_destroy_requires_login(self, ids_mock, fetch_mock):
        with self.assertNumQueries(0):
            response = self.client.delete(reverse('api:jobs-retrieve', kwargs={'pk': self.user_mocked_job.id}))
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
        self.assertFalse(ids_mock.called)
        self.assertFalse(fetch_mock.called)

    def test_destroy_requires_verified(self, ids_mock, fetch_mock):
        self.user.verified_email = False
        self.user.save()
        self.client.force_login(self.user)
        with self.assertNumQueries(2):
            response = self.client.delete(reverse('api:jobs-retrieve', kwargs={'pk': self.user_mocked_job.id}))
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
        self.assertFalse(ids_mock.called)
        self.assertFalse(fetch_mock.called)

    def test_destroy(self, ids_mock, fetch_mock):
        fetch_mock.return_value = self.user_mocked_job
        self.client.force_login(self.user)
        with self.assertNumQueries(2):
            response = self.client.delete(reverse('api:jobs-retrieve', kwargs={'pk': self.user_mocked_job.id}))
        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
        self.assertEqual(fetch_mock.call_count, 1)
        self.assertEqual(fetch_mock.call_args, call(self.user_mocked_job.id))
        # The job's (mocked) delete method must have been triggered
        self.assertEqual(self.user_mocked_job.delete.call_count, 1)

    def test_destroy_not_found(self, ids_mock, fetch_mock):
        fetch_mock.return_value = None
        self.client.force_login(self.user)
        with self.assertNumQueries(2):
            response = self.client.delete(
                reverse('api:jobs-retrieve', kwargs={'pk': 'aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa'})
            )
        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
        # Called once per queue
        self.assertEqual(fetch_mock.call_count, 2)
        self.assertListEqual(fetch_mock.call_args_list, [
            call('aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa'),
            call('aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa'),
        ])

    def test_destroy_wrong_user(self, ids_mock, fetch_mock):
        fetch_mock.return_value = self.superuser_mocked_job
        self.client.force_login(self.user)
        with self.assertNumQueries(2):
            response = self.client.get(reverse('api:jobs-retrieve', kwargs={'pk': self.superuser_mocked_job.id}))
        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
        self.assertEqual(fetch_mock.call_count, 1)
        self.assertEqual(fetch_mock.call_args, call(self.superuser_mocked_job.id))

    def test_destroy_not_started(self, ids_mock, fetch_mock):
        # Running jobs cannot be deleted
        started_job = MockedJob(user_id=self.user.id, status=JobStatus.STARTED)
        fetch_mock.return_value = started_job
        self.client.force_login(self.user)
        with self.assertNumQueries(2):
            response = self.client.delete(reverse('api:jobs-retrieve', kwargs={'pk': started_job.id}))
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertListEqual(response.json(), ['Cannot delete a running job.'])
        self.assertEqual(fetch_mock.call_count, 1)
        self.assertEqual(fetch_mock.call_args, call(started_job.id))
        self.assertEqual(started_job.delete.call_count, 0)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment