Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • arkindex/backend
1 result
Show changes
Commits on Source (10)
Showing
with 1816 additions and 536 deletions
1.5.0-rc2
1.5.0
......@@ -25,11 +25,12 @@ from arkindex.ponos.keys import load_private_key
from arkindex.ponos.models import Agent, Artifact, Farm, Secret, State, Task
from arkindex.ponos.permissions import (
IsAgent,
IsAgentOrArtifactAdmin,
IsAgentOrTaskAdmin,
IsAgentOrTaskAdminOrReadOnly,
IsAgentOrArtifactGuest,
IsAgentOrTaskGuest,
IsAssignedAgentOrReadOnly,
IsAssignedAgentOrTaskOrReadOnly,
IsTask,
IsTaskAdmin,
)
from arkindex.ponos.renderers import PublicKeyPEMRenderer
from arkindex.ponos.serializers import (
......@@ -87,26 +88,53 @@ class PublicKeyEndpoint(APIView):
@extend_schema(tags=["ponos"])
@extend_schema_view(
get=extend_schema(
operation_id="RetrieveTaskFromAgent", description="Retrieve a Ponos task status"
operation_id="RetrieveTaskFromAgent",
description=dedent("""
Retrieve a Ponos task.
Requires **guest** access on the task's process or Ponos agent authentication.
"""),
),
put=extend_schema(
operation_id="UpdateTaskFromAgent", description="Update a task, from an agent"
operation_id="UpdateTaskFromAgent",
description=dedent("""
Update a task.
Requires authentication as the Ponos agent assigned to the task.
"""),
),
patch=extend_schema(
operation_id="PartialUpdateTaskFromAgent",
description="Partially update a task, from an agent",
description=dedent("""
Partially update a task.
Requires authentication as the Ponos agent assigned to the task.
"""),
),
)
class TaskDetailsFromAgent(RetrieveUpdateAPIView):
"""
Retrieve information about a single task, including its logs.
Authenticated agents assigned to a task can use this endpoint to report its state.
"""
# Avoid stale read when a recently assigned agent wants to update
# the state of one of its tasks
queryset = Task.objects.all().using("default")
permission_classes = (IsAssignedAgentOrReadOnly,)
queryset = Task.objects.using("default").select_related(
# Serialized in responses
'agent__farm',
'gpu',
# Used for permission checks
'process__corpus',
'process__revision__repo',
)
authentication_classes = (
AgentAuthentication,
TokenAuthentication,
SessionAuthentication,
)
permission_classes = (
# On all HTTP methods, require either any Ponos agent, an instance admin, the task itself, or guest access to the task's process
IsAgentOrTaskGuest,
# On unsafe HTTP methods, require a Ponos agent assigned to the task. Both permission classes are combined.
IsAssignedAgentOrReadOnly,
)
serializer_class = TaskSerializer
......@@ -152,11 +180,12 @@ class AgentDetails(RetrieveAPIView):
"""
Retrieve details of an agent including its running tasks
Requires authentication with a verified e-mail address.
Requires authentication with a verified e-mail address. Cannot be used with Ponos agent or task authentication.
"""
authentication_classes = (TokenAuthentication, SessionAuthentication)
permission_classes = (IsVerified, )
serializer_class = AgentDetailsSerializer
queryset = Agent.objects.all()
queryset = Agent.objects.select_related('farm')
@extend_schema(
......@@ -252,41 +281,52 @@ class TaskDefinition(RetrieveAPIView):
@extend_schema_view(
get=extend_schema(
operation_id="ListArtifacts",
description="List all the artifacts of a task",
description=dedent("""
List the artifacts of a task.
Requires **guest** access on the task's process, Ponos agent authentication, or authentication as the Ponos task to list artifacts of.
"""),
),
post=extend_schema(
operation_id="CreateArtifact", description="Create an artifact on a task"
operation_id="CreateArtifact",
description=dedent("""
Create an artifact on task.
Requires authentication as the Ponos task to create an artifact on, or as a Ponos agent assigned to the task.
"""),
),
)
class TaskArtifacts(ListCreateAPIView):
"""
List all artifacts linked to a task or create one
"""
# Used for OpenAPI schema serialization: the ID in the path is the task ID
queryset = Task.objects.none()
permission_classes = (IsAgentOrTaskAdminOrReadOnly, )
serializer_class = ArtifactSerializer
permission_classes = (
# On all HTTP methods, require either any Ponos agent, an instance admin, the task itself, or guest access to the task's process
IsAgentOrTaskGuest,
# On unsafe HTTP methods, require a Ponos agent assigned to the task, or the task itself. Both permission classes are combined.
IsAssignedAgentOrTaskOrReadOnly,
)
# Force no pagination, even when global settings add them
pagination_class = None
def get_task(self):
@property
def task(self):
task = get_object_or_404(
# Select the required tables for permissions checking
Task.objects.select_related('process__corpus', 'process__revision'),
Task.objects.select_related('process__corpus', 'process__revision__repo'),
pk=self.kwargs["pk"],
)
self.check_object_permissions(self.request, task)
return task
def get_queryset(self):
task = self.get_task()
return task.artifacts.all()
return self.task.artifacts.all()
def perform_create(self, serializer):
# Assign task when creating through the API
serializer.save(task=self.get_task())
serializer.save(task=self.task)
class TaskArtifactDownload(APIView):
......@@ -294,10 +334,15 @@ class TaskArtifactDownload(APIView):
Redirects to the S3 URL of an artifact in order to download it.
"""
permission_classes = (IsAgentOrArtifactAdmin, )
permission_classes = (IsAgentOrArtifactGuest, )
def get_object(self, pk, path):
artifact = get_object_or_404(Artifact, task_id=pk, path=path)
artifact = get_object_or_404(
# Select the required tables for permissions checking
Artifact.objects.select_related('task__process__corpus', 'task__process__revision__repo'),
task_id=pk,
path=path,
)
self.check_object_permissions(self.request, artifact)
return artifact
......@@ -319,7 +364,9 @@ class TaskArtifactDownload(APIView):
)
class TaskCreate(CreateAPIView):
"""
Create a task with a parent
Create a task that depends on an existing task.
Requires authentication as a Ponos task. Tasks can only be created on the process of the authenticated task.
"""
authentication_classes = (TaskAuthentication, )
......@@ -330,18 +377,28 @@ class TaskCreate(CreateAPIView):
@extend_schema(tags=["ponos"])
@extend_schema_view(
put=extend_schema(
description="Update a task, allowing humans to change the task's state"
description=dedent("""
Update a task.
Requires **admin** access on the task's process, or to be the creator of the process.
Cannot be used with Ponos agent or task authentication.
"""),
),
patch=extend_schema(
description="Partially update a task, allowing humans to change the task's state"
description=dedent("""
Partially update a task.
Requires **admin** access on the task's process, or to be the creator of the process.
Cannot be used with Ponos agent or task authentication.
"""),
),
)
class TaskUpdate(UpdateAPIView):
"""
Admins and task creators can use this endpoint to update a task
"""
permission_classes = (IsAgentOrTaskAdmin, )
queryset = Task.objects.all()
# Only allow regular users, not Ponos agents or tasks
authentication_classes = (TokenAuthentication, SessionAuthentication)
# Only allow regular users that have admin access to the task's process
permission_classes = (IsTaskAdmin, )
queryset = Task.objects.select_related('process__corpus', 'process__revision__repo')
serializer_class = TaskTinySerializer
......
from rest_framework.permissions import SAFE_METHODS
from arkindex.ponos.models import Task
from arkindex.process.models import Process
from arkindex.project.mixins import CorpusACLMixin, ProcessACLMixin
from arkindex.project.permissions import IsAuthenticated
from arkindex.project.mixins import ProcessACLMixin
from arkindex.project.permissions import IsAuthenticated, IsVerified
from arkindex.users.models import Role
......@@ -29,6 +28,13 @@ class IsAgent(IsAuthenticated):
checks = IsAuthenticated.checks + (require_agent_or_admin, )
class IsAgentOrTask(IsAuthenticated):
    """
    Only allow Ponos agents, tasks, and admins.

    Runs every check inherited from IsAuthenticated, followed by `require_agent_or_task`.
    """
    # Tuple concatenation: the parent's checks are kept and the extra check is appended last
    checks = IsAuthenticated.checks + (require_agent_or_task, )
class IsAgentOrReadOnly(IsAgent):
"""
Restricts write access to Ponos agents and admins,
......@@ -54,67 +60,63 @@ class IsAssignedAgentOrReadOnly(IsAgentOrReadOnly):
return super().has_object_permission(request, view, obj)
class IsAgentOrTaskAdmin(CorpusACLMixin, IsAuthenticated):
class IsAssignedAgentOrTaskOrReadOnly(IsAgentOrTask):
"""
Permission to access a task with high privilege
Restricts write access to Ponos agents, Ponos tasks, and admins, and allows read access to anyone.
When checking object write permissions for a Ponos task, requires either a Ponos agent assigned to the task,
or authentication as the task itself.
"""
allow_safe_methods = True
Allowed for admins, agents, creators of the task's process,
and users with an admin right on the process' corpus.
def has_object_permission(self, request, view, obj) -> bool:
assert isinstance(obj, Task)
if isinstance(request.auth, Task):
return obj == request.auth
return super().has_object_permission(request, view, obj) and (
obj.agent_id == request.user.id or request.method in SAFE_METHODS
)
class IsTaskAdmin(ProcessACLMixin, IsVerified):
"""
Allow instance admins and users with a verified email and admin access to the task's process.
"""
def has_object_permission(self, request, view, task):
# Add request to attributes for the ACL mixin to work with self.user
self.request = request
return (
require_agent_or_admin(request, view)
or (
task.process is not None
and task.process.corpus_id is not None
and self.has_admin_access(task.process.corpus)
)
)
level = self.process_access_level(task.process)
# process_access_level can return None if there is no access at all
return level and level >= Role.Admin.value
class IsAgentOrTaskAdminOrReadOnly(ProcessACLMixin, IsAuthenticated):
class IsAgentOrTaskGuest(ProcessACLMixin, IsAuthenticated):
"""
Instance admins, agents, process admins, and the task itself are always allowed.
For GET/HEAD, only a Guest level on the process is required for regular users.
Allow admins, Ponos agents, users with a verified email and guest access to the task's process, or the task itself.
"""
def has_object_permission(self, request, view, task):
# Allow agents and instance admins
if require_agent_or_admin(request, view):
return True
# Allow a task to access itself
if isinstance(request.auth, Task) and request.auth.id == task.id:
return True
assert isinstance(task, Task)
# Add request to attributes for the ACL mixin to work with self.user
self.request = request
try:
level = self.process_access_level(task.process)
except Process.DoesNotExist:
# Reject if the task has no process
return False
# Require *some* access to the process
if level is None:
return False
# Require only a guest access for GET/HEAD
if request.method in SAFE_METHODS:
return level >= Role.Guest.value
# Require admin access for other methods
return level >= Role.Admin.value
return (
task == request.auth
or require_agent_or_admin(request, view)
or (
getattr(request.user, 'verified_email', False)
# process_access_level can return None if there is no access at all
and (self.process_access_level(task.process) or 0) >= Role.Guest.value
)
)
class IsAgentOrArtifactAdmin(IsAgentOrTaskAdmin):
class IsAgentOrArtifactGuest(IsAgentOrTaskGuest):
"""
Permission to access an artifact with high privilege, based on
access to the artifact's task through IsAgentOrTaskAdmin.
Permission to access an artifact, based on access to the artifact's task through IsAgentOrTaskGuest.
"""
def has_object_permission(self, request, view, artifact):
......
......@@ -506,7 +506,7 @@ class ArtifactSerializer(serializers.ModelSerializer):
def validate_path(self, path):
"""Check that no artifacts with this path already exist in DB"""
task = self.context["view"].get_task()
task = self.context["view"].task
if task.artifacts.filter(path=path).exists():
raise ValidationError("An artifact with this path already exists")
......
This diff is collapsed.
This diff is collapsed.
......@@ -48,6 +48,7 @@ from arkindex.process.models import (
GitRef,
GitRefType,
Process,
ProcessDataset,
ProcessMode,
Repository,
Revision,
......@@ -78,7 +79,7 @@ from arkindex.process.serializers.imports import (
StartProcessSerializer,
)
from arkindex.process.serializers.ingest import BucketSerializer, S3ImportSerializer
from arkindex.process.serializers.training import StartTrainingSerializer
from arkindex.process.serializers.training import ProcessDatasetSerializer, StartTrainingSerializer
from arkindex.process.serializers.worker_runs import WorkerRunEditSerializer, WorkerRunSerializer
from arkindex.process.serializers.workers import (
DockerWorkerVersionSerializer,
......@@ -108,6 +109,8 @@ from arkindex.project.pagination import CustomCursorPagination
from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly
from arkindex.project.tools import PercentileCont, RTrimChr
from arkindex.project.triggers import process_delete
from arkindex.training.models import Dataset
from arkindex.training.serializers import DatasetSerializer
from arkindex.users.models import OAuthCredentials, Role, Scope
from arkindex.users.utils import get_max_level
......@@ -557,6 +560,7 @@ class StartProcess(CorpusACLMixin, CreateAPIView):
.select_related('corpus')
.filter(corpus_id__isnull=False)
.prefetch_related(Prefetch('worker_runs', queryset=WorkerRun.objects.select_related('version', 'model_version')))
.prefetch_related('datasets')
# Uses Exists() for has_tasks and not a __isnull because we are not joining on tasks and do not need to fetch them
.annotate(has_tasks=Exists(Task.objects.filter(process=OuterRef('pk'))))
)
......@@ -567,9 +571,9 @@ class StartProcess(CorpusACLMixin, CreateAPIView):
if not self.has_admin_access(process.corpus):
raise PermissionDenied(detail='You do not have an admin access to the corpus of this process.')
if process.mode != ProcessMode.Workers or process.has_tasks:
if process.mode not in (ProcessMode.Workers, ProcessMode.Dataset) or process.has_tasks:
raise ValidationError(
{'__all__': ['Only a Process with Workers mode and not already launched can be started later on']})
{'__all__': ['Only a Process with Workers or Dataset mode and not already launched can be started later on']})
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data, instance=self.get_object())
......@@ -652,6 +656,92 @@ class DataFileCreate(CreateAPIView):
serializer_class = DataFileCreateSerializer
@extend_schema(tags=['process'])
@extend_schema_view(
get=extend_schema(
operation_id='ListProcessDatasets',
description=dedent(
"""
List all datasets on a process.
Requires a **guest** access to the process.
"""
),
),
)
class ProcessDatasets(ProcessACLMixin, ListAPIView):
    """
    List the datasets linked to a process, ordered by name.
    Any access level on the process is sufficient.
    """
    permission_classes = (IsVerified, )
    serializer_class = DatasetSerializer
    # The datasets that are actually listed come from get_queryset
    queryset = Dataset.objects.none()

    @cached_property
    def process(self):
        """
        Process designated by the `pk` URL parameter.
        Raises an HTTP 404 when it does not exist, and PermissionDenied
        when the user has no access level at all on it.
        """
        candidates = Process.objects.using('default').select_related('corpus', 'revision__repo')
        found = get_object_or_404(candidates, pk=self.kwargs['pk'])
        if self.process_access_level(found):
            return found
        raise PermissionDenied(detail='You do not have guest access to this process.')

    def get_queryset(self):
        linked_datasets = self.process.datasets.select_related('creator')
        return linked_datasets.order_by('name')

    def get_serializer_context(self):
        context = super().get_serializer_context()
        # There are no URL kwargs when generating the OpenAPI schema, so there is no process to resolve
        if self.kwargs:
            context['process'] = self.process
        return context
@extend_schema(tags=['process'])
@extend_schema_view(
post=extend_schema(
operation_id='CreateProcessDataset',
description=dedent(
"""
Add a dataset to a process.
Requires an **admin** access to the process and a **guest** access to the dataset's corpus.
"""
),
),
delete=extend_schema(
operation_id='DestroyProcessDataset',
description=dedent(
"""
Remove a dataset from a process.
Requires an **admin** access to the process.
"""
),
),
)
class ProcessDatasetManage(CreateAPIView, DestroyAPIView):
    """
    Add a dataset to a process, or remove a dataset from a process.
    The serializer input is built from the keyword arguments given to the view, not from the request body.
    """
    permission_classes = (IsVerified, )
    serializer_class = ProcessDatasetSerializer

    def get_serializer_from_params(self, process=None, dataset=None, **kwargs):
        """
        Build a ProcessDatasetSerializer whose input data is made of the given process and dataset.
        """
        kwargs['context'] = self.get_serializer_context()
        return ProcessDatasetSerializer(
            data={'process': process, 'dataset': dataset},
            **kwargs,
        )

    def create(self, request, *args, **kwargs):
        serializer = self.get_serializer_from_params(**kwargs)
        serializer.is_valid(raise_exception=True)
        serializer.create(serializer.validated_data)
        payload = serializer.data
        return Response(
            payload,
            status=status.HTTP_201_CREATED,
            headers=self.get_success_headers(payload),
        )

    def destroy(self, request, *args, **kwargs):
        serializer = self.get_serializer_from_params(**kwargs)
        serializer.is_valid(raise_exception=True)
        # Respond with an HTTP 404 when this dataset is not linked to this process
        link = get_object_or_404(ProcessDataset, **serializer.validated_data)
        link.delete()
        return Response(status=status.HTTP_204_NO_CONTENT)
@extend_schema(exclude=True)
class GitRepositoryImportHook(APIView):
"""
......
......@@ -467,6 +467,7 @@ class Process(IndexableModel):
Create Ponos tasks according to configuration
"""
ml_workflow_chunks = 1
import_task = None
import_task_name = 'import'
# Use the default Ponos farm if no farm is specified
self.farm_id = farm.id if farm is not None else get_default_farm_id()
......@@ -619,12 +620,19 @@ class Process(IndexableModel):
)
import_task.save()
elif self.mode == ProcessMode.Dataset:
# For Dataset processes, there is no initial task
import_task_name = None
if chunks is not None:
assert chunks <= settings.MAX_CHUNKS, f'Import distribution is limited to {settings.MAX_CHUNKS} chunks'
ml_workflow_chunks = chunks
else:
raise NotImplementedError
# Handle chunks and thumbnails generation on processes that list elements during the initial task
# The S3 import does not generate a list of elements yet
if self.mode == ProcessMode.Workers or (thumbnails and self.mode in (
if self.mode in (ProcessMode.Workers, ProcessMode.Dataset) or (thumbnails and self.mode in (
ProcessMode.Transkribus,
ProcessMode.IIIF,
ProcessMode.Files
......@@ -633,8 +641,8 @@ class Process(IndexableModel):
if ml_workflow_chunks > 1:
elts_chunk_files = [f'elements_chunk_{n}.json' for n in range(1, ml_workflow_chunks + 1)]
if self.mode == ProcessMode.Workers:
# Build up tasks for each worker run on each chunk
if self.mode in (ProcessMode.Workers, ProcessMode.Dataset):
# Retrieve worker runs
worker_runs = list(self.worker_runs.select_related('version__worker__repository', 'model_version').using('default'))
else:
worker_runs = []
......@@ -647,9 +655,12 @@ class Process(IndexableModel):
tasks_to_create = []
# Holds the task parents relationships, so that they can be built after building tasks
# {task.slug: [parent_slug1, …]}
parents = {
import_task.slug: [],
}
if import_task:
parents = {
import_task.slug: [],
}
else:
parents = {}
for index, elts_chunk in enumerate(elts_chunk_files, start=1):
# Add a name suffix if task is distributed
......@@ -659,10 +670,17 @@ class Process(IndexableModel):
task_suffix = f'_{index}'
chunk = index
elements_path = shlex.quote(path.join('/data', import_task_name, elts_chunk))
elements_path = None
if self.mode != ProcessMode.Dataset:
# No element handling in process for Dataset processes
elements_path = shlex.quote(path.join('/data', import_task_name, elts_chunk))
# Generate thumbnails directly after import step
if thumbnails:
# Thumbnails generation can't work with dataset processes because it requires the element.json that
# is generated in the initialisation task, and there is no initialisation task in dataset processes
assert self.mode != ProcessMode.Dataset, 'Thumbnails generation is incompatible with dataset mode processes.'
thumbnails_task = self.build_task(
command=f'python3 -m arkindex_tasks.generate_thumbnails {elements_path}',
slug=f'thumbnails{task_suffix}',
......@@ -670,7 +688,10 @@ class Process(IndexableModel):
env=env,
)
tasks_to_create.append(thumbnails_task)
parents[thumbnails_task.slug] = [import_task.slug]
if import_task:
parents[thumbnails_task.slug] = [import_task.slug]
else:
parents[thumbnails_task.slug] = []
# Generate a task for each WorkerRun on the Process
for worker_run in worker_runs:
......@@ -697,7 +718,8 @@ class Process(IndexableModel):
# Use a tasks by slug map to find tasks more easily
tasks = {task.slug: task for task in tasks_to_create}
tasks[import_task.slug] = import_task
if import_task:
tasks[import_task.slug] = import_task
Task.parents.through.objects.bulk_create(
Task.parents.through(
......@@ -1184,8 +1206,10 @@ class WorkerRun(models.Model):
f'{worker_run.version.slug}{suffix}'
for worker_run in parent_runs
]
else:
elif import_task_name:
parents = [import_task_name]
else:
parents = []
assert (
self.version.state == WorkerVersionState.Available
......
......@@ -332,6 +332,15 @@ class StartProcessSerializer(serializers.Serializer):
assert self.instance is not None, 'A Process instance is required for this serializer'
errors = defaultdict(list)
if self.instance.mode == ProcessMode.Dataset:
# Only call .count() and .all() as they will access the prefetched datasets and not cause any extra query
if not self.instance.datasets.count():
errors['non_field_errors'].append('A dataset process cannot be started if it does not have any associated datasets.')
elif not any(dataset.corpus_id == self.instance.corpus.id for dataset in self.instance.datasets.all()):
errors['non_field_errors'].append('At least one of the process datasets must be from the same corpus as the process.')
if validated_data.get('thumbnails'):
errors['thumbnails'].append('Thumbnails generation is incompatible with dataset mode processes.')
# Use process.worker_runs.all() to access the (prefetched) worker_runs to avoid new SQL queries
# The related worker versions have also been prefetched
if len(list(self.instance.worker_runs.all())) > 0:
......
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from rest_framework.exceptions import PermissionDenied, ValidationError
from arkindex.documents.models import Corpus, Element
from arkindex.process.models import (
Process,
ProcessDataset,
ProcessMode,
WorkerConfiguration,
WorkerVersion,
WorkerVersionGPUUsage,
WorkerVersionState,
)
from arkindex.project.mixins import TrainingModelMixin, WorkerACLMixin
from arkindex.training.models import Model, ModelVersion
from arkindex.project.mixins import ProcessACLMixin, TrainingModelMixin, WorkerACLMixin
from arkindex.training.models import Dataset, Model, ModelVersion
from arkindex.users.models import Role
class StartTrainingSerializer(serializers.ModelSerializer, WorkerACLMixin, TrainingModelMixin):
......@@ -192,3 +194,43 @@ class StartTrainingSerializer(serializers.ModelSerializer, WorkerACLMixin, Train
use_gpu=validated_data["use_gpu"],
)
return self.instance
class ProcessDatasetSerializer(serializers.ModelSerializer, ProcessACLMixin):
    """
    Validate a (process, dataset) pair before linking or unlinking them.
    The process must be in Dataset mode and the user must have admin access to it.
    """
    process = serializers.PrimaryKeyRelatedField(
        queryset=Process.objects.using('default').select_related('corpus'),
        style={'base_template': 'input.html'},
    )
    dataset = serializers.PrimaryKeyRelatedField(
        # Replaced in __init__ once the request, and therefore the user, is known
        queryset=Dataset.objects.none(),
        style={'base_template': 'input.html'},
    )

    class Meta:
        model = ProcessDataset
        fields = ('dataset', 'process', 'id', )
        read_only_fields = ('process', 'id', )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        request = self.context.get('request')
        if not request:
            # No request is available when building the OpenAPI schema: keep the empty queryset
            return

        http_method = request.method
        # The ProcessACLMixin and the readable corpora filter both rely on this user
        self._user = request.user

        if http_method == 'DELETE':
            # A ProcessDataset can still be removed after the user loses access to the dataset's corpus
            selectable = Dataset.objects.all()
        else:
            selectable = Dataset.objects.filter(corpus__in=Corpus.objects.readable(self._user))
        self.fields['dataset'].queryset = selectable

    def validate_process(self, process):
        if process.mode != ProcessMode.Dataset:
            raise ValidationError(detail='Datasets can only be added to or removed from processes of mode "dataset".')

        # process_access_level may return a falsy value when the user has no access at all
        level = self.process_access_level(process)
        if level and level >= Role.Admin.value:
            return process
        raise PermissionDenied(detail='You do not have admin access to this process.')
......@@ -587,7 +587,7 @@ class TestCreateProcess(FixtureAPITestCase):
self.assertFalse(self.corpus.worker_versions.exists())
self.client.force_login(self.user)
with self.assertNumQueries(19):
with self.assertNumQueries(20):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(process_2.id)})
)
......@@ -674,7 +674,7 @@ class TestCreateProcess(FixtureAPITestCase):
process_2.use_cache = True
process_2.save()
self.client.force_login(self.user)
with self.assertNumQueries(19):
with self.assertNumQueries(20):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(process_2.id)})
)
......@@ -721,7 +721,7 @@ class TestCreateProcess(FixtureAPITestCase):
process_2.use_gpu = True
process_2.save()
self.client.force_login(self.user)
with self.assertNumQueries(19):
with self.assertNumQueries(20):
response = self.client.post(reverse('api:process-start', kwargs={'pk': str(process_2.id)}))
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
......@@ -789,7 +789,7 @@ class TestCreateProcess(FixtureAPITestCase):
process.use_gpu = True
process.save()
self.client.force_login(self.user)
with self.assertNumQueries(7):
with self.assertNumQueries(8):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(process.id)}),
{'use_gpu': 'true'}
......@@ -813,7 +813,7 @@ class TestCreateProcess(FixtureAPITestCase):
process.use_gpu = True
process.save()
self.client.force_login(self.user)
with self.assertNumQueries(19):
with self.assertNumQueries(20):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(process.id)}),
{'use_gpu': 'true'}
......
import uuid
from unittest.mock import patch
from django.urls import reverse
from rest_framework import status
from arkindex.documents.models import Corpus
from arkindex.process.models import Process, ProcessDataset, ProcessMode
from arkindex.project.tests import FixtureAPITestCase
from arkindex.training.models import Dataset
from arkindex.users.models import Role, User
# Using the fake DB fixtures creation date when needed
FAKE_CREATED = '2020-02-02T01:23:45.678000Z'
class TestProcessDatasets(FixtureAPITestCase):
@classmethod
def setUpTestData(cls):
    """
    Create a private corpus holding one dataset, a verified user with admin access to it,
    and two Dataset mode processes with distinct sets of datasets.
    """
    super().setUpTestData()
    cls.private_corpus = Corpus.objects.create(name='Private corpus')
    # Patch timezone.now while creating the dataset; the list test expects FAKE_CREATED
    # as the `created` and `updated` values of this dataset
    with patch('django.utils.timezone.now') as mock_now:
        mock_now.return_value = FAKE_CREATED
        cls.private_dataset = cls.private_corpus.datasets.create(
            name='Dead sea scrolls',
            description='Human instrumentality manual',
            creator=cls.user
        )
    cls.test_user = User.objects.create(email='katsuragi@nerv.co.jp', verified_email=True)
    cls.private_corpus.memberships.create(user=cls.test_user, level=Role.Admin.value)

    # Datasets from another corpus
    # The unpacking only works if this corpus holds exactly two datasets
    cls.dataset1, cls.dataset2 = Dataset.objects.filter(corpus=cls.corpus).order_by('name')

    cls.dataset_process = Process.objects.create(
        creator_id=cls.user.id,
        mode=ProcessMode.Dataset,
        corpus_id=cls.private_corpus.id
    )
    cls.dataset_process.datasets.set([cls.dataset1, cls.private_dataset])

    # Control process to check that its datasets are not retrieved
    cls.dataset_process_2 = Process.objects.create(
        creator_id=cls.user.id,
        mode=ProcessMode.Dataset,
        corpus_id=cls.corpus.id
    )
    cls.dataset_process_2.datasets.set([cls.dataset2])

    # For repository process
    cls.creds = cls.user.credentials.get()
    cls.repo = cls.creds.repos.get(url='http://my_repo.fake/workers/worker')
    cls.repo.memberships.create(user=cls.test_user, level=Role.Admin.value)
    cls.rev = cls.repo.revisions.get()
