Skip to content
Snippets Groups Projects
Commit c53ed02b authored by Bastien Abadie's avatar Bastien Abadie
Browse files

Merge branch 'workerversion-fk-on-transcription' into 'master'

Add a FK on Transcription towards a WorkerVersion

Closes #366

See merge request !872
parents 9da2f108 91daa2c6
No related branches found
No related tags found
1 merge request!872Add a FK on Transcription towards a WorkerVersion
......@@ -17,4 +17,14 @@ class Migration(migrations.Migration):
name='worker_version',
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='elements', to='dataimport.WorkerVersion'),
),
migrations.AddField(
model_name='transcription',
name='worker_version',
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='transcriptions', to='dataimport.WorkerVersion'),
),
migrations.AlterField(
model_name='transcription',
name='source',
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='transcriptions', to='documents.DataSource'),
),
]
......@@ -442,6 +442,15 @@ class Transcription(models.Model):
DataSource,
on_delete=models.CASCADE,
related_name='transcriptions',
null=True,
blank=True,
)
worker_version = models.ForeignKey(
'dataimport.WorkerVersion',
on_delete=models.CASCADE,
related_name='transcriptions',
null=True,
blank=True,
)
text = models.TextField()
score = models.FloatField(null=True, blank=True)
......
......@@ -5,6 +5,7 @@ from rest_framework.exceptions import ValidationError
from arkindex_common.ml_tool import MLToolType
from arkindex_common.enums import TranscriptionType
from arkindex.project.serializer_fields import EnumField, DataSourceSlugField, PolygonField
from arkindex.dataimport.models import WorkerVersion
from arkindex.documents.models import (
Corpus, Element, ElementType, Transcription, DataSource, MLClass, Classification, ClassificationState
)
......@@ -219,12 +220,13 @@ class TranscriptionCreateSerializer(serializers.ModelSerializer):
Allows the insertion of a manual transcription attached to an element
"""
type = EnumField(TranscriptionType)
source = DataSourceSlugField(tool_type=MLToolType.Recognizer)
source = DataSourceSlugField(tool_type=MLToolType.Recognizer, required=False)
worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all(), required=False, allow_null=True)
score = serializers.FloatField(min_value=0, max_value=1, required=False)
class Meta:
model = Transcription
fields = ('text', 'type', 'source', 'score')
fields = ('text', 'type', 'source', 'worker_version', 'score')
def validate(self, data):
data = super().validate(data)
......@@ -233,18 +235,31 @@ class TranscriptionCreateSerializer(serializers.ModelSerializer):
if not element.zone:
raise ValidationError({'element': ['The element has no zone']})
slug = data.get('source').slug
if slug == 'manual':
# Assert the type is allowed for manual transcription
allowed_transcription = element.type.allowed_transcription
if not allowed_transcription:
raise ValidationError({'element': ['The element type does not allow creating a manual transcription']})
if data['type'] is not allowed_transcription:
raise ValidationError({'type': [
f"Only transcriptions of type '{allowed_transcription.value}' are allowed for this element"
]})
return data
source = data.get('source')
worker_version = data.get('worker_version')
if not source and not worker_version:
raise ValidationError({
'source': ['This field XOR worker_version field must be set to create a transcription'],
'worker_version': ['This field XOR source field must be set to create a transcription']
})
elif source and worker_version:
raise ValidationError({
'source': ['You can only refer to a DataSource XOR a WorkerVersion on a transcription'],
'worker_version': ['You can only refer to a DataSource XOR a WorkerVersion on a transcription']
})
elif source:
slug = source.slug
if slug == 'manual':
# Assert the type is allowed for manual transcription
allowed_transcription = element.type.allowed_transcription
if not allowed_transcription:
raise ValidationError({'element': ['The element type does not allow creating a manual transcription']})
if data['type'] is not allowed_transcription:
raise ValidationError({'type': [
f"Only transcriptions of type '{allowed_transcription.value}' are allowed for this element"
]})
return data
# Additional validation for transcriptions with an internal source
if not data.get('score'):
......@@ -252,10 +267,16 @@ class TranscriptionCreateSerializer(serializers.ModelSerializer):
user = self.context['request'].user
if (not user or not user.is_internal):
raise ValidationError({'source': [
if source:
raise ValidationError({'source': [
'An internal user is required to create a transcription with '
f'the internal source "{slug}"'
]})
raise ValidationError({'worker_version': [
'An internal user is required to create a transcription with '
f'the internal source "{slug}"'
f'the worker_version "{worker_version.id}"'
]})
return data
......
......@@ -6,6 +6,7 @@ from rest_framework import status
from arkindex.project.tests import FixtureAPITestCase
from arkindex_common.enums import TranscriptionType
from arkindex_common.ml_tool import MLToolType
from arkindex.dataimport.models import Worker, WorkerVersion
from arkindex.documents.models import Corpus, Transcription, DataSource
from arkindex.users.models import User
from uuid import uuid4
......@@ -29,6 +30,9 @@ class TestTranscriptionCreate(FixtureAPITestCase):
cls.private_read_user.verified_email = True
cls.private_read_user.save()
cls.private_corpus.corpus_right.create(user=cls.private_read_user)
cls.creds = cls.user.credentials.get()
cls.repo = cls.creds.repos.get()
cls.rev = cls.repo.revisions.get()
def setUp(self):
self.manual_source = DataSource.objects.create(type=MLToolType.Recognizer, slug='manual', internal=False)
......@@ -228,6 +232,128 @@ class TestTranscriptionCreate(FixtureAPITestCase):
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertFalse(get_layer_mock().send.called)
def test_create_transcription_no_source_no_worker_version(self):
self.client.force_login(self.user)
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.line.id}),
format='json',
data={
'type': 'word',
'text': 'NEKUDOTAYIM',
}
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
'source': ['This field XOR worker_version field must be set to create a transcription'],
'worker_version': ['This field XOR source field must be set to create a transcription']
})
def test_create_transcription_source_and_worker_version_returns_error(self):
self.client.force_login(self.user)
worker = Worker.objects.create(
repository=self.repo,
name='Worker 1',
slug='worker_1',
type=MLToolType.Classifier
)
version = WorkerVersion.objects.create(
worker=worker,
revision=self.rev,
configuration={"test": "test1"}
)
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.line.id}),
format='json',
data={
'type': 'word',
'text': 'NEKUDOTAYIM',
'source': 'manual',
'worker_version': str(version.id),
}
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
'source': ['You can only refer to a DataSource XOR a WorkerVersion on a transcription'],
'worker_version': ['You can only refer to a DataSource XOR a WorkerVersion on a transcription']
})
def test_create_transcription_worker_version_non_internal(self):
self.client.force_login(self.user)
worker = Worker.objects.create(
repository=self.repo,
name='Worker 1',
slug='worker_1',
type=MLToolType.Classifier
)
version = WorkerVersion.objects.create(
worker=worker,
revision=self.rev,
configuration={"test": "test1"}
)
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.line.id}),
format='json',
data={
'type': 'word',
'text': 'NEKUDOTAYIM',
'worker_version': str(version.id),
'score': .42
}
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
'worker_version': [f'An internal user is required to create a transcription with the worker_version "{version.id}"']
})
@patch('arkindex.project.triggers.get_channel_layer')
def test_create_transcription_worker_version(self, get_layer_mock):
get_layer_mock.return_value.send = AsyncMock()
self.client.force_login(self.internal_user)
worker = Worker.objects.create(
repository=self.repo,
name='Worker 1',
slug='worker_1',
type=MLToolType.Classifier
)
version = WorkerVersion.objects.create(
worker=worker,
revision=self.rev,
configuration={"test": "test1"}
)
response = self.client.post(
reverse('api:transcription-create', kwargs={'pk': self.line.id}),
format='json',
data={
'type': 'word',
'text': 'NEKUDOTAYIM',
'worker_version': str(version.id),
'score': .42
}
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
tr = Transcription.objects.get(text='NEKUDOTAYIM')
self.assertEqual(tr.worker_version, version)
self.assertDictEqual(response.json(), {
'id': str(tr.id),
'score': .42,
'source': None,
'text': 'NEKUDOTAYIM',
'type': 'word',
'zone': None
})
get_layer_mock().send.assert_called_once_with('reindex', {
'type': 'reindex.start',
'element': str(self.line.id),
'corpus': None,
'entity': None,
'transcriptions': True,
'elements': True,
'entities': False,
'drop': False,
})
def test_manual_transcription_forbidden_type(self):
"""
Creating a manual transcription with a non allowed type is forbidden
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment