Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • arkindex/backend
1 result
Show changes
Commits on Source (18)
Showing
with 933 additions and 176 deletions
1.0.2-rc1
1.0.2-rc3
......@@ -65,12 +65,14 @@ from arkindex.dataimport.serializers.workers import (
RepositorySerializer,
WorkerActivitySerializer,
WorkerSerializer,
WorkerStatisticsSerializer,
WorkerVersionEditSerializer,
WorkerVersionSerializer,
)
from arkindex.documents.models import Corpus, Element
from arkindex.project.fields import ArrayRemove
from arkindex.project.mixins import (
ConflictAPIException,
CorpusACLMixin,
CustomPaginationViewMixin,
DeprecatedMixin,
......@@ -79,6 +81,7 @@ from arkindex.project.mixins import (
SelectionMixin,
WorkerACLMixin,
)
from arkindex.project.pagination import CustomCursorPagination
from arkindex.project.permissions import IsInternalUser, IsVerified, IsVerifiedOrReadOnly
from arkindex.project.tools import RTrimChr
from arkindex.users.models import OAuthCredentials, Role
......@@ -906,18 +909,27 @@ class WorkerVersionList(WorkerACLMixin, ListCreateAPIView):
@extend_schema_view(
get=extend_schema(
tags=['ml'],
operation_id='ListCorpusWorkerVersions',
tags=['ml'],
description=(
'List worker versions used by elements of a given corpus.\n\n'
'No check is performed on workers access level in order to allow any user to see versions.'
),
parameters=[
OpenApiParameter(
'with_element_count',
type=bool,
default=False,
description='Include element counts in the response.',
)
],
)
)
class CorpusWorkerVersionList(CorpusACLMixin, ListAPIView):
"""
List worker versions used by elements of a given corpus.
"""
pagination_class = CustomCursorPagination
permission_classes = (IsVerifiedOrReadOnly, )
serializer_class = WorkerVersionSerializer
# For OpenAPI type discovery
......@@ -928,12 +940,24 @@ class CorpusWorkerVersionList(CorpusACLMixin, ListAPIView):
def get_queryset(self):
corpus = self.get_corpus()
return WorkerVersion.objects \
queryset = WorkerVersion.objects \
.filter(elements__corpus_id=corpus.id) \
.select_related('revision__repo', 'worker', 'worker__repository') \
.prefetch_related('revision__refs', 'revision__versions') \
.order_by('-revision__created') \
.annotate(element_count=Count('id'))
.prefetch_related(
'revision__repo',
'revision__refs',
'revision__versions',
'worker__repository',
) \
.order_by('-id')
if self.request.query_params.get('with_element_count', '').lower() in ('true', '1'):
queryset = queryset.annotate(element_count=Count('id'))
else:
# The Count() causes Django to add a GROUP BY, and without a count we need a DISTINCT
# because filtering on `elements` causes worker versions to be duplicated.
queryset = queryset.distinct()
return queryset
@extend_schema(tags=['repos'])
......@@ -1221,7 +1245,6 @@ class ListProcessElements(CustomPaginationViewMixin, CorpusACLMixin, ListAPIView
)
@extend_schema(tags=['ml'])
class UpdateWorkerActivity(GenericAPIView):
"""
Makes a worker (internal user) able to update its activity on an element
......@@ -1229,6 +1252,7 @@ class UpdateWorkerActivity(GenericAPIView):
"""
permission_classes = (IsInternalUser, )
serializer_class = WorkerActivitySerializer
queryset = WorkerActivity.objects.none()
@cached_property
def allowed_transitions(self):
......@@ -1245,35 +1269,125 @@ class UpdateWorkerActivity(GenericAPIView):
}
@extend_schema(
tags=['ml'],
responses={
200: DataImportSerializer,
409: None,
},
operation_id='UpdateWorkerActivity',
description=(
'Updates the activity of a worker version on an element.\n\n'
'The user must be **internal** to perform this request.'
'The user must be **internal** to perform this request.\n\n'
'A **HTTP_409_CONFLICT** is returned in case the body is valid but the update failed.'
),
)
def put(self, request, *args, **kwarg):
"""
Update a worker activity with a single database requests.
If the couple element_id, worker_version matches no existing
activity, the update count is 0.
If the new state is disallowed, the update count is 0 too.
"""
serializer = self.get_serializer(data=request.data, partial=False)
serializer.is_valid(raise_exception=True)
worker_version_id = self.kwargs['pk']
element_id = serializer.validated_data['element_id']
state = serializer.validated_data['state'].value
process_id = serializer.validated_data['process_id']
# We use the fact that only one worker activity may match the filter due to DB constraint
# Between zero and one worker activity can match the filter due to the DB constraint
activity = WorkerActivity.objects.filter(
worker_version_id=worker_version_id,
element_id=element_id,
state__in=self.allowed_transitions[state]
)
update_count = activity.update(state=state)
update_count = activity.update(state=state, process_id=process_id)
if not update_count:
# As no row has been updated the provided data was incorrect
raise ValidationError({
'__all__': [
'Either this worker activity does not exists '
f"or updating the state to '{state}' is forbidden."
]
})
# As no row has been updated the provided data was in conflict with the actual state
raise ConflictAPIException(
{
'__all__': [
'Either this worker activity does not exists '
f"or updating the state to '{state}' is forbidden."
]
}
)
return Response(serializer.data)
@extend_schema_view(
    get=extend_schema(
        operation_id='CorpusWorkersActivity',
        tags=['imports']
    )
)
class CorpusWorkersActivity(CorpusACLMixin, ListAPIView):
    """
    Return per-worker-version activity statistics over every element of a corpus.

    Requires an admin role on the corpus; the response is a flat,
    unpaginated list of one row per worker version.
    """
    permission_classes = (IsVerified, )
    serializer_class = WorkerStatisticsSerializer
    pagination_class = None
    # Placeholder queryset: the response is built manually in `list`
    queryset = WorkerActivity.objects.none()

    def list(self, request, *args, **kwargs):
        corpus = self.get_corpus(self.kwargs['corpus'], role=Role.Admin)
        # One conditional Count per possible activity state,
        # e.g. {'queued': Count(...), 'started': Count(...), ...}
        per_state_counts = {
            state.value: Count('id', filter=Q(state=state.value))
            for state in WorkerActivityState
        }
        # Distribution of activities on this corpus, grouped by worker version
        activity_stats = (
            WorkerActivity.objects
            .filter(element_id__in=corpus.elements.values('id'))
            .values('worker_version_id')
            .annotate(**per_state_counts)
        )
        serialized = WorkerStatisticsSerializer(activity_stats, many=True)
        return Response(status=status.HTTP_200_OK, data=serialized.data)
@extend_schema_view(
    get=extend_schema(
        operation_id='ProcessWorkersActivity',
        tags=['imports']
    )
)
class ProcessWorkersActivity(ProcessACLMixin, ListAPIView):
    """
    Return per-worker-version activity statistics for a single process.

    Requires an admin access level on the process; the response is a flat,
    unpaginated list of one row per worker version.
    """
    permission_classes = (IsVerified, )
    serializer_class = WorkerStatisticsSerializer
    pagination_class = None
    # Placeholder queryset: the response is built manually in `list`
    queryset = WorkerActivity.objects.none()

    def list(self, request, *args, **kwargs):
        process = get_object_or_404(DataImport.objects.all(), pk=self.kwargs['pk'])
        # Hide the process entirely from users without any access level
        access_level = self.process_access_level(process)
        if not access_level:
            raise NotFound
        if access_level < Role.Admin.value:
            raise PermissionDenied(detail='You do not have an admin access to this process.')
        # One conditional Count per possible activity state,
        # e.g. {'queued': Count(...), 'started': Count(...), ...}
        per_state_counts = {
            state.value: Count('id', filter=Q(state=state.value))
            for state in WorkerActivityState
        }
        # Distribution of activities on this process, grouped by worker version
        activity_stats = (
            WorkerActivity.objects
            .filter(process_id=process.id)
            .values('worker_version_id')
            .annotate(**per_state_counts)
        )
        serialized = WorkerStatisticsSerializer(activity_stats, many=True)
        return Response(status=status.HTTP_200_OK, data=serialized.data)
# Generated by Django 3.1.5 on 2021-05-07 07:48
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
    # Schema migration: link each worker activity to the process (DataImport)
    # that spawned it, so activity statistics can be computed per process.

    dependencies = [
        ('dataimport', '0032_dataimport_activity_state'),
    ]

    operations = [
        migrations.AddField(
            model_name='workeractivity',
            name='process',
            # Nullable FK: activities created before this migration have no
            # process, and deleting a process only detaches its activities
            # (SET_NULL) instead of cascading the delete to them.
            field=models.ForeignKey(
                blank=True,
                null=True,
                on_delete=django.db.models.deletion.SET_NULL,
                related_name='activities',
                to='dataimport.dataimport'
            ),
        ),
    ]