# List process datasets
def test_list_requires_login(self):
    """Anonymous requests are rejected with an HTTP 403 and no SQL query."""
    url = reverse('api:process-datasets', kwargs={'pk': self.dataset_process.id})
    with self.assertNumQueries(0):
        resp = self.client.get(url)
    self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN)
def test_list_process_does_not_exist(self):
    """Listing the datasets of an unknown process returns an HTTP 404."""
    self.client.force_login(self.test_user)
    unknown_url = reverse('api:process-datasets', kwargs={'pk': str(uuid.uuid4())})
    with self.assertNumQueries(3):
        resp = self.client.get(unknown_url)
    self.assertEqual(resp.status_code, status.HTTP_404_NOT_FOUND)
def test_list_process_access_level(self):
    """A user without any access to the process cannot list its datasets."""
    # Revoke the membership created in the fixtures
    self.private_corpus.memberships.filter(user=self.test_user).delete()
    self.client.force_login(self.test_user)
    url = reverse('api:process-datasets', kwargs={'pk': self.dataset_process.id})
    with self.assertNumQueries(5):
        resp = self.client.get(url)
    self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN)
    self.assertDictEqual(resp.json(), {'detail': 'You do not have guest access to this process.'})
def test_list_process_datasets(self):
    """Only the datasets of the requested process are listed, ordered by name."""
    self.client.force_login(self.test_user)
    url = reverse('api:process-datasets', kwargs={'pk': self.dataset_process.id})
    with self.assertNumQueries(8):
        resp = self.client.get(url)
    self.assertEqual(resp.status_code, status.HTTP_200_OK)

    # Attributes that are identical on both expected datasets
    shared = {
        'creator': 'Test user',
        'sets': ['training', 'test', 'validation'],
        'state': 'open',
        'task_id': None,
        'created': FAKE_CREATED,
        'updated': FAKE_CREATED,
    }
    expected = [
        {
            **shared,
            'id': str(self.private_dataset.id),
            'name': 'Dead sea scrolls',
            'description': 'Human instrumentality manual',
            'corpus_id': str(self.private_corpus.id),
        },
        {
            **shared,
            'id': str(self.dataset1.id),
            'name': 'First Dataset',
            'description': 'dataset number one',
            'corpus_id': str(self.corpus.id),
        },
    ]
    self.assertEqual(resp.json()['results'], expected)
# Create process dataset
def test_create_process_dataset_requires_login(self):
    """Anonymous requests are rejected with an HTTP 403 and no SQL query."""
    url = reverse(
        'api:process-dataset',
        kwargs={'process': self.dataset_process.id, 'dataset': self.dataset2.id},
    )
    with self.assertNumQueries(0):
        resp = self.client.post(url)
    self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN)
def test_create_process_dataset_requires_verified(self):
    """A user created without a verified email gets an HTTP 403."""
    self.client.force_login(User.objects.create(email='email@mail.com'))
    url = reverse(
        'api:process-dataset',
        kwargs={'process': self.dataset_process.id, 'dataset': self.dataset2.id},
    )
    with self.assertNumQueries(2):
        resp = self.client.post(url)
    self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN)
def test_create_process_dataset_access_level(self):
    """Any access level below admin on the process is not enough to add a dataset."""
    url = reverse(
        'api:process-dataset',
        kwargs={'process': self.dataset_process.id, 'dataset': self.dataset2.id},
    )
    for level in (None, Role.Guest, Role.Contributor):
        with self.subTest(level=level):
            # Start from no membership at all, then grant the level under test
            memberships = self.private_corpus.memberships
            memberships.filter(user=self.test_user).delete()
            if level:
                memberships.create(user=self.test_user, level=level.value)
            self.client.force_login(self.test_user)

            with self.assertNumQueries(6):
                resp = self.client.post(url)

            self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN)
            self.assertEqual(resp.json(), {'detail': 'You do not have admin access to this process.'})
