diff --git a/arkindex/documents/migrations/0018_element_worker_version.py b/arkindex/documents/migrations/0018_worker_version_attributes.py similarity index 51% rename from arkindex/documents/migrations/0018_element_worker_version.py rename to arkindex/documents/migrations/0018_worker_version_attributes.py index 940d1266de804108d0a8c8cb9152ef7e5d9aedea..059be105bc396e229014835bf4825be44aca84eb 100644 --- a/arkindex/documents/migrations/0018_element_worker_version.py +++ b/arkindex/documents/migrations/0018_worker_version_attributes.py @@ -17,4 +17,14 @@ class Migration(migrations.Migration): name='worker_version', field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='elements', to='dataimport.WorkerVersion'), ), + migrations.AddField( + model_name='transcription', + name='worker_version', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='transcriptions', to='dataimport.WorkerVersion'), + ), + migrations.AlterField( + model_name='transcription', + name='source', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='transcriptions', to='documents.DataSource'), + ), ] diff --git a/arkindex/documents/models.py b/arkindex/documents/models.py index 5af509a1c5cdf98f74a941f3904c666c4fc25a84..0799cc173e49fe9afd6f5981e496226aba4f8243 100644 --- a/arkindex/documents/models.py +++ b/arkindex/documents/models.py @@ -442,6 +442,15 @@ class Transcription(models.Model): DataSource, on_delete=models.CASCADE, related_name='transcriptions', + null=True, + blank=True, + ) + worker_version = models.ForeignKey( + 'dataimport.WorkerVersion', + on_delete=models.CASCADE, + related_name='transcriptions', + null=True, + blank=True, ) text = models.TextField() score = models.FloatField(null=True, blank=True) diff --git a/arkindex/documents/serializers/ml.py b/arkindex/documents/serializers/ml.py index 778249896bb2e812f49c35b27a60f3b6a0334d00..7d8d21899b9ee5e96307077f20d313aa92f04fd8 100644 --- a/arkindex/documents/serializers/ml.py +++ b/arkindex/documents/serializers/ml.py @@ -5,6 +5,7 @@ from rest_framework.exceptions import ValidationError from arkindex_common.ml_tool import MLToolType from arkindex_common.enums import TranscriptionType from arkindex.project.serializer_fields import EnumField, DataSourceSlugField, PolygonField +from arkindex.dataimport.models import WorkerVersion from arkindex.documents.models import ( Corpus, Element, ElementType, Transcription, DataSource, MLClass, Classification, ClassificationState ) @@ -219,12 +220,13 @@ class TranscriptionCreateSerializer(serializers.ModelSerializer): Allows the insertion of a manual transcription attached to an element """ type = EnumField(TranscriptionType) - source = DataSourceSlugField(tool_type=MLToolType.Recognizer) + source = DataSourceSlugField(tool_type=MLToolType.Recognizer, required=False) + worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all(), required=False, allow_null=True) score = serializers.FloatField(min_value=0, max_value=1, required=False) class Meta: model = Transcription - fields = ('text', 'type', 'source', 'score') + fields = ('text', 'type', 'source', 'worker_version', 'score') def validate(self, data): data = super().validate(data) @@ -233,18 +235,31 @@ class TranscriptionCreateSerializer(serializers.ModelSerializer): if not element.zone: raise ValidationError({'element': ['The element has no zone']}) - slug = data.get('source').slug - if slug == 'manual': - # Assert the type is allowed for manual transcription - allowed_transcription = element.type.allowed_transcription - if not allowed_transcription: - raise ValidationError({'element': ['The element type does not allow creating a manual transcription']}) - - if data['type'] is not allowed_transcription: - raise ValidationError({'type': [ - f"Only transcriptions of type '{allowed_transcription.value}' are allowed for this element" - ]}) - return data + source = data.get('source') + worker_version = data.get('worker_version') + if not source and not worker_version: + raise ValidationError({ + 'source': ['This field XOR worker_version field must be set to create a transcription'], + 'worker_version': ['This field XOR source field must be set to create a transcription'] + }) + elif source and worker_version: + raise ValidationError({ + 'source': ['You can only refer to a DataSource XOR a WorkerVersion on a transcription'], + 'worker_version': ['You can only refer to a DataSource XOR a WorkerVersion on a transcription'] + }) + elif source: + slug = source.slug + if slug == 'manual': + # Assert the type is allowed for manual transcription + allowed_transcription = element.type.allowed_transcription + if not allowed_transcription: + raise ValidationError({'element': ['The element type does not allow creating a manual transcription']}) + + if data['type'] is not allowed_transcription: + raise ValidationError({'type': [ + f"Only transcriptions of type '{allowed_transcription.value}' are allowed for this element" + ]}) + return data # Additional validation for transcriptions with an internal source if not data.get('score'): @@ -252,10 +267,16 @@ class TranscriptionCreateSerializer(serializers.ModelSerializer): user = self.context['request'].user if (not user or not user.is_internal): - raise ValidationError({'source': [ + if source: + raise ValidationError({'source': [ + 'An internal user is required to create a transcription with ' + f'the internal source "{slug}"' + ]}) + raise ValidationError({'worker_version': [ 'An internal user is required to create a transcription with ' - f'the internal source "{slug}"' + f'the worker_version "{worker_version.id}"' ]}) + return data diff --git a/arkindex/documents/tests/test_create_transcriptions.py b/arkindex/documents/tests/test_create_transcriptions.py index 1fcdc51b62438a949dab73b4b327da7d9845b211..81444eb8b6567f616fd46900e7261ea0008e5a17 100644 --- a/arkindex/documents/tests/test_create_transcriptions.py +++ b/arkindex/documents/tests/test_create_transcriptions.py @@ -6,6 +6,7 @@ from rest_framework import status from arkindex.project.tests import FixtureAPITestCase from arkindex_common.enums import TranscriptionType from arkindex_common.ml_tool import MLToolType +from arkindex.dataimport.models import Worker, WorkerVersion from arkindex.documents.models import Corpus, Transcription, DataSource from arkindex.users.models import User from uuid import uuid4 @@ -29,6 +30,9 @@ class TestTranscriptionCreate(FixtureAPITestCase): cls.private_read_user.verified_email = True cls.private_read_user.save() cls.private_corpus.corpus_right.create(user=cls.private_read_user) + cls.creds = cls.user.credentials.get() + cls.repo = cls.creds.repos.get() + cls.rev = cls.repo.revisions.get() def setUp(self): self.manual_source = DataSource.objects.create(type=MLToolType.Recognizer, slug='manual', internal=False) @@ -228,6 +232,128 @@ class TestTranscriptionCreate(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertFalse(get_layer_mock().send.called) + def test_create_transcription_no_source_no_worker_version(self): + self.client.force_login(self.user) + response = self.client.post( + reverse('api:transcription-create', kwargs={'pk': self.line.id}), + format='json', + data={ + 'type': 'word', + 'text': 'NEKUDOTAYIM', + } + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + 'source': ['This field XOR worker_version field must be set to create a transcription'], + 'worker_version': ['This field XOR source field must be set to create a transcription'] + }) + + def test_create_transcription_source_and_worker_version_returns_error(self): + self.client.force_login(self.user) + worker = Worker.objects.create( + repository=self.repo, + name='Worker 1', + slug='worker_1', + type=MLToolType.Classifier + ) + version = WorkerVersion.objects.create( + worker=worker, + revision=self.rev, + configuration={"test": "test1"} + ) + response = self.client.post( + reverse('api:transcription-create', kwargs={'pk': self.line.id}), + format='json', + data={ + 'type': 'word', + 'text': 'NEKUDOTAYIM', + 'source': 'manual', + 'worker_version': str(version.id), + } + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + 'source': ['You can only refer to a DataSource XOR a WorkerVersion on a transcription'], + 'worker_version': ['You can only refer to a DataSource XOR a WorkerVersion on a transcription'] + }) + + def test_create_transcription_worker_version_non_internal(self): + self.client.force_login(self.user) + worker = Worker.objects.create( + repository=self.repo, + name='Worker 1', + slug='worker_1', + type=MLToolType.Classifier + ) + version = WorkerVersion.objects.create( + worker=worker, + revision=self.rev, + configuration={"test": "test1"} + ) + response = self.client.post( + reverse('api:transcription-create', kwargs={'pk': self.line.id}), + format='json', + data={ + 'type': 'word', + 'text': 'NEKUDOTAYIM', + 'worker_version': str(version.id), + 'score': .42 + } + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + 'worker_version': [f'An internal user is required to create a transcription with the worker_version "{version.id}"'] + }) + + @patch('arkindex.project.triggers.get_channel_layer') + def test_create_transcription_worker_version(self, get_layer_mock): + get_layer_mock.return_value.send = AsyncMock() + + self.client.force_login(self.internal_user) + worker = Worker.objects.create( + repository=self.repo, + name='Worker 1', + slug='worker_1', + type=MLToolType.Classifier + ) + version = WorkerVersion.objects.create( + worker=worker, + revision=self.rev, + configuration={"test": "test1"} + ) + response = self.client.post( + reverse('api:transcription-create', kwargs={'pk': self.line.id}), + format='json', + data={ + 'type': 'word', + 'text': 'NEKUDOTAYIM', + 'worker_version': str(version.id), + 'score': .42 + } + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + tr = Transcription.objects.get(text='NEKUDOTAYIM') + self.assertEqual(tr.worker_version, version) + self.assertDictEqual(response.json(), { + 'id': str(tr.id), + 'score': .42, + 'source': None, + 'text': 'NEKUDOTAYIM', + 'type': 'word', + 'zone': None + }) + + get_layer_mock().send.assert_called_once_with('reindex', { + 'type': 'reindex.start', + 'element': str(self.line.id), + 'corpus': None, + 'entity': None, + 'transcriptions': True, + 'elements': True, + 'entities': False, + 'drop': False, + }) + def test_manual_transcription_forbidden_type(self): """ Creating a manual transcription with a non allowed type is forbidden