......@@ -627,7 +627,7 @@ class WorkerActivityState(Enum):
class ActivityManager(models.Manager):
"""Model management for worker activities"""
def bulk_insert(self, worker_version_id, elements_qs, state=WorkerActivityState.Queued):
def bulk_insert(self, worker_version_id, process_id, elements_qs, state=WorkerActivityState.Queued):
"""
Create initial worker activities from a queryset of elements in a efficient way.
Due to the possible large amount of elements, we use a bulk insert from the elements query (best performances).
......@@ -641,11 +641,12 @@ class ActivityManager(models.Manager):
cursor.execute(
f"""
INSERT INTO dataimport_workeractivity
(element_id, worker_version_id, state, id, created, updated)
(element_id, worker_version_id, state, process_id, id, created, updated)
SELECT
elt.id,
'{worker_version_id}'::uuid,
'{state.value}',
'{process_id}',
uuid_generate_v4(),
current_timestamp,
current_timestamp
......@@ -675,7 +676,14 @@ class WorkerActivity(IndexableModel):
)
state = EnumField(
WorkerActivityState,
default=WorkerActivityState.Queued
default=WorkerActivityState.Queued,
)
process = models.ForeignKey(
DataImport,
related_name='activities',
on_delete=models.SET_NULL,
null=True,
blank=True,
)
# Specific WorkerActivity manager
......
......@@ -21,10 +21,11 @@ class DataImportLightSerializer(serializers.ModelSerializer):
Serialize a data importing workflow
"""
state = EnumField(State)
state = EnumField(State, read_only=True)
mode = EnumField(DataImportMode, read_only=True)
creator = serializers.HiddenField(default=serializers.CurrentUserDefault())
workflow = serializers.HyperlinkedRelatedField(read_only=True, view_name='ponos:workflow-details')
activity_state = EnumField(ActivityState, read_only=True)
class Meta:
model = DataImport
......@@ -36,19 +37,25 @@ class DataImportLightSerializer(serializers.ModelSerializer):
'corpus',
'creator',
'workflow',
'activity_state',
)
read_only_fields = ('id', 'state', 'mode', 'corpus', 'creator', 'workflow')
read_only_fields = ('id', 'state', 'mode', 'corpus', 'creator', 'workflow', 'activity_state')
class DataImportSerializer(DataImportLightSerializer):
"""
Serialize a data importing workflow with its settings
"""
# Redefine state as writable
state = EnumField(State, read_only=True)
revision = RevisionSerializer(read_only=True)
element = ElementSlimSerializer(read_only=True)
element_id = serializers.PrimaryKeyRelatedField(
queryset=Element.objects.none(),
default=None,
allow_null=True,
write_only=True,
source='element',
style={'base_template': 'input.html'},
)
folder_type = serializers.SlugField(source='folder_type.slug', default=None, read_only=True)
element_type = serializers.SlugRelatedField(queryset=ElementType.objects.none(), slug_field='slug', allow_null=True)
element_name_contains = serializers.CharField(
......@@ -57,26 +64,18 @@ class DataImportSerializer(DataImportLightSerializer):
allow_blank=True,
max_length=250
)
activity_state = EnumField(ActivityState, read_only=True)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
dataimport = self.context.get('dataimport')
if not dataimport or not dataimport.corpus:
return
self.fields['element_type'].queryset = ElementType.objects.filter(corpus=dataimport.corpus)
class Meta(DataImportLightSerializer.Meta):
fields = DataImportLightSerializer.Meta.fields + (
'files',
'revision',
'element',
'element_id',
'folder_type',
'element_type',
'element_name_contains',
'load_children',
'use_cache',
'activity_state',
)
read_only_fields = DataImportLightSerializer.Meta.read_only_fields + (
'files',
......@@ -84,9 +83,16 @@ class DataImportSerializer(DataImportLightSerializer):
'element',
'folder_type',
'use_cache',
'activity_state',
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
dataimport = self.context.get('dataimport')
if not dataimport or not dataimport.corpus:
return
self.fields['element_type'].queryset = ElementType.objects.filter(corpus=dataimport.corpus)
self.fields['element_id'].queryset = dataimport.corpus.elements.all()
def validate(self, data):
data = super().validate(data)
# Editing a dataimport name only is always allowed
......@@ -95,10 +101,17 @@ class DataImportSerializer(DataImportLightSerializer):
if not self.instance:
return
# Allow editing the element ID on file imports at any time
if self.instance.mode in (DataImportMode.Images, DataImportMode.PDF, DataImportMode.IIIF) and set(data.keys()) == {'element'}:
return data
if self.instance.state == State.Running:
raise serializers.ValidationError({'__all__': ['Cannot edit a workflow while it is running']})
if self.instance.mode != DataImportMode.Workers:
raise serializers.ValidationError({'__all__': [f'Only processes of mode {DataImportMode.Workers} can be updated']})
return data
......
......@@ -5,6 +5,7 @@ from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from arkindex.dataimport.models import (
DataImport,
Repository,
RepositoryType,
Revision,
......@@ -158,14 +159,27 @@ class RepositorySerializer(serializers.ModelSerializer):
class WorkerActivitySerializer(serializers.ModelSerializer):
"""
Serialize a repository
Serialize a worker activity on an element
"""
state = EnumField(WorkerActivityState)
element_id = serializers.UUIDField()
process_id = serializers.PrimaryKeyRelatedField(queryset=DataImport.objects.all())
class Meta:
model = WorkerActivity
fields = (
'element_id',
'process_id',
'state',
)
class WorkerStatisticsSerializer(serializers.Serializer):
    """
    Serialize activity statistics of a worker version.

    Instances are rows of an annotated WorkerActivity queryset (one row per
    worker version, with one count per activity state), never user input —
    hence every field is read-only.
    """
    worker_version_id = serializers.UUIDField(read_only=True)
    # Number of activities currently in each WorkerActivityState
    queued = serializers.IntegerField(read_only=True)
    started = serializers.IntegerField(read_only=True)
    processed = serializers.IntegerField(read_only=True)
    error = serializers.IntegerField(read_only=True)
......@@ -10,7 +10,15 @@ from django.urls import reverse
from rest_framework import status
from rest_framework.exceptions import ValidationError
from arkindex.dataimport.models import ActivityState, DataImport, DataImportMode, RepositoryType
from arkindex.dataimport.models import (
ActivityState,
DataImport,
DataImportMode,
RepositoryType,
WorkerActivity,
WorkerActivityState,
WorkerVersion,
)
from arkindex.dataimport.utils import get_default_farm_id
from arkindex.documents.models import Corpus, ElementType
from arkindex.project.tests import FixtureAPITestCase
......@@ -86,6 +94,7 @@ class TestImports(FixtureAPITestCase):
'mode': process.mode.value,
'corpus': process.corpus_id and str(process.corpus.id),
'workflow': process.workflow and f'http://testserver/ponos/v1/workflow/{process.workflow.id}/',
'activity_state': ActivityState.Disabled.value,
}
def build_task(self, workflow_id, run, state, depth=1):
......@@ -140,6 +149,8 @@ class TestImports(FixtureAPITestCase):
"""
self.user_img_process.start()
self.client.force_login(self.user)
self.user_img_process.activity_state = ActivityState.Ready
self.user_img_process.save()
with self.assertNumQueries(8):
response = self.client.get(reverse('api:import-list'), {'with_workflow': 'true'})
self.assertEqual(response.status_code, status.HTTP_200_OK)
......@@ -153,6 +164,7 @@ class TestImports(FixtureAPITestCase):
'mode': DataImportMode.Images.value,
'corpus': str(self.user_img_process.corpus.id),
'workflow': f'http://testserver/ponos/v1/workflow/{self.user_img_process.workflow.id}/',
'activity_state': ActivityState.Ready.value,
}])
def test_list_exclude_workflow(self):
......@@ -175,6 +187,7 @@ class TestImports(FixtureAPITestCase):
'mode': DataImportMode.Images.value,
'corpus': str(self.user_img_process.corpus.id),
'workflow': None,
'activity_state': ActivityState.Disabled.value,
}])
def test_list_filter_corpus(self):
......@@ -433,7 +446,7 @@ class TestImports(FixtureAPITestCase):
A user is allowed to delete a dataimport if he has an admin right to its corpus
"""
self.client.force_login(self.user)
with self.assertNumQueries(10):
with self.assertNumQueries(11):
response = self.client.delete(reverse('api:import-details', kwargs={'pk': self.elts_process.id}))
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
......@@ -442,10 +455,31 @@ class TestImports(FixtureAPITestCase):
A superuser is allowed to delete any dataimport
"""
self.client.force_login(self.superuser)
with self.assertNumQueries(7):
with self.assertNumQueries(8):
response = self.client.delete(reverse('api:import-details', kwargs={'pk': self.user_img_process.id}))
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
def test_delete_import_activities(self):
    """
    Deleting a process sets the FK on related activities to null.
    This should be deprecated soon and done asynchronously, as
    the cost of updating many activities is rather high.
    """
    self.client.force_login(self.superuser)
    activity = WorkerActivity.objects.create(
        element=self.corpus.elements.first(),
        process=self.elts_process,
        worker_version=WorkerVersion.objects.get(worker__slug='reco'),
        state=WorkerActivityState.Queued,
    )
    with self.assertNumQueries(9):
        response = self.client.delete(reverse('api:import-details', kwargs={'pk': self.elts_process.id}))
    self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
    # The process is gone, but the activity survives with its FK cleared
    with self.assertRaises(DataImport.DoesNotExist):
        self.elts_process.refresh_from_db()
    activity.refresh_from_db()
    self.assertEqual(activity.process_id, None)