def test_create_process_dataset_process_mode(self):
    """Datasets cannot be added to processes whose mode is not Dataset."""
    process = self.dataset_process
    url = reverse(
        'api:process-dataset',
        kwargs={'process': process.id, 'dataset': self.dataset2.id},
    )
    # Local mode is covered by a dedicated test
    for mode in set(ProcessMode) - {ProcessMode.Dataset, ProcessMode.Local}:
        with self.subTest(mode=mode):
            process.mode = mode
            if mode == ProcessMode.Repository:
                # The Repository case uses a revision and no corpus
                process.corpus = None
                process.revision = self.rev
            else:
                process.corpus = self.private_corpus
            process.save()
            self.client.force_login(self.test_user)

            with self.assertNumQueries(6):
                resp = self.client.post(url)

            self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
            self.assertEqual(resp.json(), {'process': ['Datasets can only be added to or removed from processes of mode "dataset".']})
def test_create_process_dataset_process_mode_local(self):
    """
    Adding a dataset to the local-mode process created by ``self.user`` is
    refused with a 400 on the ``process`` field.
    """
    self.client.force_login(self.user)
    local_process = Process.objects.get(creator=self.user, mode=ProcessMode.Local)
    url = reverse(
        'api:process-dataset',
        kwargs={'process': local_process.id, 'dataset': self.dataset2.id},
    )

    with self.assertNumQueries(6):
        response = self.client.post(url)

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    self.assertEqual(response.json(), {'process': ['Datasets can only be added to or removed from processes of mode "dataset".']})
def test_create_process_dataset_wrong_process_uuid(self):
    """
    A random process UUID in the URL yields a 400 with an "Invalid pk"
    error on the ``process`` field.
    """
    wrong_id = uuid.uuid4()
    url = reverse(
        'api:process-dataset',
        kwargs={'process': wrong_id, 'dataset': self.dataset2.id},
    )
    self.client.force_login(self.test_user)

    with self.assertNumQueries(6):
        response = self.client.post(url)

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    self.assertEqual(response.json(), {'process': [f'Invalid pk "{str(wrong_id)}" - object does not exist.']})
def test_create_process_dataset_wrong_dataset_uuid(self):
    """
    A random dataset UUID in the URL yields a 400 with an "Invalid pk"
    error on the ``dataset`` field.
    """
    wrong_id = uuid.uuid4()
    url = reverse(
        'api:process-dataset',
        kwargs={'process': self.dataset_process.id, 'dataset': wrong_id},
    )
    self.client.force_login(self.test_user)

    with self.assertNumQueries(7):
        response = self.client.post(url)

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    self.assertEqual(response.json(), {'dataset': [f'Invalid pk "{str(wrong_id)}" - object does not exist.']})
def test_create_process_dataset_dataset_access(self):
    """
    A dataset living in a newly created corpus is reported as a
    non-existent pk (400 on the ``dataset`` field) when ``self.test_user``
    tries to add it to the process.
    """
    other_corpus = Corpus.objects.create(name='NERV')
    other_dataset = other_corpus.datasets.create(
        name='Eva series',
        description='We created the Evas from Adam',
        creator=self.user,
    )
    url = reverse(
        'api:process-dataset',
        kwargs={'process': self.dataset_process.id, 'dataset': other_dataset.id},
    )
    self.client.force_login(self.test_user)

    with self.assertNumQueries(7):
        response = self.client.post(url)

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    self.assertEqual(response.json(), {'dataset': [f'Invalid pk "{str(other_dataset.id)}" - object does not exist.']})
def test_create_process_dataset(self):
    """
    A successful POST creates one new ProcessDataset row linking
    ``self.dataset2`` to the process and answers with a 201.
    """
    def link_exists():
        # Fresh lookup on every call so the database is always re-queried
        return ProcessDataset.objects.filter(
            process=self.dataset_process.id,
            dataset=self.dataset2.id,
        ).exists()

    url = reverse(
        'api:process-dataset',
        kwargs={'process': self.dataset_process.id, 'dataset': self.dataset2.id},
    )
    self.client.force_login(self.test_user)
    self.assertEqual(ProcessDataset.objects.count(), 3)
    self.assertFalse(link_exists())

    with self.assertNumQueries(8):
        response = self.client.post(url)

    self.assertEqual(response.status_code, status.HTTP_201_CREATED)
    self.assertEqual(ProcessDataset.objects.count(), 4)
    self.assertTrue(link_exists())
    expected_datasets = [self.private_dataset, self.dataset1, self.dataset2]
    self.assertQuerysetEqual(self.dataset_process.datasets.order_by('name'), expected_datasets)
# Destroy process dataset
def test_destroy_requires_login(self):
    """
    An anonymous DELETE is refused with a 403 before any query is made.
    """
    url = reverse(
        'api:process-dataset',
        kwargs={'process': self.dataset_process.id, 'dataset': self.private_dataset.id},
    )

    with self.assertNumQueries(0):
        response = self.client.delete(url)

    self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_destroy_process_does_not_exist(self):
    """
    A DELETE with a random process UUID yields a 400 with an "Invalid pk"
    error on the ``process`` field.
    """
    wrong_id = uuid.uuid4()
    url = reverse(
        'api:process-dataset',
        kwargs={'process': wrong_id, 'dataset': self.private_dataset.id},
    )
    self.client.force_login(self.test_user)

    with self.assertNumQueries(4):
        response = self.client.delete(url)

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    self.assertDictEqual(response.json(), {'process': [f'Invalid pk "{str(wrong_id)}" - object does not exist.']})
def test_destroy_dataset_does_not_exist(self):
    """
    A DELETE with a random dataset UUID yields a 400 with an "Invalid pk"
    error on the ``dataset`` field.
    """
    wrong_id = uuid.uuid4()
    url = reverse(
        'api:process-dataset',
        kwargs={'process': self.dataset_process.id, 'dataset': wrong_id},
    )
    self.client.force_login(self.test_user)

    with self.assertNumQueries(7):
        response = self.client.delete(url)

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    self.assertDictEqual(response.json(), {'dataset': [f'Invalid pk "{str(wrong_id)}" - object does not exist.']})
def test_destroy_not_found(self):
    """
    Removing a dataset that is not linked to the process answers with
    a 404.
    """
    # Precondition: dataset2 is not attached to the process
    already_linked = self.dataset_process.datasets.filter(id=self.dataset2.id).exists()
    self.assertFalse(already_linked)
    url = reverse(
        'api:process-dataset',
        kwargs={'process': self.dataset_process.id, 'dataset': self.dataset2.id},
    )
    self.client.force_login(self.test_user)

    with self.assertNumQueries(8):
        response = self.client.delete(url)

    self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_destroy_process_access_level(self):
    """
    Once the user's membership on ``self.private_corpus`` is removed, a
    DELETE is refused with a 403 asking for admin access to the process.
    """
    url = reverse(
        'api:process-dataset',
        kwargs={'process': self.dataset_process.id, 'dataset': self.private_dataset.id},
    )
    self.private_corpus.memberships.filter(user=self.test_user).delete()
    self.client.force_login(self.test_user)

    with self.assertNumQueries(6):
        response = self.client.delete(url)

    self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
    self.assertDictEqual(response.json(), {'detail': 'You do not have admin access to this process.'})
def test_destroy_no_dataset_access_requirement(self):
    """
    A dataset from a newly created corpus can still be removed from the
    process: the DELETE answers with a 204 and the link is gone.
    """
    def link_exists():
        # Fresh lookup on every call so the database is always re-queried
        return ProcessDataset.objects.filter(process=self.dataset_process, dataset=other_dataset).exists()

    other_corpus = Corpus.objects.create(name='NERV')
    other_dataset = other_corpus.datasets.create(
        name='Eva series',
        description='We created the Evas from Adam',
        creator=self.user,
    )
    self.dataset_process.datasets.add(other_dataset)
    self.assertTrue(link_exists())
    url = reverse(
        'api:process-dataset',
        kwargs={'process': self.dataset_process.id, 'dataset': other_dataset.id},
    )
    self.client.force_login(self.test_user)

    with self.assertNumQueries(9):
        response = self.client.delete(url)

    self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
    self.assertFalse(link_exists())
def test_destroy_process_mode(self):
    """
    For every process mode other than Dataset and Local, removing a
    dataset from the process is refused with a 400 on the ``process`` field.
    """
    process = self.dataset_process
    url = reverse(
        'api:process-dataset',
        kwargs={'process': process.id, 'dataset': self.dataset2.id},
    )
    for mode in set(ProcessMode) - {ProcessMode.Dataset, ProcessMode.Local}:
        with self.subTest(mode=mode):
            process.mode = mode
            if mode == ProcessMode.Repository:
                # Repository processes get a revision and no corpus
                process.corpus = None
                process.revision = self.rev
            else:
                process.corpus = self.private_corpus
            process.save()
            self.client.force_login(self.test_user)

            with self.assertNumQueries(4):
                response = self.client.delete(url)

            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
            self.assertEqual(response.json(), {'process': ['Datasets can only be added to or removed from processes of mode "dataset".']})
def test_destroy_process_mode_local(self):
    """
    Removing a dataset from the local-mode process created by
    ``self.user`` is refused with a 400 on the ``process`` field.
    """
    self.client.force_login(self.user)
    local_process = Process.objects.get(creator=self.user, mode=ProcessMode.Local)
    url = reverse(
        'api:process-dataset',
        kwargs={'process': local_process.id, 'dataset': self.dataset2.id},
    )

    with self.assertNumQueries(4):
        response = self.client.delete(url)

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    self.assertEqual(response.json(), {'process': ['Datasets can only be added to or removed from processes of mode "dataset".']})
def test_destroy(self):
    """
    A DELETE on an existing process/dataset link answers with a 204 and
    removes the ProcessDataset row.
    """
    url = reverse(
        'api:process-dataset',
        kwargs={'process': self.dataset_process.id, 'dataset': self.dataset1.id},
    )
    self.client.force_login(self.test_user)

    with self.assertNumQueries(9):
        response = self.client.delete(url)

    self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
    still_linked = ProcessDataset.objects.filter(process=self.dataset_process, dataset=self.dataset1).exists()
    self.assertFalse(still_linked)
