Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • arkindex/backend
1 result
Show changes
Commits on Source (32)
Showing
with 1032 additions and 1025 deletions
......@@ -158,7 +158,7 @@ backend-build:
script:
- ci/build.sh Dockerfile
backend-build-binary:
backend-build-binary-docker:
stage: build
image: docker:19.03.1
services:
......@@ -180,6 +180,21 @@ backend-build-binary:
script:
- ci/build.sh Dockerfile.binary "-binary"
# Make sure arkindex is always compatible with Nuitka
backend-build-binary:
stage: build
image: python:3.10
before_script:
- pip install nuitka
script:
- python -m nuitka --nofollow-imports --include-package=arkindex --nofollow-import-to=*.tests arkindex/manage.py
except:
- schedules
backend-static-deploy:
stage: deploy
image: python:3-slim
......
......@@ -43,26 +43,6 @@ You will need to edit the ImageMagick policy file to get PDF and Image imports t
The line that sets the PDF policy is `<policy domain="coder" rights="none" pattern="PDF" />`. Replace `none` with `read|write` for it to work. See [this StackOverflow question](https://stackoverflow.com/questions/52998331) for more info.
### GitLab OAuth setup
Arkindex uses OAuth to let a user connect their GitLab account(s) and register Git repositories. In local development, you will need to register Arkindex as a GitLab OAuth application for it to work.
Go to GitLab's [Applications settings](https://gitlab.teklia.com/profile/applications) and create a new application with the `api` scope and add the following callback URIs:
```
http://127.0.0.1:8000/api/v1/oauth/providers/gitlab/callback/
http://ark.localhost:8000/api/v1/oauth/providers/gitlab/callback/
https://ark.localhost/api/v1/oauth/providers/gitlab/callback/
```
Once the application is created, GitLab will provide you with an application ID and a secret. Use the `arkindex/config.yml` file to set them:
```yaml
gitlab:
app_id: 24cacf5004bf68ae9daad19a5bba391d85ad1cb0b31366e89aec86fad0ab16cb
app_secret: 9d96d9d5b1addd7e7e6119a23b1e5b5f68545312bfecb21d1cdc6af22b8628b8
```
### Local image server
Arkindex splits up image URLs in their image server and the image path. For example, an IIIF server at `http://iiif.irht.cnrs.fr/iiif/` and an image at `/Paris/JJ042/1.jpg` would be represented as an ImageServer instance holding one Image. Since Arkindex has a local IIIF server for image uploads and thumbnails, a special instance of ImageServer is required to point to this local server. In local development, this server should be available at `https://ark.localhost/iiif`. You will therefore need to create an ImageServer via the Django admin or the Django shell with this URL. To set the local server ID, you can add a custom setting in `arkindex/config.yml`:
......@@ -161,9 +141,6 @@ SHELL_PLUS_POST_IMPORTS = [
)),
('arkindex.project.aws', (
'S3FileStatus',
)),
('arkindex.users.models', (
'OAuthStatus',
))
]
```
......
1.5.2
1.5.3
......@@ -9,7 +9,6 @@ from django.conf import settings
from django.core.exceptions import ValidationError as DjangoValidationError
from django.db import connection, transaction
from django.db.models import (
CharField,
Count,
Exists,
F,
......@@ -1625,8 +1624,8 @@ class TranscriptionsPagination(PageNumberPagination):
class ElementTranscriptions(ListAPIView):
"""
List all transcriptions for an element, optionally filtered by type or worker version id.
Recursive parameter allow listing transcriptions on sub-elements,
otherwise element fields in the response will be set to null.
Recursive parameter allow listing transcriptions on sub-elements.
Element fields in the response are only set when using the recursive parameter.
"""
serializer_class = ElementTranscriptionSerializer
pagination_class = TranscriptionsPagination
......@@ -1655,78 +1654,83 @@ class ElementTranscriptions(ListAPIView):
return context
def get_queryset(self):
# ORDER BY casting IDs as char to avoid the PostgreSQL optimizer's inefficient scan
# TODO: See if select_related is faster than a prefetch on this endpoint
queryset = Transcription.objects \
.prefetch_related('worker_version') \
.annotate(char_id=Cast('id', output_field=CharField())) \
.order_by('char_id')
queryset = Transcription.objects.select_related('worker_run')
if self.is_recursive:
# The transcription's `element` field is only included when recursive=true,
# so we add the prefetch here
queryset = queryset.prefetch_related('element__image__server', 'element__type')
queryset = queryset.filter(element__in=(
Element
.objects
# Transcriptions from the current element
.filter(id=self.element.id)
# We are about to use a UNION; we need to explicitly say we only want the ID column in the SELECT,
# because Django will otherwise pick all Element attributes, which is not supported by `__in`.
.values('id')
# Add the current element and the child elements together using a UNION
# Using Q(element_id=…) | Q(element__in=…) makes PostgreSQL use very slow nested loops and multi-processing.
.union(
# Transcriptions from all children of the current element.
# We are not using get_descending, because it includes an ORDER BY clause and a DISTINCT clause,
# which are both unnecessary here.
ElementPath.objects.filter(path__contains=[self.element.id]).values('element_id'),
# Use UNION ALL so that PostgreSQL does not unnecessarily sort and deduplicate UUIDs
all=True,
)
))
# List and filter children results. Current element transcriptions
# are conditionally added after filtering the queryset.
queryset = (
queryset
.filter(element__paths__path__overlap=[self.element.id])
# Also filter by corpus ID for better performance
.filter(element__corpus_id=self.element.corpus_id)
# Transcription's `element` field is only included when recursive=true
.select_related('element__type')
)
else:
queryset = queryset.filter(element_id=self.element.id)
return queryset
def filter_queryset(self, queryset):
errors = {}
filters = Q()
errors = defaultdict(list)
# Filter by worker run
if 'worker_run' in self.request.query_params:
worker_run_id = self.request.query_params['worker_run']
if worker_run_id.lower() in ('false', '0'):
# Restrict to transcriptions without worker runs
queryset = queryset.filter(worker_run_id=None)
filters &= Q(worker_run_id=None)
else:
try:
queryset = queryset.filter(worker_run_id=worker_run_id)
except DjangoValidationError as e:
errors['worker_run'] = e.messages
filters &= Q(worker_run_id=uuid.UUID(worker_run_id))
except (TypeError, ValueError):
errors['worker_run'].append(f'{worker_run_id}” is not a valid UUID.')
# Filter by worker version
if 'worker_version' in self.request.query_params:
worker_version_id = self.request.query_params['worker_version']
if worker_version_id.lower() in ('false', '0'):
# Restrict to transcriptions without worker versions
queryset = queryset.filter(worker_version_id=None)
# Restrict to transcriptions without worker runs
filters &= Q(worker_version_id=None)
else:
try:
queryset = queryset.filter(worker_version_id=worker_version_id)
except DjangoValidationError as e:
errors['worker_version'] = e.messages
filters &= Q(worker_version_id=uuid.UUID(worker_version_id))
except (TypeError, ValueError):
errors['worker_version'].append(f'{worker_version_id}” is not a valid UUID.')
# Filter by element_type
element_type = self.request.query_params.get('element_type')
if element_type:
queryset = queryset.select_related('element__type').filter(element__type__slug=element_type)
elt_type_filter = self.request.query_params.get('element_type')
if elt_type_filter:
queryset = queryset.filter(element__type__slug=elt_type_filter)
if errors:
raise ValidationError(errors)
return queryset
queryset = queryset.filter(filters)
# Perform a UNION after applying filters including parent transcriptions.
# This has better performance than a OR clause,
# especially when filtering by element type.
# https://gitlab.teklia.com/arkindex/backend/-/merge_requests/2180
if (
self.is_recursive
and (
elt_type_filter is None
or elt_type_filter == self.element.type.slug
)
):
queryset = queryset.union(
(
self.element.transcriptions
.select_related('element__type', 'worker_run')
.filter(filters)
),
# No element can be duplicated here
all=True,
)
return queryset.order_by('id')
@extend_schema_view(
......
......@@ -446,10 +446,10 @@ class TranscriptionEntities(ListAPIView):
raise serializers.ValidationError(errors)
transcription = get_object_or_404(
Transcription.objects.filter(
id=self.kwargs['pk'],
element__corpus__in=Corpus.objects.readable(self.request.user),
).only("id")
Transcription.objects
.using("default")
.filter(id=self.kwargs['pk'], element__corpus__in=Corpus.objects.readable(self.request.user))
.only("id")
)
return (
......
......@@ -54,6 +54,8 @@ class TranscriptionCreate(ACLMixin, CreateAPIView):
def get_object(self):
if not hasattr(self, 'element'):
self.element = super().get_object()
if not self.has_access(self.element.corpus, Role.Contributor.value):
raise PermissionDenied(detail="A write access to the element's corpus is required.")
return self.element
def get_queryset(self):
......@@ -65,8 +67,11 @@ class TranscriptionCreate(ACLMixin, CreateAPIView):
# We retrieve the readable objects then check permissions
# instead of retrieving writable objects directly so as not to
# get 404_NOT_FOUND errors on elements the user has access to.
return Element.objects.using('default').filter(
corpus__in=Corpus.objects.readable(self.request.user)
return (
Element.objects
.using('default')
.filter(corpus__in=Corpus.objects.readable(self.request.user))
.select_related('corpus')
)
def get_serializer_context(self):
......@@ -75,23 +80,10 @@ class TranscriptionCreate(ACLMixin, CreateAPIView):
context['element'] = self.get_object()
return context
def check_object_permissions(self, request, context):
super().check_object_permissions(request, context)
role = Role.Contributor
detail = "A write access to the element's corpus is required."
if not self.has_access(context.corpus, role.value):
raise PermissionDenied(detail=detail)
def perform_create(self, serializer):
return Transcription.objects.create(
element=self.element,
**serializer.validated_data
)
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
obj = self.perform_create(serializer)
obj = serializer.save()
headers = self.get_success_headers(serializer.data)
return Response(
# Use a single transcription serializer for the response
......
This diff is collapsed.
......@@ -10,14 +10,7 @@ from rest_framework.authtoken.models import Token
from arkindex.images.models import ImageServer
from arkindex.ponos.models import Farm
from arkindex.process.models import (
Repository,
Worker,
WorkerType,
WorkerVersion,
WorkerVersionGPUUsage,
WorkerVersionState,
)
from arkindex.process.models import FeatureUsage, Repository, Worker, WorkerType, WorkerVersion, WorkerVersionState
from arkindex.users.models import User
# Constants used in architecture project
......@@ -165,9 +158,6 @@ class Command(BaseCommand):
# Create a fake worker version on a fake worker on a fake repo with a fake revision for file imports
repo, created = Repository.objects.get_or_create(
url=IMPORT_WORKER_REPO,
defaults={
"hook_token": str(uuid4()),
}
)
if created:
self.success(f'Created Git repository for {IMPORT_WORKER_REPO}')
......@@ -210,7 +200,7 @@ class Command(BaseCommand):
'id': IMPORT_WORKER_VERSION_ID,
'configuration': {},
'state': WorkerVersionState.Created,
'gpu_usage': WorkerVersionGPUUsage.Disabled,
'gpu_usage': FeatureUsage.Disabled,
'docker_image': None,
'docker_image_iid': None,
}
......@@ -243,7 +233,7 @@ class Command(BaseCommand):
# Ensure it has the right attributes
version.configuration = {}
version.state = WorkerVersionState.Created
version.gpu_usage = WorkerVersionGPUUsage.Disabled
version.gpu_usage = FeatureUsage.Disabled
version.docker_image = None
version.docker_image_iid = None
version.save()
......
......@@ -10,13 +10,14 @@ from arkindex.documents.models import Corpus, Element, MetaData, MetaType
from arkindex.images.models import Image, ImageServer
from arkindex.ponos.models import Farm, State
from arkindex.process.models import (
FeatureUsage,
Process,
ProcessMode,
Repository,
Worker,
WorkerRun,
WorkerType,
WorkerVersion,
WorkerVersionGPUUsage,
WorkerVersionState,
)
from arkindex.project.tools import fake_now
......@@ -69,19 +70,9 @@ class Command(BaseCommand):
level=Role.Guest.value
)
# Create OAuth credentials for a user
creds = user.credentials.create(
provider_url='https://somewhere',
token='oauth-token',
refresh_token='refresh-token',
# Use an expiry very far away to avoid OAuth token refreshes in every test
expiry=datetime(2100, 12, 31, 23, 59, 59, 999999, timezone.utc),
)
# Create a GitLab worker repository
gitlab_repo = creds.repos.create(
gitlab_repo = Repository.objects.create(
url='http://gitlab/repo',
hook_token='hook-token',
)
# Create a revision on this repository
......@@ -92,9 +83,8 @@ class Command(BaseCommand):
)
# Create another worker repository
worker_repo = creds.repos.create(
worker_repo = Repository.objects.create(
url="http://my_repo.fake/workers/worker",
hook_token='worker-hook-token',
)
# Create a revision on this repository
......@@ -140,7 +130,7 @@ class Command(BaseCommand):
revision=revision,
configuration={'test': 42},
state=WorkerVersionState.Available,
model_usage=False,
model_usage=FeatureUsage.Disabled,
docker_image=docker_image
)
dla_worker = WorkerVersion.objects.create(
......@@ -152,7 +142,7 @@ class Command(BaseCommand):
revision=revision,
configuration={'test': 42},
state=WorkerVersionState.Available,
model_usage=False,
model_usage=FeatureUsage.Disabled,
docker_image=docker_image
)
......@@ -165,7 +155,7 @@ class Command(BaseCommand):
revision=revision,
configuration={},
state=WorkerVersionState.Available,
model_usage=False,
model_usage=FeatureUsage.Disabled,
docker_image=docker_image,
)
......@@ -178,9 +168,9 @@ class Command(BaseCommand):
revision=revision,
configuration={'test': 42},
state=WorkerVersionState.Available,
model_usage=False,
model_usage=FeatureUsage.Disabled,
docker_image=docker_image,
gpu_usage=WorkerVersionGPUUsage.Required
gpu_usage=FeatureUsage.Required
)
# Create a generic worker and its version that uses a ML Model
......@@ -193,8 +183,8 @@ class Command(BaseCommand):
revision=revision,
configuration={'test': 42},
state=WorkerVersionState.Available,
gpu_usage=WorkerVersionGPUUsage.Disabled,
model_usage=True,
gpu_usage=FeatureUsage.Disabled,
model_usage=FeatureUsage.Required,
docker_image=docker_image
)
......
......@@ -2,7 +2,6 @@
import json
import os
import sqlite3
import uuid
from datetime import datetime, timezone
from pathlib import Path
......@@ -410,9 +409,6 @@ class Command(BaseCommand):
def create_repository(self, row):
repo, created = Repository.objects.get_or_create(
url=row['repository_url'],
defaults={
'hook_token': str(uuid.uuid4()),
},
)
return repo, created
......
......@@ -935,6 +935,7 @@ class ElementBulkSerializer(serializers.Serializer):
type_ids = dict(
ElementType
.objects
.using('default')
.filter(
corpus_id=self.context['element'].corpus_id,
slug__in=type_slugs)
......
......@@ -362,6 +362,12 @@ class TranscriptionCreateSerializer(serializers.ModelSerializer):
return data
def create(self, validated_data):
    """
    Create a Transcription attached to the element provided in the
    serializer context (set by the view), using the validated payload.
    """
    return Transcription.objects.create(
        element=self.context['element'],
        **validated_data,
    )
class SimpleTranscriptionSerializer(serializers.Serializer):
"""
......
......@@ -33,7 +33,7 @@ class TestLoadExport(FixtureTestCase):
unexpected_fields_by_model = {
'documents.elementtype': ['display_name', 'indexable'],
'documents.mlclass': [],
'process.repository': ['hook_token', 'credentials', 'git_ref_revisions'],
'process.repository': ['git_ref_revisions'],
'process.worker': [],
'process.revision': ['message', 'author'],
'process.workerversion': ['created', 'updated', 'configuration', 'state', 'docker_image', 'docker_image_iid'],
......
......@@ -96,7 +96,6 @@ class TestDeleteCorpus(FixtureTestCase):
# Create a separate corpus that should not get anything deleted
cls.repo = Repository.objects.create(
url='http://lol.git',
hook_token='h00k',
)
cls.corpus2 = Corpus.objects.create(name='Other corpus')
......
......@@ -43,7 +43,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
def test_write_right(self):
self.client.force_login(self.private_read_user)
with self.assertNumQueries(7):
with self.assertNumQueries(6):
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.private_page.id}),
format='json',
......@@ -76,7 +76,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
def test_manual(self):
self.client.force_login(self.user)
with self.assertNumQueries(8):
with self.assertNumQueries(7):
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.line.id}),
format='json',
......@@ -104,7 +104,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
"""
self.client.force_login(self.user)
ts = self.line.transcriptions.create(text='GLOUBIBOULGA')
with self.assertNumQueries(8):
with self.assertNumQueries(7):
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.line.id}),
format='json',
......@@ -123,7 +123,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
Check that a transcription is created with the specified orientation
"""
self.client.force_login(self.user)
with self.assertNumQueries(8):
with self.assertNumQueries(7):
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.line.id}),
format='json',
......@@ -147,7 +147,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
Specifying an invalid text-orientation causes an error
"""
self.client.force_login(self.user)
with self.assertNumQueries(7):
with self.assertNumQueries(6):
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.line.id}),
format='json',
......@@ -158,7 +158,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
@override_settings(ARKINDEX_FEATURES={'search': False})
def test_no_search(self):
self.client.force_login(self.user)
with self.assertNumQueries(8):
with self.assertNumQueries(7):
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.line.id}),
format='json',
......@@ -168,7 +168,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
def test_worker_version(self):
self.client.force_login(self.user)
with self.assertNumQueries(7):
with self.assertNumQueries(6):
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.line.id}),
format='json',
......@@ -189,7 +189,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
"""
self.client.force_login(self.user)
with self.assertNumQueries(9):
with self.assertNumQueries(10):
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.line.id}),
format='json',
......@@ -227,7 +227,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
self.worker_run.process.run()
task = self.worker_run.process.tasks.first()
with self.assertNumQueries(8):
with self.assertNumQueries(9):
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.line.id}),
format='json',
......@@ -263,7 +263,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
self.worker_run.process.run()
task = self.worker_run.process.tasks.first()
with self.assertNumQueries(8):
with self.assertNumQueries(10):
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.line.id}),
format='json',
......@@ -297,7 +297,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
"""
self.client.force_login(self.superuser)
with self.assertNumQueries(5):
with self.assertNumQueries(4):
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.line.id}),
format='json',
......@@ -320,7 +320,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
worker_run = self.user.processes.get(mode=ProcessMode.Local).worker_runs.first()
self.client.force_login(self.superuser)
with self.assertNumQueries(5):
with self.assertNumQueries(4):
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.line.id}),
format='json',
......@@ -349,7 +349,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
self.worker_run.process.run()
task = self.worker_run.process.tasks.first()
with self.assertNumQueries(7):
with self.assertNumQueries(6):
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.line.id}),
format='json',
......@@ -368,7 +368,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
def test_worker_run_not_found(self):
self.client.force_login(self.user)
with self.assertNumQueries(8):
with self.assertNumQueries(7):
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.line.id}),
format='json',
......@@ -386,7 +386,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
def test_worker_run_required_confidence(self):
self.client.force_login(self.user)
with self.assertNumQueries(8):
with self.assertNumQueries(7):
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.line.id}),
format='json',
......@@ -402,7 +402,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
def test_worker_version_xor_worker_run(self):
self.client.force_login(self.user)
with self.assertNumQueries(8):
with self.assertNumQueries(7):
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.line.id}),
format='json',
......@@ -424,7 +424,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
null_zone_page = self.corpus.elements.create(type=self.page.type)
self.client.force_login(self.user)
with self.assertNumQueries(8):
with self.assertNumQueries(7):
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': null_zone_page.id}),
format='json',
......@@ -450,7 +450,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
"""
self.client.force_login(self.user)
with self.assertNumQueries(9):
with self.assertNumQueries(10):
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.line.id}),
format='json',
......
......@@ -44,7 +44,7 @@ class TestTranscriptions(FixtureAPITestCase):
)
self.client.force_login(self.user)
with self.assertNumQueries(8):
with self.assertNumQueries(7):
response = self.client.get(reverse('api:element-transcriptions', kwargs={'pk': str(self.page.id)}))
self.assertEqual(response.status_code, status.HTTP_200_OK)
......@@ -72,7 +72,7 @@ class TestTranscriptions(FixtureAPITestCase):
def test_list_transcriptions_recursive(self):
self.client.force_login(self.user)
with self.assertNumQueries(12):
with self.assertNumQueries(7):
response = self.client.get(
reverse('api:element-transcriptions', kwargs={'pk': str(self.page.id)}),
data={'recursive': 'true'}
......@@ -94,7 +94,7 @@ class TestTranscriptions(FixtureAPITestCase):
def test_list_transcriptions_recursive_filter_element_type(self):
self.client.force_login(self.user)
with self.assertNumQueries(10):
with self.assertNumQueries(8):
response = self.client.get(
reverse('api:element-transcriptions', kwargs={'pk': str(self.page.id)}),
data={'recursive': 'true', 'element_type': 'page'}
......@@ -119,7 +119,7 @@ class TestTranscriptions(FixtureAPITestCase):
self.client.force_login(self.user)
with self.assertNumQueries(12):
with self.assertNumQueries(7):
response = self.client.get(
reverse('api:element-transcriptions', kwargs={'pk': str(self.page.id)}),
data={'recursive': 'true', 'worker_version': str(self.worker_version_2.id)}
......@@ -156,7 +156,7 @@ class TestTranscriptions(FixtureAPITestCase):
self.client.force_login(self.user)
with self.assertNumQueries(11):
with self.assertNumQueries(7):
response = self.client.get(
reverse('api:element-transcriptions', kwargs={'pk': str(self.page.id)}),
data={'recursive': 'true', 'worker_version': False}
......@@ -215,7 +215,7 @@ class TestTranscriptions(FixtureAPITestCase):
self.client.force_login(self.user)
with self.assertNumQueries(13):
with self.assertNumQueries(7):
response = self.client.get(
reverse('api:element-transcriptions', kwargs={'pk': str(self.page.id)}),
data={'recursive': 'true', 'worker_version': str(self.worker_version_2.id)}
......@@ -257,7 +257,7 @@ class TestTranscriptions(FixtureAPITestCase):
self.client.force_login(self.user)
with self.assertNumQueries(13):
with self.assertNumQueries(7):
response = self.client.get(
reverse('api:element-transcriptions', kwargs={'pk': str(self.page.id)}),
data={'recursive': 'true', 'worker_run': str(self.worker_run.id)}
......
# Generated by Django 4.1.7 on 2023-11-28 11:02
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
    """Add an optional ``worker_run`` foreign key on the Ponos ``Task`` model."""

    dependencies = [
        ('process', '0023_alter_workerversion_model_usage'),
        ('ponos', '0005_remove_task_tags'),
    ]

    operations = [
        migrations.AddField(
            model_name='task',
            name='worker_run',
            # SET_NULL: deleting a WorkerRun must not cascade-delete its tasks,
            # only detach them. null/blank allow tasks without any worker run.
            field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='tasks', to='process.workerrun'),
        ),
    ]
......@@ -502,6 +502,13 @@ class Task(models.Model):
related_name="tasks",
on_delete=models.CASCADE,
)
worker_run = models.ForeignKey(
'process.WorkerRun',
on_delete=models.SET_NULL,
related_name='tasks',
null=True,
blank=True,
)
parents = models.ManyToManyField(
"self",
related_name="children",
......
import hashlib
import logging
import uuid
from urllib.parse import urljoin
from textwrap import dedent
from django.conf import settings
from django.core.mail import send_mail
from django.db import transaction
from django.db.models import Exists, OuterRef, Prefetch, prefetch_related_objects
from django.db.models import Exists, OuterRef
from django.shortcuts import reverse
from django.template.loader import render_to_string
from django.utils import timezone
from drf_spectacular.utils import extend_schema_field
from rest_framework import serializers
......@@ -29,8 +27,8 @@ from arkindex.ponos.models import (
)
from arkindex.ponos.serializer_fields import Base64Field, CurrentProcessDefault, PublicKeyField
from arkindex.ponos.signals import task_failure
from arkindex.process.models import ProcessMode
from arkindex.project.serializer_fields import EnumField
from arkindex.project.triggers import notify_process_completion
logger = logging.getLogger(__name__)
......@@ -99,8 +97,17 @@ class TaskLightSerializer(serializers.ModelSerializer):
Serializes a :class:`~arkindex.ponos.models.Task` instance without logs or agent information.
Used to list tasks inside a process.
"""
state = EnumField(
State,
help_text=dedent("""
Allowed transitions for the state of a task by an agent are defined below:
state = EnumField(State)
Pending ⟶ Running ⟶ Completed
└⟶ Error ├⟶ Failed
└⟶ Error
Stopping ⟶ Stopped
""").strip(),
)
class Meta:
model = Task
......@@ -122,6 +129,17 @@ class TaskLightSerializer(serializers.ModelSerializer):
"shm_size",
)
def validate_state(self, state):
    """
    Restrict task state updates by agents to the allowed transition graph.

    Raises a ValidationError when the serializer is bound to an existing task
    and the requested state is not reachable from its current state.
    """
    # Maps each current state to the states an agent may move the task into.
    agent_transitions = {
        State.Unscheduled: [State.Pending],
        State.Pending: [State.Running, State.Error],
        State.Running: [State.Completed, State.Failed, State.Stopping, State.Error],
        State.Stopping: [State.Stopped],
    }
    if self.instance:
        reachable = agent_transitions.get(self.instance.state, [])
        if state not in reachable:
            raise ValidationError(f'Transition from state {self.instance.state} to state {state} is forbidden.')
    return state
class TaskSerializer(TaskLightSerializer):
"""
......@@ -199,65 +217,7 @@ class TaskSerializer(TaskLightSerializer):
instance.process.finished = timezone.now()
instance.process.save()
# For any process except Repository mode, notify the creator by email
state_msg = {
State.Completed: "successfully",
State.Failed: "with failures",
State.Error: "with errors",
State.Stopped: "because it was stopped",
}
prefetch_related_objects([instance.process], Prefetch(
'tasks',
# Avoid a stale read, since we just updated a task's state
Task.objects.using("default")
# Pick only the tasks of the last run, computed in a subquery
.filter(run=instance.process.tasks.order_by('-run').values('run')[:1])
# Sort the tasks to get a sorted output in the email
.order_by('depth', 'slug')
))
state = instance.process.state
request = self.context.get("request")
if (
request
and instance.process.mode != ProcessMode.Repository
and state in state_msg.keys()
):
process_name = instance.process.name if instance.process.name else str(instance.process.id)
current_run = instance.process.get_last_run()
tasks_stats = {
# Only tasks of the last run are prefetched
task.slug: task.state for task in instance.process.tasks.all()
}
sent = send_mail(
subject=f'Your process {process_name} finished {state_msg[state]}',
message=render_to_string(
'process_completion.html',
context={
'process': instance.process,
'state': state,
'user': instance.process.creator,
'run': current_run,
'tasks_stats': tasks_stats,
'url': urljoin(
settings.PUBLIC_HOSTNAME,
reverse('frontend-process-details', kwargs={
'pk': instance.process.id,
'run': current_run,
})
),
},
request=self.context.get("request"),
),
from_email=None,
recipient_list=[instance.process.creator.email],
fail_silently=True,
)
if sent == 0:
logger.error(
f'Failed to send status email for process {instance.process.id}'
f' to {instance.process.creator.email}'
)
notify_process_completion(instance.process)
# We already checked earlier that the task was in a final state.
# If this state is both final and not completed, then we should trigger the task failure signal.
......@@ -273,26 +233,48 @@ class TaskTinySerializer(TaskSerializer):
Used by humans to update a task.
"""
state = EnumField(State)
state = EnumField(
State,
help_text=dedent("""
Allowed transitions for the state of a task by a user are defined below:
Completed ⟶ Pending
Failed ⟶ Pending
Error ⟶ Pending
Stopped ⟶ Pending
Running ⟶ Stopping
""").strip(),
)
class Meta:
model = Task
fields = ("id", "state")
read_only_fields = ("id",)
def validate_state(self, state):
    """
    Only allow a user to manually stop or retry a task:
    any final state may go back to Pending (retry), and a
    Running task may be moved to Stopping (stop request).
    """
    transitions = {final_state: [State.Pending] for final_state in FINAL_STATES}
    transitions[State.Running] = [State.Stopping]
    if self.instance:
        if state not in transitions.get(self.instance.state, []):
            raise ValidationError(f'Transition from state {self.instance.state} to state {state} is forbidden.')
    return state
def update(self, instance: Task, validated_data) -> Task:
new_state = validated_data.get('state')
if new_state == State.Stopping and instance.state != State.Running:
raise ValidationError("You can only stop a 'Running' task")
if new_state == State.Pending:
if instance.state not in FINAL_STATES:
raise ValidationError(
"You can only restart a task with a state equal to 'Completed', 'Failed', 'Error' or 'Stopped'."
)
if instance.state == State.Unscheduled:
# Prevent a user from restarting a task that was never assigned to an agent.
raise ValidationError({
'state': [f'Transition from state {State.Unscheduled} to state {State.Pending} is forbidden.']
})
# Restart the task
instance.agent = None
instance.gpu = None
# Un-finish the process since a task will run again
instance.process.finished = None
instance.process.save()
......@@ -476,6 +458,7 @@ class TaskDefinitionSerializer(serializers.ModelSerializer):
image_artifact_url = serializers.SerializerMethodField()
s3_logs_put_url = serializers.SerializerMethodField()
extra_files = serializers.DictField(default={})
state = EnumField(State)
@extend_schema_field(serializers.URLField(allow_null=True))
def get_image_artifact_url(self, task):
......@@ -515,6 +498,7 @@ class TaskDefinitionSerializer(serializers.ModelSerializer):
"process_id",
"gpu_id",
"extra_files",
"state",
)
read_only_fields = fields
......
import logging
from urllib.parse import urljoin
from django.conf import settings
from django.core.mail import send_mail
from django.db.models import Count, F, Q
from django.db.models.functions import Round
from django.shortcuts import reverse
from django.template.loader import render_to_string
from django_rq import job
from arkindex.process.models import Process, WorkerActivityState
logger = logging.getLogger(__name__)
@job('default', timeout=settings.RQ_TIMEOUTS['notify_process_completion'])
def notify_process_completion(
    process: Process,
    subject: str = 'Your process finished',
) -> None:
    """
    Asynchronous RQ job that emails the process creator a summary of the
    last run: the state of each task, plus per-worker activity failure
    statistics, rendered through the `process_completion.html` template.

    :param process: The process whose completion is being reported.
    :param subject: Subject line of the notification email.
    """
    current_run = process.get_last_run()
    # Inspect the state of tasks in the last run
    tasks_stats = {
        task.slug: task.state
        for task in process.tasks.all()
        # Run check is done in Python as tasks can be prefetched for the last run only
        if task.run == current_run
    }
    # Aggregate statistics about worker activities failures
    worker_failures = (
        process.activities
        .values('worker_version__worker__name')
        .annotate(
            failures=Count('id', filter=Q(state=WorkerActivityState.Error)),
            total=Count('id')
        )
        # Second annotate: `percent` depends on the `failures`/`total` aggregates above
        .annotate(
            percent=Round(100 * F('failures') / F('total'))
        )
        # Only report workers that actually had failures
        .filter(failures__gt=0)
        .values('worker_version__worker__name', 'failures', 'total', 'percent')
    )
    send_mail(
        subject=subject,
        message=render_to_string(
            'process_completion.html',
            context={
                'process': process,
                'run': current_run,
                'tasks_stats': tasks_stats,
                'worker_failures': worker_failures,
                # Absolute link to the frontend details page for this run
                'url': urljoin(
                    settings.PUBLIC_HOSTNAME,
                    reverse('frontend-process-details', kwargs={
                        'pk': process.id,
                        'run': current_run,
                    })
                ),
            },
        ),
        # None: use the DEFAULT_FROM_EMAIL setting
        from_email=None,
        recipient_list=[process.creator.email],
        # fail_silently=False: let SMTP errors propagate so the RQ job is marked failed
        fail_silently=False,
    )