Compare revisions — arkindex/backend

Changes are shown as if the source revision was being merged into the target revision.

Commits on Source (33)
Showing with 427 additions and 198 deletions
......@@ -20,7 +20,7 @@ include:
before_script:
# Custom line to install our own deps from Git using GitLab CI credentials
- "pip install -e git+https://gitlab-ci-token:${CI_JOB_TOKEN}@gitlab.com/teklia/arkindex/transkribus#egg=transkribus-client"
- pip install -r tests-requirements.txt codecov
- pip install -r tests-requirements.txt
- "echo 'database: {host: postgres, port: 5432}\npublic_hostname: http://ci.arkindex.localhost' > $CONFIG_PATH"
# Those jobs require the base image; they might fail if the image is not up to date.
......@@ -61,7 +61,6 @@ backend-tests:
script:
- python3 setup.py test
- codecov
backend-lint:
image: python:3.10
......
1.4.1-beta1
1.4.1
......@@ -157,9 +157,19 @@ class EntityRoleAdmin(admin.ModelAdmin):
class EntityTypeAdmin(admin.ModelAdmin):
list_display = ('id', 'name', 'color')
list_display = ('id', 'corpus', 'name', 'color')
list_filter = ('corpus', )
readonly_fields = ('id', 'corpus', )
def get_readonly_fields(self, request, obj=None):
# Make the corpus field read-only only for existing entity types.
# Otherwise, new EntityTypes would be created with corpus=None
if obj:
return ('id', 'corpus')
return ('id', )
def has_delete_permission(self, request, obj=None):
# Require everyone to use the frontend or DestroyEntityType
return False
admin.site.register(Corpus, CorpusAdmin)
......
......@@ -78,6 +78,7 @@ from arkindex.documents.serializers.elements import (
from arkindex.documents.serializers.light import CorpusAllowedMetaDataSerializer, ElementTypeLightSerializer
from arkindex.documents.serializers.ml import ElementTranscriptionSerializer
from arkindex.images.models import Image
from arkindex.ponos.utils import is_admin_or_ponos_task
from arkindex.process.models import WorkerRun, WorkerVersion
from arkindex.project.fields import Unnest
from arkindex.project.mixins import ACLMixin, CorpusACLMixin, SelectionMixin
......@@ -1316,8 +1317,23 @@ class ElementNeighbors(ACLMixin, ListAPIView):
@extend_schema(tags=['elements'], request=None)
@extend_schema_view(
post=extend_schema(operation_id='CreateElementParent', description='Link an element to a new parent'),
delete=extend_schema(operation_id='DestroyElementParent', description='Delete the relation between an element and one of its parents'),
post=extend_schema(
operation_id='CreateElementParent',
description='Link an element to a new parent',
parameters=[
OpenApiParameter(
'type_ordering',
type=bool,
default=True,
description='Add the child element at the last position in the parent relative to '
'only the elements of the same type, or to all elements.',
)
],
),
delete=extend_schema(
operation_id='DestroyElementParent',
description='Delete the relation between an element and one of its parents',
),
)
class ElementParent(CreateAPIView, DestroyAPIView):
"""
......@@ -1326,8 +1342,12 @@ class ElementParent(CreateAPIView, DestroyAPIView):
serializer_class = ElementParentSerializer
permission_classes = (IsVerified, )
@property
def type_ordering(self):
return self.request.query_params.get('type_ordering', 'true').lower() not in ('0', 'false')
def get_serializer_from_params(self, child=None, parent=None, **kwargs):
data = {'child': child, 'parent': parent}
data = {'child': child, 'parent': parent, 'type_ordering': self.type_ordering}
kwargs['context'] = self.get_serializer_context()
return ElementParentSerializer(data=data, **kwargs)
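
For reference, a minimal stand-alone sketch of the boolean parsing done by the `type_ordering` property above: any query parameter value other than `0` or `false` (case-insensitive) is treated as true, and an omitted parameter defaults to true. `parse_bool_param` is a hypothetical name used only for this illustration; the view inlines the logic directly.

```python
# Sketch of the query parameter parsing used by the type_ordering property above.
def parse_bool_param(value, default='true'):
    if value is None:
        value = default
    return value.lower() not in ('0', 'false')

assert parse_bool_param(None) is True        # parameter omitted -> defaults to True
assert parse_bool_param('False') is False    # case-insensitive
assert parse_bool_param('0') is False
assert parse_bool_param('anything') is True  # any other value is treated as True
```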
......@@ -1593,6 +1613,13 @@ class TranscriptionsPagination(PageNumberPagination):
'If set to `False`, only include transcriptions created by humans.',
required=False,
),
OpenApiParameter(
'worker_run',
type=UUID_OR_FALSE,
description='Only include transcriptions created by a specific worker run. '
'If set to `False`, only include transcriptions created by no worker run.',
required=False,
),
OpenApiParameter(
'element_type',
description='Filter transcriptions by element type',
......@@ -1671,6 +1698,20 @@ class ElementTranscriptions(ListAPIView):
return queryset
def filter_queryset(self, queryset):
errors = {}
# Filter by worker run
if 'worker_run' in self.request.query_params:
worker_run_id = self.request.query_params['worker_run']
if worker_run_id.lower() in ('false', '0'):
# Restrict to transcriptions without worker runs
queryset = queryset.filter(worker_run_id=None)
else:
try:
queryset = queryset.filter(worker_run_id=worker_run_id)
except DjangoValidationError as e:
errors['worker_run'] = e.messages
# Filter by worker version
if 'worker_version' in self.request.query_params:
worker_version_id = self.request.query_params['worker_version']
......@@ -1681,13 +1722,16 @@ class ElementTranscriptions(ListAPIView):
try:
queryset = queryset.filter(worker_version_id=worker_version_id)
except DjangoValidationError as e:
raise ValidationError({'worker_version': e.messages})
errors['worker_version'] = e.messages
# Filter by element_type
element_type = self.request.query_params.get('element_type')
if element_type:
queryset = queryset.select_related('element__type').filter(element__type__slug=element_type)
if errors:
raise ValidationError(errors)
return queryset
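
A self-contained sketch of the `UUID_OR_FALSE` convention applied by the `worker_run` and `worker_version` filters above: `false` or `0` selects objects created by no worker run, while any other value must be a valid UUID. `parse_uuid_or_false` is a hypothetical helper name, not part of the diff.

```python
# Hypothetical helper illustrating the UUID_OR_FALSE filter values accepted above.
from uuid import UUID

def parse_uuid_or_false(value):
    if value.lower() in ('false', '0'):
        return None                  # filter on worker_run_id=None
    return UUID(value)               # raises ValueError for malformed values

print(parse_uuid_or_false('false'))  # None
print(parse_uuid_or_false('12341234-1234-4123-8123-123412341234'))
```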
......@@ -1815,7 +1859,7 @@ class ElementMetadata(ListCreateAPIView):
def perform_create(self, serializer):
instance = serializer.save()
if self.request.user.is_admin or self.request.user.is_internal:
if is_admin_or_ponos_task(self.request):
AllowedMetaData.objects.get_or_create(name=instance.name, type=instance.type, corpus=instance.element.corpus)
......@@ -1847,7 +1891,7 @@ class ElementMetadataBulk(CreateAPIView):
def perform_create(self, serializer):
instances = serializer.save()
if self.request.user.is_admin or self.request.user.is_internal:
if is_admin_or_ponos_task(self.request):
AllowedMetaData.objects.bulk_create(
[
AllowedMetaData(name=instance.name, type=instance.type, corpus=instance.element.corpus)
......
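
The `is_admin_or_ponos_task` helper imported from `arkindex.ponos.utils` is not shown in this diff, so its exact implementation is unknown. Based on how it replaces the old `user.is_admin or user.is_internal` checks throughout the views and serializers, a plausible sketch could look like the following; the `is_ponos_task` attribute check is an assumption, not confirmed by the source.

```python
# Plausible sketch only: the real arkindex.ponos.utils.is_admin_or_ponos_task may differ.
def is_admin_or_ponos_task(request):
    # Requests authenticated with a Ponos task token (assumed attribute name)
    if getattr(request.auth, 'is_ponos_task', False):
        return True
    # Otherwise, require an authenticated instance admin
    return request.user.is_authenticated and request.user.is_admin
```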
......@@ -37,7 +37,7 @@ from arkindex.documents.serializers.entities import (
TranscriptionEntitySerializer,
)
from arkindex.documents.serializers.light import EntityTypeLightSerializer
from arkindex.process.models import WorkerVersion
from arkindex.process.models import WorkerRun, WorkerVersion
from arkindex.project.mixins import ACLMixin, CorpusACLMixin
from arkindex.project.openapi import UUID_OR_FALSE
from arkindex.project.permissions import IsVerified, IsVerifiedOrReadOnly
......@@ -373,13 +373,27 @@ class TranscriptionEntityCreate(CreateAPIView):
'If set to `False`, only include transcription entities created by humans.',
required=False,
),
OpenApiParameter(
'worker_run',
type=UUID_OR_FALSE,
description='Only include transcription entities created by a specific worker run. '
'If set to `False`, only include transcription entities created by no worker run.',
required=False,
),
OpenApiParameter(
'entity_worker_version',
type=UUID_OR_FALSE,
description='Only include transcription entities whose entity was created by a specific worker version. '
'If set to `False`, only include transcription entities whose entity was created by humans.',
required=False,
)
),
OpenApiParameter(
'entity_worker_run',
type=UUID_OR_FALSE,
description='Only include transcription entities whose entity was created by a specific worker run. '
'If set to `False`, only include transcription entities whose entity was created by no worker run.',
required=False,
),
]
)
)
......@@ -393,34 +407,41 @@ class TranscriptionEntities(ListAPIView):
# For OpenAPI type discovery: a transcription's ID is in the path
queryset = Transcription.objects.none()
def parse_worker_version(self, value):
def parse_model_id(self, model, value):
if value.lower() in ('false', '0'):
return None
try:
validated = UUID(value)
except (TypeError, ValueError):
raise serializers.ValidationError(['Invalid UUID.'])
if not WorkerVersion.objects.filter(id=validated).exists():
raise serializers.ValidationError(['This worker version does not exist.'])
if not model.objects.filter(id=validated).exists():
raise serializers.ValidationError([f'This {model._meta.verbose_name} does not exist.'])
return validated
def get_queryset(self):
filters = {}
errors = defaultdict()
if 'worker_version' in self.request.query_params:
try:
worker_version_id = self.parse_worker_version(self.request.query_params['worker_version'])
filters['worker_version_id'] = worker_version_id
except serializers.ValidationError as e:
errors['worker_version'] = e.detail
if 'entity_worker_version' in self.request.query_params:
try:
entity_worker_version_id = self.parse_worker_version(self.request.query_params['entity_worker_version'])
filters['entity__worker_version_id'] = entity_worker_version_id
except serializers.ValidationError as e:
errors['entity_worker_version'] = e.detail
# List of WorkerVersion and WorkerRun filters to handle:
# (query parameter, QuerySet field name, model class)
worker_filters = (
('worker_version', 'worker_version_id', WorkerVersion),
('entity_worker_version', 'entity__worker_version_id', WorkerVersion),
('worker_run', 'worker_run_id', WorkerRun),
('entity_worker_run', 'entity__worker_run_id', WorkerRun),
)
for query_param, field_name, model in worker_filters:
if query_param in self.request.query_params:
try:
value = self.parse_model_id(model, self.request.query_params[query_param])
filters[field_name] = value
except serializers.ValidationError as e:
errors[query_param] = e.detail
if errors:
raise serializers.ValidationError(errors)
......@@ -438,7 +459,7 @@ class TranscriptionEntities(ListAPIView):
**filters,
)
.order_by('offset')
.select_related('entity__type')
.select_related('entity__type', 'worker_run')
)
......
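
A stand-alone sketch of the accumulate-then-raise validation pattern now shared by both filter implementations above: every worker filter is validated, and all failures are reported in a single `ValidationError` instead of stopping at the first invalid parameter. Names below are illustrative, not taken from the diff.

```python
# Illustrative version of the error-accumulation pattern used by the filter methods above.
from uuid import UUID

def validate_worker_filters(query_params):
    filters, errors = {}, {}
    for param, field in (('worker_run', 'worker_run_id'),
                         ('entity_worker_run', 'entity__worker_run_id')):
        if param not in query_params:
            continue
        value = query_params[param]
        if value.lower() in ('false', '0'):
            filters[field] = None
            continue
        try:
            filters[field] = UUID(value)
        except ValueError:
            errors[param] = ['Invalid UUID.']
    return filters, errors

print(validate_worker_filters({'worker_run': 'false', 'entity_worker_run': 'oops'}))
# ({'worker_run_id': None}, {'entity_worker_run': ['Invalid UUID.']})
```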
......@@ -1648,30 +1648,12 @@
"display_name": "Admin",
"transkribus_email": null,
"is_active": true,
"is_internal": false,
"is_admin": true,
"verified_email": true,
"created": "2020-02-02T01:23:45.678Z",
"updated": "2020-02-02T01:23:45.678Z"
}
},
{
"model": "users.user",
"pk": 2,
"fields": {
"password": "pbkdf2_sha256$390000$QDrLXttZfp4Tq3zzbzz4j5$JIgbn37CuP7iU1MAinYJsmmjeS8F2hjZaDOloo560G0=",
"last_login": null,
"email": "internal@internal.fr",
"display_name": "Internal user",
"transkribus_email": null,
"is_active": true,
"is_internal": true,
"is_admin": false,
"verified_email": true,
"created": "2020-02-02T01:23:45.678Z",
"updated": "2020-02-02T01:23:45.678Z"
}
},
{
"model": "users.user",
"pk": 3,
......@@ -1682,7 +1664,6 @@
"display_name": "Test user",
"transkribus_email": null,
"is_active": true,
"is_internal": false,
"is_admin": false,
"verified_email": true,
"created": "2020-02-02T01:23:45.678Z",
......@@ -1699,7 +1680,6 @@
"display_name": "Test user write",
"transkribus_email": null,
"is_active": true,
"is_internal": false,
"is_admin": false,
"verified_email": false,
"created": "2020-02-02T01:23:45.678Z",
......@@ -1716,7 +1696,6 @@
"display_name": "Test user read",
"transkribus_email": null,
"is_active": true,
"is_internal": false,
"is_admin": false,
"verified_email": false,
"created": "2020-02-02T01:23:45.678Z",
......
......@@ -32,7 +32,7 @@ IMPORT_WORKER_SLUG = 'file_import'
IMPORT_WORKER_REPO = 'https://gitlab.com/teklia/arkindex/tasks'
IMPORT_WORKER_REVISION_MESSAGE = 'File import worker bootstrap'
IMPORT_WORKER_REVISION_AUTHOR = 'Dev Bootstrap'
INTERNAL_API_TOKEN = "deadbeefTestToken"
ADMIN_API_TOKEN = "deadbeefTestToken"
class Command(BaseCommand):
......@@ -51,14 +51,13 @@ class Command(BaseCommand):
self.stdout.write(self.style.ERROR(f"{msg}"))
def check_user(self, user):
"""Ensure a user is admin + internal"""
if user.is_internal and user.is_admin:
self.success(f"Internal user {user} is valid")
"""Ensure a user is admin"""
if user.is_admin:
self.success(f"Admin user for legacy worker API tokens {user} is valid")
else:
user.is_internal = True
user.is_admin = True
user.save()
self.warn(f"Updated user {user} to internal+admin")
self.warn(f"Updated user {user} to admin")
def handle(self, **options):
# Never allow running this script in production
......@@ -92,17 +91,16 @@ class Command(BaseCommand):
)
self.success("Ponos farm created")
# an internal API user with a specific token
# An admin API user with a specific token
try:
token = Token.objects.get(key=INTERNAL_API_TOKEN)
token = Token.objects.get(key=ADMIN_API_TOKEN)
self.check_user(token.user)
except Token.DoesNotExist:
# Create a new internal user
user, _ = User.objects.get_or_create(
email='internal+bootstrap@teklia.com',
defaults={
'display_name': 'Bootstrap Internal user',
'is_internal': True,
'display_name': 'Bootstrap Admin user',
'is_admin': True,
}
)
......@@ -113,8 +111,8 @@ class Command(BaseCommand):
if hasattr(user, "auth_token"):
# Support One-To-One relation
user.auth_token.delete()
Token.objects.create(key=INTERNAL_API_TOKEN, user=user)
self.success(f"Created token {INTERNAL_API_TOKEN}")
Token.objects.create(key=ADMIN_API_TOKEN, user=user)
self.success(f"Created token {ADMIN_API_TOKEN}")
# an image server for local cantaloupe https://ark.localhost/iiif/2
try:
......
......@@ -56,15 +56,6 @@ class Command(BaseCommand):
superuser.verified_email = True
superuser.save()
internal_user = User.objects.create_user(
'internal@internal.fr',
'Pa$$w0rd',
display_name='Internal user',
internal=True,
)
internal_user.verified_email = True
internal_user.save()
user = User.objects.create_user('user@user.fr', 'Pa$$w0rd', display_name='Test user')
user.verified_email = True
user.save()
......
......@@ -3,18 +3,22 @@ from datetime import timedelta
from urllib.parse import quote
from uuid import UUID
import django_rq
from botocore.exceptions import ClientError
from django.conf import settings
from django.core.management.base import BaseCommand
from django.db.models import Exists, F, Max, OuterRef, Q, Value
from django.utils import timezone
from rq.utils import as_text
from arkindex.documents.models import CorpusExport, CorpusExportState, Element
from arkindex.images.models import Image, ImageServer
from arkindex.ponos.models import Artifact, Task, Workflow
from arkindex.process.models import DataFile, GitRef, GitRefType, WorkerVersion, WorkerVersionState
from arkindex.project.aws import s3
from arkindex.project.rq_overrides import Job
from arkindex.training.models import ModelVersion
from redis.exceptions import ConnectionError
# Ponos artifacts use the path: <workflow uuid>/<task id>/<path>
# Before June 2020, artifacts used <workflow uuid>/run_<run id>/<task id>.tar.zst
......@@ -48,6 +52,8 @@ class Command(BaseCommand):
self.cleanup_unlinked_model_versions()
self.cleanup_rq_user_registries()
def cleanup_artifacts(self):
"""
Remove all Ponos artifacts that are not tied to a Process
......@@ -330,3 +336,58 @@ class Command(BaseCommand):
self.stdout.write(self.style.ERROR(str(e)))
self.stdout.write(self.style.SUCCESS('Successfully cleaned up orphaned model versions archives.'))
def cleanup_rq_user_registries(self):
"""
To link RQ tasks to users and make them available in ListJobs, we have a "user registry" which stores a list
of RQ job IDs under rq:registry:user:<user ID>. This registry cannot be cleaned up automatically by RQ workers
as they are not easily overridable; the cleanup is instead done when ListJobs is called, which means cleaning up
a very large number of tasks can cause the API to time out and the whole cleanup to be cancelled.
To at least reduce the impact of this issue, the cleanup command also performs this cleanup for all users.
"""
self.stdout.write('Cleaning up deleted jobs from RQ user registries…')
deleted_jobs, user_count = 0, 0
connection = django_rq.get_connection()
# Skip cleaning up when Redis appears to be down.
# No actual connection is made when calling get_connection, so we check on the first command we run
try:
registry_keys = connection.keys(pattern='rq:registry:user:*')
except ConnectionError as e:
self.stdout.write(self.style.ERROR(str(e)))
return
for registry_key in registry_keys:
job_ids = list(connection.zrange(registry_key, 0, -1))
to_delete = set()
if not job_ids:
# Empty registry
continue
# To check whether or not each job ID exists, we'll use the EXISTS command.
# It accepts multiple names as arguments and returns how many of those names exist,
# but we need to know which specific job IDs exist, so we issue one EXISTS command per job.
# To make this much faster, we use a pipeline to run many EXISTS commands and return all the results at once.
with connection.pipeline() as pipeline:
for job_id in job_ids:
pipeline.exists(Job.key_for(as_text(job_id)))
for job_id, job_exists in zip(job_ids, pipeline.execute()):
if job_exists:
continue
to_delete.add(job_id)
# ZREM can be called with multiple members at once, but with enough of them we'll get a broken pipe error,
# so to make sure we can really clean up a huge amount of job IDs, we delete in chunks.
if len(to_delete) >= settings.REDIS_ZREM_CHUNK_SIZE:
deleted_jobs += connection.zrem(registry_key, *to_delete)
to_delete = set()
if to_delete:
deleted_jobs += connection.zrem(registry_key, *to_delete)
user_count += 1
self.stdout.write(self.style.SUCCESS(f'Successfully cleaned up {deleted_jobs} deleted jobs from {user_count} RQ user registries.'))
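
As a stand-alone illustration of the pipelined EXISTS plus chunked ZREM pattern described in the comments above (assuming the standard redis-py API and a reachable Redis instance; key formats follow the `rq:job:` and `rq:registry:user:` names used here):

```python
# Sketch of the pattern used by cleanup_rq_user_registries above.
import redis

connection = redis.Redis()
registry_key = 'rq:registry:user:1'
job_ids = [b'job1', b'job2', b'job3']

# One EXISTS per job, all answered in a single round trip thanks to the pipeline
with connection.pipeline() as pipeline:
    for job_id in job_ids:
        pipeline.exists(b'rq:job:' + job_id)
    results = pipeline.execute()  # e.g. [1, 0, 1]

# Remove only the registry entries whose job hash no longer exists
missing = {job_id for job_id, found in zip(job_ids, results) if not found}
if missing:
    connection.zrem(registry_key, *missing)
```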
......@@ -100,12 +100,16 @@ SQL_ELEMENT_QUERY = """
LEFT JOIN image_map ON (image_map.url = image.url)
"""
SQL_ELEMENT_PATH_QUERY = "SELECT * FROM element_path"
SQL_PARENT_QUERY = "SELECT * FROM element_path WHERE child_id = '{}'"
SQL_PARENT_QUERY = "SELECT parent_id FROM element_path WHERE child_id = '{}'"
SQL_TOP_LEVEL_PATH_QUERY = """
SELECT element.id
FROM element
LEFT JOIN element_path ON element_path.child_id = element.id
WHERE element_path.id IS NULL
WHERE NOT EXISTS (
SELECT 1
FROM element_path
WHERE child_id = element.id
LIMIT 1
)
"""
SQL_ENTITY_QUERY = "SELECT * FROM entity"
......@@ -142,24 +146,25 @@ class Command(BaseCommand):
return
yield chunk
def build_element_paths(self, child_id):
def build_element_paths(self, parent_id):
"""
The SQL database only stores links to direct parents.
It is necessary to be able to reconstruct the complete paths (with all grandparents).
This function retrieves the complete paths by making a recursive call.
Reconstructs the paths of a child element through a single parent element at a time.
The SQLite database only stores direct links, not arrays like in Arkindex,
so this function calls itself recursively to reconstruct the arrays.
"""
sql_query = SQL_PARENT_QUERY.format(child_id)
sql_query = SQL_PARENT_QUERY.format(parent_id)
found = False
paths = []
for db_chunk in self.sql_chunk(sql_query):
for row in db_chunk:
parent_paths = self.build_element_paths(row["parent_id"])
for parent_path in parent_paths:
parent_path.append(row["parent_id"])
paths.append(parent_path)
if not parent_paths:
paths.append([row["parent_id"]])
return paths
found = True
yield parent_path + [parent_id]
if not found:
# When this element has no parents, return a path with the element alone
yield [parent_id]
def convert_images(self, row):
assert row["url"].startswith(row["server_url"]), "The url of the image does not start with the url of its server"
......@@ -187,7 +192,7 @@ class Command(BaseCommand):
)]
def convert_element_paths(self, row):
paths = self.build_element_paths(row["child_id"])
paths = self.build_element_paths(row["parent_id"])
return [ElementPath(
element_id=row["child_id"],
path=path,
......
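
To make the recursion in `build_element_paths` concrete, here is a toy, self-contained version working on an in-memory parent mapping instead of the SQLite export. The names and data are illustrative only.

```python
# Toy version of the recursive path reconstruction: expand a flat child -> parents
# mapping into complete paths ending with the element itself.
PARENTS = {
    'page': ['volume'],
    'volume': ['book'],
    'book': [],          # top-level element: no parents
}

def build_paths(element_id):
    parents = PARENTS.get(element_id, [])
    if not parents:
        # No parents: the element alone is its own path
        yield [element_id]
        return
    for parent_id in parents:
        for parent_path in build_paths(parent_id):
            yield parent_path + [element_id]

print(list(build_paths('page')))  # [['book', 'volume', 'page']]
```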
......@@ -246,7 +246,7 @@ class Element(IndexableModel):
Element.objects.filter(id=self.id)._raw_delete(using='default')
@transaction.atomic
def add_parent(self, parent, skip_children=False):
def add_parent(self, parent, skip_children=False, type_ordering=True):
'''
Add an element as ancestor
'''
......@@ -262,7 +262,7 @@ class Element(IndexableModel):
raise ValueError('Cannot add a descendant as a parent')
# Get the next order for this type in the parent
order = parent.get_next_order(self.type_id)
order = parent.get_next_order(self.type_id if type_ordering else None)
# Remove any top-level path
if [] in existing_paths:
......@@ -391,11 +391,15 @@ class Element(IndexableModel):
Uses the primary database to avoid stale reads and duplicate orderings,
and path__overlap to let Postgres use the GIN index before filtering by last item.
"""
assert isinstance(type, (ElementType, uuid.UUID))
return ElementPath.objects \
.using('default') \
.filter(path__overlap=[self.id], path__last=self.id, element__type=type) \
.aggregate(max=models.Max('ordering') + 1)['max'] or 0
assert type is None or isinstance(type, (ElementType, uuid.UUID))
paths = ElementPath.objects \
.using('default') \
.filter(path__overlap=[self.id], path__last=self.id)
if type:
paths = paths.filter(element__type=type)
return paths.aggregate(max=models.Max('ordering') + 1)['max'] or 0
@transaction.atomic
def remove_child(self, child):
......
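
A worked illustration, with hypothetical data, of how `get_next_order` now behaves with and without the type filter, matching the `type_ordering` option on `add_parent`: with `type_ordering=True` the new child is appended after siblings of the same type only, with `type_ordering=False` after all siblings.

```python
# Simplified model of the ordering computation in get_next_order above.
existing_children = [
    {'type': 'page', 'ordering': 0},
    {'type': 'page', 'ordering': 1},
    {'type': 'act',  'ordering': 0},
    {'type': 'act',  'ordering': 1},
    {'type': 'act',  'ordering': 2},
]

def next_order(children, type=None):
    orderings = [c['ordering'] for c in children if type is None or c['type'] == type]
    return (max(orderings) + 1) if orderings else 0

print(next_order(existing_children, type='page'))  # 2 -> appended after the other pages
print(next_order(existing_children))               # 3 -> appended after all children
```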
......@@ -23,6 +23,7 @@ from arkindex.documents.serializers.light import (
from arkindex.documents.serializers.ml import ClassificationSerializer, WorkerRunSummarySerializer
from arkindex.images.models import Image
from arkindex.images.serializers import ZoneSerializer
from arkindex.ponos.utils import is_admin_or_ponos_task
from arkindex.process.models import WorkerRun, WorkerVersion
from arkindex.project.fields import Array
from arkindex.project.mixins import SelectionMixin
......@@ -428,8 +429,7 @@ class ElementSlimSerializer(ElementTinySerializer):
@extend_schema_field(serializers.CharField(allow_null=True))
def get_thumbnail_put_url(self, element):
user = self.context['request'].user
if user.is_authenticated and (user.is_admin or user.is_internal) and element.type.folder:
if is_admin_or_ponos_task(self.context['request']) and element.type.folder:
return element.thumbnail.s3_put_url
class Meta(ElementTinySerializer.Meta):
......@@ -492,6 +492,11 @@ class ElementParentSerializer(serializers.Serializer):
"""
child = serializers.PrimaryKeyRelatedField(queryset=Element.objects.none())
parent = serializers.PrimaryKeyRelatedField(queryset=Element.objects.none())
type_ordering = serializers.BooleanField(
default=True,
help_text='Add the child element at the last position in the parent relative to '
'only the elements of the same type, or to all elements.',
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
......@@ -508,7 +513,7 @@ class ElementParentSerializer(serializers.Serializer):
child = data.get('child')
parent = data.get('parent')
if parent.corpus != child.corpus:
if parent.corpus_id != child.corpus_id:
errors['parent'].append("Parent is not from corpus '{}'".format(child.corpus.name))
if parent.id == child.id:
errors['parent'].append('A child cannot be its own parent')
......@@ -531,8 +536,9 @@ class ElementParentSerializer(serializers.Serializer):
def create(self, validated_data):
child = validated_data['child']
parent = validated_data['parent']
type_ordering = validated_data['type_ordering']
self.perform_create_checks(child, parent)
child.add_parent(parent)
child.add_parent(parent, type_ordering=type_ordering)
def delete(self, validated_data):
child = validated_data['child']
......@@ -789,10 +795,10 @@ class ElementCreateSerializer(ElementLightSerializer):
worker_run = data.get('worker_run', None)
if worker_run:
if self.context['request'].user.is_internal:
if is_admin_or_ponos_task(self.context['request']):
data['worker_version_id'] = worker_run.version_id
else:
errors['worker_run_id'].append('Only an internal user can create an element with a worker run.')
errors['worker_run_id'].append('Only an instance admin or a Ponos task can create an element with a worker run.')
else:
# Set the element creator for a manual element
data['creator'] = self.context['request'].user
......
......@@ -484,7 +484,7 @@ class TranscriptionEntityBulkItemSerializer(serializers.ModelSerializer):
extra_kwargs = {
'offset': {'write_only': True},
'length': {'write_only': True},
'confidence': {'write_only': True, 'required': True},
'confidence': {'write_only': True},
'entity_id': {'help_text': 'UUID of the newly created Entity.'},
}
......@@ -569,7 +569,7 @@ class TranscriptionEntitiesBulkSerializer(serializers.Serializer):
entity=entity,
offset=item["offset"],
length=item["length"],
confidence=item["confidence"],
confidence=item.get("confidence"),
worker_run=self.validated_data["worker_run"],
worker_version_id=self.validated_data["worker_run"].version_id,
)
......
......@@ -5,6 +5,7 @@ from rest_framework.exceptions import ValidationError
from arkindex.documents.dates import DateType
from arkindex.documents.models import AllowedMetaData, Corpus, Element, ElementType, EntityType, MetaData, MetaType
from arkindex.images.serializers import ZoneLightSerializer
from arkindex.ponos.utils import is_admin_or_ponos_task
from arkindex.project.serializer_fields import EnumField, MetaDataValueField
......@@ -133,10 +134,9 @@ class MetaDataLightSerializer(serializers.ModelSerializer):
})
def create(self, validated_data):
user = self.context['request'].user
element = self.context['element']
if not (user.is_admin or user.is_internal):
if not is_admin_or_ponos_task(self.context['request']):
self.check_allowed(corpus=element.corpus, **validated_data)
try:
......@@ -164,9 +164,8 @@ class MetaDataLightSerializer(serializers.ModelSerializer):
return element.metadatas.create(**validated_data)
def update(self, instance, validated_data):
user = self.context['request'].user
element = instance.element
if not (user.is_admin or user.is_internal):
if not is_admin_or_ponos_task(self.context['request']):
# Assert actual instance is part of AllowedMetaData
self.check_allowed(corpus=element.corpus, instance=instance)
# Assert the new values is part of AllowedMetaData
......@@ -203,8 +202,7 @@ class MetaDataLightSerializer(serializers.ModelSerializer):
return super().update(instance, validated_data)
def delete(self, instance):
user = self.context['request'].user
if not (user.is_admin or user.is_internal):
if not is_admin_or_ponos_task(self.context['request']):
self.check_allowed(corpus=instance.element.corpus, instance=instance)
instance.delete()
......
......@@ -19,6 +19,7 @@ from arkindex.documents.models import (
Transcription,
)
from arkindex.documents.serializers.light import ElementZoneSerializer
from arkindex.ponos.utils import is_admin_or_ponos_task
from arkindex.process.models import WorkerRun
from arkindex.project.serializer_fields import EnumField, ForbiddenField, LinearRingField
from arkindex.project.tools import polygon_outside_image
......@@ -199,7 +200,6 @@ class ClassificationCreateSerializer(serializers.ModelSerializer):
def validate(self, data):
# Note that (worker_version, class, element) unicity is already checked by DRF
errors = {}
user = self.context['request'].user
if data['element'].corpus_id != data['ml_class'].corpus_id:
errors['non_field_errors'] = ['Element and ML class are not in the same corpus']
......@@ -207,8 +207,8 @@ class ClassificationCreateSerializer(serializers.ModelSerializer):
worker_run = data.get('worker_run')
if worker_run is not None:
if not user or not user.is_internal:
errors['worker_run_id'] = ['An internal user is required to create a classification with a worker run.']
if not is_admin_or_ponos_task(self.context['request']):
errors['worker_run_id'] = ['An admin user or a Ponos task is required to create a classification with a worker run.']
if data.get('confidence') is None:
errors['confidence'] = ['This field is required to create a classification with a worker run.']
if data.get('high_confidence') is None:
......@@ -343,13 +343,12 @@ class TranscriptionCreateSerializer(serializers.ModelSerializer):
if worker_run is None:
return data
# Additional validation for transcriptions with a worker version or worker run
# Additional validation for transcriptions with a worker run
errors = {}
user = self.context['request'].user
if worker_run is not None:
if not user or not user.is_internal:
errors['worker_run'] = ['An internal user is required to create a transcription with a worker run.']
if not is_admin_or_ponos_task(self.context['request']):
errors['worker_run'] = ['An admin user or a Ponos task is required to create a transcription with a worker run.']
if 'confidence' not in data:
errors['non_field_errors'] = ['The confidence field must be defined for a transcription with a worker run.']
......
......@@ -51,8 +51,9 @@ def corpus_delete(corpus_id: str) -> None:
# Process-DataFile M2M with implicit model
Process.files.through.objects.filter(process__corpus_id=corpus_id),
Process.files.through.objects.filter(datafile__corpus_id=corpus_id),
# Worker activities
WorkerActivity.objects.filter(Q(element__corpus_id=corpus_id) | Q(process__corpus_id=corpus_id)),
# Worker activities are deleted in two queries, as filtering using OR is slower
WorkerActivity.objects.filter(element__corpus_id=corpus_id),
WorkerActivity.objects.filter(process__corpus_id=corpus_id),
corpus.files.all(),
MetaData.objects.filter(element__corpus_id=corpus_id),
EntityLink.objects.filter(role__corpus_id=corpus_id),
......
......@@ -17,6 +17,7 @@ from arkindex.training.models import Model, ModelVersion
@override_settings(AWS_EXPORT_BUCKET='export', PONOS_S3_ARTIFACTS_BUCKET='ponos-artifacts', PONOS_S3_LOGS_BUCKET='ponos-logs', AWS_TRAINING_BUCKET='training')
@patch('django_rq.get_connection')
@patch('arkindex.project.aws.s3')
class TestCleanupCommand(FixtureTestCase):
......@@ -28,7 +29,7 @@ class TestCleanupCommand(FixtureTestCase):
call_command('cleanup', args + ['--no-color'], stdout=output, stderr=output)
return output.getvalue().strip()
def test_cleanup(self, s3_mock):
def test_cleanup(self, s3_mock, rq_mock):
young_export = self.corpus.exports.create(user=self.superuser)
# Use a fake creation time to make old exports
with patch('django.utils.timezone.now') as mock_now:
......@@ -65,6 +66,8 @@ class TestCleanupCommand(FixtureTestCase):
Successfully cleaned up orphaned Ponos logs.
Removing orphaned model versions archives…
Successfully cleaned up orphaned model versions archives.
Cleaning up deleted jobs from RQ user registries…
Successfully cleaned up 0 deleted jobs from 0 RQ user registries.
"""
).strip()
)
......@@ -75,7 +78,7 @@ class TestCleanupCommand(FixtureTestCase):
self.assertEqual(s3_mock.Object.call_args, call('export', str(done_export.id)))
self.assertEqual(s3_mock.Object.return_value.delete.call_count, 1)
def test_nothing(self, s3_mock):
def test_nothing(self, s3_mock, rq_mock):
self.assertFalse(CorpusExport.objects.exists())
self.assertEqual(
self.cleanup(),
......@@ -105,11 +108,13 @@ class TestCleanupCommand(FixtureTestCase):
Successfully cleaned up orphaned Ponos logs.
Removing orphaned model versions archives…
Successfully cleaned up orphaned model versions archives.
Cleaning up deleted jobs from RQ user registries…
Successfully cleaned up 0 deleted jobs from 0 RQ user registries.
"""
).strip()
)
def test_s3_not_found(self, s3_mock):
def test_s3_not_found(self, s3_mock, rq_mock):
s3_mock.Object.return_value.delete.side_effect = ClientError({'Error': {'Code': '404'}}, 'delete_object')
with patch('django.utils.timezone.now') as mock_now:
......@@ -146,6 +151,8 @@ class TestCleanupCommand(FixtureTestCase):
Successfully cleaned up orphaned Ponos logs.
Removing orphaned model versions archives…
Successfully cleaned up orphaned model versions archives.
Cleaning up deleted jobs from RQ user registries…
Successfully cleaned up 0 deleted jobs from 0 RQ user registries.
"""
).strip()
)
......@@ -156,7 +163,7 @@ class TestCleanupCommand(FixtureTestCase):
# s3.Object.delete should only be called once, not retried
self.assertEqual(s3_mock.Object.return_value.delete.call_count, 1)
def test_s3_error(self, s3_mock):
def test_s3_error(self, s3_mock, rq_mock):
error = ClientError({'Error': {'Code': '500'}}, 'delete_object')
# Fail twice, then delete successfully
s3_mock.Object.return_value.delete.side_effect = [error, error, None]
......@@ -194,6 +201,8 @@ class TestCleanupCommand(FixtureTestCase):
Successfully cleaned up orphaned Ponos logs.
Removing orphaned model versions archives…
Successfully cleaned up orphaned model versions archives.
Cleaning up deleted jobs from RQ user registries…
Successfully cleaned up 0 deleted jobs from 0 RQ user registries.
"""
).strip()
)
......@@ -203,7 +212,7 @@ class TestCleanupCommand(FixtureTestCase):
self.assertEqual(s3_mock.Object.call_args, call('export', str(done_export.id)))
self.assertEqual(s3_mock.Object.return_value.delete.call_count, 3)
def test_cleanup_trashed_datafiles(self, s3_mock):
def test_cleanup_trashed_datafiles(self, s3_mock, rq_mock):
DataFile.objects.bulk_create([
DataFile(
name=f'test{i}.txt',
......@@ -245,6 +254,8 @@ class TestCleanupCommand(FixtureTestCase):
Successfully cleaned up orphaned Ponos logs.
Removing orphaned model versions archives…
Successfully cleaned up orphaned model versions archives.
Cleaning up deleted jobs from RQ user registries…
Successfully cleaned up 0 deleted jobs from 0 RQ user registries.
"""
).strip()
)
......@@ -252,7 +263,7 @@ class TestCleanupCommand(FixtureTestCase):
self.assertEqual(DataFile.objects.filter(trashed=True).count(), 0)
self.assertEqual(s3_mock.Object().delete.call_count, 1)
def test_cleanup_trashed_datafiles_ignore_s3_errors(self, s3_mock):
def test_cleanup_trashed_datafiles_ignore_s3_errors(self, s3_mock, rq_mock):
"""
Test the cleanup command tries multiple times to delete from S3, but ignores errors if it fails
"""
......@@ -304,6 +315,8 @@ class TestCleanupCommand(FixtureTestCase):
Successfully cleaned up orphaned Ponos logs.
Removing orphaned model versions archives…
Successfully cleaned up orphaned model versions archives.
Cleaning up deleted jobs from RQ user registries…
Successfully cleaned up 0 deleted jobs from 0 RQ user registries.
"""
).strip()
)
......@@ -312,7 +325,7 @@ class TestCleanupCommand(FixtureTestCase):
self.assertEqual(s3_mock.Object().delete.call_count, 3)
@patch('arkindex.documents.management.commands.cleanup.s3')
def test_cleanup_artifacts(self, cleanup_s3_mock, s3_mock):
def test_cleanup_artifacts(self, cleanup_s3_mock, s3_mock, rq_mock):
workflow = Workflow.objects.create(farm=Farm.objects.first())
task = workflow.tasks.create(run=0, depth=0, slug='task')
......@@ -374,6 +387,8 @@ class TestCleanupCommand(FixtureTestCase):
Successfully cleaned up orphaned Ponos logs.
Removing orphaned model versions archives…
Successfully cleaned up orphaned model versions archives.
Cleaning up deleted jobs from RQ user registries…
Successfully cleaned up 0 deleted jobs from 0 RQ user registries.
"""
).strip()
)
......@@ -392,7 +407,7 @@ class TestCleanupCommand(FixtureTestCase):
self.assertEqual(broken_s3_artifact.delete.call_count, 1)
@patch('arkindex.ponos.models.s3')
def test_cleanup_expired_workflows(self, ponos_s3_mock, s3_mock):
def test_cleanup_expired_workflows(self, ponos_s3_mock, s3_mock, rq_mock):
farm = Farm.objects.first()
expired_workflow = farm.workflows.create()
......@@ -449,6 +464,8 @@ class TestCleanupCommand(FixtureTestCase):
Successfully cleaned up orphaned Ponos logs.
Removing orphaned model versions archives…
Successfully cleaned up orphaned model versions archives.
Cleaning up deleted jobs from RQ user registries…
Successfully cleaned up 0 deleted jobs from 0 RQ user registries.
"""
).strip()
)
......@@ -497,7 +514,7 @@ class TestCleanupCommand(FixtureTestCase):
return revision, artifact
@patch('arkindex.ponos.models.s3')
def test_cleanup_expired_workflows_docker_images(self, ponos_s3_mock, s3_mock):
def test_cleanup_expired_workflows_docker_images(self, ponos_s3_mock, s3_mock, rq_mock):
"""
Artifacts used as Docker images for worker versions from expired workflows
should only be deleted if the versions are neither on Git tags or on main branches.
......@@ -548,6 +565,8 @@ class TestCleanupCommand(FixtureTestCase):
Successfully cleaned up orphaned Ponos logs.
Removing orphaned model versions archives…
Successfully cleaned up orphaned model versions archives.
Cleaning up deleted jobs from RQ user registries…
Successfully cleaned up 0 deleted jobs from 0 RQ user registries.
"""
).strip()
)
......@@ -578,7 +597,7 @@ class TestCleanupCommand(FixtureTestCase):
self.assertEqual(ponos_s3_mock.Object().delete.call_count, 4)
@patch('arkindex.ponos.models.s3')
def test_cleanup_expired_workflows_null(self, ponos_s3_mock, s3_mock):
def test_cleanup_expired_workflows_null(self, ponos_s3_mock, s3_mock, rq_mock):
repo = Repository.objects.get(url='http://my_repo.fake/workers/worker')
# This revision on the `main` branch does not have any WorkerVersions.
......@@ -644,6 +663,8 @@ class TestCleanupCommand(FixtureTestCase):
Successfully cleaned up orphaned Ponos logs.
Removing orphaned model versions archives…
Successfully cleaned up orphaned model versions archives.
Cleaning up deleted jobs from RQ user registries…
Successfully cleaned up 0 deleted jobs from 0 RQ user registries.
"""
).strip()
)
......@@ -670,7 +691,7 @@ class TestCleanupCommand(FixtureTestCase):
self.assertEqual(ponos_s3_mock.Object().delete.call_count, 4)
@patch('arkindex.documents.management.commands.cleanup.s3')
def test_cleanup_local_images(self, cleanup_s3_mock, s3_mock):
def test_cleanup_local_images(self, cleanup_s3_mock, s3_mock, rq_mock):
ImageServer.objects.local.images.create(path='path%2Fto%2Fimage.jpg')
img_object = MagicMock()
img_object.key = 'path/to/image.jpg'
......@@ -724,6 +745,8 @@ class TestCleanupCommand(FixtureTestCase):
Successfully cleaned up orphaned Ponos logs.
Removing orphaned model versions archives…
Successfully cleaned up orphaned model versions archives.
Cleaning up deleted jobs from RQ user registries…
Successfully cleaned up 0 deleted jobs from 0 RQ user registries.
"""
).strip()
)
......@@ -741,7 +764,7 @@ class TestCleanupCommand(FixtureTestCase):
self.assertEqual(orphan_object.delete.call_count, 1)
@patch('arkindex.documents.management.commands.cleanup.s3')
def test_cleanup_orphan_images(self, cleanup_s3_mock, s3_mock):
def test_cleanup_orphan_images(self, cleanup_s3_mock, s3_mock, rq_mock):
element = Element.objects.get(name='Volume 2, page 1v')
image_with_element = Image.objects.get(id=element.image.id)
img_server = ImageServer.objects.get(url='http://server')
......@@ -781,6 +804,8 @@ class TestCleanupCommand(FixtureTestCase):
Successfully cleaned up orphaned Ponos logs.
Removing orphaned model versions archives…
Successfully cleaned up orphaned model versions archives.
Cleaning up deleted jobs from RQ user registries…
Successfully cleaned up 0 deleted jobs from 0 RQ user registries.
"""
).strip()
)
......@@ -793,7 +818,7 @@ class TestCleanupCommand(FixtureTestCase):
image_no_element_new.refresh_from_db()
@patch('arkindex.documents.management.commands.cleanup.s3')
def test_cleanup_logs(self, cleanup_s3_mock, s3_mock):
def test_cleanup_logs(self, cleanup_s3_mock, s3_mock, rq_mock):
workflow = Workflow.objects.create(farm=Farm.objects.first())
task = workflow.tasks.create(run=0, depth=0, slug='task')
......@@ -852,6 +877,8 @@ class TestCleanupCommand(FixtureTestCase):
Successfully cleaned up orphaned Ponos logs.
Removing orphaned model versions archives…
Successfully cleaned up orphaned model versions archives.
Cleaning up deleted jobs from RQ user registries…
Successfully cleaned up 0 deleted jobs from 0 RQ user registries.
"""
).strip()
)
......@@ -871,7 +898,7 @@ class TestCleanupCommand(FixtureTestCase):
self.assertEqual(broken_s3_log.delete.call_count, 1)
@patch('arkindex.documents.management.commands.cleanup.s3')
def test_cleanup_model_versions(self, cleanup_s3_mock, s3_mock):
def test_cleanup_model_versions(self, cleanup_s3_mock, s3_mock, rq_mock):
model = Model.objects.create(name='Nice Model')
model_version = ModelVersion.objects.create(model_id=model.id, hash='bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', archive_hash='aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', size=8)
......@@ -930,6 +957,8 @@ class TestCleanupCommand(FixtureTestCase):
Removing model version archive aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa.zst…
An error occurred (500) when calling the delete_object operation: Unknown
Successfully cleaned up orphaned model versions archives.
Cleaning up deleted jobs from RQ user registries…
Successfully cleaned up 0 deleted jobs from 0 RQ user registries.
"""
).strip()
)
......@@ -949,7 +978,7 @@ class TestCleanupCommand(FixtureTestCase):
self.assertEqual(broken_s3_version.delete.call_count, 1)
@patch('arkindex.documents.management.commands.cleanup.s3')
def test_cleanup_orphan_exports(self, cleanup_s3_mock, s3_mock):
def test_cleanup_orphan_exports(self, cleanup_s3_mock, s3_mock, rq_mock):
export = self.corpus.exports.create(user=self.superuser)
good_export = MagicMock()
......@@ -1007,6 +1036,8 @@ class TestCleanupCommand(FixtureTestCase):
Successfully cleaned up orphaned Ponos logs.
Removing orphaned model versions archives…
Successfully cleaned up orphaned model versions archives.
Cleaning up deleted jobs from RQ user registries…
Successfully cleaned up 0 deleted jobs from 0 RQ user registries.
"""
).strip()
)
......@@ -1024,3 +1055,92 @@ class TestCleanupCommand(FixtureTestCase):
self.assertEqual(orphan_export.delete.call_count, 1)
self.assertEqual(unsupported_export.delete.call_count, 0)
self.assertEqual(broken_export.delete.call_count, 1)
@override_settings(REDIS_ZREM_CHUNK_SIZE=2)
def test_cleanup_rq_user_registries(self, s3_mock, rq_mock):
# List of user registries
rq_mock().keys.return_value = [
'rq:registry:user:1',
'rq:registry:user:2',
'rq:registry:user:3',
]
# All keys in each user registry, in the order specified above
rq_mock().zrange.side_effect = [
[],
['job1', 'job2', 'job3', 'job4', 'job5'],
['job2', 'job3'],
]
# Whether or not each job exists in each non-empty user registry, in the order specified above
rq_mock().pipeline().__enter__().execute.side_effect = [
# Three jobs should be deleted: two ZREM calls will be made since we use chunks of 2
[True, False, True, False, False],
# Nothing to delete, no ZREM should be called
[True, True],
]
# How many jobs get deleted with each ZREM call
rq_mock().zrem.side_effect = [2, 1]
# Reset the call history, keeping the return values and side effects in place
rq_mock.reset_mock()
self.assertEqual(
self.cleanup(),
dedent(
"""
Removing orphaned Ponos artifacts…
Successfully cleaned up orphaned Ponos artifacts.
Removing 0 artifacts of expired workflows from S3…
Removing logs for 0 tasks of expired workflows from S3…
Updating 0 available worker versions to the Error state…
Removing 0 artifacts of expired workflows…
Removing 0 tasks of expired workflows…
Removing 0 expired workflows…
Successfully cleaned up expired workflows.
Removing 0 old corpus exports from S3…
Removing 0 old corpus exports…
Successfully cleaned up old corpus exports.
Removing orphaned corpus exports…
Successfully cleaned up orphaned corpus exports.
Deleting 0 DataFiles marked as trashed from S3 and the database…
Successfully cleaned up DataFiles marked as trashed.
Removing orphan images…
Successfully cleaned up orphan images.
Removing orphaned local images…
Successfully cleaned up orphaned local images.
Removing orphaned Ponos logs…
Successfully cleaned up orphaned Ponos logs.
Removing orphaned model versions archives…
Successfully cleaned up orphaned model versions archives.
Cleaning up deleted jobs from RQ user registries…
Successfully cleaned up 3 deleted jobs from 2 RQ user registries.
"""
).strip()
)
self.assertEqual(rq_mock.call_count, 1)
self.assertEqual(rq_mock().keys.call_count, 1)
self.assertEqual(rq_mock().keys.call_args, call(pattern='rq:registry:user:*'))
self.assertListEqual(rq_mock().zrange.call_args_list, [
call('rq:registry:user:1', 0, -1),
call('rq:registry:user:2', 0, -1),
call('rq:registry:user:3', 0, -1),
])
self.assertEqual(rq_mock().pipeline.call_count, 2)
self.assertListEqual(rq_mock().pipeline().__enter__().exists.call_args_list, [
call(b'rq:job:job1'),
call(b'rq:job:job2'),
call(b'rq:job:job3'),
call(b'rq:job:job4'),
call(b'rq:job:job5'),
call(b'rq:job:job2'),
call(b'rq:job:job3'),
])
self.assertEqual(rq_mock().pipeline().__enter__().execute.call_count, 2)
self.assertEqual(rq_mock().zrem.call_count, 2)
chunk_call, remaining_call = rq_mock().zrem.call_args_list
# The chunk call is made using a set(), because we do not want to send duplicate keys,
# so the arguments might be either job2,job4 or job4,job2
self.assertIn(chunk_call, (
call('rq:registry:user:2', 'job2', 'job4'),
call('rq:registry:user:2', 'job4', 'job2'),
))
self.assertEqual(remaining_call, call('rq:registry:user:2', 'job5'))
......@@ -208,8 +208,8 @@ class TestBulkElementTranscriptions(FixtureAPITestCase):
.filter(paths__path__last=self.page.id, type__slug='text_line')
self.assertEqual(created_elts.count(), 1)
self.client.force_login(self.internal_user)
with self.assertNumQueries(14):
self.client.force_login(self.user)
with self.assertNumQueries(16):
response = self.client.post(
reverse('api:element-transcriptions-bulk', kwargs={'pk': self.page.id}),
format='json',
......@@ -250,8 +250,8 @@ class TestBulkElementTranscriptions(FixtureAPITestCase):
} for poly, text, confidence in transcriptions]
}
self.client.force_login(self.internal_user)
with self.assertNumQueries(14):
self.client.force_login(self.user)
with self.assertNumQueries(16):
response = self.client.post(
reverse('api:element-transcriptions-bulk', kwargs={'pk': self.page.id}),
format='json',
......@@ -299,8 +299,8 @@ class TestBulkElementTranscriptions(FixtureAPITestCase):
],
}
self.client.force_login(self.internal_user)
with self.assertNumQueries(12):
self.client.force_login(self.user)
with self.assertNumQueries(14):
response = self.client.post(
reverse('api:element-transcriptions-bulk', kwargs={'pk': self.huge_page.id}),
format='json',
......@@ -466,8 +466,8 @@ class TestBulkElementTranscriptions(FixtureAPITestCase):
'return_elements': True
}
self.client.force_login(self.internal_user)
with self.assertNumQueries(14):
self.client.force_login(self.user)
with self.assertNumQueries(16):
response = self.client.post(
reverse('api:element-transcriptions-bulk', kwargs={'pk': self.page.id}),
format='json',
......@@ -498,8 +498,8 @@ class TestBulkElementTranscriptions(FixtureAPITestCase):
}])
def test_bulk_transcriptions_requires_zone(self):
self.client.force_login(self.internal_user)
with self.assertNumQueries(3):
self.client.force_login(self.user)
with self.assertNumQueries(5):
response = self.client.post(
reverse('api:element-transcriptions-bulk', kwargs={'pk': self.vol.id}),
format='json',
......
......@@ -132,7 +132,6 @@ class TestBulkTranscriptionEntities(FixtureAPITestCase):
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
'entities': [{
'confidence': ['This field is required.'],
'length': ['This field is required.'],
'name': ['This field is required.'],
'offset': ['This field is required.'],
......@@ -197,14 +196,13 @@ class TestBulkTranscriptionEntities(FixtureAPITestCase):
'type_id': str(self.person_ent_type.id),
'offset': 0,
'length': 1,
'confidence': 0.7,
'confidence': None,
},
{
'name': 'Knight',
'type_id': str(self.person_ent_type.id),
'offset': 10,
'length': 5,
'confidence': 0.05,
},
],
'worker_run_id': str(self.worker_run.id),
......@@ -228,8 +226,8 @@ class TestBulkTranscriptionEntities(FixtureAPITestCase):
.values_list('entity__name', 'entity__type', 'offset', 'length', 'confidence', 'worker_run_id')
.order_by('entity__name', 'offset'),
[
('Knight', self.person_ent_type.id, 0, 1, 0.7, self.worker_run.id),
('Knight', self.person_ent_type.id, 10, 5, 0.05, self.worker_run.id),
('Knight', self.person_ent_type.id, 0, 1, None, self.worker_run.id),
('Knight', self.person_ent_type.id, 10, 5, None, self.worker_run.id),
('Paris', self.location_ent_type.id, 0, 1, 0.7, self.worker_run.id),
],
)
......
......@@ -7,7 +7,7 @@ from rest_framework import status
from arkindex.documents.models import Classification, ClassificationState, Corpus, Element, MLClass
from arkindex.process.models import WorkerVersion
from arkindex.project.tests import FixtureAPITestCase
from arkindex.users.models import Role, User
from arkindex.users.models import Role
class TestClassifications(FixtureAPITestCase):
......@@ -24,7 +24,6 @@ class TestClassifications(FixtureAPITestCase):
cls.worker_version_2 = WorkerVersion.objects.get(worker__slug='reco')
cls.worker_run_1 = cls.worker_version_1.worker_runs.get()
cls.worker_run_2 = cls.worker_version_2.worker_runs.get()
cls.internal_user = User.objects.get_by_natural_key('internal@internal.fr')
def test_create_manual(self):
"""
......@@ -101,7 +100,7 @@ class TestClassifications(FixtureAPITestCase):
If a classification from the same worker run already exists,
creation must respond a 400_BAD_REQUEST with an explicit message
"""
self.client.force_login(self.internal_user)
self.client.force_login(self.superuser)
request = (
reverse('api:classification-create'),
{
......@@ -205,8 +204,8 @@ class TestClassifications(FixtureAPITestCase):
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_create_worker_version(self):
self.client.force_login(self.internal_user)
with self.assertNumQueries(4):
self.client.force_login(self.user)
with self.assertNumQueries(6):
response = self.client.post(reverse('api:classification-create'), {
'element': str(self.element.id),
'ml_class': str(self.text.id),
......@@ -219,48 +218,8 @@ class TestClassifications(FixtureAPITestCase):
'worker_version': ['This field is forbidden.'],
})
def test_create_worker_run_requires_internal(self):
"""
Test creating a classification with a worker version requires the user to be internal
"""
self.client.force_login(self.user)
with self.assertNumQueries(8):
response = self.client.post(reverse('api:classification-create'), {
'element': str(self.element.id),
'ml_class': str(self.text.id),
'worker_run_id': str(self.worker_run_1.id)
})
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
'worker_run_id': ['An internal user is required to create a classification with a worker run.'],
'confidence': ['This field is required to create a classification with a worker run.'],
'high_confidence': ['This field is required to create a classification with a worker run.']
})
def test_create_worker_run_non_admin(self):
"""
Test creating a classification with a worker run requires the user to be internal,
and does not make an exception for admin users
"""
self.assertFalse(self.superuser.is_internal)
self.client.force_login(self.superuser)
with self.assertNumQueries(6):
response = self.client.post(reverse('api:classification-create'), {
'element': str(self.element.id),
'ml_class': str(self.text.id),
'worker_run_id': str(self.worker_run_1.id)
})
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
'worker_run_id': ['An internal user is required to create a classification with a worker run.'],
'confidence': ['This field is required to create a classification with a worker run.'],
'high_confidence': ['This field is required to create a classification with a worker run.']
})
def test_create_worker_run(self):
self.client.force_login(self.internal_user)
self.client.force_login(self.superuser)
with self.assertNumQueries(7):
response = self.client.post(reverse('api:classification-create'), {
'element': str(self.element.id),
......@@ -290,9 +249,45 @@ class TestClassifications(FixtureAPITestCase):
'high_confidence': False,
})
def test_create_worker_run_task(self):
self.worker_run_1.process.start()
task = self.worker_run_1.process.workflow.tasks.first()
with self.assertNumQueries(8):
response = self.client.post(
reverse('api:classification-create'),
{
'element': str(self.element.id),
'ml_class': str(self.text.id),
'worker_run_id': str(self.worker_run_1.id),
'confidence': 0.42,
'high_confidence': False,
},
HTTP_AUTHORIZATION=f'Ponos {task.token}',
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
classification = self.element.classifications.get()
self.assertEqual(classification.worker_run, self.worker_run_1)
self.assertEqual(classification.worker_version, self.worker_version_1)
self.assertEqual(classification.ml_class, self.text)
self.assertEqual(classification.state, ClassificationState.Pending)
self.assertEqual(classification.confidence, 0.42)
self.assertFalse(classification.high_confidence)
self.assertDictEqual(response.json(), {
'id': str(classification.id),
'element': str(self.element.id),
'ml_class': str(self.text.id),
'worker_run_id': str(self.worker_run_1.id),
'worker_version': str(self.worker_version_1.id),
'state': ClassificationState.Pending.value,
'confidence': 0.42,
'high_confidence': False,
})
def test_create_worker_version_xor_worker_run(self):
self.client.force_login(self.internal_user)
with self.assertNumQueries(5):
self.client.force_login(self.user)
with self.assertNumQueries(7):
response = self.client.post(reverse('api:classification-create'), {
'element': str(self.element.id),
'ml_class': str(self.text.id),
......@@ -306,8 +301,8 @@ class TestClassifications(FixtureAPITestCase):
})
def test_create_worker_run_not_found(self):
self.client.force_login(self.internal_user)
with self.assertNumQueries(5):
self.client.force_login(self.user)
with self.assertNumQueries(7):
response = self.client.post(reverse('api:classification-create'), {
'element': str(self.element.id),
'ml_class': str(self.text.id),
......@@ -323,7 +318,7 @@ class TestClassifications(FixtureAPITestCase):
"""
Ensure CreateClassification accepts a confidence of 0
"""
self.client.force_login(self.internal_user)
self.client.force_login(self.superuser)
with self.assertNumQueries(7):
response = self.client.post(reverse('api:classification-create'), {
'element': str(self.element.id),
......@@ -347,7 +342,7 @@ class TestClassifications(FixtureAPITestCase):
CreateClassification should allow creating the same ML class from a worker run,
a worker version and no worker version on the same element.
"""
self.client.force_login(self.internal_user)
self.client.force_login(self.superuser)
# Create a classification with a worker version and no worker run
self.element.classifications.create(
......