......@@ -19,7 +19,7 @@ from arkindex.process.models import (
)
from arkindex.process.utils import get_default_farm_id
from arkindex.project.tests import FixtureAPITestCase
from arkindex.training.models import Model, ModelVersion, ModelVersionState
from arkindex.training.models import Dataset, Model, ModelVersion, ModelVersionState
from arkindex.users.models import Role, User
......@@ -35,6 +35,13 @@ class TestProcesses(FixtureAPITestCase):
cls.creds = cls.user.credentials.get()
cls.repo = cls.creds.repos.get(url='http://my_repo.fake/workers/worker')
cls.rev = cls.repo.revisions.get()
cls.dataset1, cls.dataset2 = Dataset.objects.filter(corpus=cls.corpus).order_by('name')
cls.private_corpus = Corpus.objects.create(name='Private corpus')
cls.private_dataset = cls.private_corpus.datasets.create(
name='Dead sea scrolls',
description='Human instrumentality manual',
creator=cls.user
)
cls.img_df = cls.corpus.files.create(
name='test.jpg',
size=42,
......@@ -1889,7 +1896,7 @@ class TestProcesses(FixtureAPITestCase):
self.assertEqual(response.json(), {'detail': 'Not found.'})
def test_start_process_wrong_process_mode(self):
self.assertNotEqual(self.user_img_process.mode, ProcessMode.Workers)
self.assertFalse(self.user_img_process.mode in (ProcessMode.Workers, ProcessMode.Dataset))
# grant an admin access to this process
self.user_img_process.corpus.memberships.create(user=self.user, level=Role.Admin.value)
self.client.force_login(self.user)
......@@ -1899,7 +1906,7 @@ class TestProcesses(FixtureAPITestCase):
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(
response.json(),
{'__all__': ['Only a Process with Workers mode and not already launched can be started later on']}
{'__all__': ['Only a Process with Workers or Dataset mode and not already launched can be started later on']}
)
def test_start_process_process_already_started(self):
......@@ -1914,7 +1921,7 @@ class TestProcesses(FixtureAPITestCase):
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(
response.json(),
{'__all__': ['Only a Process with Workers mode and not already launched can be started later on']}
{'__all__': ['Only a Process with Workers or Dataset mode and not already launched can be started later on']}
)
def test_start_process_without_required_model(self):
......@@ -1940,7 +1947,7 @@ class TestProcesses(FixtureAPITestCase):
self.assertFalse(process2.tasks.exists())
self.client.force_login(self.user)
with self.assertNumQueries(19):
with self.assertNumQueries(20):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(process2.id)})
)
......@@ -1951,7 +1958,7 @@ class TestProcesses(FixtureAPITestCase):
process2 = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Workers)
self.client.force_login(self.user)
with self.assertNumQueries(7):
with self.assertNumQueries(8):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(process2.id)})
)
......@@ -1970,7 +1977,7 @@ class TestProcesses(FixtureAPITestCase):
self.assertFalse(process2.tasks.exists())
self.client.force_login(self.user)
with self.assertNumQueries(7):
with self.assertNumQueries(8):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(process2.id)})
)
......@@ -1988,7 +1995,7 @@ class TestProcesses(FixtureAPITestCase):
self.assertFalse(process2.tasks.exists())
self.client.force_login(self.user)
with self.assertNumQueries(7):
with self.assertNumQueries(8):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(process2.id)})
)
......@@ -1998,7 +2005,7 @@ class TestProcesses(FixtureAPITestCase):
{'model_version': ['This process contains one or more unavailable model versions and cannot be started.']},
)
def test_start_process(self):
def test_start_process_workers(self):
"""
A user can start a process with no parameters.
Default chunks, thumbnails and farm are used. Nor cache or workers activity is set.
......@@ -2009,7 +2016,7 @@ class TestProcesses(FixtureAPITestCase):
self.client.force_login(self.user)
with self.assertNumQueries(19):
with self.assertNumQueries(20):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(process2.id)})
)
......@@ -2026,6 +2033,65 @@ class TestProcesses(FixtureAPITestCase):
self.assertEqual(task2.slug, f'reco_{str(self.recognizer.id)[:6]}')
self.assertIn('--chunks-number 1', task1.command)
def test_start_process_dataset_requires_datasets(self):
    """
    Starting a dataset-mode process that has no dataset attached is
    refused with a 400 in ``non_field_errors``.
    """
    process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
    process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
    self.assertFalse(process.tasks.exists())
    url = reverse('api:process-start', kwargs={'pk': str(process.id)})
    self.client.force_login(self.user)

    with self.assertNumQueries(8):
        response = self.client.post(url)

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    self.assertDictEqual(response.json(), {
        'non_field_errors': ['A dataset process cannot be started if it does not have any associated datasets.']
    })
def test_start_process_dataset_requires_dataset_in_same_corpus(self):
    """
    Starting a dataset-mode process whose only dataset is
    ``self.private_dataset`` is refused with a 400 in ``non_field_errors``.
    """
    process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
    process.datasets.set([self.private_dataset])
    process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
    self.assertFalse(process.tasks.exists())
    url = reverse('api:process-start', kwargs={'pk': str(process.id)})
    self.client.force_login(self.user)

    with self.assertNumQueries(8):
        response = self.client.post(url)

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    self.assertDictEqual(response.json(), {
        'non_field_errors': ['At least one of the process datasets must be from the same corpus as the process.']
    })
def test_start_process_dataset(self):
    """
    A dataset-mode process holding ``self.dataset1`` and
    ``self.private_dataset`` can be started: a 201 is returned, one task is
    created and the process datasets are left untouched.
    """
    process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
    process.datasets.set([self.dataset1, self.private_dataset])
    process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
    self.assertFalse(process.tasks.exists())
    url = reverse('api:process-start', kwargs={'pk': str(process.id)})
    self.client.force_login(self.user)

    with self.assertNumQueries(16):
        response = self.client.post(url)

    self.assertEqual(response.status_code, status.HTTP_201_CREATED)
    self.assertEqual(response.json()['id'], str(process.id))

    process.refresh_from_db()
    self.assertEqual(process.state, State.Unscheduled)
    # No farm was sent in the request, so the default one is expected
    self.assertEqual(process.farm_id, get_default_farm_id())

    self.assertEqual(process.tasks.count(), 1)
    created_task = process.tasks.get()
    self.assertEqual(created_task.slug, f'reco_{str(self.recognizer.id)[:6]}')

    expected_datasets = [self.private_dataset, self.dataset1]
    self.assertQuerysetEqual(process.datasets.order_by('name'), expected_datasets)
def test_start_process_from_docker_image(self):
"""
Start a process with an available WorkerVersion that only has a docker_image_iid
......@@ -2040,7 +2106,7 @@ class TestProcesses(FixtureAPITestCase):
self.assertFalse(process.tasks.exists())
self.client.force_login(self.user)
with self.assertNumQueries(19):
with self.assertNumQueries(20):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(process.id)})
)
......@@ -2065,7 +2131,7 @@ class TestProcesses(FixtureAPITestCase):
workers_process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
self.client.force_login(self.user)
with self.assertNumQueries(19):
with self.assertNumQueries(20):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(workers_process.id)}),
{'farm': str(barley_farm.id)}
......@@ -2084,7 +2150,7 @@ class TestProcesses(FixtureAPITestCase):
self.client.force_login(self.user)
wrong_farm_id = uuid.uuid4()
with self.assertNumQueries(8):
with self.assertNumQueries(9):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(workers_process.id)}),
{'farm': str(wrong_farm_id)}
......@@ -2103,7 +2169,7 @@ class TestProcesses(FixtureAPITestCase):
({'thumbnails': 'gloubiboulga'}, {'thumbnails': ['Must be a valid boolean.']})
]
for (params, check) in wrong_params_checks:
with self.subTest(**params), self.assertNumQueries(7):
with self.subTest(**params), self.assertNumQueries(8):
response = self.client.post(reverse('api:process-start', kwargs={'pk': str(process.id)}), params)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), check)
......@@ -2117,7 +2183,7 @@ class TestProcesses(FixtureAPITestCase):
process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
self.client.force_login(self.user)
with self.assertNumQueries(7):
with self.assertNumQueries(8):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(process.id)}),
{'chunks': 43},
......@@ -2128,9 +2194,9 @@ class TestProcesses(FixtureAPITestCase):
'chunks': ['Ensure this value is less than or equal to 42.'],
})
def test_start_process_parameters(self):
def test_start_process_workers_parameters(self):
"""
It should be possible to pass chunks and thumbnails parameters when starting a process
It should be possible to pass chunks and thumbnails parameters when starting a workers process
"""
process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Workers)
# Add a worker run to this process
......@@ -2154,6 +2220,46 @@ class TestProcesses(FixtureAPITestCase):
'thumbnails_3'
])
def test_start_process_dataset_chunks(self):
    """
    Starting a dataset process with ``chunks=3`` is accepted and creates
    three tasks whose slugs are suffixed ``_1``, ``_2`` and ``_3``.
    """
    process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
    process.datasets.set([self.dataset1, self.dataset2])
    # The process needs a worker run before it can be started
    process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
    url = reverse('api:process-start', kwargs={'pk': str(process.id)})
    self.client.force_login(self.user)

    response = self.client.post(url, {'chunks': 3})

    self.assertEqual(response.status_code, status.HTTP_201_CREATED)
    process.refresh_from_db()
    task_slugs = list(process.tasks.order_by('slug').values_list('slug', flat=True))
    self.assertEqual(task_slugs, [
        f'reco_{str(self.recognizer.id)[:6]}_1',
        f'reco_{str(self.recognizer.id)[:6]}_2',
        f'reco_{str(self.recognizer.id)[:6]}_3'
    ])
