Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • arkindex/backend
1 result
Show changes
Commits on Source (65)
Showing
with 761 additions and 116 deletions
......@@ -51,7 +51,7 @@ backend-tests:
stage: test
services:
- name: postgis/postgis:12-3.0
- name: postgis/postgis:12-3.1
alias: postgres
artifacts:
......@@ -91,7 +91,7 @@ backend-migrations:
stage: test
services:
- name: postgis/postgis:12-3.0
- name: postgis/postgis:12-3.1
alias: postgres
script:
......
......@@ -74,6 +74,7 @@ COPY VERSION /etc/arkindex.version
# Copy templates in base dir for binary
ENV BASE_DIR=/usr/share/arkindex
COPY arkindex/templates /usr/share/arkindex/templates
COPY arkindex/documents/export/*.sql /usr/share/arkindex/documents/export/
# Touch python files for needed management commands
# Otherwise Django will not load the compiled module
......
......@@ -5,3 +5,4 @@ include tests-requirements.txt
recursive-include arkindex/templates *.html
recursive-include arkindex/templates *.json
recursive-include arkindex/templates *.txt
include arkindex/documents/export/*.sql
1.0.1-rc2
1.0.2
......@@ -65,12 +65,14 @@ from arkindex.dataimport.serializers.workers import (
RepositorySerializer,
WorkerActivitySerializer,
WorkerSerializer,
WorkerStatisticsSerializer,
WorkerVersionEditSerializer,
WorkerVersionSerializer,
)
from arkindex.documents.models import Corpus, Element
from arkindex.project.fields import ArrayRemove
from arkindex.project.mixins import (
ConflictAPIException,
CorpusACLMixin,
CustomPaginationViewMixin,
DeprecatedMixin,
......@@ -79,6 +81,7 @@ from arkindex.project.mixins import (
SelectionMixin,
WorkerACLMixin,
)
from arkindex.project.pagination import CustomCursorPagination
from arkindex.project.permissions import IsInternalUser, IsVerified, IsVerifiedOrReadOnly
from arkindex.project.tools import RTrimChr
from arkindex.users.models import OAuthCredentials, Role
......@@ -906,18 +909,27 @@ class WorkerVersionList(WorkerACLMixin, ListCreateAPIView):
@extend_schema_view(
get=extend_schema(
tags=['ml'],
operation_id='ListCorpusWorkerVersions',
tags=['ml'],
description=(
'List worker versions used by elements of a given corpus.\n\n'
'No check is performed on workers access level in order to allow any user to see versions.'
),
parameters=[
OpenApiParameter(
'with_element_count',
type=bool,
default=False,
description='Include element counts in the response.',
)
],
)
)
class CorpusWorkerVersionList(CorpusACLMixin, ListAPIView):
"""
List worker versions used by elements of a given corpus.
"""
pagination_class = CustomCursorPagination
permission_classes = (IsVerifiedOrReadOnly, )
serializer_class = WorkerVersionSerializer
# For OpenAPI type discovery
......@@ -928,12 +940,24 @@ class CorpusWorkerVersionList(CorpusACLMixin, ListAPIView):
def get_queryset(self):
corpus = self.get_corpus()
return WorkerVersion.objects \
queryset = WorkerVersion.objects \
.filter(elements__corpus_id=corpus.id) \
.select_related('revision__repo', 'worker', 'worker__repository') \
.prefetch_related('revision__refs', 'revision__versions') \
.order_by('-revision__created') \
.annotate(element_count=Count('id'))
.prefetch_related(
'revision__repo',
'revision__refs',
'revision__versions',
'worker__repository',
) \
.order_by('-id')
if self.request.query_params.get('with_element_count', '').lower() in ('true', '1'):
queryset = queryset.annotate(element_count=Count('id'))
else:
# The Count() causes Django to add a GROUP BY, and without a count we need a DISTINCT
# because filtering on `elements` causes worker versions to be duplicated.
queryset = queryset.distinct()
return queryset
@extend_schema(tags=['repos'])
......@@ -1221,7 +1245,6 @@ class ListProcessElements(CustomPaginationViewMixin, CorpusACLMixin, ListAPIView
)
@extend_schema(tags=['ml'])
class UpdateWorkerActivity(GenericAPIView):
"""
Makes a worker (internal user) able to update its activity on an element
......@@ -1229,6 +1252,7 @@ class UpdateWorkerActivity(GenericAPIView):
"""
permission_classes = (IsInternalUser, )
serializer_class = WorkerActivitySerializer
queryset = WorkerActivity.objects.none()
@cached_property
def allowed_transitions(self):
......@@ -1245,35 +1269,127 @@ class UpdateWorkerActivity(GenericAPIView):
}
@extend_schema(
tags=['ml'],
responses={
200: DataImportSerializer,
409: None,
},
operation_id='UpdateWorkerActivity',
description=(
'Updates the activity of a worker version on an element.\n\n'
'The user must be **internal** to perform this request.'
'The user must be **internal** to perform this request.\n\n'
'A **HTTP_409_CONFLICT** is returned in case the body is valid but the update failed.'
),
)
def put(self, request, *args, **kwarg):
"""
Update a worker activity with a single database requests.
If the couple element_id, worker_version matches no existing
activity, the update count is 0.
If the new state is disallowed, the update count is 0 too.
"""
serializer = self.get_serializer(data=request.data, partial=False)
serializer.is_valid(raise_exception=True)
worker_version_id = self.kwargs['pk']
element_id = serializer.validated_data['element_id']
state = serializer.validated_data['state'].value
process_id = serializer.validated_data['process_id']
# We use the fact that only one worker activity may match the filter due to DB constraint
# Between zero and one worker activity can match the filter due to the DB constraint
activity = WorkerActivity.objects.filter(
worker_version_id=worker_version_id,
element_id=element_id,
state__in=self.allowed_transitions[state]
)
update_count = activity.update(state=state)
update_count = activity.update(state=state, process_id=process_id)
if not update_count:
# As no row has been updated the provided data was incorrect
raise ValidationError({
'__all__': [
'Either this worker activity does not exists '
f"or updating the state to '{state}' is forbidden."
]
})
# As no row has been updated the provided data was in conflict with the actual state
raise ConflictAPIException(
{
'__all__': [
'Either this worker activity does not exists '
f"or updating the state to '{state}' is forbidden."
]
}
)
return Response(serializer.data)
@extend_schema_view(
get=extend_schema(
operation_id='CorpusWorkersActivity',
tags=['imports']
)
)
class CorpusWorkersActivity(CorpusACLMixin, ListAPIView):
"""
Retrieve corpus wise statistics about the activity of all its worker processes.\n
Requires a **guest** access.
"""
permission_classes = (IsVerified, )
serializer_class = WorkerStatisticsSerializer
pagination_class = None
queryset = WorkerActivity.objects.none()
def list(self, request, *args, **kwargs):
corpus = self.get_corpus(self.kwargs['corpus'], role=Role.Guest)
# Retrieve the distribution of activities on this corpus grouped by worker version
stats = WorkerActivity.objects \
.filter(element_id__in=corpus.elements.values('id')) \
.values('worker_version_id') \
.annotate(
**{
state.value: Count('id', filter=Q(state=state.value))
for state in WorkerActivityState
}
)
return Response(
status=status.HTTP_200_OK,
data=WorkerStatisticsSerializer(stats, many=True).data
)
@extend_schema_view(
get=extend_schema(
operation_id='ProcessWorkersActivity',
tags=['imports']
)
)
class ProcessWorkersActivity(ProcessACLMixin, ListAPIView):
"""
Retrieve process statistics about the activity of its workers.\n
Requires a **guest** access.
"""
permission_classes = (IsVerified, )
serializer_class = WorkerStatisticsSerializer
pagination_class = None
queryset = WorkerActivity.objects.none()
def list(self, request, *args, **kwargs):
process = get_object_or_404(DataImport.objects.all(), pk=self.kwargs['pk'])
access_level = self.process_access_level(process)
if not access_level:
raise NotFound
if access_level < Role.Admin.value:
raise PermissionDenied(detail='You do not have an admin access to this process.')
# Retrieve the distribution of activities on this process grouped by worker version
stats = WorkerActivity.objects \
.filter(process_id=process.id) \
.values('worker_version_id') \
.annotate(
**{
state.value: Count('id', filter=Q(state=state.value))
for state in WorkerActivityState
}
)
return Response(
status=status.HTTP_200_OK,
data=WorkerStatisticsSerializer(stats, many=True).data
)
# Generated by Django 3.1.5 on 2021-05-07 07:48
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('dataimport', '0032_dataimport_activity_state'),
]
operations = [
migrations.AddField(
model_name='workeractivity',
name='process',
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name='activities',
to='dataimport.dataimport'
),
),
]
......@@ -627,7 +627,7 @@ class WorkerActivityState(Enum):
class ActivityManager(models.Manager):
"""Model management for worker activities"""
def bulk_insert(self, worker_version_id, elements_qs, state=WorkerActivityState.Queued):
def bulk_insert(self, worker_version_id, process_id, elements_qs, state=WorkerActivityState.Queued):
"""
Create initial worker activities from a queryset of elements in a efficient way.
Due to the possible large amount of elements, we use a bulk insert from the elements query (best performances).
......@@ -641,11 +641,12 @@ class ActivityManager(models.Manager):
cursor.execute(
f"""
INSERT INTO dataimport_workeractivity
(element_id, worker_version_id, state, id, created, updated)
(element_id, worker_version_id, state, process_id, id, created, updated)
SELECT
elt.id,
'{worker_version_id}'::uuid,
'{state.value}',
'{process_id}',
uuid_generate_v4(),
current_timestamp,
current_timestamp
......@@ -675,7 +676,14 @@ class WorkerActivity(IndexableModel):
)
state = EnumField(
WorkerActivityState,
default=WorkerActivityState.Queued
default=WorkerActivityState.Queued,
)
process = models.ForeignKey(
DataImport,
related_name='activities',
on_delete=models.SET_NULL,
null=True,
blank=True,
)
# Specific WorkerActivity manager
......
......@@ -394,7 +394,16 @@ class GitLabProvider(GitProvider):
return
if not sha:
raise ValidationError('Missing checkout SHA')
# If there isn't any SHA it means that a branch was deleted
ref = request.data.get('ref')
if not ref:
raise ValidationError('Missing branch reference')
# Delete existing branch
branch_name = ref[11:] if ref.startswith('refs/heads/') else ref
repo.refs.filter(name=branch_name, type=GitRefType.Branch).delete()
return
# Already took care of this event
if repo.revisions.filter(hash=sha).exists():
......
......@@ -21,10 +21,11 @@ class DataImportLightSerializer(serializers.ModelSerializer):
Serialize a data importing workflow
"""
state = EnumField(State)
state = EnumField(State, read_only=True)
mode = EnumField(DataImportMode, read_only=True)
creator = serializers.HiddenField(default=serializers.CurrentUserDefault())
workflow = serializers.HyperlinkedRelatedField(read_only=True, view_name='ponos:workflow-details')
activity_state = EnumField(ActivityState, read_only=True)
class Meta:
model = DataImport
......@@ -36,19 +37,25 @@ class DataImportLightSerializer(serializers.ModelSerializer):
'corpus',
'creator',
'workflow',
'activity_state',
)
read_only_fields = ('id', 'state', 'mode', 'corpus', 'creator', 'workflow')
read_only_fields = ('id', 'state', 'mode', 'corpus', 'creator', 'workflow', 'activity_state')
class DataImportSerializer(DataImportLightSerializer):
"""
Serialize a data importing workflow with its settings
"""
# Redefine state as writable
state = EnumField(State, read_only=True)
revision = RevisionSerializer(read_only=True)
element = ElementSlimSerializer(read_only=True)
element_id = serializers.PrimaryKeyRelatedField(
queryset=Element.objects.none(),
default=None,
allow_null=True,
write_only=True,
source='element',
style={'base_template': 'input.html'},
)
folder_type = serializers.SlugField(source='folder_type.slug', default=None, read_only=True)
element_type = serializers.SlugRelatedField(queryset=ElementType.objects.none(), slug_field='slug', allow_null=True)
element_name_contains = serializers.CharField(
......@@ -57,26 +64,18 @@ class DataImportSerializer(DataImportLightSerializer):
allow_blank=True,
max_length=250
)
activity_state = EnumField(ActivityState, read_only=True)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
dataimport = self.context.get('dataimport')
if not dataimport or not dataimport.corpus:
return
self.fields['element_type'].queryset = ElementType.objects.filter(corpus=dataimport.corpus)
class Meta(DataImportLightSerializer.Meta):
fields = DataImportLightSerializer.Meta.fields + (
'files',
'revision',
'element',
'element_id',
'folder_type',
'element_type',
'element_name_contains',
'load_children',
'use_cache',
'activity_state',
)
read_only_fields = DataImportLightSerializer.Meta.read_only_fields + (
'files',
......@@ -84,9 +83,16 @@ class DataImportSerializer(DataImportLightSerializer):
'element',
'folder_type',
'use_cache',
'activity_state',
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
dataimport = self.context.get('dataimport')
if not dataimport or not dataimport.corpus:
return
self.fields['element_type'].queryset = ElementType.objects.filter(corpus=dataimport.corpus)
self.fields['element_id'].queryset = dataimport.corpus.elements.all()
def validate(self, data):
data = super().validate(data)
# Editing a dataimport name only is always allowed
......@@ -95,10 +101,17 @@ class DataImportSerializer(DataImportLightSerializer):
if not self.instance:
return
# Allow editing the element ID on file imports at any time
if self.instance.mode in (DataImportMode.Images, DataImportMode.PDF, DataImportMode.IIIF) and set(data.keys()) == {'element'}:
return data
if self.instance.state == State.Running:
raise serializers.ValidationError({'__all__': ['Cannot edit a workflow while it is running']})
if self.instance.mode != DataImportMode.Workers:
raise serializers.ValidationError({'__all__': [f'Only processes of mode {DataImportMode.Workers} can be updated']})
return data
......
......@@ -5,6 +5,7 @@ from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from arkindex.dataimport.models import (
DataImport,
Repository,
RepositoryType,
Revision,
......@@ -158,14 +159,27 @@ class RepositorySerializer(serializers.ModelSerializer):
class WorkerActivitySerializer(serializers.ModelSerializer):
"""
Serialize a repository
Serialize a worker activity on an element
"""
state = EnumField(WorkerActivityState)
element_id = serializers.UUIDField()
process_id = serializers.PrimaryKeyRelatedField(queryset=DataImport.objects.all())
class Meta:
model = WorkerActivity
fields = (
'element_id',
'process_id',
'state',
)
class WorkerStatisticsSerializer(serializers.Serializer):
"""
Serialize activity statistics of a worker version
"""
worker_version_id = serializers.UUIDField(read_only=True)
queued = serializers.IntegerField(read_only=True)
started = serializers.IntegerField(read_only=True)
processed = serializers.IntegerField(read_only=True)
error = serializers.IntegerField(read_only=True)
......@@ -484,14 +484,6 @@ class TestGitLabProvider(FixtureTestCase):
self.assertFalse(rev.exists())
self.assertFalse(repo_imports.exists())
# Missing SHA
request_mock.data['object_kind'] = 'push'
del request_mock.data['checkout_sha']
with self.assertRaises(ValidationError):
glp.handle_webhook(self.repo, request_mock)
self.assertFalse(rev.exists())
self.assertFalse(repo_imports.exists())
# Breaking change: a list!
request_mock.data = [request_mock.data]
with self.assertRaises(ValidationError):
......@@ -499,6 +491,39 @@ class TestGitLabProvider(FixtureTestCase):
self.assertFalse(rev.exists())
self.assertFalse(repo_imports.exists())
def test_handle_webhook_delete_branch(self):
"""
Test GitLabProvider properly handles a branch deletion
"""
rev = Revision(
repo=self.repo,
hash='1',
message='commit message',
author='bob',
)
rev.save()
self.assertTrue(self.repo.revisions.filter(hash='1').exists())
repo_imports = DataImport.objects.filter(revision__repo_id=str(self.repo.id))
glp = GitLabProvider(url='http://aaa', credentials=self.creds)
glp.update_or_create_ref(self.repo, rev, 'test', GitRefType.Branch)
self.assertEqual(len(self.repo.refs.all()), 1)
request_mock = MagicMock()
request_mock.META = {
'HTTP_X_GITLAB_EVENT': 'Push Hook',
'HTTP_X_GITLAB_TOKEN': 'hook-token',
}
request_mock.data = {
'object_kind': 'push',
'ref': 'refs/heads/test',
'commits': []
}
glp.handle_webhook(self.repo, request_mock)
self.assertTrue(self.repo.revisions.filter(hash='1').exists())
self.assertEqual(len(self.repo.refs.all()), 0)
self.assertFalse(repo_imports.exists())
def test_retrieve_repo_type(self):
"""
Gitlab provider allow to retrieve a project type
......
......@@ -10,7 +10,15 @@ from django.urls import reverse
from rest_framework import status
from rest_framework.exceptions import ValidationError
from arkindex.dataimport.models import ActivityState, DataImport, DataImportMode, RepositoryType
from arkindex.dataimport.models import (
ActivityState,
DataImport,
DataImportMode,
RepositoryType,
WorkerActivity,
WorkerActivityState,
WorkerVersion,
)
from arkindex.dataimport.utils import get_default_farm_id
from arkindex.documents.models import Corpus, ElementType
from arkindex.project.tests import FixtureAPITestCase
......@@ -86,6 +94,7 @@ class TestImports(FixtureAPITestCase):
'mode': process.mode.value,
'corpus': process.corpus_id and str(process.corpus.id),
'workflow': process.workflow and f'http://testserver/ponos/v1/workflow/{process.workflow.id}/',
'activity_state': ActivityState.Disabled.value,
}
def build_task(self, workflow_id, run, state, depth=1):
......@@ -140,6 +149,8 @@ class TestImports(FixtureAPITestCase):
"""
self.user_img_process.start()
self.client.force_login(self.user)
self.user_img_process.activity_state = ActivityState.Ready
self.user_img_process.save()
with self.assertNumQueries(8):
response = self.client.get(reverse('api:import-list'), {'with_workflow': 'true'})
self.assertEqual(response.status_code, status.HTTP_200_OK)
......@@ -153,6 +164,7 @@ class TestImports(FixtureAPITestCase):
'mode': DataImportMode.Images.value,
'corpus': str(self.user_img_process.corpus.id),
'workflow': f'http://testserver/ponos/v1/workflow/{self.user_img_process.workflow.id}/',
'activity_state': ActivityState.Ready.value,
}])
def test_list_exclude_workflow(self):
......@@ -175,6 +187,7 @@ class TestImports(FixtureAPITestCase):
'mode': DataImportMode.Images.value,
'corpus': str(self.user_img_process.corpus.id),
'workflow': None,
'activity_state': ActivityState.Disabled.value,
}])
def test_list_filter_corpus(self):
......@@ -433,7 +446,7 @@ class TestImports(FixtureAPITestCase):
A user is allowed to delete a dataimport if he has an admin right to its corpus
"""
self.client.force_login(self.user)
with self.assertNumQueries(10):
with self.assertNumQueries(11):
response = self.client.delete(reverse('api:import-details', kwargs={'pk': self.elts_process.id}))
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
......@@ -442,10 +455,31 @@ class TestImports(FixtureAPITestCase):
A superuser is allowed to delete any dataimport
"""
self.client.force_login(self.superuser)
with self.assertNumQueries(7):
with self.assertNumQueries(8):
response = self.client.delete(reverse('api:import-details', kwargs={'pk': self.user_img_process.id}))
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
def test_delete_import_activities(self):
"""
Deleting a process sets the the FK on related activities to null
This should be deprecated soon and done asynchroniously as
the cost of updating many activities is rather high
"""
self.client.force_login(self.superuser)
activity = WorkerActivity.objects.create(
element=self.corpus.elements.first(),
process=self.elts_process,
worker_version=WorkerVersion.objects.get(worker__slug='reco'),
state=WorkerActivityState.Queued,
)
with self.assertNumQueries(9):
response = self.client.delete(reverse('api:import-details', kwargs={'pk': self.elts_process.id}))
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
with self.assertRaises(DataImport.DoesNotExist):
self.elts_process.refresh_from_db()
activity.refresh_from_db()
self.assertEqual(activity.process_id, None)
def test_update_process_requires_login(self):
response = self.client.delete(reverse('api:import-details', kwargs={'pk': self.user_img_process.id}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
......@@ -488,6 +522,27 @@ class TestImports(FixtureAPITestCase):
self.elts_process.refresh_from_db()
self.assertEqual(self.elts_process.name, 'newName')
def test_update_file_import_element(self):
"""
A file import's element can be updated even while it is running
"""
self.client.force_login(self.user)
process = self.corpus.imports.create(mode=DataImportMode.PDF, creator=self.user)
process.start()
process.workflow.tasks.update(state=State.Running)
self.assertIsNone(self.elts_process.element)
element = self.corpus.elements.first()
with self.assertNumQueries(14):
response = self.client.patch(
reverse('api:import-details', kwargs={'pk': process.id}),
{'element_id': str(element.id)},
format='json'
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
process.refresh_from_db()
self.assertEqual(process.element, element)
def test_update_process_no_permission(self):
"""
A user cannot update a dataimport linked to a corpus he has no admin access to
......@@ -659,7 +714,7 @@ class TestImports(FixtureAPITestCase):
},
format='json'
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.status_code, status.HTTP_200_OK, response.json())
self.assertEqual(response.json(), {
'name': 'newName',
'element_name_contains': 'AAA',
......
......@@ -4,10 +4,18 @@ from unittest.mock import MagicMock, call, patch
from django.urls import reverse
from rest_framework import status
from arkindex.dataimport.models import ActivityState, DataImportMode, WorkerActivity, WorkerActivityState, WorkerVersion
from arkindex.dataimport.models import (
ActivityState,
DataImport,
DataImportMode,
WorkerActivity,
WorkerActivityState,
WorkerVersion,
)
from arkindex.documents.models import Classification, ClassificationState, Element, MLClass
from arkindex.documents.tasks import initialize_activity
from arkindex.project.tests import FixtureTestCase
from arkindex.users.models import User
class TestWorkerActivity(FixtureTestCase):
......@@ -17,8 +25,18 @@ class TestWorkerActivity(FixtureTestCase):
super().setUpTestData()
cls.worker_version = WorkerVersion.objects.get(worker__slug='reco')
cls.element = Element.objects.get(name='Volume 1, page 2r')
cls.process = DataImport.objects.create(
mode=DataImportMode.Workers,
creator=cls.user
)
def setUp(self):
# Create a queued activity for this element
cls.activity = cls.element.activities.create(worker_version=cls.worker_version, state=WorkerActivityState.Queued)
self.activity = self.element.activities.create(
process=self.process,
worker_version=self.worker_version,
state=WorkerActivityState.Queued
)
def test_bulk_insert_activity_children(self):
"""
......@@ -28,28 +46,37 @@ class TestWorkerActivity(FixtureTestCase):
params = {
'worker_version_id': self.worker_version.id,
'corpus_id': self.corpus.id,
'state': WorkerActivityState.Started.value
'state': WorkerActivityState.Started.value,
'process_id': self.process.id,
}
with self.assertExactQueries('workeractivity_bulk_insert.sql', params=params):
WorkerActivity.objects.bulk_insert(self.worker_version.id, elements_qs, state=WorkerActivityState.Started)
WorkerActivity.objects.bulk_insert(self.worker_version.id, self.process.id, elements_qs, state=WorkerActivityState.Started)
self.assertEqual(elements_qs.count(), 5)
self.assertEqual(WorkerActivity.objects.filter(state=WorkerActivityState.Started).count(), 5)
self.assertEqual(WorkerActivity.objects.filter(state=WorkerActivityState.Started, process=self.process).count(), 5)
def test_bulk_insert_activity_existing(self):
"""
Elements in the queryset should be skipped if they already have an activity
The old process ID is preserved during the bulk create
"""
elements_qs = Element.objects.filter(type__slug='act', type__corpus_id=self.corpus.id)
old_process = DataImport.objects.create(mode=DataImportMode.Workers, creator=self.user)
WorkerActivity.objects.bulk_create([
WorkerActivity(element=element, worker_version=self.worker_version, state=WorkerActivityState.Processed.value)
WorkerActivity(
element=element,
worker_version=self.worker_version,
state=WorkerActivityState.Processed.value,
process=old_process
)
for element in elements_qs[:2]
])
with self.assertNumQueries(1):
WorkerActivity.objects.bulk_insert(self.worker_version.id, elements_qs, state=WorkerActivityState.Started)
WorkerActivity.objects.bulk_insert(self.worker_version.id, self.process.id, elements_qs, state=WorkerActivityState.Started)
self.assertEqual(WorkerActivity.objects.filter(element_id__in=elements_qs.values('id')).count(), 5)
self.assertEqual(elements_qs.count(), 5)
# Only 3 acts have been marked as started for this worker
self.assertEqual(WorkerActivity.objects.filter(state=WorkerActivityState.Started).count(), 3)
self.assertEqual(WorkerActivity.objects.filter(process=old_process).count(), 2)
@patch('arkindex.project.triggers.tasks.initialize_activity.delay')
def test_bulk_insert_children_class_filter(self, activities_delay_mock):
......@@ -84,20 +111,32 @@ class TestWorkerActivity(FixtureTestCase):
def test_put_activity_requires_internal(self):
"""
Only internal users (workers) are able to update the state of a worker activity
Internal users with an instance admin are able to update a worker activity
"""
internal_admin_user = User.objects.create_user('god@test.test', 'G0D')
internal_admin_user.is_internal = True
internal_admin_user.is_admin = True
internal_admin_user.save()
cases = (
(None, status.HTTP_403_FORBIDDEN, 0),
(self.user, status.HTTP_403_FORBIDDEN, 2),
(self.superuser, status.HTTP_403_FORBIDDEN, 2),
(self.internal_user, status.HTTP_200_OK, 3),
(self.internal_user, status.HTTP_200_OK, 4),
(internal_admin_user, status.HTTP_200_OK, 4)
)
for user, status_code, requests_count in cases:
self.activity.state = WorkerActivityState.Queued
self.activity.save()
if user:
self.client.force_login(user)
with self.assertNumQueries(requests_count):
response = self.client.put(
reverse('api:update-worker-activity', kwargs={'pk': str(self.worker_version.id)}),
{'element_id': str(self.element.id), 'state': WorkerActivityState.Started.value},
{
'element_id': str(self.element.id),
'process_id': str(self.process.id),
'state': WorkerActivityState.Started.value,
},
content_type='application/json',
)
self.assertEqual(response.status_code, status_code)
......@@ -105,15 +144,20 @@ class TestWorkerActivity(FixtureTestCase):
def test_put_activity_wrong_worker_version(self):
"""
Raises a generic error in case the worker version does not exists because a single SQL request is performed
The response is a HTTP_409_CONFLICT
"""
self.client.force_login(self.internal_user)
with self.assertNumQueries(3):
with self.assertNumQueries(4):
response = self.client.put(
reverse('api:update-worker-activity', kwargs={'pk': str(uuid.uuid4())}),
{'element_id': str(self.element.id), 'state': WorkerActivityState.Started.value},
{
'element_id': str(self.element.id),
'process_id': str(self.process.id),
'state': WorkerActivityState.Started.value,
},
content_type='application/json',
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.status_code, status.HTTP_409_CONFLICT)
self.assertDictEqual(response.json(), {
'__all__': [
'Either this worker activity does not exists or '
......@@ -121,18 +165,23 @@ class TestWorkerActivity(FixtureTestCase):
]
})
def test_put_activity_unexisting(self):
def test_put_activity_element_unexisting(self):
"""
Raises a generic error in case no activity exists for this element
The response is a HTTP_409_CONFLICT
"""
self.client.force_login(self.internal_user)
with self.assertNumQueries(3):
with self.assertNumQueries(4):
response = self.client.put(
reverse('api:update-worker-activity', kwargs={'pk': str(self.worker_version.id)}),
{'element_id': str(uuid.uuid4()), 'state': WorkerActivityState.Started.value},
{
'element_id': str(uuid.uuid4()),
'process_id': str(self.process.id),
'state': WorkerActivityState.Started.value,
},
content_type='application/json',
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.status_code, status.HTTP_409_CONFLICT)
self.assertDictEqual(response.json(), {
'__all__': [
'Either this worker activity does not exists or '
......@@ -140,6 +189,27 @@ class TestWorkerActivity(FixtureTestCase):
]
})
def test_put_activity_process_unexisting(self):
"""
Raises an error in case the process does not exist
"""
self.client.force_login(self.internal_user)
wrong_process_id = uuid.uuid4()
with self.assertNumQueries(3):
response = self.client.put(
reverse('api:update-worker-activity', kwargs={'pk': str(self.worker_version.id)}),
{
'element_id': str(self.element.id),
'process_id': str(wrong_process_id),
'state': WorkerActivityState.Started.value,
},
content_type='application/json',
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
'process_id': [f'Invalid pk "{wrong_process_id}" - object does not exist.']
})
def test_put_activity_allowed_states(self):
"""
Check each case state update is allowed depending on the actual state of the activity
......@@ -159,10 +229,14 @@ class TestWorkerActivity(FixtureTestCase):
for state, payload_state in allowed_states_update:
self.activity.state = state
self.activity.save()
with self.assertNumQueries(3):
with self.assertNumQueries(4):
response = self.client.put(
reverse('api:update-worker-activity', kwargs={'pk': str(self.worker_version.id)}),
{'element_id': str(self.element.id), 'state': payload_state},
{
'element_id': str(self.element.id),
'process_id': str(self.process.id),
'state': payload_state,
},
content_type='application/json',
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
......@@ -172,6 +246,7 @@ class TestWorkerActivity(FixtureTestCase):
def test_put_activity_forbidden_states(self):
"""
Check state update is forbidden for some non consistant cases
The response is a HTTP_409_CONFLICT
"""
queued, started, error, processed = (
WorkerActivityState[state].value
......@@ -187,13 +262,17 @@ class TestWorkerActivity(FixtureTestCase):
for state, payload_state in forbidden_states_update:
self.activity.state = state
self.activity.save()
with self.assertNumQueries(3):
with self.assertNumQueries(4):
response = self.client.put(
reverse('api:update-worker-activity', kwargs={'pk': str(self.worker_version.id)}),
{'element_id': str(self.element.id), 'state': payload_state},
{
'element_id': str(self.element.id),
'process_id': str(self.process.id),
'state': payload_state,
},
content_type='application/json',
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.status_code, status.HTTP_409_CONFLICT)
self.assertDictEqual(response.json(), {
'__all__': [
'Either this worker activity does not exists or '
......@@ -203,6 +282,37 @@ class TestWorkerActivity(FixtureTestCase):
self.activity.refresh_from_db()
self.assertEqual(self.activity.state.value, state)
def test_put_activity_process_overriding(self):
"""
A process can update a worker activity event if it was initialized in another process
"""
self.client.force_login(self.internal_user)
process2 = DataImport.objects.create(
mode=DataImportMode.Workers,
creator=self.user
)
self.activity.process_id = process2.id
self.activity.save()
with self.assertNumQueries(4):
response = self.client.put(
reverse('api:update-worker-activity', kwargs={'pk': str(self.worker_version.id)}),
{
'element_id': str(self.element.id),
'process_id': str(process2.id),
'state': WorkerActivityState.Started.value,
},
content_type='application/json',
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertDictEqual(response.json(), {
'element_id': str(self.element.id),
'process_id': str(process2.id),
'state': WorkerActivityState.Started.value,
})
self.activity.refresh_from_db()
self.assertEqual(self.activity.process_id, process2.id)
@patch('arkindex.dataimport.models.DataImport.versions')
@patch('arkindex.dataimport.models.ActivityManager.bulk_insert')
def test_async_activities_error(self, bulk_insert_mock, versions_mock):
......@@ -235,8 +345,8 @@ class TestWorkerActivity(FixtureTestCase):
self.assertEqual(bulk_insert_mock.call_count, 2)
self.assertListEqual(bulk_insert_mock.call_args_list, [
call(worker_version_id=v1_id, elements_qs=elts_qs),
call(worker_version_id=v2_id, elements_qs=elts_qs)
call(worker_version_id=v1_id, process_id=process.id, elements_qs=elts_qs),
call(worker_version_id=v2_id, process_id=process.id, elements_qs=elts_qs)
])
process.refresh_from_db()
......
import itertools
import uuid
from django.urls import reverse
from rest_framework import status
from arkindex.dataimport.models import DataImport, DataImportMode, WorkerActivity, WorkerActivityState, WorkerVersion
from arkindex.documents.models import Corpus
from arkindex.project.tests import FixtureAPITestCase
from arkindex.users.models import Role, User
class TestWorkersActivity(FixtureAPITestCase):
    """
    Tests for the worker activity statistics endpoints:
    - api:corpus-workers-activity (per-corpus statistics)
    - api:process-workers-activity (per-process statistics)
    """

    @classmethod
    def setUpTestData(cls):
        super().setUpTestData()
        cls.version_1 = WorkerVersion.objects.get(worker__slug='reco')
        cls.version_2 = WorkerVersion.objects.get(worker__slug='dla')
        cls.private_corpus = Corpus.objects.create(name='private', public=False)
        cls.elts_count = cls.corpus.elements.count()
        cls.process = DataImport.objects.create(
            mode=DataImportMode.Workers,
            creator=cls.user,
            corpus=cls.corpus,
        )
        # Generate worker activities:
        # - version_1 activities are attached to cls.process, with states
        #   cycling over every WorkerActivityState so all four counters are
        #   exercised;
        # - version_2 activities have no process and are all Processed.
        WorkerActivity.objects.bulk_create([
            *(
                WorkerActivity(
                    element_id=elt.id,
                    state=state,
                    worker_version_id=cls.version_1.id,
                    process_id=cls.process.id,
                ) for elt, state in zip(cls.corpus.elements.all(), itertools.cycle(WorkerActivityState))
            ), *(
                WorkerActivity(
                    element_id=elt.id,
                    state=WorkerActivityState.Processed.value,
                    worker_version_id=cls.version_2.id,
                    process_id=None,
                ) for elt in cls.corpus.elements.all()
            )
        ])
        # Pre-compute the expected per-state counters for version_1 so the
        # assertions below do not depend on the exact number of fixture
        # elements.
        cls.error, cls.processed, cls.queued, cls.started = [
            WorkerActivity.objects.filter(
                element__corpus_id=cls.corpus.id,
                worker_version_id=cls.version_1.id,
                state=state
            ).count()
            for state in [
                WorkerActivityState.Error,
                WorkerActivityState.Processed,
                WorkerActivityState.Queued,
                WorkerActivityState.Started
            ]
        ]

    def test_workeractivities_stats_requires_login(self):
        # Anonymous users cannot retrieve corpus activity statistics
        with self.assertNumQueries(0):
            response = self.client.get(
                reverse('api:corpus-workers-activity', kwargs={'corpus': str(self.corpus.id)})
            )
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
        self.assertDictEqual(response.json(), {'detail': 'Authentication credentials were not provided.'})

    def test_workeractivities_private_corpus(self):
        # A user without any access to a private corpus gets a 403
        self.client.force_login(self.user)
        with self.assertNumQueries(6):
            response = self.client.get(
                reverse('api:corpus-workers-activity', kwargs={'corpus': str(self.private_corpus.id)})
            )
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
        self.assertDictEqual(response.json(), {'detail': 'You do not have guest access to this corpus.'})

    def test_workeractivities_unexisting_corpus(self):
        # A random UUID that matches no corpus results in a 404
        self.client.force_login(self.user)
        with self.assertNumQueries(3):
            response = self.client.get(
                reverse('api:corpus-workers-activity', kwargs={'corpus': str(uuid.uuid4())})
            )
        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

    def test_workeractivities_empty(self):
        """
        Handle a corpus that has absolutely no activity.
        A user with a guest access can retrieve statistics.
        """
        user = User.objects.create_user('user42@test.test', 'abcd')
        self.private_corpus.memberships.create(user=user, level=Role.Guest.value)
        self.client.force_login(user)
        with self.assertNumQueries(6):
            response = self.client.get(
                reverse('api:corpus-workers-activity', kwargs={'corpus': str(self.private_corpus.id)})
            )
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertListEqual(response.json(), [])

    def test_workers_activity_distributed_states(self):
        # Statistics are aggregated per worker version: version_1 activities
        # are spread over all four states, version_2 are all processed.
        self.client.force_login(self.user)
        with self.assertNumQueries(4):
            response = self.client.get(
                reverse('api:corpus-workers-activity', kwargs={'corpus': str(self.corpus.id)})
            )
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertCountEqual(response.json(), [
            {
                'worker_version_id': str(self.version_1.id),
                'queued': self.queued,
                'started': self.started,
                'processed': self.processed,
                'error': self.error,
            }, {
                'worker_version_id': str(self.version_2.id),
                'queued': 0,
                'started': 0,
                'processed': self.corpus.elements.count(),
                'error': 0,
            }
        ])

    def test_process_activity_stats_requires_login(self):
        # Anonymous users cannot retrieve process activity statistics
        with self.assertNumQueries(0):
            response = self.client.get(
                reverse('api:process-workers-activity', kwargs={'pk': str(self.process.id)})
            )
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
        self.assertDictEqual(response.json(), {'detail': 'Authentication credentials were not provided.'})

    def test_process_activity_stats_private(self):
        # Process statistics require an admin access on the process
        self.process.corpus = self.private_corpus
        self.process.save()
        self.client.force_login(self.user)
        with self.assertNumQueries(7):
            response = self.client.get(
                reverse('api:process-workers-activity', kwargs={'pk': str(self.process.id)})
            )
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
        self.assertDictEqual(response.json(), {'detail': 'You do not have an admin access to this process.'})

    def test_process_activity_stats_unexisting(self):
        # A random UUID that matches no process results in a 404
        self.client.force_login(self.user)
        with self.assertNumQueries(3):
            response = self.client.get(
                reverse('api:process-workers-activity', kwargs={'pk': str(uuid.uuid4())})
            )
        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

    def test_process_activity_stats(self):
        # Only activities linked to this process (version_1) are reported;
        # version_2 activities have no process and are excluded.
        self.client.force_login(self.user)
        with self.assertNumQueries(8):
            response = self.client.get(
                reverse('api:process-workers-activity', kwargs={'pk': str(self.process.id)})
            )
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertListEqual(response.json(), [
            {
                'worker_version_id': str(self.version_1.id),
                'queued': self.queued,
                'started': self.started,
                'processed': self.processed,
                'error': self.error,
            }
        ])
......@@ -842,45 +842,49 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
self.assertEqual(self.version_1.docker_command, 'mysupercommand')
self.version_1.configuration = {"test": "test1"}
def _assert_corpus_worker_version_list(self, response):
    """
    Assert that a paginated worker version listing contains exactly
    version_1, serialized with its revision, Docker attributes and worker.
    """
    expected_version = {
        'id': str(self.version_1.id),
        'configuration': {'test': 42},
        'revision': {
            'id': str(self.version_1.revision_id),
            'hash': '1337',
            'author': 'Test user',
            'message': 'My w0rk3r',
            'created': '2020-02-02T01:23:45.678000Z',
            'commit_url': 'http://my_repo.fake/workers/worker/commit/1337',
            'refs': []
        },
        'docker_image': str(self.version_1.docker_image_id),
        'docker_image_iid': None,
        'docker_image_name': self.version_1.docker_image_name,
        'state': 'available',
        'element_count': 9,
        'worker': {
            'id': str(self.worker_1.id),
            'name': self.worker_1.name,
            'type': self.worker_1.type,
            'slug': self.worker_1.slug,
        }
    }
    self.assertDictEqual(response.json(), {
        'count': 1,
        'next': None,
        'number': 1,
        'previous': None,
        'results': [expected_version],
    })
def _serialize_worker_version(self, version, element_count=False):
    """
    Build the expected API payload for a worker version.
    When element_count is set, also include the number of elements of
    self.corpus linked to this version.
    """
    worker = version.worker
    payload = {
        'id': str(version.id),
        'configuration': {'test': 42},
        'revision': {
            'id': str(version.revision_id),
            'hash': '1337',
            'author': 'Test user',
            'message': 'My w0rk3r',
            'created': '2020-02-02T01:23:45.678000Z',
            'commit_url': 'http://my_repo.fake/workers/worker/commit/1337',
            'refs': []
        },
        'docker_image': str(version.docker_image_id),
        'docker_image_iid': None,
        'docker_image_name': version.docker_image_name,
        'state': 'available',
        'worker': {
            'id': str(worker.id),
            'name': worker.name,
            'type': worker.type,
            'slug': worker.slug,
        }
    }
    if element_count:
        payload['element_count'] = version.elements.filter(corpus=self.corpus).count()
    return payload
def test_corpus_worker_version_no_login(self):
    """
    Anonymous users can list worker versions used in a public corpus.
    """
    self.corpus.elements.filter(type__slug='word').update(worker_version=self.version_1)

    url = reverse('api:corpus-versions', kwargs={'pk': self.corpus.id})
    with self.assertNumQueries(8):
        response = self.client.get(url)

    self.assertEqual(response.status_code, status.HTTP_200_OK)
    # Cursor pagination: no total count, only version_1 is listed
    self.assertDictEqual(response.json(), {
        'count': None,
        'previous': None,
        'next': None,
        'results': [self._serialize_worker_version(self.version_1)],
    })
def test_corpus_worker_version_not_verified(self):
self.user.verified_email = False
......@@ -888,16 +892,52 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
self.client.force_login(self.user)
self.corpus.elements.filter(type__slug='word').update(worker_version=self.version_1)
with self.assertNumQueries(11):
with self.assertNumQueries(12):
response = self.client.get(reverse('api:corpus-versions', kwargs={'pk': self.corpus.id}))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self._assert_corpus_worker_version_list(response)
self.assertDictEqual(response.json(), {
'count': None,
'previous': None,
'next': None,
'results': [
self._serialize_worker_version(self.version_1)
]
})
def test_corpus_worker_version_list(self):
    """
    Authenticated users can list worker versions used in a corpus.
    """
    self.client.force_login(self.user)
    self.corpus.elements.filter(type__slug='word').update(worker_version=self.version_1)

    url = reverse('api:corpus-versions', kwargs={'pk': self.corpus.id})
    with self.assertNumQueries(12):
        response = self.client.get(url)

    self.assertEqual(response.status_code, status.HTTP_200_OK)
    # Cursor pagination: no total count, only version_1 is listed
    self.assertDictEqual(response.json(), {
        'count': None,
        'previous': None,
        'next': None,
        'results': [self._serialize_worker_version(self.version_1)],
    })
def test_corpus_worker_version_list_with_element_count(self):
    """
    The with_element_count query parameter adds an element_count attribute
    to each listed worker version.
    """
    self.client.force_login(self.user)
    self.corpus.elements.filter(type__slug='word').update(worker_version=self.version_1)

    url = reverse('api:corpus-versions', kwargs={'pk': self.corpus.id})
    with self.assertNumQueries(12):
        response = self.client.get(url, {'with_element_count': 'true'})

    self.assertEqual(response.status_code, status.HTTP_200_OK)
    self.assertDictEqual(response.json(), {
        'count': None,
        'previous': None,
        'next': None,
        'results': [self._serialize_worker_version(self.version_1, element_count=True)],
    })
......@@ -25,9 +25,11 @@ class ElementTypeInline(admin.TabularInline):
class CorpusAdmin(admin.ModelAdmin):
list_display = ('id', 'name', 'public', 'repository', 'top_level_type')
list_display = ('id', 'name', 'public', 'repository', 'top_level_type', 'created')
raw_id_fields = ('thumbnail', )
search_fields = ('name', )
inlines = (ElementTypeInline, UserMembershipInline, GroupMembershipInline)
ordering = ('-created', )
def has_delete_permission(self, request, obj=None):
# Require everyone to use the asynchronous corpus deletion
......@@ -54,6 +56,11 @@ class AllowedMetaDataAdmin(admin.ModelAdmin):
list_display = ('id', 'corpus', 'type', 'name')
readonly_fields = ('id', )
def get_form(self, *args, **kwargs):
    # Sort the corpus dropdown by name (then ID) to make it easier to browse.
    form = super().get_form(*args, **kwargs)
    corpus_field = form.base_fields['corpus']
    corpus_field.queryset = Corpus.objects.order_by('name', 'id')
    return form
class MetaDataAdmin(admin.ModelAdmin):
list_display = ('id', 'name', 'type', )
......@@ -105,6 +112,11 @@ class MLClassAdmin(admin.ModelAdmin):
search_fields = ('name',)
fields = ('name', 'corpus')
def get_form(self, *args, **kwargs):
    # Keep the corpus selector ordered alphabetically (name, then ID).
    form = super().get_form(*args, **kwargs)
    ordered_corpora = Corpus.objects.order_by('name', 'id')
    form.base_fields['corpus'].queryset = ordered_corpora
    return form
class EntityMetaForm(forms.ModelForm):
metas = HStoreFormField()
......
......@@ -47,6 +47,7 @@ from arkindex.documents.serializers.elements import (
CorpusSerializer,
ElementBulkSerializer,
ElementCreateSerializer,
ElementDestinationSerializer,
ElementListSerializer,
ElementNeighborsSerializer,
ElementParentSerializer,
......@@ -63,7 +64,7 @@ from arkindex.project.openapi import AutoSchema
from arkindex.project.pagination import LargePageNumberPagination, PageNumberPagination
from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly
from arkindex.project.tools import BulkMap
from arkindex.project.triggers import corpus_delete, element_trash, worker_results_delete
from arkindex.project.triggers import corpus_delete, element_trash, move_element, worker_results_delete
from arkindex.users.models import Role
from arkindex.users.utils import filter_rights
......@@ -915,6 +916,7 @@ class CorpusList(ListCreateAPIView):
corpora = Corpus.objects \
.filter(id__in=corpora_level) \
.annotate(authorized_users=Count('memberships')) \
.select_related('thumbnail__zone__image__server') \
.prefetch_related('types') \
.order_by('name', 'id')
......@@ -1409,3 +1411,26 @@ class WorkerResultsDestroy(CorpusACLMixin, DestroyAPIView):
)
return Response(status=status.HTTP_204_NO_CONTENT)
@extend_schema_view(
    post=extend_schema(operation_id='MoveElement', tags=['elements']),
)
class ElementMove(CreateAPIView):
    """
    Move an element to a new destination folder
    """
    serializer_class = ElementDestinationSerializer
    permission_classes = (IsVerified, )

    def create(self, request, *args, **kwargs):
        # Validate the source/destination payload
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        validated = serializer.validated_data
        source, destination = validated['source'], validated['destination']

        # Run the serializer's extra consistency checks, then trigger the
        # asynchronous move task.
        serializer.perform_create_checks(source, destination)
        move_element(source=source, destination=destination, user_id=self.request.user.id)

        # The move runs asynchronously; echo the accepted payload back.
        return Response(serializer.data, status=status.HTTP_200_OK)
......@@ -213,7 +213,19 @@ class EntityLinkCreate(CreateAPIView):
serializer_class = EntityLinkCreateSerializer
@extend_schema_view(post=extend_schema(operation_id='CreateTranscriptionEntity', tags=['entities']))
@extend_schema_view(post=extend_schema(
operation_id='CreateTranscriptionEntity',
tags=['entities'],
parameters=[
OpenApiParameter(
'id',
type=UUID,
location=OpenApiParameter.PATH,
description='ID of the transcription to link an entity to.',
required=True,
)
]
))
class TranscriptionEntityCreate(CreateAPIView):
"""
Link an existing Entity to a given transcription with its position
......
......@@ -140,6 +140,9 @@ class TranscriptionEdit(ACLMixin, RetrieveUpdateDestroyAPIView):
if not self.has_access(transcription.element.corpus, role.value):
raise PermissionDenied(detail=detail)
if self.request.method in ('PUT', 'PATCH') and transcription.transcription_entities.exists():
raise PermissionDenied(detail='Transcriptions with entities cannot be modified.')
@extend_schema_view(
post=extend_schema(
......@@ -266,7 +269,8 @@ class ElementTranscriptionsBulk(CreateAPIView):
zone_id=annotation['zone_id'],
corpus_id=self.element.corpus_id,
type=elt_type,
name=next_path_ordering + 1
name=next_path_ordering + 1,
worker_version=worker_version
)
# Specify the annotated element has been created
annotation['created'] = True
......
......@@ -80,6 +80,16 @@ def delete_element(element_id: UUID) -> None:
""", {'id': element_id})
logger.info(f"Deleted {cursor.rowcount} user selections")
# Remove workers activity on this element
cursor.execute("""
DELETE FROM dataimport_workeractivity
WHERE element_id = %(id)s
OR element_id IN (
SELECT element_id FROM documents_elementpath WHERE path && ARRAY[%(id)s]
)
""", {'id': element_id})
logger.info(f"Deleted {cursor.rowcount} worker activities")
# Remove element paths and elements simultaneously
cursor.execute("""
WITH children_ids (id) AS (
......