def test_update_process_requires_login(self):
response = self.client.delete(reverse('api:import-details', kwargs={'pk': self.user_img_process.id}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
......@@ -488,6 +522,27 @@ class TestImports(FixtureAPITestCase):
self.elts_process.refresh_from_db()
self.assertEqual(self.elts_process.name, 'newName')
def test_update_file_import_element(self):
    """
    A file import's element can be updated even while it is running.
    """
    self.client.force_login(self.user)
    process = self.corpus.imports.create(mode=DataImportMode.PDF, creator=self.user)
    process.start()
    process.workflow.tasks.update(state=State.Running)
    # Precondition: the process under test starts without any element.
    # (Previously checked `self.elts_process.element`, an unrelated fixture —
    # a copy-paste slip that made this assertion meaningless for this test.)
    self.assertIsNone(process.element)
    element = self.corpus.elements.first()
    with self.assertNumQueries(14):
        response = self.client.patch(
            reverse('api:import-details', kwargs={'pk': process.id}),
            {'element_id': str(element.id)},
            format='json'
        )
    self.assertEqual(response.status_code, status.HTTP_200_OK)
    process.refresh_from_db()
    self.assertEqual(process.element, element)
def test_update_process_no_permission(self):
"""
A user cannot update a dataimport linked to a corpus he has no admin access to
......@@ -659,7 +714,7 @@ class TestImports(FixtureAPITestCase):
},
format='json'
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.status_code, status.HTTP_200_OK, response.json())
self.assertEqual(response.json(), {
'name': 'newName',
'element_name_contains': 'AAA',
......
......@@ -4,7 +4,14 @@ from unittest.mock import MagicMock, call, patch
from django.urls import reverse
from rest_framework import status
from arkindex.dataimport.models import ActivityState, DataImportMode, WorkerActivity, WorkerActivityState, WorkerVersion
from arkindex.dataimport.models import (
ActivityState,
DataImport,
DataImportMode,
WorkerActivity,
WorkerActivityState,
WorkerVersion,
)
from arkindex.documents.models import Classification, ClassificationState, Element, MLClass
from arkindex.documents.tasks import initialize_activity
from arkindex.project.tests import FixtureTestCase
......@@ -18,8 +25,18 @@ class TestWorkerActivity(FixtureTestCase):
super().setUpTestData()
cls.worker_version = WorkerVersion.objects.get(worker__slug='reco')
cls.element = Element.objects.get(name='Volume 1, page 2r')
cls.process = DataImport.objects.create(
mode=DataImportMode.Workers,
creator=cls.user
)
def setUp(self):
# Create a queued activity for this element
cls.activity = cls.element.activities.create(worker_version=cls.worker_version, state=WorkerActivityState.Queued)
self.activity = self.element.activities.create(
process=self.process,
worker_version=self.worker_version,
state=WorkerActivityState.Queued
)
def test_bulk_insert_activity_children(self):
"""
......@@ -29,28 +46,37 @@ class TestWorkerActivity(FixtureTestCase):
params = {
'worker_version_id': self.worker_version.id,
'corpus_id': self.corpus.id,
'state': WorkerActivityState.Started.value
'state': WorkerActivityState.Started.value,
'process_id': self.process.id,
}
with self.assertExactQueries('workeractivity_bulk_insert.sql', params=params):
WorkerActivity.objects.bulk_insert(self.worker_version.id, elements_qs, state=WorkerActivityState.Started)
WorkerActivity.objects.bulk_insert(self.worker_version.id, self.process.id, elements_qs, state=WorkerActivityState.Started)
self.assertEqual(elements_qs.count(), 5)
self.assertEqual(WorkerActivity.objects.filter(state=WorkerActivityState.Started).count(), 5)
self.assertEqual(WorkerActivity.objects.filter(state=WorkerActivityState.Started, process=self.process).count(), 5)
def test_bulk_insert_activity_existing(self):
"""
Elements in the queryset should be skipped if they already have an activity
The old process ID is preserved during the bulk create
"""
elements_qs = Element.objects.filter(type__slug='act', type__corpus_id=self.corpus.id)
old_process = DataImport.objects.create(mode=DataImportMode.Workers, creator=self.user)
WorkerActivity.objects.bulk_create([
WorkerActivity(element=element, worker_version=self.worker_version, state=WorkerActivityState.Processed.value)
WorkerActivity(
element=element,
worker_version=self.worker_version,
state=WorkerActivityState.Processed.value,
process=old_process
)
for element in elements_qs[:2]
])
with self.assertNumQueries(1):
WorkerActivity.objects.bulk_insert(self.worker_version.id, elements_qs, state=WorkerActivityState.Started)
WorkerActivity.objects.bulk_insert(self.worker_version.id, self.process.id, elements_qs, state=WorkerActivityState.Started)
self.assertEqual(WorkerActivity.objects.filter(element_id__in=elements_qs.values('id')).count(), 5)
self.assertEqual(elements_qs.count(), 5)
# Only 3 acts have been marked as started for this worker
self.assertEqual(WorkerActivity.objects.filter(state=WorkerActivityState.Started).count(), 3)
self.assertEqual(WorkerActivity.objects.filter(process=old_process).count(), 2)
@patch('arkindex.project.triggers.tasks.initialize_activity.delay')
def test_bulk_insert_children_class_filter(self, activities_delay_mock):
......@@ -95,8 +121,8 @@ class TestWorkerActivity(FixtureTestCase):
(None, status.HTTP_403_FORBIDDEN, 0),
(self.user, status.HTTP_403_FORBIDDEN, 2),
(self.superuser, status.HTTP_403_FORBIDDEN, 2),
(self.internal_user, status.HTTP_200_OK, 3),
(internal_admin_user, status.HTTP_200_OK, 3)
(self.internal_user, status.HTTP_200_OK, 4),
(internal_admin_user, status.HTTP_200_OK, 4)
)
for user, status_code, requests_count in cases:
self.activity.state = WorkerActivityState.Queued
......@@ -106,7 +132,11 @@ class TestWorkerActivity(FixtureTestCase):
with self.assertNumQueries(requests_count):
response = self.client.put(
reverse('api:update-worker-activity', kwargs={'pk': str(self.worker_version.id)}),
{'element_id': str(self.element.id), 'state': WorkerActivityState.Started.value},
{
'element_id': str(self.element.id),
'process_id': str(self.process.id),
'state': WorkerActivityState.Started.value,
},
content_type='application/json',
)
self.assertEqual(response.status_code, status_code)
......@@ -114,15 +144,20 @@ class TestWorkerActivity(FixtureTestCase):
def test_put_activity_wrong_worker_version(self):
"""
Raises a generic error in case the worker version does not exists because a single SQL request is performed
The response is a HTTP_409_CONFLICT
"""
self.client.force_login(self.internal_user)
with self.assertNumQueries(3):
with self.assertNumQueries(4):
response = self.client.put(
reverse('api:update-worker-activity', kwargs={'pk': str(uuid.uuid4())}),
{'element_id': str(self.element.id), 'state': WorkerActivityState.Started.value},
{
'element_id': str(self.element.id),
'process_id': str(self.process.id),
'state': WorkerActivityState.Started.value,
},
content_type='application/json',
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.status_code, status.HTTP_409_CONFLICT)
self.assertDictEqual(response.json(), {
'__all__': [
'Either this worker activity does not exists or '
......@@ -130,18 +165,23 @@ class TestWorkerActivity(FixtureTestCase):
]
})
def test_put_activity_unexisting(self):
def test_put_activity_element_unexisting(self):
"""
Raises a generic error in case no activity exists for this element
The response is a HTTP_409_CONFLICT
"""
self.client.force_login(self.internal_user)
with self.assertNumQueries(3):
with self.assertNumQueries(4):
response = self.client.put(
reverse('api:update-worker-activity', kwargs={'pk': str(self.worker_version.id)}),
{'element_id': str(uuid.uuid4()), 'state': WorkerActivityState.Started.value},
{
'element_id': str(uuid.uuid4()),
'process_id': str(self.process.id),
'state': WorkerActivityState.Started.value,
},
content_type='application/json',
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.status_code, status.HTTP_409_CONFLICT)
self.assertDictEqual(response.json(), {
'__all__': [
'Either this worker activity does not exists or '
......@@ -149,6 +189,27 @@ class TestWorkerActivity(FixtureTestCase):
]
})
def test_put_activity_process_unexisting(self):
    """
    Raises an error in case the process does not exist
    """
    self.client.force_login(self.internal_user)
    wrong_process_id = uuid.uuid4()
    # Unlike an unknown element or worker version (which yield HTTP 409 from
    # the no-op UPDATE), an unknown process fails serializer validation
    # up front and returns a regular HTTP 400
    with self.assertNumQueries(3):
        response = self.client.put(
            reverse('api:update-worker-activity', kwargs={'pk': str(self.worker_version.id)}),
            {
                'element_id': str(self.element.id),
                'process_id': str(wrong_process_id),
                'state': WorkerActivityState.Started.value,
            },
            content_type='application/json',
        )
    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    self.assertDictEqual(response.json(), {
        'process_id': [f'Invalid pk "{wrong_process_id}" - object does not exist.']
    })
def test_put_activity_allowed_states(self):
"""
Check each case state update is allowed depending on the actual state of the activity
......@@ -168,10 +229,14 @@ class TestWorkerActivity(FixtureTestCase):
for state, payload_state in allowed_states_update:
self.activity.state = state
self.activity.save()
with self.assertNumQueries(3):
with self.assertNumQueries(4):
response = self.client.put(
reverse('api:update-worker-activity', kwargs={'pk': str(self.worker_version.id)}),
{'element_id': str(self.element.id), 'state': payload_state},
{
'element_id': str(self.element.id),
'process_id': str(self.process.id),
'state': payload_state,
},
content_type='application/json',
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
......@@ -181,6 +246,7 @@ class TestWorkerActivity(FixtureTestCase):
def test_put_activity_forbidden_states(self):
"""
Check state update is forbidden for some non consistant cases
The response is a HTTP_409_CONFLICT
"""
queued, started, error, processed = (
WorkerActivityState[state].value
......@@ -196,13 +262,17 @@ class TestWorkerActivity(FixtureTestCase):
for state, payload_state in forbidden_states_update:
self.activity.state = state
self.activity.save()
with self.assertNumQueries(3):
with self.assertNumQueries(4):
response = self.client.put(
reverse('api:update-worker-activity', kwargs={'pk': str(self.worker_version.id)}),
{'element_id': str(self.element.id), 'state': payload_state},
{
'element_id': str(self.element.id),
'process_id': str(self.process.id),
'state': payload_state,
},
content_type='application/json',
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.status_code, status.HTTP_409_CONFLICT)
self.assertDictEqual(response.json(), {
'__all__': [
'Either this worker activity does not exists or '
......@@ -212,6 +282,37 @@ class TestWorkerActivity(FixtureTestCase):
self.activity.refresh_from_db()
self.assertEqual(self.activity.state.value, state)
def test_put_activity_process_overriding(self):
    """
    A process can update a worker activity even if it was initialized in another process
    """
    # NOTE(review): the activity is reassigned to process2 *before* the request,
    # and the payload also sends process2, so the request's process never
    # differs from the activity's current one — the "overriding" path this
    # docstring describes may not actually be exercised. Confirm the setup.
    self.client.force_login(self.internal_user)
    process2 = DataImport.objects.create(
        mode=DataImportMode.Workers,
        creator=self.user
    )
    self.activity.process_id = process2.id
    self.activity.save()
    with self.assertNumQueries(4):
        response = self.client.put(
            reverse('api:update-worker-activity', kwargs={'pk': str(self.worker_version.id)}),
            {
                'element_id': str(self.element.id),
                'process_id': str(process2.id),
                'state': WorkerActivityState.Started.value,
            },
            content_type='application/json',
        )
    self.assertEqual(response.status_code, status.HTTP_200_OK)
    self.assertDictEqual(response.json(), {
        'element_id': str(self.element.id),
        'process_id': str(process2.id),
        'state': WorkerActivityState.Started.value,
    })
    self.activity.refresh_from_db()
    self.assertEqual(self.activity.process_id, process2.id)
@patch('arkindex.dataimport.models.DataImport.versions')
@patch('arkindex.dataimport.models.ActivityManager.bulk_insert')
def test_async_activities_error(self, bulk_insert_mock, versions_mock):
......@@ -244,8 +345,8 @@ class TestWorkerActivity(FixtureTestCase):
self.assertEqual(bulk_insert_mock.call_count, 2)
self.assertListEqual(bulk_insert_mock.call_args_list, [
call(worker_version_id=v1_id, elements_qs=elts_qs),
call(worker_version_id=v2_id, elements_qs=elts_qs)
call(worker_version_id=v1_id, process_id=process.id, elements_qs=elts_qs),
call(worker_version_id=v2_id, process_id=process.id, elements_qs=elts_qs)
])
process.refresh_from_db()
......
......@@ -4,7 +4,7 @@ import uuid
from django.urls import reverse
from rest_framework import status
from arkindex.dataimport.models import WorkerActivity, WorkerActivityState, WorkerVersion
from arkindex.dataimport.models import DataImport, DataImportMode, WorkerActivity, WorkerActivityState, WorkerVersion
from arkindex.documents.models import Corpus
from arkindex.project.tests import FixtureAPITestCase
from arkindex.users.models import Role, User
......@@ -19,27 +19,48 @@ class TestWorkersActivity(FixtureAPITestCase):
cls.version_2 = WorkerVersion.objects.get(worker__slug='dla')
cls.private_corpus = Corpus.objects.create(name='private', public=False)
cls.elts_count = cls.corpus.elements.count()
cls.process = DataImport.objects.create(
mode=DataImportMode.Workers,
creator=cls.user,
corpus=cls.corpus,
)
# Generate worker activities
WorkerActivity.objects.bulk_create([
*(
WorkerActivity(
element_id=elt.id,
state=state,
worker_version_id=cls.version_1.id
worker_version_id=cls.version_1.id,
process_id=cls.process.id,
) for elt, state in zip(cls.corpus.elements.all(), itertools.cycle(WorkerActivityState))
), *(
WorkerActivity(
element_id=elt.id,
state=WorkerActivityState.Processed.value,
worker_version_id=cls.version_2.id
worker_version_id=cls.version_2.id,
process_id=None,
) for elt in cls.corpus.elements.all()
)
])
cls.error, cls.processed, cls.queued, cls.started = [
WorkerActivity.objects.filter(
element__corpus_id=cls.corpus.id,
worker_version_id=cls.version_1.id,
state=state
).count()
for state in [
WorkerActivityState.Error,
WorkerActivityState.Processed,
WorkerActivityState.Queued,
WorkerActivityState.Started
]
]
def test_version_stats_requires_login(self):
with self.assertNumQueries(0):
response = self.client.get(
reverse('api:workers-activity', kwargs={'corpus': str(self.corpus.id)})
reverse('api:corpus-workers-activity', kwargs={'corpus': str(self.corpus.id)})
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {'detail': 'Authentication credentials were not provided.'})
......@@ -48,7 +69,7 @@ class TestWorkersActivity(FixtureAPITestCase):
self.client.force_login(self.user)
with self.assertNumQueries(6):
response = self.client.get(
reverse('api:workers-activity', kwargs={'corpus': str(self.private_corpus.id)})
reverse('api:corpus-workers-activity', kwargs={'corpus': str(self.private_corpus.id)})
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {'detail': 'You do not have admin access to this corpus.'})
......@@ -59,7 +80,7 @@ class TestWorkersActivity(FixtureAPITestCase):
self.client.force_login(user)
with self.assertNumQueries(5):
response = self.client.get(
reverse('api:workers-activity', kwargs={'corpus': str(self.corpus.id)})
reverse('api:corpus-workers-activity', kwargs={'corpus': str(self.corpus.id)})
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {'detail': 'You do not have admin access to this corpus.'})
......@@ -68,7 +89,7 @@ class TestWorkersActivity(FixtureAPITestCase):
self.client.force_login(self.user)
with self.assertNumQueries(3):
response = self.client.get(
reverse('api:workers-activity', kwargs={'corpus': str(uuid.uuid4())})
reverse('api:corpus-workers-activity', kwargs={'corpus': str(uuid.uuid4())})
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
......@@ -81,39 +102,25 @@ class TestWorkersActivity(FixtureAPITestCase):
self.client.force_login(user)
with self.assertNumQueries(6):
response = self.client.get(
reverse('api:workers-activity', kwargs={'corpus': str(self.private_corpus.id)})
reverse('api:corpus-workers-activity', kwargs={'corpus': str(self.private_corpus.id)})
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertListEqual(response.json(), [])
def test_workers_activity_distributed_states(self):
error, processed, queued, started = [
WorkerActivity.objects.filter(
element__corpus_id=self.corpus.id,
worker_version_id=self.version_1.id,
state=state
).count()
for state in [
WorkerActivityState.Error,
WorkerActivityState.Processed,
WorkerActivityState.Queued,
WorkerActivityState.Started
]
]
assert error > 0 and processed > 0 and queued > 0 and started > 0
self.client.force_login(self.user)
with self.assertNumQueries(7):
response = self.client.get(
reverse('api:workers-activity', kwargs={'corpus': str(self.corpus.id)})
reverse('api:corpus-workers-activity', kwargs={'corpus': str(self.corpus.id)})
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertCountEqual(response.json(), [
{
'worker_version_id': str(self.version_1.id),
'queued': queued,
'started': started,
'processed': processed,
'error': error,
'queued': self.queued,
'started': self.started,
'processed': self.processed,
'error': self.error,
}, {
'worker_version_id': str(self.version_2.id),
'queued': 0,
......@@ -122,3 +129,47 @@ class TestWorkersActivity(FixtureAPITestCase):
'error': 0,
}
])
def test_process_activity_stats_requires_login(self):
    """Anonymous requests to the process workers-activity endpoint are rejected."""
    # Authentication is checked before any database access
    with self.assertNumQueries(0):
        response = self.client.get(
            reverse('api:process-workers-activity', kwargs={'pk': str(self.process.id)})
        )
    self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
    self.assertDictEqual(response.json(), {'detail': 'Authentication credentials were not provided.'})
def test_process_activity_stats_requires_admin(self):
    """Activity stats require an admin access to the process."""
    # Move the process onto a corpus the logged-in user has no admin rights on
    self.process.corpus = self.private_corpus
    self.process.save()
    self.client.force_login(self.user)
    with self.assertNumQueries(7):
        response = self.client.get(
            reverse('api:process-workers-activity', kwargs={'pk': str(self.process.id)})
        )
    self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
    self.assertDictEqual(response.json(), {'detail': 'You do not have an admin access to this process.'})
def test_process_activity_stats_unexisting(self):
    """A random process UUID yields a 404."""
    self.client.force_login(self.user)
    with self.assertNumQueries(3):
        response = self.client.get(
            reverse('api:process-workers-activity', kwargs={'pk': str(uuid.uuid4())})
        )
    self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_process_activity_stats(self):
    """
    Only activities attached to the process are aggregated: version_1
    activities were created with this process' ID in setUpTestData, while
    version_2 activities have process_id=None and must not appear.
    """
    self.client.force_login(self.user)
    with self.assertNumQueries(8):
        response = self.client.get(
            reverse('api:process-workers-activity', kwargs={'pk': str(self.process.id)})
        )
    self.assertEqual(response.status_code, status.HTTP_200_OK)
    self.assertListEqual(response.json(), [
        {
            'worker_version_id': str(self.version_1.id),
            'queued': self.queued,
            'started': self.started,
            'processed': self.processed,
            'error': self.error,
        }
    ])
......@@ -842,45 +842,49 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
self.assertEqual(self.version_1.docker_command, 'mysupercommand')
self.version_1.configuration = {"test": "test1"}
def _assert_corpus_worker_version_list(self, response):
self.assertDictEqual(response.json(), {
'count': 1,
'next': None,
'number': 1,
'previous': None,
'results': [{
'id': str(self.version_1.id),
'configuration': {'test': 42},
'revision': {
'id': str(self.version_1.revision_id),
'hash': '1337',
'author': 'Test user',
'message': 'My w0rk3r',
'created': '2020-02-02T01:23:45.678000Z',
'commit_url': 'http://my_repo.fake/workers/worker/commit/1337',
'refs': []
},
'docker_image': str(self.version_1.docker_image_id),
'docker_image_iid': None,
'docker_image_name': self.version_1.docker_image_name,
'state': 'available',
'element_count': 9,
'worker': {
'id': str(self.worker_1.id),
'name': self.worker_1.name,
'type': self.worker_1.type,
'slug': self.worker_1.slug,
}
}]
})
def _serialize_worker_version(self, version, element_count=False):
    """
    Build the expected API payload for a worker version.

    When ``element_count`` is True, the payload also carries the number of
    elements of ``self.corpus`` produced by this version, mirroring the
    endpoint's ``with_element_count=true`` behaviour.
    """
    revision_payload = {
        'id': str(version.revision_id),
        'hash': '1337',
        'author': 'Test user',
        'message': 'My w0rk3r',
        'created': '2020-02-02T01:23:45.678000Z',
        'commit_url': 'http://my_repo.fake/workers/worker/commit/1337',
        'refs': []
    }
    worker_payload = {
        'id': str(version.worker.id),
        'name': version.worker.name,
        'type': version.worker.type,
        'slug': version.worker.slug,
    }
    payload = {
        'id': str(version.id),
        'configuration': {'test': 42},
        'revision': revision_payload,
        'docker_image': str(version.docker_image_id),
        'docker_image_iid': None,
        'docker_image_name': version.docker_image_name,
        'state': 'available',
        'worker': worker_payload,
    }
    if element_count:
        payload['element_count'] = version.elements.filter(corpus=self.corpus).count()
    return payload
def test_corpus_worker_version_no_login(self):
self.corpus.elements.filter(type__slug='word').update(worker_version=self.version_1)
with self.assertNumQueries(7):
with self.assertNumQueries(8):
response = self.client.get(reverse('api:corpus-versions', kwargs={'pk': self.corpus.id}))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self._assert_corpus_worker_version_list(response)
self.assertDictEqual(response.json(), {
'count': None,
'previous': None,
'next': None,
'results': [
self._serialize_worker_version(self.version_1)
]
})
def test_corpus_worker_version_not_verified(self):
self.user.verified_email = False
......@@ -888,16 +892,52 @@ class TestWorkersWorkerVersions(FixtureAPITestCase):
self.client.force_login(self.user)
self.corpus.elements.filter(type__slug='word').update(worker_version=self.version_1)
with self.assertNumQueries(11):
with self.assertNumQueries(12):
response = self.client.get(reverse('api:corpus-versions', kwargs={'pk': self.corpus.id}))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self._assert_corpus_worker_version_list(response)
self.assertDictEqual(response.json(), {
'count': None,
'previous': None,
'next': None,
'results': [
self._serialize_worker_version(self.version_1)
]
})
def test_corpus_worker_version_list(self):
self.client.force_login(self.user)
self.corpus.elements.filter(type__slug='word').update(worker_version=self.version_1)
with self.assertNumQueries(11):
with self.assertNumQueries(12):
response = self.client.get(reverse('api:corpus-versions', kwargs={'pk': self.corpus.id}))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self._assert_corpus_worker_version_list(response)
self.assertDictEqual(response.json(), {
'count': None,
'previous': None,
'next': None,
'results': [
self._serialize_worker_version(self.version_1)
]
})
def test_corpus_worker_version_list_with_element_count(self):
self.client.force_login(self.user)
self.corpus.elements.filter(type__slug='word').update(worker_version=self.version_1)
with self.assertNumQueries(12):
response = self.client.get(
reverse('api:corpus-versions', kwargs={'pk': self.corpus.id}),
{'with_element_count': 'true'}
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertDictEqual(response.json(), {
'count': None,
'previous': None,
'next': None,
'results': [
self._serialize_worker_version(self.version_1, element_count=True)
]
})
......@@ -31,7 +31,7 @@ from rest_framework.generics import (
from rest_framework.mixins import DestroyModelMixin
from rest_framework.response import Response
from arkindex.dataimport.models import WorkerActivity, WorkerActivityState, WorkerVersion
from arkindex.dataimport.models import WorkerVersion
from arkindex.documents.models import (
AllowedMetaData,
Classification,
......@@ -47,6 +47,7 @@ from arkindex.documents.serializers.elements import (
CorpusSerializer,
ElementBulkSerializer,
ElementCreateSerializer,
ElementDestinationSerializer,
ElementListSerializer,
ElementNeighborsSerializer,
ElementParentSerializer,
......@@ -54,7 +55,6 @@ from arkindex.documents.serializers.elements import (
ElementSlimSerializer,
ElementTypeSerializer,
MetaDataUpdateSerializer,
WorkerStatisticsSerializer,
)
from arkindex.documents.serializers.light import CorpusAllowedMetaDataSerializer, ElementTypeLightSerializer
from arkindex.documents.serializers.ml import ElementTranscriptionSerializer
......@@ -64,7 +64,7 @@ from arkindex.project.openapi import AutoSchema
from arkindex.project.pagination import LargePageNumberPagination, PageNumberPagination
from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly
from arkindex.project.tools import BulkMap
from arkindex.project.triggers import corpus_delete, element_trash, worker_results_delete
from arkindex.project.triggers import corpus_delete, element_trash, move_element, worker_results_delete
from arkindex.users.models import Role
from arkindex.users.utils import filter_rights
......@@ -1414,35 +1414,23 @@ class WorkerResultsDestroy(CorpusACLMixin, DestroyAPIView):
@extend_schema_view(
get=extend_schema(
operation_id='RetriveWorkersActivity',
tags=['elements']
)
post=extend_schema(operation_id='MoveElement', tags=['elements']),
)
class WorkersActivity(CorpusACLMixin, ListAPIView):
class ElementMove(CreateAPIView):
"""
Retrieve corpus wise statistics about the activity of a single worker version
Move an element to a new destination folder
"""
serializer_class = ElementDestinationSerializer
permission_classes = (IsVerified, )
serializer_class = WorkerStatisticsSerializer
pagination_class = None
queryset = WorkerActivity.objects.none()
def list(self, request, *args, **kwargs):
corpus = self.get_corpus(self.kwargs['corpus'], role=Role.Admin)
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
# Retrieve the distribution of activities on this corpus grouped by worker version
stats = WorkerActivity.objects \
.filter(element_id__in=corpus.elements.values('id')) \
.values('worker_version_id') \
.annotate(
**{
state.value: Count('id', filter=Q(state=state.value))
for state in WorkerActivityState
}
)
source = serializer.validated_data['source']
destination = serializer.validated_data['destination']
serializer.perform_create_checks(source, destination)
return Response(
status=status.HTTP_200_OK,
data=WorkerStatisticsSerializer(stats, many=True).data
)
move_element(source=source, destination=destination, user_id=self.request.user.id)
return Response(serializer.data, status=status.HTTP_200_OK)
......@@ -49,8 +49,18 @@ class MetaDataUpdateSerializer(MetaDataLightSerializer):
"""
Allow editing MetaData
"""
entity = serializers.PrimaryKeyRelatedField(queryset=Entity.objects.none(), required=False, allow_null=True)
worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all(), required=False, allow_null=True)
entity = serializers.PrimaryKeyRelatedField(
queryset=Entity.objects.none(),
required=False,
allow_null=True,
style={'base_template': 'input.html'},
)
worker_version = serializers.PrimaryKeyRelatedField(
queryset=WorkerVersion.objects.all(),
required=False,
allow_null=True,
style={'base_template': 'input.html'},
)
class Meta:
model = MetaData
......@@ -79,7 +89,12 @@ class CorpusSerializer(serializers.ModelSerializer):
rights = serializers.SerializerMethodField(read_only=True)
types = ElementTypeLightSerializer(many=True, read_only=True)
authorized_users = serializers.SerializerMethodField(read_only=True)
thumbnail = serializers.PrimaryKeyRelatedField(queryset=Element.objects.none(), allow_null=True, default=None)
thumbnail = serializers.PrimaryKeyRelatedField(
queryset=Element.objects.none(),
allow_null=True,
default=None,
style={'base_template': 'input.html'},
)
thumbnail_url = serializers.SerializerMethodField(read_only=True)
class Meta:
......@@ -320,7 +335,8 @@ class ElementSerializer(ElementSlimSerializer):
write_only=True,
help_text='Link this element to an image by UUID via a polygon. '
'When the image is updated, if there was an image before and the polygon is not updated, '
'the previous polygon is reused. Otherwise, a polygon filling the new image is used.'
'the previous polygon is reused. Otherwise, a polygon filling the new image is used.',
style={'base_template': 'input.html'},
)
polygon = LinearRingField(
required=False,
......@@ -680,3 +696,40 @@ class WorkerStatisticsSerializer(serializers.Serializer):
started = serializers.IntegerField(read_only=True)
processed = serializers.IntegerField(read_only=True)
error = serializers.IntegerField(read_only=True)
class ElementDestinationSerializer(serializers.Serializer):
    """
    Validates moving a source element into a destination element.

    Both elements must belong to a corpus the request's user can write to,
    must live in the same corpus, and must not already be (directly or
    transitively) related in a way that would make the move a no-op or a cycle.
    """
    source = serializers.PrimaryKeyRelatedField(queryset=Element.objects.none())
    destination = serializers.PrimaryKeyRelatedField(queryset=Element.objects.none())

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        request = self.context.get('request')
        if not request:
            # Keep the empty querysets instead of raising, so that the
            # OpenAPI schema can still be generated without a request
            return
        writable_elements = Element.objects.filter(
            corpus__in=Corpus.objects.writable(request.user)
        ).select_related('corpus')
        self.fields['source'].queryset = writable_elements
        self.fields['destination'].queryset = writable_elements

    def validate(self, data):
        data = super().validate(data)
        source, destination = data.get('source'), data.get('destination')
        if source.id == destination.id:
            raise ValidationError({'destination': ['A source element cannot be moved into itself']})
        if source.corpus != destination.corpus:
            raise ValidationError({'destination': ['A source element cannot be moved to a destination from another corpus']})
        return data

    def perform_create_checks(self, source, destination):
        # Refuse a no-op move: destination is already a direct parent of source
        already_direct_parent = ElementPath.objects.filter(
            element_id=source.id,
            path__last=destination.id,
        ).exists()
        if already_direct_parent:
            raise ValidationError({'destination': [
                "'{}' is already a direct parent of '{}'".format(destination.id, source.id)
            ]})
        # Refuse a cycle: destination is located somewhere under source
        destination_is_descendant = ElementPath.objects.filter(
            element_id=destination.id,
            path__contains=[source.id],
        ).exists()
        if destination_is_descendant:
            raise ValidationError({'destination': [
                "'{}' is a child of element '{}'".format(destination.id, source.id)
            ]})
......@@ -124,11 +124,18 @@ class EntityCreateSerializer(BaseEntitySerializer):
"""
Serialize an entity with a possible parents and children
"""
corpus = serializers.PrimaryKeyRelatedField(queryset=Corpus.objects.none())
corpus = serializers.PrimaryKeyRelatedField(
queryset=Corpus.objects.none(),
style={'base_template': 'input.html'},
)
metas = serializers.HStoreField(child=serializers.CharField(), required=False)
children = EntityLinkSerializer(many=True, read_only=True)
parents = EntityLinkSerializer(many=True, read_only=True)
worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all(), default=None)
worker_version = serializers.PrimaryKeyRelatedField(
queryset=WorkerVersion.objects.all(),
default=None,
style={'base_template': 'input.html'},
)
class Meta:
model = Entity
......@@ -162,9 +169,18 @@ class EntityLinkCreateSerializer(EntityLinkSerializer):
"""
Serialize an entity with a possible parents and children
"""
parent = serializers.PrimaryKeyRelatedField(queryset=Entity.objects.none())
child = serializers.PrimaryKeyRelatedField(queryset=Entity.objects.none())
role = serializers.PrimaryKeyRelatedField(queryset=Entity.objects.none())
parent = serializers.PrimaryKeyRelatedField(
queryset=Entity.objects.none(),
style={'base_template': 'input.html'},
)
child = serializers.PrimaryKeyRelatedField(
queryset=Entity.objects.none(),
style={'base_template': 'input.html'},
)
role = serializers.PrimaryKeyRelatedField(
queryset=EntityRole.objects.none(),
style={'base_template': 'input.html'},
)
class Meta:
model = EntityLink
......@@ -200,6 +216,7 @@ class TranscriptionEntitySerializer(serializers.ModelSerializer):
queryset=WorkerVersion.objects.all(),
source='worker_version',
default=None,
style={'base_template': 'input.html'},
)
class Meta:
......
......@@ -81,9 +81,20 @@ class ClassificationCreateSerializer(serializers.ModelSerializer):
"""
Serializer to create a single classification, defaulting to manual
"""
element = serializers.PrimaryKeyRelatedField(queryset=Element.objects.using('default').none())
ml_class = serializers.PrimaryKeyRelatedField(queryset=MLClass.objects.using('default').none())
worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all(), allow_null=True, default=None)
element = serializers.PrimaryKeyRelatedField(
queryset=Element.objects.using('default').none(),
style={'base_template': 'input.html'},
)
ml_class = serializers.PrimaryKeyRelatedField(
queryset=MLClass.objects.using('default').none(),
style={'base_template': 'input.html'},
)
worker_version = serializers.PrimaryKeyRelatedField(
queryset=WorkerVersion.objects.all(),
allow_null=True,
default=None,
style={'base_template': 'input.html'},
)
confidence = serializers.FloatField(
min_value=0,
max_value=1,
......@@ -243,7 +254,12 @@ class TranscriptionCreateSerializer(serializers.ModelSerializer):
"""
Allows the insertion of a manual transcription attached to an element
"""
worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all(), required=False, allow_null=True)
worker_version = serializers.PrimaryKeyRelatedField(
queryset=WorkerVersion.objects.all(),
required=False,
allow_null=True,
style={'base_template': 'input.html'},
)
score = serializers.FloatField(
min_value=0,
max_value=1,
......@@ -333,7 +349,8 @@ class ElementTranscriptionsBulkSerializer(serializers.Serializer):
)
worker_version = serializers.PrimaryKeyRelatedField(
queryset=WorkerVersion.objects.all(),
help_text='A WorkerVersion ID that transcriptions will refer to'
help_text='A WorkerVersion ID that transcriptions will refer to',
style={'base_template': 'input.html'},
)
transcriptions = SimpleTranscriptionSerializer(
many=True,
......@@ -394,7 +411,10 @@ class TranscriptionBulkItemSerializer(serializers.Serializer):
class TranscriptionBulkSerializer(serializers.Serializer):
worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all())
worker_version = serializers.PrimaryKeyRelatedField(
queryset=WorkerVersion.objects.all(),
style={'base_template': 'input.html'},
)
transcriptions = TranscriptionBulkItemSerializer(many=True)
def validate(self, data):
......@@ -449,8 +469,13 @@ class ClassificationsSerializer(serializers.Serializer):
parent = serializers.PrimaryKeyRelatedField(
# The real queryset is set in __init__
queryset=Element.objects.none(),
style={'base_template': 'input.html'},
)
worker_version = serializers.PrimaryKeyRelatedField(
queryset=WorkerVersion.objects.all(),
default=None,
style={'base_template': 'input.html'},
)
worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all(), default=None)
classifications = ClassificationBulkSerializer(many=True, allow_empty=False)
def __init__(self, *args, **kwargs):
......
......@@ -169,6 +169,14 @@ def worker_results_delete(corpus_id: str, version_id: str, parent_id: str) -> No
transcriptions._raw_delete(using='default')
@job('high')
def move_element(source: Element, destination: Element) -> None:
    """
    Detach ``source`` from all of its current direct parents, then attach it
    under ``destination``. Descendant paths are updated through the
    remove_child/add_parent calls.
    """
    # One ElementPath row exists per direct parent of source; path[-1] is the
    # direct parent's ID (same convention as the path__last lookup used by
    # ElementDestinationSerializer.perform_create_checks).
    # NOTE(review): assumes every path is non-empty — a top-level element with
    # an empty path array would raise IndexError here; confirm upstream
    # validation guarantees a parent exists.
    paths = ElementPath.objects.filter(element_id=source.id)
    for path in paths:
        Element.objects.get(id=path.path[-1]).remove_child(source)
    source.add_parent(destination)
@job('default', timeout=3600)
def initialize_activity(process: DataImport):
"""
......@@ -178,7 +186,11 @@ def initialize_activity(process: DataImport):
try:
with transaction.atomic():
for version_id in process.versions.values_list('id', flat=True):
WorkerActivity.objects.bulk_insert(worker_version_id=version_id, elements_qs=process.list_elements())
WorkerActivity.objects.bulk_insert(
worker_version_id=version_id,
process_id=process.id,
elements_qs=process.list_elements()
)
except Exception as e:
process.activity_state = ActivityState.Error
process.save()
......
from unittest.mock import patch
from uuid import UUID
from django.db import connections
from django.db.backends.base.base import _thread
from arkindex.documents.models import ElementPath
from arkindex.documents.tasks import move_element
from arkindex.project.tests import FixtureTestCase
# Fixed UUIDs injected as ElementPath primary keys by patching the `id`
# field's `get_default` in the tests below, so that the SQL captured by
# assertExactQueries is fully deterministic.
PATHS_IDS = [
    UUID('00000000-0000-0000-0000-000000000000'),
    UUID('11111111-1111-1111-1111-111111111111'),
    UUID('22222222-2222-2222-2222-222222222222'),
    UUID('33333333-3333-3333-3333-333333333333'),
    UUID('44444444-4444-4444-4444-444444444444'),
    UUID('55555555-5555-5555-5555-555555555555'),
    UUID('66666666-6666-6666-6666-666666666666'),
    UUID('77777777-7777-7777-7777-777777777777'),
    UUID('88888888-8888-8888-8888-888888888888'),
]
class TestMoveElement(FixtureTestCase):
    """
    Tests for the move_element asynchronous task, asserting both the
    resulting ElementPath rows and the exact SQL emitted (via
    assertExactQueries against .sql fixture files).
    """

    @classmethod
    def setUpTestData(cls):
        super().setUpTestData()
        cls.page_type = cls.corpus.types.get(slug='page')
        cls.destination = cls.corpus.elements.get(name='Volume 2')
        cls.parent = cls.corpus.elements.get(name='Volume 1')
        cls.source_with_children = cls.corpus.elements.get(name='Volume 1, page 1r')
        cls.source_without_child = cls.corpus.elements.get(name='Volume 1, page 2r')
        # Turn 'Volume 1, page 2r' into a leaf by deleting all paths going through it
        ElementPath.objects.filter(path__contains=[cls.source_without_child.id]).delete()

    @patch.object(ElementPath._meta.get_field('id'), 'get_default')
    def test_run_on_source_without_child(self, default_field_mock):
        """Moving a leaf element rewrites its single path to point to the destination."""
        # Pin the next generated ElementPath ID so the SQL fixture can reference it
        default_field_mock.return_value = PATHS_IDS[0]
        # No child on this page
        self.assertEqual(ElementPath.objects.filter(path__contains=[self.source_without_child.id]).count(), 0)
        source_paths = ElementPath.objects.filter(element_id=self.source_without_child.id)
        # len() evaluates and caches the queryset; values('path') below re-queries
        self.assertEqual(len(source_paths), 1)
        self.assertEqual(list(source_paths.values('path')), [{'path': [self.parent.id]}])
        with self.assertExactQueries('element_move_without_child.sql', params={
            'source_id': str(self.source_without_child.id),
            'parent_id': str(self.parent.id),
            'destination_id': str(self.destination.id),
            'page_type_id': str(self.page_type.id),
            'path_id': str(PATHS_IDS[0]),
            # Savepoint names depend on the current thread ID and the
            # connection's savepoint counter at the time of the call
            'savepoints': [f"s{_thread.get_ident()}_x{connections['default'].savepoint_state + 1}", f"s{_thread.get_ident()}_x{connections['default'].savepoint_state + 2}"]
        }):
            move_element(self.source_without_child, self.destination)
        self.assertEqual(len(source_paths), 1)
        self.assertEqual(list(source_paths.values('path')), [{'path': [self.destination.id]}])

    @patch.object(ElementPath._meta.get_field('id'), 'get_default')
    def test_run_on_source_with_children(self, default_field_mock):
        """Moving an element also rewrites the paths of all of its children."""
        # Each successive ElementPath creation consumes the next pinned UUID
        default_field_mock.side_effect = PATHS_IDS
        # 4 children on this page
        children_paths = ElementPath.objects.filter(path__contains=[self.source_with_children.id])
        self.assertEqual(children_paths.count(), 4)
        self.assertEqual(list(children_paths.values('path')), [
            {'path': [self.parent.id, self.source_with_children.id]},
            {'path': [self.parent.id, self.source_with_children.id]},
            {'path': [self.parent.id, self.source_with_children.id]},
            {'path': [self.parent.id, self.source_with_children.id]}
        ])
        source_paths = ElementPath.objects.filter(element_id=self.source_with_children.id)
        self.assertEqual(len(source_paths), 1)
        self.assertEqual(list(source_paths.values('path')), [{'path': [self.parent.id]}])
        with self.assertExactQueries('element_move_with_children.sql', params={
            'source_id': str(self.source_with_children.id),
            'parent_id': str(self.parent.id),
            'destination_id': str(self.destination.id),
            'children_ids': [str(id) for id in children_paths.values_list('element_id', flat=True)],
            'page_type_id': str(self.page_type.id),
            'paths_ids': [str(id) for id in PATHS_IDS],
            # Savepoint names depend on the current thread ID and the
            # connection's savepoint counter at the time of the call
            'savepoints': [f"s{_thread.get_ident()}_x{connections['default'].savepoint_state + 1}", f"s{_thread.get_ident()}_x{connections['default'].savepoint_state + 2}"]
        }):
            move_element(self.source_with_children, self.destination)
        self.assertEqual(len(source_paths), 1)
        self.assertEqual(list(source_paths.values('path')), [{'path': [self.destination.id]}])
        # Assert children were also moved
        self.assertEqual(list(children_paths.values('path')), [
            {'path': [self.destination.id, self.source_with_children.id]},
            {'path': [self.destination.id, self.source_with_children.id]},
            {'path': [self.destination.id, self.source_with_children.id]},
            {'path': [self.destination.id, self.source_with_children.id]}
        ])
from unittest.mock import call, patch
from django.urls import reverse
from rest_framework import status
from arkindex.documents.models import Corpus
from arkindex.project.tests import FixtureAPITestCase
from arkindex.users.models import Role
class TestMoveElement(FixtureAPITestCase):
    """
    API tests for the MoveElement endpoint (POST api:move-element):
    authentication, ACL filtering, serializer validation, and the
    asynchronous task trigger.
    """

    @classmethod
    def setUpTestData(cls):
        super().setUpTestData()
        # A page and a sibling folder from the fixture corpus
        cls.source = cls.corpus.elements.get(name='Volume 1, page 1r')
        cls.destination = cls.corpus.elements.get(name='Volume 2')

    def test_move_element_requires_login(self):
        """Anonymous users cannot move elements."""
        with self.assertNumQueries(0):
            response = self.client.post(reverse('api:move-element'), {'source': str(self.source.id), 'destination': str(self.destination.id)}, format='json')
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_move_element_requires_verified(self):
        """Users without a verified email address cannot move elements."""
        self.user.verified_email = False
        self.user.save()
        self.client.force_login(self.user)
        with self.assertNumQueries(2):
            response = self.client.post(reverse('api:move-element'), {'source': str(self.source.id), 'destination': str(self.destination.id)}, format='json')
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_move_element_wrong_acl(self):
        """
        Elements of corpora the user cannot write to are filtered out of the
        serializer querysets, so they surface as unknown primary keys.
        """
        private_corpus = Corpus.objects.create(name='private', public=False)
        private_element = private_corpus.elements.create(
            type=private_corpus.types.create(slug='folder'),
        )
        self.client.force_login(self.user)
        with self.assertNumQueries(6):
            response = self.client.post(reverse('api:move-element'), {'source': str(private_element.id), 'destination': str(private_element.id)}, format='json')
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertDictEqual(
            response.json(),
            {
                'source': [f'Invalid pk "{private_element.id}" - object does not exist.'],
                'destination': [f'Invalid pk "{private_element.id}" - object does not exist.']
            }
        )

    def test_move_element_wrong_source(self):
        """An unknown source ID is rejected."""
        self.client.force_login(self.user)
        with self.assertNumQueries(6):
            response = self.client.post(reverse('api:move-element'), {'source': '12341234-1234-1234-1234-123412341234', 'destination': str(self.destination.id)}, format='json')
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertDictEqual(
            response.json(),
            {'source': ['Invalid pk "12341234-1234-1234-1234-123412341234" - object does not exist.']}
        )

    def test_move_element_wrong_destination(self):
        """An unknown destination ID is rejected."""
        self.client.force_login(self.user)
        with self.assertNumQueries(6):
            response = self.client.post(reverse('api:move-element'), {'source': str(self.source.id), 'destination': '12341234-1234-1234-1234-123412341234'}, format='json')
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertDictEqual(
            response.json(),
            {'destination': ['Invalid pk "12341234-1234-1234-1234-123412341234" - object does not exist.']}
        )

    def test_move_element_same_source_destination(self):
        """An element cannot be moved into itself."""
        self.client.force_login(self.user)
        with self.assertNumQueries(6):
            response = self.client.post(reverse('api:move-element'), {'source': str(self.source.id), 'destination': str(self.source.id)}, format='json')
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertDictEqual(
            response.json(),
            {'destination': ['A source element cannot be moved into itself']}
        )

    def test_move_element_different_corpus(self):
        """Source and destination must belong to the same corpus."""
        corpus2 = Corpus.objects.create(name='new')
        corpus2.memberships.create(user=self.user, level=Role.Contributor.value)
        destination = corpus2.elements.create(type=corpus2.types.create(slug='folder'))
        self.client.force_login(self.user)
        with self.assertNumQueries(5):
            response = self.client.post(reverse('api:move-element'), {'source': str(self.source.id), 'destination': str(destination.id)}, format='json')
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertDictEqual(
            response.json(),
            {'destination': ['A source element cannot be moved to a destination from another corpus']}
        )

    def test_move_element_destination_is_direct_parent(self):
        """Moving an element to its existing direct parent is a no-op and is refused."""
        destination = self.corpus.elements.get(name='Volume 1')
        self.client.force_login(self.user)
        with self.assertNumQueries(7):
            response = self.client.post(reverse('api:move-element'), {'source': str(self.source.id), 'destination': str(destination.id)}, format='json')
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertDictEqual(
            response.json(),
            {'destination': [f"'{destination.id}' is already a direct parent of '{self.source.id}'"]}
        )

    def test_move_element_destination_is_child(self):
        """Moving an element into one of its own descendants is refused (cycle)."""
        source = self.corpus.elements.get(name='Volume 1')
        destination_id = self.source.id
        self.client.force_login(self.user)
        with self.assertNumQueries(8):
            response = self.client.post(reverse('api:move-element'), {'source': str(source.id), 'destination': str(destination_id)}, format='json')
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertDictEqual(
            response.json(),
            {'destination': [f"'{destination_id}' is a child of element '{source.id}'"]}
        )

    @patch('arkindex.project.triggers.tasks.move_element.delay')
    def test_move_element(self, delay_mock):
        """A valid request enqueues the asynchronous move_element task once."""
        self.client.force_login(self.user)
        with self.assertNumQueries(8):
            response = self.client.post(reverse('api:move-element'), {'source': str(self.source.id), 'destination': str(self.destination.id)}, format='json')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(delay_mock.call_count, 1)
        self.assertEqual(delay_mock.call_args, call(
            source=self.source,
            destination=self.destination,
            user_id=self.user.id,
            description=f"Moving element {self.source.name} to element {self.destination.name}"
        ))
......@@ -208,6 +208,7 @@ class ImageUploadSerializer(ImageSerializer):
queryset=ImageServer.objects.exclude(s3_bucket=None),
required=False,
write_only=True,
style={'base_template': 'input.html'},
)
path = IIIFPathField(
required=False,
......
......@@ -3,6 +3,7 @@ from django.views.generic.base import RedirectView
from arkindex.dataimport.api import (
AvailableRepositoriesList,
CorpusWorkersActivity,
CorpusWorkerVersionList,
CorpusWorkflow,
DataFileCreate,
......@@ -16,6 +17,7 @@ from arkindex.dataimport.api import (
GitRepositoryImportHook,
ImportTranskribus,
ListProcessElements,
ProcessWorkersActivity,
RepositoryList,
RepositoryRetrieve,
RevisionRetrieve,
......@@ -38,6 +40,7 @@ from arkindex.documents.api.elements import (
ElementBulkCreate,
ElementChildren,
ElementMetadata,
ElementMove,
ElementNeighbors,
ElementParent,
ElementParents,
......@@ -49,7 +52,6 @@ from arkindex.documents.api.elements import (
ManageSelection,
MetadataEdit,
WorkerResultsDestroy,
WorkersActivity,
)
from arkindex.documents.api.entities import (
CorpusRoles,
......@@ -125,6 +127,7 @@ api = [
name='element-transcriptions-bulk'
),
path('element/<uuid:child>/parent/<uuid:parent>/', ElementParent.as_view(), name='element-parent'),
path('element/move/', ElementMove.as_view(), name='move-element'),
# Corpora
path('corpus/', CorpusList.as_view(), name='corpus'),
......@@ -138,7 +141,7 @@ api = [
path('corpus/<uuid:pk>/selection/', CorpusSelectionDestroy.as_view(), name='corpus-delete-selection'),
path('corpus/<uuid:pk>/search/', CorpusSearch.as_view(), name='corpus-search'),
path('corpus/<uuid:corpus>/workerversion/<uuid:version>/results/', WorkerResultsDestroy.as_view(), name='worker-delete-results'),
path('corpus/<uuid:corpus>/workers-activity/', WorkersActivity.as_view(), name='workers-activity'),
path('corpus/<uuid:corpus>/workers-activity/', CorpusWorkersActivity.as_view(), name='corpus-workers-activity'),
# Moderation
......@@ -218,6 +221,7 @@ api = [
path('imports/<uuid:pk>/workers/', WorkerRunList.as_view(), name='worker-run-list'),
path('imports/workers/<uuid:pk>/', WorkerRunDetails.as_view(), name='worker-run-details'),
path('process/<uuid:pk>/elements/', ListProcessElements.as_view(), name='process-elements-list'),
path('process/<uuid:pk>/workers-activity/', ProcessWorkersActivity.as_view(), name='process-workers-activity'),
# Image management
path('image/', ImageCreate.as_view(), name='image-create'),
......
......@@ -3,6 +3,7 @@ from django.db.models import Q
from django.shortcuts import get_object_or_404
from django.views.decorators.cache import cache_page
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import status
from rest_framework.exceptions import APIException, PermissionDenied, ValidationError
from rest_framework.serializers import CharField, Serializer
......@@ -242,6 +243,11 @@ class DeprecatedExceptionSerializer(Serializer):
detail = CharField()
class ConflictAPIException(APIException):
    """
    API exception returning an HTTP 409 Conflict response.
    """
    default_code = 'conflict'
    status_code = status.HTTP_409_CONFLICT
class DeprecatedAPIException(APIException):
    """
    API exception returning an HTTP 410 Gone response for deprecated endpoints.
    """
    # Use the rest_framework status constant instead of the magic number 410,
    # for consistency with ConflictAPIException above.
    status_code = status.HTTP_410_GONE
    default_detail = 'This endpoint has been deprecated.'
......