Compare revisions: arkindex/backend
Commits on Source (20)
Showing with 711 additions and 27 deletions
......@@ -16,3 +16,4 @@ htmlcov
ponos
*.key
arkindex/config.yml
test-report.xml
......@@ -11,7 +11,7 @@ include:
# For jobs that run backend scripts directly
.backend-setup:
image: registry.gitlab.com/teklia/arkindex/backend/base:django-4.1.6
image: registry.gitlab.com/teklia/arkindex/backend/base:python-3.10
cache:
paths:
......@@ -57,14 +57,14 @@ backend-tests:
when: always
reports:
junit:
- nosetests.xml
- test-report.xml
script:
- python3 setup.py test
- codecov
backend-lint:
image: python:3.8
image: python:3.10
stage: test
except:
......
FROM registry.gitlab.com/teklia/arkindex/backend/base:django-4.1.6 as build
FROM registry.gitlab.com/teklia/arkindex/backend/base:python-3.10 as build
RUN mkdir build
ADD . build
RUN cd build && python3 setup.py sdist
FROM registry.gitlab.com/teklia/arkindex/backend/base:django-4.1.6
FROM registry.gitlab.com/teklia/arkindex/backend/base:python-3.10
ARG TRANSKRIBUS_BRANCH=master
ARG TRANSKRIBUS_ID=11180199
ARG GITLAB_TOKEN="gaFM7LRa9zy9QMowcUhx"
......
FROM python:3.8-slim AS compilation
FROM python:3.10-slim AS compilation
RUN apt-get update && apt-get install --no-install-recommends -y build-essential wget
......@@ -50,7 +50,7 @@ RUN python -m nuitka \
arkindex/manage.py
# Start over from a clean setup
FROM registry.gitlab.com/teklia/arkindex/backend/base:django-4.1.6 as build
FROM registry.gitlab.com/teklia/arkindex/backend/base:python-3.10 as build
# Import files from compilation
RUN mkdir /usr/share/arkindex
......
......@@ -155,11 +155,8 @@ SHELL_PLUS_POST_IMPORTS = [
('arkindex.documents.models', (
'ElementType',
'Right',
'PageType',
'PageDirection',
'PageComplement',
)),
('arkindex.dataimport.models', (
('arkindex.process.models', (
'DataImportMode',
)),
('arkindex.project.aws', (
......@@ -171,8 +168,6 @@ SHELL_PLUS_POST_IMPORTS = [
]
```
You may also want to uninstall `django-nose`, an optional test runner used for code coverage in the CI. Uninstalling it removes about a hundred useless lines from the `./manage.py test` output, so you no longer have to scroll to reach the list of test errors.
## Asynchronous tasks
We use [rq](https://python-rq.org/), integrated via [django-rq](https://pypi.org/project/django-rq/), to run tasks without blocking an API request or causing timeouts. To call them from Python code, use the trigger methods in `arkindex.project.triggers`; those perform safety checks that make catching errors easier in development. The actual tasks are in `arkindex.documents.tasks`. The following tasks exist:
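Before listing them, here is a minimal sketch of the trigger/task split described above; the checks shown are illustrative assumptions, not the actual Arkindex implementation:
```python
# Hypothetical sketch of a trigger wrapping an RQ task; the checks here are
# illustrative only (see arkindex.project.triggers for the real implementations).
from arkindex.documents import tasks


def reindex_corpus(corpus_id, drop=True):
    # Safety checks run synchronously, so mistakes surface immediately in dev
    assert corpus_id is not None, 'A corpus ID is required'
    # The heavy work is enqueued and handled by an RQ worker, outside the API request cycle
    tasks.reindex_corpus.delay(corpus_id=corpus_id, drop=drop)
```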
......
1.4.0-beta2
1.4.1-beta1
......@@ -141,7 +141,7 @@ class EntityLinkInLine(admin.TabularInline):
class EntityAdmin(admin.ModelAdmin):
list_display = ('id', 'name', 'type')
list_filter = [('type', EnumFieldListFilter), 'corpus']
list_filter = ['corpus', 'type']
readonly_fields = ('id', )
raw_id_fields = ('worker_version', 'worker_run', )
search_fields = ('name', )
......
......@@ -32,6 +32,7 @@ from arkindex.documents.serializers.entities import (
EntitySerializer,
EntityTypeCreateSerializer,
EntityTypeSerializer,
TranscriptionEntitiesBulkSerializer,
TranscriptionEntityCreateSerializer,
TranscriptionEntitySerializer,
)
......@@ -530,3 +531,35 @@ class ElementLinks(CorpusACLMixin, ListAPIView):
| Q(parent__in=entities_meta, child__in=entities_tr)
| Q(parent__in=entities_meta, child__in=entities_meta)
).select_related('role', 'child__type', 'parent__type').order_by('parent__name')
@extend_schema_view(
post=extend_schema(
operation_id='CreateTranscriptionEntities',
tags=['entities'],
)
)
class TranscriptionEntitiesBulk(CorpusACLMixin, CreateAPIView):
"""
Create multiple entities attached to a transcription.
This requires that the transcription does not already have any TranscriptionEntities from the specified WorkerRun.
Requires **contributor** access to the transcription's corpus.
"""
serializer_class = TranscriptionEntitiesBulkSerializer
permission_classes = (IsVerified, )
# Use the default database to avoid a stale read on a newly created transcription
queryset = Transcription.objects.select_related("element__corpus").using('default')
def get_object(self):
transcription = super().get_object()
if not self.has_write_access(transcription.element.corpus):
raise PermissionDenied(detail='You do not have a contributor access to the corpus of this transcription.')
return transcription
def get_serializer_context(self):
return {
**super().get_serializer_context(),
'transcription': self.get_object(),
}
import math
from django.conf import settings
from django.utils.functional import cached_property
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import status
from rest_framework.exceptions import NotFound, ValidationError
from rest_framework.generics import RetrieveAPIView
from rest_framework.generics import CreateAPIView, RetrieveAPIView
from rest_framework.response import Response
from rest_framework.utils.urls import replace_query_param
from SolrClient import SolrClient
from SolrClient.exceptions import SolrError
from arkindex.documents.models import EntityType, MetaType
from arkindex.documents.serializers.search import CorpusSearchQuerySerializer, CorpusSearchResultSerializer
from arkindex.documents.serializers.search import (
CorpusSearchQuerySerializer,
CorpusSearchResultSerializer,
ReindexCorpusSerializer,
)
from arkindex.project.mixins import CorpusACLMixin
from arkindex.project.permissions import IsVerified
from arkindex.users.models import Role
solr = SolrClient(settings.SOLR_API_URL)
......@@ -123,3 +130,35 @@ class CorpusSearch(CorpusACLMixin, RetrieveAPIView):
'results': results.docs if not only_facets else None,
'facets': results.get_facets()
}).data, status=status.HTTP_200_OK)
@extend_schema_view(post=extend_schema(operation_id="BuildSearchIndex", tags=['search']))
class SearchIndexBuild(CorpusACLMixin, CreateAPIView):
"""
Starts an indexation task for a specific corpus, run by RQ.
"""
permission_classes = (IsVerified, )
serializer_class = ReindexCorpusSerializer
@cached_property
def corpus(self):
corpus = self.get_corpus(self.kwargs['pk'], role=Role.Admin)
if not corpus.indexable:
raise ValidationError({'__all__': ['This project is not indexable.']})
if not corpus.types.filter(indexable=True).exists():
raise ValidationError({'__all__': ['There are no indexable element types for this project.']})
return corpus
def get_serializer_context(self):
return {
**super().get_serializer_context(),
'corpus_id': self.corpus.id,
'user_id': self.request.user.id,
}
def create(self, request, *args, **kwargs):
if not settings.ARKINDEX_FEATURES['search']:
raise ValidationError({
'__all__': ['Building search index is unavailable due to the search feature being disabled.']
})
return super().create(request, *args, **kwargs)
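A request body for the BuildSearchIndex endpoint above only carries the `drop` flag exposed by `ReindexCorpusSerializer`; a minimal sketch:
```python
# BuildSearchIndex request body sketch; `drop` defaults to True when omitted
payload = {
    # When False, existing Solr collections for the corpus are kept instead of being dropped first
    'drop': False,
}
```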
......@@ -78,7 +78,7 @@ def save_sqlite(rows, table, cursor):
return float(value)
# Show very explicit error messages if we stumble upon an unexpected type
# https://docs.python.org/3.8/library/sqlite3.html#sqlite-and-python-types
# https://docs.python.org/3.10/library/sqlite3.html#sqlite-and-python-types
assert value is None or isinstance(value, (int, float, str, bytes)), f'Type {type(value)} is not supported by sqlite3'
return value
......
......@@ -72,6 +72,17 @@
"author": "Test user"
}
},
{
"model": "process.worker",
"pk": "0fb9c903-6d8b-4ed9-b532-1ccc7ad8b78b",
"fields": {
"name": "File import",
"slug": "file_import",
"type": "1a1a76ad-27c8-4f4d-a6cc-b37890fc68d9",
"repository": "867b5357-e9dd-4a61-8f40-7a84cbf92a20",
"public": false
}
},
{
"model": "process.worker",
"pk": "3fb21e96-99ca-446a-812e-5ab4a43ea356",
......@@ -116,6 +127,16 @@
"public": false
}
},
{
"model": "process.workertype",
"pk": "1a1a76ad-27c8-4f4d-a6cc-b37890fc68d9",
"fields": {
"created": "2020-02-02T01:23:45.678Z",
"updated": "2020-02-02T01:23:45.678Z",
"slug": "import",
"display_name": "Import"
}
},
{
"model": "process.workertype",
"pk": "4c8075e4-6281-426f-9600-f93f86035a17",
......@@ -194,6 +215,20 @@
"docker_image_iid": null
}
},
{
"model": "process.workerversion",
"pk": "985770d3-4926-41c6-abfb-58627c6aaa40",
"fields": {
"worker": "0fb9c903-6d8b-4ed9-b532-1ccc7ad8b78b",
"revision": "5d2ff50a-cd96-4a00-9069-450af3fa57c3",
"configuration": {},
"state": "available",
"gpu_usage": "disabled",
"model_usage": false,
"docker_image": "35a1ba90-c22c-4b99-b9d5-244e9a6eb3ed",
"docker_image_iid": null
}
},
{
"model": "process.workerversion",
"pk": "da32c17c-2ac6-44cd-9d84-7c7b48e37178",
......@@ -3726,11 +3761,11 @@
"slug": "docker_build",
"priority": 10,
"state": "completed",
"tags": "[]",
"tags": [],
"image": "",
"shm_size": null,
"command": null,
"env": null,
"env": "{}",
"has_docker_socket": false,
"image_artifact": null,
"agent": null,
......@@ -3741,7 +3776,7 @@
"created": "2020-02-02T01:23:45.678Z",
"updated": "2020-02-02T01:23:45.678Z",
"expiry": "2050-03-03T01:23:45.678Z",
"extra_files": {},
"extra_files": "{}",
"parents": []
}
},
......
......@@ -120,6 +120,7 @@ class Command(BaseCommand):
dla_worker_type = WorkerType.objects.create(slug="dla")
recognizer_worker_type = WorkerType.objects.create(slug="recognizer")
gpu_worker_type = WorkerType.objects.create(slug="worker")
import_worker_type = WorkerType.objects.create(slug="import", display_name="Import")
# Create a fake docker build with a docker image task
farm = Farm.objects.create(name="Wheat farm")
......@@ -153,6 +154,19 @@ class Command(BaseCommand):
docker_image=docker_image
)
WorkerVersion.objects.create(
worker=worker_repo.workers.create(
name='File import',
slug='file_import',
type=import_worker_type,
),
revision=revision,
configuration={},
state=WorkerVersionState.Available,
model_usage=False,
docker_image=docker_image,
)
WorkerVersion.objects.create(
worker=worker_repo.workers.create(
name='Worker requiring a GPU',
......
......@@ -497,6 +497,9 @@ class EntityType(models.Model):
('name', 'corpus'),
)
def __str__(self):
return self.name
class Entity(models.Model):
"""
......
from collections import defaultdict
from textwrap import dedent
from django.db import transaction
from drf_spectacular.utils import extend_schema_serializer
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
......@@ -426,12 +427,18 @@ class TranscriptionEntityCreateSerializer(serializers.ModelSerializer):
if worker_run is not None:
data['worker_version_id'] = worker_run.version_id
existing_transcription_entities = TranscriptionEntity.objects.filter(transcription=data['transcription'], entity=data['entity'], offset=data['offset'], length=data['length'])
if worker_run:
if existing_transcription_entities.filter(worker_run=worker_run).exists():
existing_transcription_entities = TranscriptionEntity.objects.filter(
transcription=data['transcription'],
entity=data['entity'],
offset=data['offset'],
length=data['length'],
worker_run=worker_run,
)
if existing_transcription_entities.exists():
if worker_run:
errors['__all__'] = ['This entity is already linked to this transcription by this worker run at this position.']
elif existing_transcription_entities.exists():
errors['__all__'] = ['This entity is already linked to this transcription at this position.']
else:
errors['__all__'] = ['This entity is already linked to this transcription at this position.']
if errors:
raise serializers.ValidationError(errors)
......@@ -446,3 +453,134 @@ class TranscriptionEntitySerializer(TranscriptionEntityCreateSerializer):
"""
worker_version_id = serializers.UUIDField(read_only=True)
entity = BaseEntitySerializer()
class TranscriptionEntityBulkItemSerializer(serializers.ModelSerializer):
name = serializers.CharField(write_only=True)
type_id = serializers.UUIDField(
help_text='UUID of the EntityType to use for this entity.',
write_only=True,
)
transcription_entity_id = serializers.IntegerField(
help_text='ID of the newly created TranscriptionEntity.',
read_only=True,
)
class Meta:
model = TranscriptionEntity
fields = (
'name',
'type_id',
'offset',
'length',
'confidence',
'transcription_entity_id',
'entity_id',
)
read_only_fields = (
'transcription_entity_id',
'entity_id',
)
extra_kwargs = {
'offset': {'write_only': True},
'length': {'write_only': True},
'confidence': {'write_only': True, 'required': True},
'entity_id': {'help_text': 'UUID of the newly created Entity.'},
}
def validate(self, data):
# Ensure no TranscriptionEntity overflows transcription's length
if data['offset'] + data['length'] > len(self.context['transcription'].text):
raise ValidationError({'__all__': ['Entity position overflows transcription text size']})
return data
class TranscriptionEntitiesBulkSerializer(serializers.Serializer):
worker_run_id = serializers.PrimaryKeyRelatedField(
queryset=WorkerRun.objects.all(),
write_only=True,
style={'base_template': 'input.html'},
source='worker_run',
)
entities = TranscriptionEntityBulkItemSerializer(
many=True,
allow_empty=False,
help_text='Attributes for the Entities and TranscriptionEntities to create. Each item must be unique.',
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.transcription = self.context.get('transcription')
# Pass the transcription through to the item serializer for validation
self.fields['entities'].context['transcription'] = self.transcription
def validate(self, data):
errors = {}
if self.transcription.transcription_entities.filter(worker_run=data['worker_run']).exists():
errors['worker_run_id'] = ['Some TranscriptionEntities were already created for this worker run on this transcription.']
# Check that all the EntityType IDs exist in one query
found_types = set(self.transcription.element.corpus.entity_types.filter(
id__in={item['type_id'] for item in data['entities']}
).values_list('id', flat=True))
entity_errors = {}
for i, item in enumerate(data['entities']):
if item['type_id'] not in found_types:
entity_errors[i] = {'type_id': ["An EntityType with this ID does not exist in the transcription's corpus."]}
# Verify all TranscriptionEntities will be unique: the DB unique constraint that applies here is
# (transcription, entity, offset, length, worker_run), so here we need to check on (entity, offset, length).
# Since we are creating entities based on the (name, type_id) combo, we use (name, type_id, offset, length).
unique_count = len(set(
(item["offset"], item["length"], item["name"], item["type_id"])
for item in data['entities']
))
# If the set is smaller than the original list, there are duplicates.
if len(data['entities']) != unique_count:
entity_errors['__all__'] = ['Some TranscriptionEntities have the same name, type, offset and length.']
if entity_errors:
errors['entities'] = entity_errors
if errors:
raise ValidationError(errors)
return data
@transaction.atomic
def save(self):
entities = Entity.objects.bulk_create([
Entity(
corpus=self.transcription.element.corpus,
name=item['name'],
type_id=item['type_id'],
validated=True,
worker_run=self.validated_data["worker_run"],
worker_version_id=self.validated_data["worker_run"].version_id,
)
for item in self.validated_data['entities']
])
transcription_entities = TranscriptionEntity.objects.bulk_create([
TranscriptionEntity(
transcription=self.transcription,
entity=entity,
offset=item["offset"],
length=item["length"],
confidence=item["confidence"],
worker_run=self.validated_data["worker_run"],
worker_version_id=self.validated_data["worker_run"].version_id,
)
for entity, item in zip(entities, self.validated_data["entities"])
])
# Update self.validated_data so that DRF can output a response with the created UUIDs
self.validated_data['entities'] = [
{
"entity_id": transcription_entity.entity.id,
"transcription_entity_id": transcription_entity.id,
}
for transcription_entity in transcription_entities
]
from textwrap import dedent
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from arkindex.documents.models import MetaType
from arkindex.project.rq_overrides import get_existing_job
from arkindex.project.serializer_fields import EnumField
from arkindex.project.triggers import reindex_corpus
class SolrDocumentSerializer(serializers.Serializer):
......@@ -164,3 +167,24 @@ class CorpusSearchQuerySerializer(serializers.Serializer):
required=False,
help_text='Filter by name of the worker that created the entity.',
)
class ReindexCorpusSerializer(serializers.Serializer):
drop = serializers.BooleanField(default=True, help_text='Drop existing collections for this corpus.')
def save(self, **kwargs):
corpus_id = self.context.get('corpus_id')
user_id = self.context.get('user_id')
assert corpus_id and user_id, 'corpus_id and user_id must be passed in the serializer context'
# Ensure the reindex job has not already been started
job_id = f'reindex-{corpus_id}'
if (job := get_existing_job(job_id)) is not None:
# A previous job can only be removed if finished
if job.ended_at is None:
raise ValidationError({
'__all__': [f'A job is already running to build search index on corpus {corpus_id}.']
})
job.delete()
reindex_corpus(**self.validated_data, corpus_id=corpus_id, user_id=user_id, job_id=job_id)
......@@ -3,10 +3,13 @@ from typing import Optional
from uuid import UUID
from django.conf import settings
from django.core.mail import send_mail
from django.db.models import Q
from django.template.loader import render_to_string
from django_rq import job
from rq import Retry, get_current_job
from arkindex.documents.indexer import Indexer
from arkindex.documents.managers import ElementQuerySet
from arkindex.documents.models import (
Classification,
......@@ -20,6 +23,7 @@ from arkindex.documents.models import (
TranscriptionEntity,
)
from arkindex.process.models import Process, ProcessElement, WorkerActivity, WorkerRun
from arkindex.users.models import User
logger = logging.getLogger(__name__)
......@@ -227,3 +231,33 @@ def add_parent_selection(corpus_id: UUID, parent: Element) -> None:
for i, item in enumerate(queryset):
rq_job.set_progress(i / total)
item.add_parent(parent)
@job('default', timeout=settings.RQ_TIMEOUTS['reindex_corpus'])
def reindex_corpus(corpus_id: UUID, drop: bool = True) -> None:
rq_job = get_current_job()
assert rq_job is not None, 'This task can only be run in a RQ job context.'
assert rq_job.user_id is not None, 'This task requires a user ID to be defined on the RQ job.'
indexer = Indexer(corpus_id)
if drop:
indexer.drop_index()
indexer.setup()
indexer.index()
# Report to the user that the index build finished
user = User.objects.get(id=rq_job.user_id)
corpus = Corpus.objects.get(id=corpus_id)
send_mail(
subject=f'Project {corpus.name} was successfully indexed',
message=render_to_string(
'reindex_corpus.html',
context={
'user': user,
'corpus': corpus,
},
),
from_email=None,
recipient_list=[user.email],
fail_silently=True,
)
......@@ -100,7 +100,7 @@ class TestDeleteCorpus(FixtureTestCase):
cls.process = cls.rev.processes.create(
creator=cls.user,
corpus=cls.corpus2,
mode=ProcessMode.Repository,
mode=ProcessMode.Files,
)
cls.df = cls.process.files.create(
name='a.txt',
......
from unittest.mock import call, patch
from arkindex.documents.tasks import reindex_corpus
from arkindex.project.tests import FixtureTestCase
EXPECTED_EMAIL_BODY = """
Hello Test user,
The search index build you started on project Unit Tests has been finished successfully.
Indexed elements can be explored from the search interface on Arkindex.
--
Arkindex
"""
class TestReindexCorpus(FixtureTestCase):
def test_no_rq_job(self):
with self.assertRaises(AssertionError) as ctx:
reindex_corpus(corpus_id=self.corpus.id)
self.assertEqual(str(ctx.exception), 'This task can only be run in a RQ job context.')
@patch('arkindex.documents.tasks.get_current_job')
def test_no_user_id(self, job_mock):
job_mock.return_value.user_id = None
with self.assertRaises(AssertionError) as ctx:
reindex_corpus(corpus_id=self.corpus.id)
self.assertEqual(str(ctx.exception), 'This task requires a user ID to be defined on the RQ job.')
@patch('arkindex.documents.tasks.send_mail')
@patch('arkindex.documents.tasks.get_current_job')
@patch('arkindex.documents.tasks.Indexer')
def test_run_drop_false(self, indexer_mock, job_mock, send_mail_mock):
job_mock.return_value.user_id = self.user.id
reindex_corpus(corpus_id=self.corpus.id, drop=False)
self.assertEqual(indexer_mock.call_count, 1)
self.assertEqual(indexer_mock.call_args, call(self.corpus.id))
self.assertEqual(indexer_mock.return_value.drop_index.call_count, 0)
self.assertEqual(indexer_mock.return_value.setup.call_count, 1)
self.assertEqual(indexer_mock.return_value.index.call_count, 1)
self.assertListEqual(send_mail_mock.call_args_list, [
call(
subject=f'Project {self.corpus.name} was successfully indexed',
message=EXPECTED_EMAIL_BODY,
from_email=None,
recipient_list=[self.user.email],
fail_silently=True,
)
])
@patch('arkindex.documents.tasks.send_mail')
@patch('arkindex.documents.tasks.get_current_job')
@patch('arkindex.documents.tasks.Indexer')
def test_run(self, indexer_mock, job_mock, send_mail_mock):
job_mock.return_value.user_id = self.user.id
reindex_corpus(corpus_id=self.corpus.id)
self.assertEqual(indexer_mock.call_count, 1)
self.assertEqual(indexer_mock.call_args, call(self.corpus.id))
self.assertEqual(indexer_mock.return_value.drop_index.call_count, 1)
self.assertEqual(indexer_mock.return_value.setup.call_count, 1)
self.assertEqual(indexer_mock.return_value.index.call_count, 1)
self.assertListEqual(send_mail_mock.call_args_list, [
call(
subject=f'Project {self.corpus.name} was successfully indexed',
message=EXPECTED_EMAIL_BODY,
from_email=None,
recipient_list=[self.user.email],
fail_silently=True,
)
])
import uuid
from django.urls import reverse
from rest_framework import status
from arkindex.process.models import WorkerVersion
from arkindex.project.tests import FixtureAPITestCase
from arkindex.users.models import Role
class TestBulkTranscriptionEntities(FixtureAPITestCase):
@classmethod
def setUpTestData(cls):
super().setUpTestData()
cls.worker_version = WorkerVersion.objects.get(worker__slug='reco')
cls.worker_run = cls.worker_version.worker_runs.get()
cls.transcription = cls.corpus.elements.get(name='Volume 1, page 1r').transcriptions.create(
text='Once upon a time in a castle a knight was living lonely',
confidence=0.42,
)
cls.location_ent_type = cls.corpus.entity_types.get(name="location")
cls.person_ent_type = cls.corpus.entity_types.get(name="person")
cls.entity = cls.corpus.entities.create(name="Paris", type=cls.location_ent_type)
def test_requires_login(self):
with self.assertNumQueries(0):
response = self.client.post(reverse('api:transcription-entities-bulk', kwargs={'pk': self.transcription.id}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_requires_verified(self):
self.user.verified_email = False
self.user.save()
self.client.force_login(self.user)
with self.assertNumQueries(2):
response = self.client.post(reverse('api:transcription-entities-bulk', kwargs={'pk': self.transcription.id}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.json(), {'detail': 'You do not have permission to perform this action.'})
def test_requires_contributor(self):
self.user.rights.update(level=Role.Guest.value)
self.client.force_login(self.user)
with self.assertNumQueries(6):
response = self.client.post(reverse('api:transcription-entities-bulk', kwargs={'pk': self.transcription.id}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.json(), {
'detail': 'You do not have a contributor access to the corpus of this transcription.'
})
def test_not_found(self):
self.client.force_login(self.user)
with self.assertNumQueries(3):
response = self.client.post(reverse('api:transcription-entities-bulk', kwargs={'pk': uuid.uuid4()}))
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_required_fields(self):
self.client.force_login(self.user)
with self.assertNumQueries(6):
response = self.client.post(
reverse('api:transcription-entities-bulk', kwargs={'pk': str(self.transcription.id)}),
data={},
format='json',
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.json(), {
'entities': ['This field is required.'],
'worker_run_id': ['This field is required.'],
})
def test_wrong_values(self):
fake_uuid = uuid.uuid4()
self.client.force_login(self.user)
with self.assertNumQueries(7):
response = self.client.post(
reverse('api:transcription-entities-bulk', kwargs={'pk': str(self.transcription.id)}),
data={
'entities': [],
'worker_run_id': str(fake_uuid),
},
format='json',
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.json(), {
'entities': {'non_field_errors': ['This list may not be empty.']},
'worker_run_id': [f'Invalid pk "{fake_uuid}" - object does not exist.'],
})
def test_existing_transcription_entity(self):
self.transcription.transcription_entities.create(
entity=self.entity,
offset=1,
length=1,
worker_run=self.worker_run,
worker_version=self.worker_run.version,
)
self.client.force_login(self.user)
with self.assertNumQueries(9):
response = self.client.post(
reverse('api:transcription-entities-bulk', kwargs={'pk': str(self.transcription.id)}),
data={
'entities': [
{
'name': 'Knight',
'type_id': str(self.person_ent_type.id),
'offset': 10,
'length': 5,
'confidence': 0.05,
},
],
'worker_run_id': str(self.worker_run.id),
},
format='json',
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
'worker_run_id': ['Some TranscriptionEntities were already created for this worker run on this transcription.'],
})
def test_entity_fields_validation(self):
self.client.force_login(self.user)
with self.assertNumQueries(7):
response = self.client.post(
reverse('api:transcription-entities-bulk', kwargs={'pk': str(self.transcription.id)}),
data={
'entities': [
{}
],
'worker_run_id': str(self.worker_run.id),
},
format='json',
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
'entities': [{
'confidence': ['This field is required.'],
'length': ['This field is required.'],
'name': ['This field is required.'],
'offset': ['This field is required.'],
'type_id': ['This field is required.'],
}]
})
def test_duplicate_entities(self):
self.client.force_login(self.user)
payload = {
'name': 'Knight',
'type_id': str(self.person_ent_type.id),
'offset': 10,
'length': 5,
'confidence': 0.05,
}
with self.assertNumQueries(9):
response = self.client.post(
reverse('api:transcription-entities-bulk', kwargs={'pk': str(self.transcription.id)}),
data={
'entities': [
payload,
payload,
],
'worker_run_id': str(self.worker_run.id),
},
format='json',
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
'entities': {
'__all__': ['Some TranscriptionEntities have the same name, type, offset and length.'],
},
})
def test_unknown_type(self):
pass
def test_new_entities(self):
self.client.force_login(self.user)
self.assertQuerysetEqual(
self.corpus.entities.values_list('name', 'type', 'worker_run_id').order_by('name', 'type', 'worker_run_id'),
[('Paris', self.location_ent_type.id, None)],
)
self.assertEqual(self.transcription.transcription_entities.count(), 0)
with self.assertNumQueries(13):
response = self.client.post(
reverse('api:transcription-entities-bulk', kwargs={'pk': str(self.transcription.id)}),
data={
'entities': [
{
# This entity already exists, and should be duplicated
'name': 'Paris',
'type_id': str(self.location_ent_type.id),
'offset': 0,
'length': 1,
'confidence': 0.7,
},
{
'name': 'Knight',
'type_id': str(self.person_ent_type.id),
'offset': 0,
'length': 1,
'confidence': 0.7,
},
{
'name': 'Knight',
'type_id': str(self.person_ent_type.id),
'offset': 10,
'length': 5,
'confidence': 0.05,
},
],
'worker_run_id': str(self.worker_run.id),
},
format='json',
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertQuerysetEqual(
self.corpus.entities.values_list('name', 'type', 'worker_run_id').order_by('name', 'type', 'worker_run_id'),
[
('Knight', self.person_ent_type.id, self.worker_run.id),
('Knight', self.person_ent_type.id, self.worker_run.id),
('Paris', self.location_ent_type.id, self.worker_run.id),
('Paris', self.location_ent_type.id, None),
],
)
self.assertQuerysetEqual(
self.transcription.transcription_entities
.values_list('entity__name', 'entity__type', 'offset', 'length', 'confidence', 'worker_run_id')
.order_by('entity__name', 'offset'),
[
('Knight', self.person_ent_type.id, 0, 1, 0.7, self.worker_run.id),
('Knight', self.person_ent_type.id, 10, 5, 0.05, self.worker_run.id),
('Paris', self.location_ent_type.id, 0, 1, 0.7, self.worker_run.id),
],
)
knight1, knight2, paris = self.transcription.transcription_entities.order_by('entity__name', 'offset')
# Ensure that the returned IDs match the order in which the payloads were sent
self.assertEqual(response.json(), {
"entities": [
{
'transcription_entity_id': paris.id,
'entity_id': str(paris.entity_id),
},
{
'transcription_entity_id': knight1.id,
'entity_id': str(knight1.entity_id),
},
{
'transcription_entity_id': knight2.id,
'entity_id': str(knight2.entity_id),
},
],
})
......@@ -975,6 +975,42 @@ class TestEntitiesAPI(FixtureAPITestCase):
'__all__': ['This entity is already linked to this transcription by this worker run at this position.']
})
def test_create_transcription_entity_manual_existing_worker_run(self):
"""
A manual TranscriptionEntity can be created even when one exists with the same attributes from a WorkerRun
"""
self.client.force_login(self.internal_user)
TranscriptionEntity.objects.create(
transcription=self.transcription,
entity=self.entity,
offset=4,
length=8,
worker_run=self.worker_run,
worker_version=self.worker_version_1
)
with self.assertNumQueries(6):
response = self.client.post(
reverse('api:transcription-entity-create', kwargs={'pk': str(self.transcription.id)}),
data={
'entity': str(self.entity.id),
'offset': 4,
'length': 8,
},
format='json'
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
transcription_entity = self.transcription.transcription_entities.filter(worker_run=None, entity=self.entity).get()
self.assertDictEqual(
response.json(),
{
'entity': str(transcription_entity.entity.id),
'offset': transcription_entity.offset,
'length': transcription_entity.length,
'worker_run': None,
'confidence': None
}
)
def test_create_transcription_entity_key_missing(self):
self.client.force_login(self.user)
data = self.tr_entities_sample
......