def test_start_process_dataset_no_thumbnails(self):
    """
    Asking for thumbnails when starting a dataset process is refused with
    a 400 on the ``thumbnails`` field.
    """
    process = self.corpus.processes.create(creator=self.user, mode=ProcessMode.Dataset)
    process.datasets.set([self.dataset1, self.dataset2])
    # The process needs a worker run before it can be started
    process.worker_runs.create(version=self.recognizer, parents=[], configuration=None)
    url = reverse('api:process-start', kwargs={'pk': str(process.id)})
    self.client.force_login(self.user)

    response = self.client.post(url, {'thumbnails': True})

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    self.assertDictEqual(response.json(), {'thumbnails': ['Thumbnails generation is incompatible with dataset mode processes.']})
@patch('arkindex.process.models.Process.worker_runs')
@patch('arkindex.project.triggers.process_tasks.initialize_activity.delay')
def test_start_process_options_requires_workers(self, activities_delay_mock, worker_runs_mock):
......@@ -2166,7 +2272,7 @@ class TestProcesses(FixtureAPITestCase):
element_type=self.corpus.types.get(slug='page')
)
self.client.force_login(self.user)
with self.assertNumQueries(6):
with self.assertNumQueries(7):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(process.id)}),
{'use_cache': 'true', 'worker_activity': 'true', 'use_gpu': 'true'}
......@@ -2206,7 +2312,7 @@ class TestProcesses(FixtureAPITestCase):
self.client.force_login(self.user)
with self.assertNumQueries(14):
with self.assertNumQueries(15):
response = self.client.post(
reverse('api:process-start', kwargs={'pk': str(process.id)}),
{'use_cache': 'true', 'worker_activity': 'true', 'use_gpu': 'true'}
......
......@@ -92,6 +92,8 @@ from arkindex.process.api import (
GitRepositoryImportHook,
ImportTranskribus,
ListProcessElements,
ProcessDatasetManage,
ProcessDatasets,
ProcessDetails,
ProcessList,
ProcessRetry,
......@@ -289,6 +291,8 @@ api = [
path('process/<uuid:pk>/clear/', ClearProcess.as_view(), name='clear-process'),
path('process/training/', StartTraining.as_view(), name='process-training'),
path('process/<uuid:pk>/select-failures/', SelectProcessFailures.as_view(), name='process-select-failures'),
path('process/<uuid:pk>/datasets/', ProcessDatasets.as_view(), name='process-datasets'),
path('process/<uuid:process>/dataset/<uuid:dataset>/', ProcessDatasetManage.as_view(), name='process-dataset'),
# ML models training
path('modelversion/<uuid:pk>/', ModelVersionsRetrieve.as_view(), name='model-version-retrieve'),
......
from django.test import override_settings
from django.urls import reverse
from rest_framework import status
from arkindex.documents.models import Corpus
from arkindex.ponos.authentication import AgentUser
from arkindex.ponos.models import Agent, Artifact, Farm
from arkindex.process.models import ProcessMode
from arkindex.project.tests import FixtureAPITestCase
from arkindex.users.models import Role, User
@override_settings(PONOS_PRIVATE_KEY='staging')
class TestPonosView(FixtureAPITestCase):
@classmethod
def setUpTestData(cls):
    """
    Build the shared fixtures: a creator user, an artifact and its task,
    a public corpus with an admin user attached to the task's process,
    and a Ponos agent in a new farm.
    """
    super().setUpTestData()
    cls.creator = User.objects.create(email="creator@user.me")
    cls.artifact = Artifact.objects.get(path='/path/to/docker_build')
    cls.task = cls.artifact.task

    # Attach a corpus to the task's process so that the tests can exercise
    # the permissions derived from the process corpus
    cls.process_corpus = Corpus.objects.create(name='Another public corpus', public=True)
    cls.corpus_admin = User.objects.create(email='corpusadmin@test.me')
    cls.corpus_admin.rights.create(content_object=cls.process_corpus, level=Role.Admin.value)
    task_process = cls.task.process
    task_process.mode = ProcessMode.Files
    task_process.corpus = cls.process_corpus
    task_process.save()

    agent_farm = Farm.objects.create()
    cls.agent = Agent.objects.create(
        cpu_cores=3,
        cpu_frequency=3e9,
        farm=agent_farm,
        ram_total=2e9,
        last_ping='1999-09-09',
    )
def test_retrieve_agent_requires_login(self):
    """
    An anonymous request for the details of an agent is refused with a 403
    """
    url = reverse('api:agent-details', kwargs={'pk': str(self.agent.id)})
    response = self.client.get(url)
    self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_retrieve_agent_requires_verified(self):
    """
    A logged-in user whose ``verified_email`` flag is False is refused
    the details of an agent with a 403
    """
    url = reverse('api:agent-details', kwargs={'pk': str(self.agent.id)})
    self.user.verified_email = False
    self.user.save()
    self.client.force_login(self.user)

    response = self.client.get(url)

    self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_retrieve_agent(self):
    """
    ``self.user``, once logged in, can retrieve the details of an agent
    """
    url = reverse('api:agent-details', kwargs={'pk': str(self.agent.id)})
    self.client.force_login(self.user)

    response = self.client.get(url)

    self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_update_task(self):
    """
    PATCHing a process task is refused (403) for anonymous requests, the
    creator user and the regular user, and accepted (200) for the
    superuser and the admin of the process corpus
    """
    url = reverse('api:task-update', kwargs={'pk': str(self.task.id)})
    # (user to log in, expected HTTP status, expected number of queries)
    scenarios = (
        (None, status.HTTP_403_FORBIDDEN, 0),
        (self.creator, status.HTTP_403_FORBIDDEN, 8),
        (self.user, status.HTTP_403_FORBIDDEN, 8),
        (self.superuser, status.HTTP_200_OK, 10),
        (self.corpus_admin, status.HTTP_200_OK, 14),
    )
    for user, expected_status, expected_queries in scenarios:
        with self.subTest(user=user):
            # The first scenario deliberately stays anonymous
            if user:
                self.client.force_login(user)
            # NOTE(review): `json=` does not look like a body argument of the
            # test client's patch(); presumably it is forwarded as an extra
            # request keyword and the PATCH carries no payload. Confirm, and
            # consider `data=..., format='json'` (the expected query counts
            # and statuses may then need to be revisited).
            with self.assertNumQueries(expected_queries):
                response = self.client.patch(url, json={'state': 'stopping'})
            self.assertEqual(response.status_code, expected_status)
def test_download_artifacts(self):
    """
    Downloading an artifact of a process task is refused (403) for
    anonymous requests, the creator user and the regular user, and
    answers with a redirect (302) for the superuser and the admin of the
    process corpus
    """
    url = reverse(
        'api:task-artifact-download',
        kwargs={'pk': str(self.task.id), 'path': self.artifact.path},
    )
    # (user to log in, expected HTTP status, expected number of queries)
    scenarios = (
        (None, status.HTTP_403_FORBIDDEN, 0),
        (self.creator, status.HTTP_403_FORBIDDEN, 9),
        (self.user, status.HTTP_403_FORBIDDEN, 9),
        (self.superuser, status.HTTP_302_FOUND, 4),
        (self.corpus_admin, status.HTTP_302_FOUND, 9),
    )
    for user, expected_status, expected_queries in scenarios:
        with self.subTest(user=user):
            # The first scenario deliberately stays anonymous
            if user:
                self.client.force_login(user)
            with self.assertNumQueries(expected_queries):
                response = self.client.get(url, follow=False)
            self.assertEqual(response.status_code, expected_status)
def test_download_artifacts_by_agent(self):
    """
    A request authenticated with an agent's bearer token gets the
    redirect (302) to the artifact, so agents can run followup tasks
    """
    url = reverse(
        'api:task-artifact-download',
        kwargs={'pk': str(self.task.id), 'path': self.artifact.path},
    )
    agent_farm = Farm.objects.create()
    agent_user = AgentUser.objects.create(
        cpu_cores=3,
        cpu_frequency=3e9,
        farm_id=agent_farm.id,
        ram_total=2e9,
        last_ping='1999-09-09'
    )

    # The token is read inside the block so that anything it triggers stays
    # part of the counted queries
    with self.assertNumQueries(3):
        response = self.client.get(
            url,
            follow=False,
            HTTP_AUTHORIZATION="Bearer {}".format(agent_user.token.access_token),
        )

    self.assertEqual(response.status_code, status.HTTP_302_FOUND)