diff --git a/arkindex/documents/api/entities.py b/arkindex/documents/api/entities.py index 3d2d505a15f3c17efb2cbd2dbb89f8d423505e65..cc69925745f4afda656150789830e793b38a9acb 100644 --- a/arkindex/documents/api/entities.py +++ b/arkindex/documents/api/entities.py @@ -137,14 +137,12 @@ class EntityCreate(CreateAPIView): type = serializer.validated_data['type'] corpus = serializer.validated_data['corpus'] metas = serializer.validated_data['metas'] if 'metas' in serializer.data else None - source = serializer.validated_data['ner'] worker_version = serializer.validated_data['worker_version'] return Entity.objects.create( name=name, type=type, corpus=corpus, metas=metas, - source=source, worker_version=worker_version ) diff --git a/arkindex/documents/api/ml.py b/arkindex/documents/api/ml.py index f25151096b5711bbc372602cd2772e3a1011a851..885372f661949bacee4b5235f7c858acd7c3f7c8 100644 --- a/arkindex/documents/api/ml.py +++ b/arkindex/documents/api/ml.py @@ -111,7 +111,7 @@ class TranscriptionEdit(RetrieveUpdateDestroyAPIView): errors = defaultdict(list) if Right.Write not in rights: errors['__all__'].append('A write access to transcription element corpus is required.') - if transcription.source.slug != 'manual': + if transcription.worker_version or transcription.source and transcription.source.slug != 'manual': errors['__all__'].append('Only manual transcriptions can be edited.') if (errors): raise PermissionDenied(errors) @@ -137,7 +137,7 @@ class ElementTranscriptionsBulk(CreateAPIView): 'value': { 'element_type': 'string', 'transcription_type': 'string', - 'source': 'string', + 'worker_version': 'string', 'transcriptions': { 'polygon': [[]], 'text': 'string', @@ -194,7 +194,6 @@ class ElementTranscriptionsBulk(CreateAPIView): def perform_create(self, serializer): elt_type = serializer.validated_data['element_type'] tr_type = serializer.validated_data['transcription_type'] - source = serializer.validated_data['source'] worker_version = serializer.validated_data['worker_version'] annotations = serializer.validated_data['transcriptions'] @@ -280,7 +279,6 @@ class ElementTranscriptionsBulk(CreateAPIView): element=annotation['element'], type=tr_type, zone=None, - source=source, worker_version=worker_version, text=annotation['text'], score=annotation['score'] @@ -379,7 +377,7 @@ class ClassificationCreate(CreateAPIView): } def perform_create(self, serializer): - if serializer.validated_data['source'] and serializer.validated_data['source'].slug == 'manual': + if serializer.validated_data['worker_version'] is None: # A manual classification is immediately valid serializer.save( moderator=self.request.user, diff --git a/arkindex/documents/migrations/0020_remove_source_xor_version_constraint.py b/arkindex/documents/migrations/0020_remove_source_xor_version_constraint.py new file mode 100644 index 0000000000000000000000000000000000000000..8907ff7578fcfdd2237963ec17610a08c7470a2c --- /dev/null +++ b/arkindex/documents/migrations/0020_remove_source_xor_version_constraint.py @@ -0,0 +1,33 @@ +# Generated by Django 3.1 on 2020-10-14 13:56 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('documents', '0019_corpus_repository'), + ] + + operations = [ + migrations.RemoveConstraint( + model_name='classification', + name='classification_unique_source', + ), + migrations.RemoveConstraint( + model_name='classification', + name='classification_unique_worker_version', + ), + migrations.RemoveConstraint( + model_name='classification', + name='classification_source_xor_workerversion', + ), + migrations.AddConstraint( + model_name='classification', + constraint=models.UniqueConstraint(condition=models.Q(('source_id__isnull', True), ('worker_version_id__isnull', True)), fields=('element', 'ml_class'), name='classification_unique_manual'), + ), + migrations.AddConstraint( + model_name='classification', + constraint=models.UniqueConstraint(condition=models.Q(('source_id__isnull', True), ('worker_version_id__isnull', False)), fields=('element', 'ml_class', 'worker_version'), name='classification_unique_worker_version'), + ), + ] diff --git a/arkindex/documents/models.py b/arkindex/documents/models.py index 1e59a848e1985a64c341732d42eb60ba18587ba0..7f77e4d49446607b318f429da5d881962d9fafb2 100644 --- a/arkindex/documents/models.py +++ b/arkindex/documents/models.py @@ -578,19 +578,16 @@ class Classification(models.Model): class Meta: constraints = [ + # Add class unicity for manual and non manual classifications on an element models.UniqueConstraint( - fields=['element', 'ml_class', 'source'], - name='classification_unique_source', - condition=Q(worker_version_id__isnull=True), + fields=['element', 'ml_class'], + name='classification_unique_manual', + condition=Q(worker_version_id__isnull=True, source_id__isnull=True), ), models.UniqueConstraint( fields=['element', 'ml_class', 'worker_version'], name='classification_unique_worker_version', - condition=Q(source_id__isnull=True), - ), - models.CheckConstraint( - check=Q(source_id__isnull=False, worker_version_id__isnull=True) | Q(source_id__isnull=True, worker_version_id__isnull=False), - name='classification_source_xor_workerversion', + condition=Q(worker_version_id__isnull=False, source_id__isnull=True), ) ] diff --git a/arkindex/documents/serializers/elements.py b/arkindex/documents/serializers/elements.py index e4c08bae036d124322bc4556fc727ff68d16123d..e63d8e3a0b3a1266676c39579d491ca5664002db 100644 --- a/arkindex/documents/serializers/elements.py +++ b/arkindex/documents/serializers/elements.py @@ -4,7 +4,6 @@ from django.contrib.gis.geos import LinearRing from django.utils.functional import cached_property from rest_framework import serializers from rest_framework.exceptions import ValidationError -from arkindex_common.ml_tool import MLToolType from arkindex_common.enums import MetaType from arkindex.dataimport.models import WorkerVersion from arkindex.images.serializers import ZoneSerializer @@ -18,7 +17,7 @@ from arkindex.documents.serializers.light import ( ) from arkindex.documents.serializers.entities import BaseEntitySerializer, TranscriptionEntityDetailsSerializer from arkindex.documents.serializers.ml import ClassificationSerializer, DataSourceSerializer -from arkindex.project.serializer_fields import LinearRingField, DataSourceSlugField +from arkindex.project.serializer_fields import LinearRingField class MetaDataSerializer(MetaDataLightSerializer): @@ -338,12 +337,11 @@ class ElementCreateSerializer(ElementLightSerializer): help_text='Set the polygon linking this element to the image. ' '`image` must be set when this field is set. Defaults to a rectangle taking up the whole image.', ) - source = DataSourceSlugField(tool_type=MLToolType.DLAnalyser, required=False) worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all(), required=False, allow_null=True) class Meta(ElementLightSerializer.Meta): model = Element - fields = ElementLightSerializer.Meta.fields + ('image', 'corpus', 'parent', 'polygon', 'source', 'worker_version') + fields = ElementLightSerializer.Meta.fields + ('image', 'corpus', 'parent', 'polygon', 'worker_version') def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -379,10 +377,6 @@ class ElementCreateSerializer(ElementLightSerializer): # will lead to many errors everywhere as this would create impossible polygons errors['image'].append('This image does not have valid dimensions.') - if data.get('source') and data.get('worker_version'): - errors['source'].append('You can only refer to a DataSource XOR a WorkerVersion on an element') - errors['worker_version'].append('You can only refer to a DataSource XOR a WorkerVersion on an element') - if errors: raise ValidationError(errors) return data @@ -423,7 +417,6 @@ class ElementCreateSerializer(ElementLightSerializer): corpus=validated_data['corpus'], type=validated_data['type'], name=validated_data['name'], - source=validated_data.get('source'), worker_version=validated_data.get('worker_version'), zone=zone, ) diff --git a/arkindex/documents/serializers/entities.py b/arkindex/documents/serializers/entities.py index 0cf5169f133cb520765b960d576f81c68c4261ec..21a0c6034b35921a8fd887f3598bd6eaadcd01da 100644 --- a/arkindex/documents/serializers/entities.py +++ b/arkindex/documents/serializers/entities.py @@ -1,13 +1,11 @@ from rest_framework import serializers -from rest_framework.exceptions import ValidationError -from arkindex_common.ml_tool import MLToolType from arkindex.dataimport.models import WorkerVersion from arkindex.documents.models import \ Corpus, Entity, EntityLink, EntityRole, TranscriptionEntity from arkindex_common.enums import EntityType from arkindex.documents.serializers.light import CorpusLightSerializer, InterpretedDateSerializer from arkindex.documents.serializers.ml import DataSourceSerializer -from arkindex.project.serializer_fields import EnumField, DataSourceSlugField +from arkindex.project.serializer_fields import EnumField from arkindex.project.triggers import reindex_start @@ -130,7 +128,6 @@ class EntityCreateSerializer(BaseEntitySerializer): metas = serializers.HStoreField(child=serializers.CharField(), required=False) children = EntityLinkSerializer(many=True, read_only=True) parents = EntityLinkSerializer(many=True, read_only=True) - ner = DataSourceSlugField(tool_type=MLToolType.NER, default=None) worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all(), default=None) class Meta: @@ -144,7 +141,6 @@ class EntityCreateSerializer(BaseEntitySerializer): 'corpus', 'parents', 'children', - 'ner', 'worker_version' ) read_only_fields = ( @@ -161,22 +157,6 @@ class EntityCreateSerializer(BaseEntitySerializer): corpora = Corpus.objects.writable(self.context['request'].user) self.fields['corpus'].queryset = corpora - def validate(self, data): - ner = data.get('ner') - worker_version = data.get('worker_version') - if not ner and not worker_version: - raise ValidationError({ - 'ner': ['This field XOR worker_version field must be set to create an entity'], - 'worker_version': ['This field XOR ner field must be set to create an entity'] - }) - elif ner and worker_version: - raise ValidationError({ - 'ner': ['You can only refer to a DataSource XOR a WorkerVersion on an entity'], - 'worker_version': ['You can only refer to a DataSource XOR a WorkerVersion on an entity'] - }) - - return data - class EntityLinkCreateSerializer(EntityLinkSerializer): """ diff --git a/arkindex/documents/serializers/ml.py b/arkindex/documents/serializers/ml.py index dae717afba2c9ec867afc28aff76ffca7364c11c..65a34c6e7d267d75baddd25d5aa2ab8371ecc91d 100644 --- a/arkindex/documents/serializers/ml.py +++ b/arkindex/documents/serializers/ml.py @@ -9,7 +9,7 @@ from arkindex.dataimport.models import WorkerVersion from arkindex.documents.models import ( Corpus, Element, ElementType, Transcription, DataSource, MLClass, Classification, ClassificationState ) -from arkindex.project.serializer_fields import EnumField, DataSourceSlugField, LinearRingField +from arkindex.project.serializer_fields import EnumField, LinearRingField from arkindex.images.serializers import ZoneSerializer from arkindex.documents.serializers.light import ElementZoneSerializer import uuid @@ -113,21 +113,20 @@ class ClassificationCreateSerializer(serializers.ModelSerializer): """ Serializer to create a single classification, defaulting to manual """ - source = DataSourceSlugField(tool_type=MLToolType.Classifier, default=None) worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all(), default=None) confidence = serializers.FloatField( min_value=0, max_value=1, required=False, - help_text='Confidence score for this classification. Required for non-manual sources.' - 'Will be ignored and set to 1.0 for classifications on a `manual` source.', + help_text='Confidence score for this classification. Required for classifications with a worker version.' + 'Will be ignored and set to 1.0 for a `manual` classification.', ) - # Use a NullBooleanField here to make it default to None and check later on non-manual sources + # Use a NullBooleanField here to make it default to None and check later on non-manual classifications high_confidence = serializers.NullBooleanField( required=False, help_text='Whether or not a machine learning tool marks this as the correct classification. ' - 'Required for non-manual sources. ' - 'Will be ignored and set to True for classifications on a `manual` source.', + 'Required for classifications with a worker version. ' + 'Will be ignored and set to True for `manual` classifications.', ) state = EnumField(ClassificationState, read_only=True) @@ -137,7 +136,6 @@ class ClassificationCreateSerializer(serializers.ModelSerializer): 'id', 'element', 'ml_class', - 'source', 'worker_version', 'confidence', 'high_confidence', @@ -146,12 +144,12 @@ class ClassificationCreateSerializer(serializers.ModelSerializer): read_only_fields = ('id', 'state') validators = [ UniqueTogetherValidator( - queryset=Classification.objects.filter(worker_version_id__isnull=True), - fields=['element', 'source', 'ml_class'] + queryset=Classification.objects.filter(worker_version__isnull=False, source_id__isnull=True), + fields=['element', 'worker_version', 'ml_class'] ), UniqueTogetherValidator( - queryset=Classification.objects.filter(source_id__isnull=True), - fields=['element', 'worker_version', 'ml_class'] + queryset=Classification.objects.filter(worker_version__isnull=True, source_id__isnull=True), + fields=['element', 'ml_class'] ) ] @@ -166,43 +164,26 @@ class ClassificationCreateSerializer(serializers.ModelSerializer): corpus__in=Corpus.objects.writable(self.context['request'].user)) def validate(self, data): - # Do not check data for manual classifications - # Note that (source, class, element) unicity is already checked by DRF + # Note that (worker_version, class, element) unicity is already checked by DRF errors = {} user = self.context['request'].user - source = data.get('source') worker_version = data.get('worker_version') - if not source and not worker_version: - raise ValidationError({ - 'source': ['This field XOR worker_version field must be set to create a classification'], - 'worker_version': ['This field XOR source field must be set to create a classification'] - }) - elif source and worker_version: - raise ValidationError({ - 'source': ['You can only refer to a DataSource XOR a WorkerVersion on a classification'], - 'worker_version': ['You can only refer to a DataSource XOR a WorkerVersion on a classification'] - }) - elif source: - slug = data.get('source').slug - if slug == 'manual': - return data - - if not user or not user.is_internal: - errors['source'] = [ - 'An internal user is required to create a classification with ' - f'the non-manual source "{slug}"' - ] - elif worker_version and (not user or not user.is_internal): - errors['worker_version'] = [ - 'An internal user is required to create a classification with ' - f'the worker_version "{worker_version.id}"' - ] - - # Additional validation for transcriptions with an internal source + + # Do not check data for manual classifications + if worker_version is None: + return data + + # Additional validation for transcriptions with a worker version + version_required_error = 'is required to create a classification with a worker version.' + + if not user or not user.is_internal: + errors['worker_version'] = [f'An internal user {version_required_error}'] + if data.get('confidence') is None: - errors['confidence'] = ['This field is required for non manual sources.'] + errors['confidence'] = [f'This field {version_required_error}'] + if data.get('high_confidence') is None: - errors['high_confidence'] = ['This field is required for non manual sources.'] + errors['high_confidence'] = [f'This field {version_required_error}'] if errors: raise ValidationError(errors) @@ -289,13 +270,12 @@ class TranscriptionCreateSerializer(serializers.ModelSerializer): Allows the insertion of a manual transcription attached to an element """ type = EnumField(TranscriptionType) - source = DataSourceSlugField(tool_type=MLToolType.Recognizer, required=False) worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all(), required=False, allow_null=True) score = serializers.FloatField(min_value=0, max_value=1, required=False) class Meta: model = Transcription - fields = ('text', 'type', 'source', 'worker_version', 'score') + fields = ('text', 'type', 'worker_version', 'score') def validate(self, data): data = super().validate(data) @@ -305,46 +285,27 @@ class TranscriptionCreateSerializer(serializers.ModelSerializer): if not element.zone_id: raise ValidationError({'element': ['The element has no zone']}) - source = data.get('source') worker_version = data.get('worker_version') - if not source and not worker_version: - raise ValidationError({ - 'source': ['This field XOR worker_version field must be set to create a transcription'], - 'worker_version': ['This field XOR source field must be set to create a transcription'] - }) - elif source and worker_version: - raise ValidationError({ - 'source': ['You can only refer to a DataSource XOR a WorkerVersion on a transcription'], - 'worker_version': ['You can only refer to a DataSource XOR a WorkerVersion on a transcription'] - }) - elif source: - slug = source.slug - if slug == 'manual': - # Assert the type is allowed for manual transcription - allowed_transcription = element.type.allowed_transcription - if not allowed_transcription: - raise ValidationError({'element': ['The element type does not allow creating a manual transcription']}) - - if data['type'] is not allowed_transcription: - raise ValidationError({'type': [ - f"Only transcriptions of type '{allowed_transcription.value}' are allowed for this element" - ]}) - return data - - # Additional validation for transcriptions with an internal source + if worker_version is None: + # Assert the type is allowed for manual transcription + allowed_transcription = element.type.allowed_transcription + if not allowed_transcription: + raise ValidationError({'element': ['The element type does not allow creating a manual transcription']}) + + if data['type'] is not allowed_transcription: + raise ValidationError({'type': [ + f"Only transcriptions of type '{allowed_transcription.value}' are allowed for this element" + ]}) + return data + + # Additional validation for transcriptions with a worker version if not data.get('score'): - raise ValidationError({'score': ['This field is required for non manual sources.']}) + raise ValidationError({'score': ['This field is required for transcription with a worker version.']}) user = self.context['request'].user if (not user or not user.is_internal): - if source: - raise ValidationError({'source': [ - 'An internal user is required to create a transcription with ' - f'the internal source "{slug}"' - ]}) raise ValidationError({'worker_version': [ - 'An internal user is required to create a transcription with ' - f'the worker_version "{worker_version.id}"' + 'An internal user is required to create a transcription refering to a worker_version' ]}) return data @@ -373,16 +334,8 @@ class ElementTranscriptionsBulkSerializer(serializers.Serializer): TranscriptionType, help_text='A TranscriptionType for created transcriptions' ) - source = DataSourceSlugField( - tool_type=MLToolType.Recognizer, - required=False, - default=None, - help_text='A recognizer DataSource slug, unique per request too. It cannot be set to manual' - ) worker_version = serializers.PrimaryKeyRelatedField( queryset=WorkerVersion.objects.all(), - required=False, - default=None, help_text='A WorkerVersion ID that transcriptions will refer to' ) transcriptions = SimpleTranscriptionSerializer( @@ -406,24 +359,6 @@ class ElementTranscriptionsBulkSerializer(serializers.Serializer): # Use the parent types for validation as elements are in the same corpus self.fields['element_type'].queryset = ElementType.objects.filter(corpus=self.context['element'].corpus) - def validate(self, data): - data = super().validate(data) - source = data.get('source') - worker_version = data.get('worker_version') - if not source and not worker_version: - raise ValidationError({ - 'source': ['This field XOR worker_version field must be set to create transcriptions'], - 'worker_version': ['This field XOR source field must be set to create transcriptions'] - }) - elif source and worker_version: - raise ValidationError({ - 'source': ['You can only refer to a DataSource XOR a WorkerVersion on transcriptions'], - 'worker_version': ['You can only refer to a DataSource XOR a WorkerVersion on transcriptions'] - }) - elif source and source.slug == 'manual': - raise ValidationError({'source': ["Transcriptions source slug cannot be set to manual."]}) - return data - class AnnotatedElementSerializer(serializers.Serializer): """ @@ -453,7 +388,6 @@ class ClassificationsSerializer(serializers.Serializer): # The real queryset is set in __init__ queryset=Element.objects.none(), ) - classifier = DataSourceSlugField(tool_type=MLToolType.Classifier, default=None) worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all(), default=None) classifications = ClassificationBulkSerializer(many=True, allow_empty=False) @@ -471,19 +405,13 @@ class ClassificationsSerializer(serializers.Serializer): def validate(self, data): data = super().validate(data) - if not (data['classifier'] is None) ^ (data['worker_version'] is None): - raise ValidationError({ - 'classifier': ['This field XOR worker_version field must be set to create classifications'], - 'worker_version': ['This field XOR classifier field must be set to create classifications'] - }) - ml_class_names = [ classification['ml_class'] for classification in data['classifications'] ] if len(ml_class_names) != len(set(ml_class_names)): raise ValidationError({ - 'classifications': ['Duplicated ML classes are not allowed from the same source or worker version.'] + 'classifications': ['Duplicated ML classes are not allowed from the same worker version.'] }) return data @@ -513,12 +441,9 @@ class ClassificationsSerializer(serializers.Serializer): ml_classes.update({ml_class.name: ml_class.id for ml_class in new_classes}) - source = validated_data.get('classifier') worker_version = validated_data.get('worker_version') - origin = {'source': source} if source else {'worker_version': worker_version} - # Delete classifications with the same origin - parent.classifications.filter(**origin).delete() + parent.classifications.filter(worker_version=worker_version).delete() Classification.objects.bulk_create([ Classification( @@ -526,7 +451,7 @@ class ClassificationsSerializer(serializers.Serializer): ml_class_id=ml_classes[cl['ml_class']], confidence=cl['confidence'], high_confidence=cl['high_confidence'], - **origin + worker_version=worker_version ) for cl in validated_data['classifications'] ]) diff --git a/arkindex/documents/tests/test_bulk_classification.py b/arkindex/documents/tests/test_bulk_classification.py index a0179856748369d3398e7e1051ae8515fb36f54b..9f6faa953d97197a02ee84343d05f5309c5eba2e 100644 --- a/arkindex/documents/tests/test_bulk_classification.py +++ b/arkindex/documents/tests/test_bulk_classification.py @@ -62,71 +62,6 @@ class TestBulkClassification(FixtureAPITestCase): } ) - def test_bulk_classification_requires_source_xor_worker_version(self): - """ - A classifier data source XOR a worker_version is required to push classifications on an element - """ - wrong_payloads = ( - {'classifier': self.src.slug, 'worker_version': self.worker_version.id}, - {'classifier': None, 'worker_version': None}, - {'classifier': ''}, - {} - ) - self.client.force_login(self.user) - for payload in wrong_payloads: - response = self.client.post( - reverse('api:classification-bulk'), - format='json', - data={ - 'parent': str(self.page.id), - 'classifications': [{'class_name': 'cat', 'confidence': 0.42}], - **payload - } - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - if payload.get('classifier') and payload.get('worker_version'): - self.assertDictEqual(response.json(), { - 'classifier': ['This field XOR worker_version field must be set to create classifications'], - 'worker_version': ['This field XOR classifier field must be set to create classifications'] - }) - - def test_bulk_classification_source(self): - """ - Bulk classifications are created using an existing classifier source - """ - self.client.force_login(self.user) - with self.assertNumQueries(8): - response = self.client.post( - reverse('api:classification-bulk'), - format='json', - data={ - "parent": str(self.page.id), - "classifier": self.src.slug, - "classifications": [{ - "class_name": 'dog', - "confidence": 0.99, - "high_confidence": True - }, { - "class_name": 'cat', - "confidence": 0.42, - }] - } - ) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertCountEqual( - list(self.page.classifications.values_list( - 'ml_class__name', - 'confidence', - 'high_confidence', - 'source', - 'worker_version' - )), - [ - ('dog', 0.99, True, self.src.id, None), - ('cat', 0.42, False, self.src.id, None) - ], - ) - def test_bulk_classification_worker_version(self): """ Classifications are created and linked to a worker version @@ -255,5 +190,5 @@ class TestBulkClassification(FixtureAPITestCase): ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), { - 'classifications': ['Duplicated ML classes are not allowed from the same source or worker version.'] + 'classifications': ['Duplicated ML classes are not allowed from the same worker version.'] }) diff --git a/arkindex/documents/tests/test_bulk_element_transcriptions.py b/arkindex/documents/tests/test_bulk_element_transcriptions.py index 0fb7b3ed14cdf5a60365aba2ec6ae66c84c53674..ee81ca3fc470a29667573ede1164a810e283c9e4 100644 --- a/arkindex/documents/tests/test_bulk_element_transcriptions.py +++ b/arkindex/documents/tests/test_bulk_element_transcriptions.py @@ -9,6 +9,7 @@ from arkindex_common.enums import TranscriptionType from arkindex.dataimport.models import WorkerVersion from arkindex.documents.models import Element, Corpus, DataSource from arkindex_common.ml_tool import MLToolType +import uuid class TestBulkElementTranscriptions(FixtureAPITestCase): @@ -27,88 +28,8 @@ class TestBulkElementTranscriptions(FixtureAPITestCase): cls.private_page = cls.private_corpus.elements.create(type=cls.page.type) cls.worker_version = WorkerVersion.objects.get(worker__slug='reco') - def setUp(self): - self.manual_source = DataSource.objects.create(type=MLToolType.Recognizer, slug='manual', internal=False) - self.manual_transcription = self.line.transcriptions.create( - text='A manual transcription', - source=self.manual_source, - type=TranscriptionType.Line - ) - self.private_transcription = self.private_page.transcriptions.create( - text='PEPE', - type=TranscriptionType.Line, - source=self.manual_source - ) - @patch('arkindex.project.triggers.get_channel_layer') - def test_bulk_transcriptions_with_source(self, get_layer_mock): - """ - Bulk creates a list of element with an attached transcription generated by a source - """ - get_layer_mock.return_value.send = AsyncMock() - self.src.internal = True - self.src.save() - - transcriptions = [ - ([[13, 37], [133, 37], [133, 137], [13, 137], [13, 37]], 'Hello world !', 0.1337), - ([[24, 42], [64, 42], [64, 142], [24, 142], [24, 42]], 'I <3 JavaScript', 0.42), - ] - data = { - 'element_type': 'text_line', - 'transcription_type': 'line', - 'source': self.src.slug, - 'transcriptions': [{ - 'polygon': poly, - 'text': text, - 'score': score - } for poly, text, score in transcriptions] - } - - self.client.force_login(self.internal_user) - with self.assertNumQueries(17): - response = self.client.post( - reverse('api:element-transcriptions-bulk', kwargs={'pk': self.page.id}), - format='json', - data=data - ) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - - created_elts = Element.objects.get_descending(self.page.id).exclude(id=self.line.id) - self.assertEqual(created_elts.count(), 2) - self.assertSetEqual( - set(created_elts.values_list('zone__image_id', flat=True)), - {self.page.zone.image_id} - ) - self.assertListEqual( - [ - (elt.paths.get().ordering, elt.name, elt.zone.polygon.coords) - for elt in created_elts.order_by('name') - ], - [ - (1, '2', ((13, 37), (13, 137), (133, 137), (133, 37), (13, 37))), - (2, '3', ((24, 42), (24, 142), (64, 142), (64, 42), (24, 42))) - ] - ) - self.assertCountEqual( - created_elts.values_list('transcriptions__type', 'transcriptions__text', 'transcriptions__zone', 'transcriptions__source', 'transcriptions__worker_version'), - [ - (TranscriptionType.Line, ('Hello world !'), None, self.src.id, None), - (TranscriptionType.Line, ('I <3 JavaScript'), None, self.src.id, None) - ] - ) - get_layer_mock().send.assert_called_once_with('reindex', { - 'type': 'reindex.start', - 'corpus': None, - 'element': str(self.page.id), - 'entity': None, - 'transcriptions': True, - 'elements': True, - 'entities': False, - 'drop': False, - }) - - @patch('arkindex.project.triggers.get_channel_layer') - def test_bulk_transcriptions_with_worker_version(self, get_layer_mock): + def test_bulk_transcriptions(self, get_layer_mock): """ Bulk creates a list of element with an attached transcription generated by a worker_version """ @@ -177,6 +98,7 @@ class TestBulkElementTranscriptions(FixtureAPITestCase): self.src.internal = True self.src.save() + # Create 100 transcriptions transcriptions = [( [[i, i], [i, i + 20], [i + 20, i + 20], [i + 20, i], [i, i]], str(i / 2), @@ -185,7 +107,7 @@ class TestBulkElementTranscriptions(FixtureAPITestCase): data = { 'element_type': 'text_line', 'transcription_type': 'line', - 'source': self.src.slug, + 'worker_version': str(self.worker_version.id), 'transcriptions': [{ 'polygon': poly, 'text': text, @@ -206,26 +128,37 @@ class TestBulkElementTranscriptions(FixtureAPITestCase): ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(created_elts.count(), 101) - # Each element has a transcription - self.assertEqual(created_elts.annotate(ts_count=Count('transcriptions')).filter(ts_count=1).count(), 101) + self.assertEqual(created_elts.count(), 101) + # Each annotated element has a transcription + self.assertEqual(created_elts.annotate(ts_count=Count('transcriptions')).filter(ts_count=1).count(), 100) @patch('arkindex.project.triggers.get_channel_layer') def test_bulk_transcriptions_similar_zone(self, get_layer_mock): """ Pushing a transcription matching an element type and zone reuses this element + Does not erase present transcriptions, even from the same worker """ get_layer_mock.return_value.send = AsyncMock() self.src.internal = True self.src.save() + # Create a manual transcription on the element + self.line.transcriptions.create( + text='A manual transcription', + worker_version=self.worker_version, + type=TranscriptionType.Line + ) + self.assertEqual(self.line.transcriptions.count(), 1) + transcriptions = [ ([[13, 37], [133, 37], [133, 137], [13, 137], [13, 37]], 'Hello world !', 0.1337), + # Use line zone to create the second transcription (list(self.line.zone.polygon), 'I <3 JavaScript', 0.42), ] data = { 'element_type': 'text_line', 'transcription_type': 'line', - 'source': self.src.slug, + 'worker_version': str(self.worker_version.id), 'transcriptions': [{ 'polygon': poly, 'text': text, @@ -242,9 +175,10 @@ class TestBulkElementTranscriptions(FixtureAPITestCase): ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) - created_elts = Element.objects.get_descending(self.page.id) + page_elts = Element.objects.get_descending(self.page.id) # The existing text line has been reused - self.assertEqual(created_elts.count(), 2) + self.assertEqual(page_elts.count(), 2) + # There are now two transcriptions on the line self.assertCountEqual(self.line.transcriptions.values_list('type', 'text', 'zone'), [ (TranscriptionType.Line, 'A manual transcription', None), (TranscriptionType.Line, 'I <3 JavaScript', None) @@ -277,59 +211,9 @@ class TestBulkElementTranscriptions(FixtureAPITestCase): ) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - def test_bulk_transcriptions_no_source_no_worker_version(self): - """ - Transcriptions must be created with a source or a worker_version - """ - self.client.force_login(self.user) - response = self.client.post( - reverse('api:element-transcriptions-bulk', kwargs={'pk': self.page.id}), - format='json', - data={ - 'element_type': 'text_line', - 'transcription_type': 'line', - 'transcriptions': [{ - 'polygon': [[13, 37], [133, 37], [133, 137], [13, 137]], - 'text': 'Can I write womething here ?', - 'score': .666 - }] - } - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertDictEqual(response.json(), { - 'source': ['This field XOR worker_version field must be set to create transcriptions'], - 'worker_version': ['This field XOR source field must be set to create transcriptions'] - }) - - def test_bulk_transcriptions_with_source_and_worker_version_returns_error(self): - """ - Transcriptions must be created with a source or a worker_version not both - """ - self.client.force_login(self.user) - response = self.client.post( - reverse('api:element-transcriptions-bulk', kwargs={'pk': self.page.id}), - format='json', - data={ - 'element_type': 'text_line', - 'transcription_type': 'line', - 'source': self.src.slug, - 'worker_version': str(self.worker_version.id), - 'transcriptions': [{ - 'polygon': [[13, 37], [133, 37], [133, 137], [13, 137]], - 'text': 'Can I write womething here ?', - 'score': .666 - }] - } - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertDictEqual(response.json(), { - 'source': ['You can only refer to a DataSource XOR a WorkerVersion on transcriptions'], - 'worker_version': ['You can only refer to a DataSource XOR a WorkerVersion on transcriptions'] - }) - def test_bulk_transcriptions_non_manual(self): """ - Transcriptions source cannot be set to manual + Worker version is a required field """ self.client.force_login(self.user) response = self.client.post( @@ -338,7 +222,6 @@ class TestBulkElementTranscriptions(FixtureAPITestCase): data={ 'element_type': 'text_line', 'transcription_type': 'line', - 'source': 'manual', 'transcriptions': [{ 'polygon': [[13, 37], [133, 37], [133, 137], [13, 137], [13, 37]], 'text': 'Can I write womething here ?', @@ -348,21 +231,22 @@ class TestBulkElementTranscriptions(FixtureAPITestCase): ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), { - 'source': ['Transcriptions source slug cannot be set to manual.'] + 'worker_version': ['This field is required.'] }) def test_bulk_transcriptions_wrong_fields_validation(self): """ Briefly assert fields values are validated """ + wrong_version_id = str(uuid.uuid4()) self.client.force_login(self.user) response = self.client.post( reverse('api:element-transcriptions-bulk', kwargs={'pk': self.page.id}), format='json', data={ 'element_type': 'paragraph', + 'worker_version': wrong_version_id, 'transcription_type': 'boulga', - 'source': 'gloubiboulga', 'transcriptions': [{ 'polygon': [[13, 37]], 'text': 'There is a snake in my computer' @@ -372,7 +256,7 @@ class TestBulkElementTranscriptions(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), { 'element_type': ['Object with slug=paragraph does not exist.'], - 'source': ["Source with slug 'gloubiboulga' not found"], + 'worker_version': [f'Invalid pk "{wrong_version_id}" - object does not exist.'], 'transcription_type': ['Value is not of type TranscriptionType'], 'transcriptions': [{ 'polygon': ['Ensure this field has at least 3 elements.'], @@ -399,7 +283,7 @@ class TestBulkElementTranscriptions(FixtureAPITestCase): data = { 'element_type': 'text_line', 'transcription_type': 'line', - 'source': self.src.slug, + 'worker_version': str(self.worker_version.id), 'transcriptions': [{ 'polygon': poly, 'text': text, diff --git a/arkindex/documents/tests/test_create_elements.py b/arkindex/documents/tests/test_create_elements.py index 49eb0b0eb817bb9cdefaf483a27375622ee8aa13..232ce91875cb9852b4d2f10cd6274b2278aa97f3 100644 --- a/arkindex/documents/tests/test_create_elements.py +++ b/arkindex/documents/tests/test_create_elements.py @@ -1,9 +1,7 @@ from django.urls import reverse from rest_framework import status -from arkindex_common.ml_tool import MLToolType from arkindex.dataimport.models import WorkerVersion -from arkindex.documents.models import \ - Element, DataSource, Corpus +from arkindex.documents.models import Element, Corpus from arkindex.images.models import ImageServer from arkindex.project.aws import S3FileStatus from arkindex.project.tests import FixtureAPITestCase @@ -134,28 +132,6 @@ class TestCreateElements(FixtureAPITestCase): self.assertEqual(act.name, 'Castle story') self.assertEqual(act.type, self.act_type) - def test_create_element_source(self): - # Create an element with a source - self.client.force_login(self.user) - source = DataSource.objects.create( - type=MLToolType.DLAnalyser, - slug='fairy_tale_detector', - internal=False, - ) - request = self.make_create_request( - name='Castle story', - elt_type='act', - source='fairy_tale_detector', - ) - with self.assertNumQueries(8): - response = self.client.post(**request) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - act = Element.objects.get(id=response.json()['id']) - self.assertEqual(act.name, 'Castle story') - self.assertEqual(act.type, self.act_type) - self.assertEqual(act.source, source) - self.assertEqual(act.worker_version, None) - def test_create_element_worker_version(self): # Create an element with a worker version self.client.force_login(self.user) @@ -173,30 +149,6 @@ class TestCreateElements(FixtureAPITestCase): self.assertEqual(act.source, None) self.assertEqual(act.worker_version, self.worker_version) - def test_create_element_source_and_worker_version_returns_error(self): - # Create an element with a source and a worker version (not allowed) - self.client.force_login(self.user) - DataSource.objects.create( - type=MLToolType.DLAnalyser, - slug='fairy_tale_detector', - internal=False, - ) - request = self.make_create_request( - name='Castle story', - elt_type='act', - source='fairy_tale_detector', - worker_version=str(self.worker_version.id), - ) - with self.assertNumQueries(6): - response = self.client.post(**request) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertDictEqual(response.json(), { - 'source': ['You can only refer to a DataSource XOR a WorkerVersion on an ' - 'element'], - 'worker_version': ['You can only refer to a DataSource XOR a WorkerVersion on ' - 'an element'] - }) - def test_create_element_polygon(self): # Create an element with a polygon to an existing volume polygon = [[10, 10], [10, 40], [40, 40], [40, 10], [10, 10]] diff --git a/arkindex/documents/tests/test_create_transcriptions.py b/arkindex/documents/tests/test_create_transcriptions.py index 0be51e51529c67e4d87be1a1c2743a7e9a02233e..999fc147aa70bd419a38c9fa09af249de36d07eb 100644 --- a/arkindex/documents/tests/test_create_transcriptions.py +++ b/arkindex/documents/tests/test_create_transcriptions.py @@ -32,9 +32,6 @@ class TestTranscriptionCreate(FixtureAPITestCase): cls.private_corpus.corpus_right.create(user=cls.private_read_user) cls.worker_version = WorkerVersion.objects.get(worker__slug='reco') - def setUp(self): - self.manual_source = DataSource.objects.create(type=MLToolType.Recognizer, slug='manual', internal=False) - def test_create_transcription_require_login(self): response = self.client.post(reverse('api:transcription-create', kwargs={'pk': self.line.id}), format='json') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @@ -47,11 +44,7 @@ class TestTranscriptionCreate(FixtureAPITestCase): response = self.client.post( reverse('api:transcription-create', kwargs={'pk': self.private_page.id}), format='json', - data={ - 'type': 'word', - 'text': 'NEKUDOTAYIM', - 'source': 'manual' - } + data={'type': TranscriptionType.Word.value, 'text': 'NEKUDOTAYIM'} ) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) @@ -61,79 +54,40 @@ class TestTranscriptionCreate(FixtureAPITestCase): response = self.client.post( reverse('api:transcription-create', kwargs={'pk': uuid4()}), format='json', - data={ - 'type': 'word', - 'text': 'NEKUDOTAYIM', - 'source': 'manual' - } + data={'type': TranscriptionType.Word.value, 'text': 'NEKUDOTAYIM'} ) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) @patch('arkindex.project.triggers.get_channel_layer') - def test_create_transcription_elt_null_zone(self, get_layer_mock): - null_zone_page = self.corpus.elements.create(type=self.page.type) - - self.client.force_login(self.user) - response = self.client.post( - reverse('api:transcription-create', kwargs={'pk': null_zone_page.id}), - format='json', - data={ - 'type': 'word', - 'text': 'NEKUDOTAYIM', - 'source': 'manual' - } - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertDictEqual(response.json(), { - 'element': ['The element has no zone'] - }) - - @patch('arkindex.project.triggers.get_channel_layer') - def test_create_transcription(self, get_layer_mock): + def test_create_manual_transcription(self, get_layer_mock): """ Checks the view creates a manual transcription and runs ES indexing """ get_layer_mock.return_value.send = AsyncMock() - # The view must have the ability to create a manual source if it does not exist - self.manual_source.delete() - self.assertFalse(DataSource.objects.filter(type=MLToolType.Recognizer, slug='manual').exists()) - self.client.force_login(self.user) - with self.assertNumQueries(9): + with self.assertNumQueries(5): response = self.client.post( reverse('api:transcription-create', kwargs={'pk': self.line.id}), format='json', - data={ - 'type': 'line', - 'text': 'A perfect day in a perfect place', - 'source': 'manual' - } + data={'type': TranscriptionType.Line.value, 'text': 'A perfect day in a perfect place'} ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) tr = Transcription.objects.get(text='A perfect day in a perfect place') self.assertDictEqual(response.json(), { 'id': str(tr.id), 'score': None, - 'source': { - 'id': str(tr.source.id), - 'internal': False, - 'name': '', - 'revision': '', - 'slug': 'manual', - 'type': 'recognizer' - }, 'text': 'A perfect day in a perfect place', 'type': 'line', + 'source': None, 'worker_version_id': None, 'zone': None }) - new_ts = Transcription.objects.get(text='A perfect day in a perfect place', type=TranscriptionType.Line) + new_ts = Transcription.objects.get(text='A perfect day in a perfect place', type=TranscriptionType.Line.value) self.assertIsNone(new_ts.zone) self.assertIsNone(new_ts.score) - self.assertTrue(DataSource.objects.filter(type=MLToolType.Recognizer, slug='manual').exists()) - self.assertEqual(new_ts.source.slug, 'manual') + self.assertEqual(new_ts.worker_version, None) self.assertTrue(self.line.transcriptions.filter(pk=new_ts.id).exists()) get_layer_mock().send.assert_called_once_with('reindex', { @@ -162,41 +116,36 @@ class TestTranscriptionCreate(FixtureAPITestCase): 'type': 'line', 'polygon': [(0, 0), (42, 0), (42, 42), (0, 42), (0, 0)], 'text': 'SQUARE', - 'source': 'manual' } ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) - new_ts = Transcription.objects.get(text='SQUARE', type=TranscriptionType.Line) + new_ts = Transcription.objects.get(text='SQUARE', type=TranscriptionType.Line.value) self.assertEqual(new_ts.zone, None) @patch('arkindex.project.triggers.get_channel_layer') def test_create_duplicated_transcription(self, get_layer_mock): """ - Check the view creates a new transcriptions with a similar manual source + Check the view creates a new manual transcriptions with a similar text and element """ get_layer_mock.return_value.send = AsyncMock() self.client.force_login(self.user) - ts = self.page.transcriptions.create( - text='GLOUBIBOULGA', - type=TranscriptionType.Word, - source=self.manual_source - ) - with self.assertNumQueries(6): + ts = self.line.transcriptions.create(text='GLOUBIBOULGA', type=TranscriptionType.Line.value) + with self.assertNumQueries(5): response = self.client.post( reverse('api:transcription-create', kwargs={'pk': self.line.id}), format='json', data={ 'type': 'line', 'text': ts.text, - 'source': 'manual' } ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertEqual(Transcription.objects.filter(text=ts.text, source=self.manual_source).count(), 2) - ts.refresh_from_db() - self.assertNotEqual(ts.score, 0.99) + self.assertEqual( + Transcription.objects.filter(element=self.line, text=ts.text, worker_version__isnull=True).count(), + 2 + ) @patch('arkindex.project.triggers.get_channel_layer') def test_create_transcription_wrong_type(self, get_layer_mock): @@ -204,11 +153,7 @@ class TestTranscriptionCreate(FixtureAPITestCase): response = self.client.post( reverse('api:transcription-create', kwargs={'pk': self.line.id}), format='json', - data={ - 'type': 'AAAAA', - 'text': 'NEKUDOTAYIM', - 'source': 'manual' - } + data={'type': 'AAAAA', 'text': 'NEKUDOTAYIM'} ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), { @@ -222,50 +167,15 @@ class TestTranscriptionCreate(FixtureAPITestCase): response = self.client.post( reverse('api:transcription-create', kwargs={'pk': self.line.id}), format='json', - data={ - 'type': 'line', - 'text': 'A classy text line', - 'source': 'manual' - } + data={'type': TranscriptionType.Line.value, 'text': 'A classy text line'} ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertFalse(get_layer_mock().send.called) - def test_create_transcription_no_source_no_worker_version(self): - self.client.force_login(self.user) - response = self.client.post( - reverse('api:transcription-create', kwargs={'pk': self.line.id}), - format='json', - data={ - 'type': 'word', - 'text': 'NEKUDOTAYIM', - } - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertDictEqual(response.json(), { - 'source': ['This field XOR worker_version field must be set to create a transcription'], - 'worker_version': ['This field XOR source field must be set to create a transcription'] - }) - - def test_create_transcription_source_and_worker_version_returns_error(self): - self.client.force_login(self.user) - response = self.client.post( - reverse('api:transcription-create', kwargs={'pk': self.line.id}), - format='json', - data={ - 'type': 'word', - 'text': 'NEKUDOTAYIM', - 'source': 'manual', - 'worker_version': str(self.worker_version.id), - } - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertDictEqual(response.json(), { - 'source': ['You can only refer to a DataSource XOR a WorkerVersion on a transcription'], - 'worker_version': ['You can only refer to a DataSource XOR a WorkerVersion on a transcription'] - }) - def test_create_transcription_worker_version_non_internal(self): + """ + An internal user is required to create a transcription with a worker version + """ self.client.force_login(self.user) response = self.client.post( reverse('api:transcription-create', kwargs={'pk': self.line.id}), @@ -279,14 +189,16 @@ class TestTranscriptionCreate(FixtureAPITestCase): ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), { - 'worker_version': [( - 'An internal user is required to create a transcription ' - f'with the worker_version "{self.worker_version.id}"' - )] + 'worker_version': [ + 'An internal user is required to create a transcription refering to a worker_version' + ] }) @patch('arkindex.project.triggers.get_channel_layer') def test_create_transcription_worker_version(self, get_layer_mock): + """ + Creates a transcription with a worker version triggers its indexation on ElasticSearch + """ get_layer_mock.return_value.send = AsyncMock() self.client.force_login(self.internal_user) @@ -332,11 +244,7 @@ class TestTranscriptionCreate(FixtureAPITestCase): response = self.client.post( reverse('api:transcription-create', kwargs={'pk': self.line.id}), format='json', - data={ - 'type': 'word', - 'text': 'NEKUDOTAYIM', - 'source': 'manual' - } + data={'type': TranscriptionType.Word.value, 'text': 'NEKUDOTAYIM'} ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), { @@ -344,51 +252,25 @@ class TestTranscriptionCreate(FixtureAPITestCase): }) def test_manual_transcription_no_allowed_type(self): + """ + Manual transctiptions are forbidden for corpus types with no allowed transcription + """ self.line.type.allowed_transcription = None self.line.type.save() self.client.force_login(self.user) response = self.client.post( reverse('api:transcription-create', kwargs={'pk': self.line.id}), format='json', - data={ - 'type': 'line', - 'text': 'A classy text line', - 'source': 'manual' - } + data={'type': TranscriptionType.Line.value, 'text': 'A classy text line'} ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), { 'element': ['The element type does not allow creating a manual transcription'] }) - def test_ml_transcription_non_internal(self): - """ - A transcription created on an element with an internal source - should only be authorized for an internal user (e.g. a worker) - """ - self.src.internal = True - self.src.save() - - self.client.force_login(self.user) - response = self.client.post( - reverse('api:transcription-create', kwargs={'pk': self.line.id}), - format='json', - data={ - 'type': 'word', - 'text': 'SQUARE', - 'source': self.src.slug, - 'score': .99 - } - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEqual(response.json(), {'source': [ - 'An internal user is required to create a transcription with the internal source "test"'] - }) - def test_ml_trancription_required_score(self): """ - A score is required when creating a transcription on an - element with an internal source + A score is required when creating a transcription on an element with an worker version """ self.src.internal = True self.src.save() @@ -400,15 +282,15 @@ class TestTranscriptionCreate(FixtureAPITestCase): data={ 'type': 'word', 'text': 'CIRCLE', - 'source': self.src.slug + 'worker_version': str(self.worker_version.id) } ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEqual(response.json(), {'score': ['This field is required for non manual sources.']}) + self.assertEqual(response.json(), {'score': ['This field is required for transcription with a worker version.']}) def test_ml_trancription_requires_element_zone(self): """ - A worker can publish a transcription on an element with a zone only + A worker cannot create a transcription on an element without a zone """ null_zone_page = self.corpus.elements.create(type=self.page.type) self.src.internal = True @@ -419,46 +301,10 @@ class TestTranscriptionCreate(FixtureAPITestCase): reverse('api:transcription-create', kwargs={'pk': null_zone_page.id}), format='json', data={ - 'type': 'word', + 'type': TranscriptionType.Word.value, 'text': 'ELLIPSE', - 'source': self.src.slug, 'score': .42 } ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.json(), {'element': ['The element has no zone']}) - - @patch('arkindex.project.triggers.get_channel_layer') - def test_ml_transcription_create(self, get_layer_mock): - """ - Creates a transcription depending on an element zone - """ - get_layer_mock.return_value.send = AsyncMock() - self.src.internal = True - self.src.save() - - self.client.force_login(self.internal_user) - self.assertTrue(self.internal_user.is_internal) - response = self.client.post( - reverse('api:transcription-create', kwargs={'pk': self.line.id}), - format='json', - data={ - 'type': 'word', - 'text': 'TRIANGLE', - 'source': self.src.slug, - 'score': .42 - } - ) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertEqual(self.line.transcriptions.filter(source=self.src).count(), 1) - - get_layer_mock().send.assert_called_once_with('reindex', { - 'type': 'reindex.start', - 'element': str(self.line.id), - 'corpus': None, - 'entity': None, - 'transcriptions': True, - 'elements': True, - 'entities': False, - 'drop': False, - }) diff --git a/arkindex/documents/tests/test_entities_api.py b/arkindex/documents/tests/test_entities_api.py index 5b1645f530b5e4d0cf6e9c0b0eb4e2788cd70c86..b538a335a94eb484f4327d28300791827f10d2c1 100644 --- a/arkindex/documents/tests/test_entities_api.py +++ b/arkindex/documents/tests/test_entities_api.py @@ -235,7 +235,7 @@ class TestEntitiesAPI(FixtureAPITestCase): 'key': 'value', 'other key': 'other value' }, - 'ner': self.entity_source.slug + 'worker_version': str(self.worker_version.id) } self.client.force_login(self.user) response = self.client.post(reverse('api:entity-create'), data=data, format='json') @@ -243,8 +243,7 @@ class TestEntitiesAPI(FixtureAPITestCase): entity = Entity.objects.get(id=response.json()['id']) self.assertEqual(entity.name, 'entity') self.assertEqual(entity.raw_dates, None) - self.assertEqual(entity.source, self.entity_source) - self.assertEqual(entity.worker_version, None) + self.assertEqual(entity.worker_version, self.worker_version) def test_create_entity_number(self): self.entity_source.internal = True @@ -257,7 +256,7 @@ class TestEntitiesAPI(FixtureAPITestCase): 'key': 'value', 'other key': 'other value' }, - 'ner': self.entity_source.slug + 'worker_version': str(self.worker_version.id) } self.client.force_login(self.user) response = self.client.post(reverse('api:entity-create'), data=data, format='json') @@ -265,8 +264,7 @@ class TestEntitiesAPI(FixtureAPITestCase): entity = Entity.objects.get(id=response.json()['id']) self.assertEqual(entity.name, '300g') self.assertEqual(entity.raw_dates, None) - self.assertEqual(entity.source, self.entity_source) - self.assertEqual(entity.worker_version, None) + self.assertEqual(entity.worker_version, self.worker_version) def test_create_entity_date(self): self.entity_source.internal = True @@ -279,7 +277,7 @@ class TestEntitiesAPI(FixtureAPITestCase): 'key': 'value', 'other key': 'other value' }, - 'ner': self.entity_source.slug + 'worker_version': str(self.worker_version.id) } self.client.force_login(self.user) response = self.client.post(reverse('api:entity-create'), data=data, format='json') @@ -287,8 +285,7 @@ class TestEntitiesAPI(FixtureAPITestCase): entity = Entity.objects.get(id=response.json()['id']) self.assertEqual(entity.name, '1789') self.assertEqual(entity.raw_dates, entity.name) - self.assertEqual(entity.source, self.entity_source) - self.assertEqual(entity.worker_version, None) + self.assertEqual(entity.worker_version, self.worker_version) def test_create_entity_requires_login(self): data = { @@ -304,44 +301,6 @@ class TestEntitiesAPI(FixtureAPITestCase): response = self.client.post(reverse('api:entity-create'), data=data, format='json') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - def test_create_entity_no_source_no_worker_version(self): - data = { - 'name': '1789', - 'type': EntityType.Date.value, - 'corpus': str(self.corpus.id), - 'metas': { - 'key': 'value', - 'other key': 'other value' - }, - } - self.client.force_login(self.user) - response = self.client.post(reverse('api:entity-create'), data=data, format='json') - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertDictEqual(response.json(), { - 'ner': ['This field XOR worker_version field must be set to create an entity'], - 'worker_version': ['This field XOR ner field must be set to create an entity'] - }) - - def test_create_entity_with_source_and_worker_version_returns_error(self): - data = { - 'name': '1789', - 'type': EntityType.Date.value, - 'corpus': str(self.corpus.id), - 'metas': { - 'key': 'value', - 'other key': 'other value' - }, - 'ner': self.entity_source.slug, - 'worker_version': str(self.worker_version.id) - } - self.client.force_login(self.user) - response = self.client.post(reverse('api:entity-create'), data=data, format='json') - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertDictEqual(response.json(), { - 'ner': ['You can only refer to a DataSource XOR a WorkerVersion on an entity'], - 'worker_version': ['You can only refer to a DataSource XOR a WorkerVersion on an entity'] - }) - def test_create_entity_with_worker_version(self): data = { 'name': '1789', @@ -863,12 +822,15 @@ class TestEntitiesAPI(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) @patch('arkindex.documents.api.entities.reindex_start') - def test_entity_create_index(self, reindex_mock): + def test_entity_create_reindex(self, reindex_mock): + """ + Created entities are indexed into ElasticSearch + """ data = { 'name': 'entity', 'type': EntityType.Person.value, 'corpus': str(self.corpus.id), - 'ner': self.entity_source.slug + 'worker_version': str(self.worker_version.id) } self.client.force_login(self.user) response = self.client.post(reverse('api:entity-create'), data=data, format='json') diff --git a/arkindex/documents/tests/test_moderation.py b/arkindex/documents/tests/test_moderation.py index 58aa646b764bcb2e59d426aa305e83c17581abc9..79cd3f7cb5bbf883164ac70dbe5b63cf3b7b791a 100644 --- a/arkindex/documents/tests/test_moderation.py +++ b/arkindex/documents/tests/test_moderation.py @@ -19,38 +19,36 @@ class TestClasses(FixtureAPITestCase): cls.act_type = cls.corpus.types.get(slug='act') cls.element = Element.objects.get(name='Volume 1, page 1v') cls.folder = cls.corpus.elements.get(name='Volume 1') - + cls.worker_version = WorkerVersion.objects.get(worker__slug='dla') + cls.internal_user = User.objects.get_by_natural_key('internal@internal.fr') cls.classifier_source = DataSource.objects.create( type=MLToolType.Classifier, slug='some_classifier', revision='1.3.3.7', internal=False, ) - cls.worker_version = WorkerVersion.objects.get(worker__slug='dla') - cls.internal_user = User.objects.get_by_natural_key('internal@internal.fr') - def _create_classification(self): + def _create_classification_from_source(self): return self.element.classifications.create( source=self.classifier_source, ml_class=self.text, confidence=.5, ) - def test_classification_creation(self): + def test_manual_classification_creation(self): """ - Ensure classification creation works and set auto fields correctly + Creating a manual classification set auto fields correctly """ self.client.force_login(self.user) response = self.client.post(reverse('api:classification-create'), { 'element': str(self.element.id), 'ml_class': str(self.text.id), - 'source': 'manual', + 'worker_version': '', }) self.assertEqual(response.status_code, status.HTTP_201_CREATED) classification = self.element.classifications.get() - self.assertEqual(classification.source.type, MLToolType.Classifier) - self.assertEqual(classification.source.slug, 'manual') - self.assertFalse(classification.source.internal) + + self.assertEqual(classification.worker_version, None) self.assertEqual(classification.ml_class, self.text) self.assertEqual(classification.state, ClassificationState.Validated) self.assertEqual(classification.confidence, 1) @@ -59,14 +57,13 @@ class TestClasses(FixtureAPITestCase): 'id': str(classification.id), 'element': str(self.element.id), 'ml_class': str(self.text.id), - 'source': 'manual', 'worker_version': None, 'state': ClassificationState.Validated.value, - 'confidence': 1, + 'confidence': 1.0, 'high_confidence': True, }) - def test_classification_exists(self): + def test_manual_classification_exists(self): """ If a classification exists, creation must respond a 400_BAD_REQUEST with an explicit message @@ -74,18 +71,14 @@ class TestClasses(FixtureAPITestCase): self.client.force_login(self.user) request = ( reverse('api:classification-create'), - { - 'element': str(self.element.id), - 'ml_class': str(self.text.id), - 'source': 'manual' - } + {'element': str(self.element.id), 'ml_class': str(self.text.id)} ) response = self.client.post(*request) self.assertEqual(response.status_code, status.HTTP_201_CREATED) response = self.client.post(*request) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), { - 'non_field_errors': ['The fields element, source, ml_class must make a unique set.'] + 'non_field_errors': ['The fields element, ml_class must make a unique set.'] }) def test_classification_ignored_params(self): @@ -96,7 +89,6 @@ class TestClasses(FixtureAPITestCase): response = self.client.post(reverse('api:classification-create'), { 'element': str(self.element.id), 'ml_class': str(self.text.id), - 'source': 'manual', 'confidence': 0.5, 'high_confidence': False, 'state': 'rejected', @@ -107,7 +99,6 @@ class TestClasses(FixtureAPITestCase): 'id': str(classification.id), 'element': str(self.element.id), 'ml_class': str(self.text.id), - 'source': 'manual', 'worker_version': None, 'state': ClassificationState.Validated.value, 'confidence': 1, @@ -122,7 +113,6 @@ class TestClasses(FixtureAPITestCase): response = self.client.post(reverse('api:classification-create'), { 'element': str(self.element.id), 'ml_class': str(self.text.id), - 'source': 'manual', }) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), { @@ -131,53 +121,22 @@ class TestClasses(FixtureAPITestCase): ] }) - def test_classification_creation_no_source_no_worker_version(self): - self.client.force_login(self.user) - response = self.client.post(reverse('api:classification-create'), { - 'element': str(self.element.id), - 'ml_class': str(self.text.id), - 'confidence': 0.42, - 'high_confidence': False, - }) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertDictEqual(response.json(), { - 'source': ['This field XOR worker_version field must be set to create a classification'], - 'worker_version': ['This field XOR source field must be set to create a classification'] - }) - - def test_classification_creation_source_and_worker_version_returns_error(self): - self.client.force_login(self.user) - response = self.client.post(reverse('api:classification-create'), { - 'element': str(self.element.id), - 'ml_class': str(self.text.id), - 'source': 'manual', - 'worker_version': str(self.worker_version.id), - 'confidence': 0.42, - 'high_confidence': False, - }) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertDictEqual(response.json(), { - 'source': ['You can only refer to a DataSource XOR a WorkerVersion on a classification'], - 'worker_version': ['You can only refer to a DataSource XOR a WorkerVersion on a classification'] - }) - def test_classification_non_manual_requires_internal(self): """ - Test creating a classification on a non-manual source requires an internal user + Test creating a classification with a worker version requires the user to be internal """ self.client.force_login(self.user) response = self.client.post(reverse('api:classification-create'), { 'element': str(self.element.id), 'ml_class': str(self.text.id), - 'source': 'some_classifier', + 'worker_version': str(self.worker_version.id) }) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.maxDiff = None self.assertDictEqual(response.json(), { - 'source': [ - 'An internal user is required to create a classification with the non-manual source "some_classifier"' - ], - 'confidence': ['This field is required for non manual sources.'], - 'high_confidence': ['This field is required for non manual sources.'], + 'worker_version': ['An internal user is required to create a classification with a worker version.'], + 'confidence': ['This field is required to create a classification with a worker version.'], + 'high_confidence': ['This field is required to create a classification with a worker version.'] }) def test_classification_non_manual(self): @@ -185,39 +144,18 @@ class TestClasses(FixtureAPITestCase): response = self.client.post(reverse('api:classification-create'), { 'element': str(self.element.id), 'ml_class': str(self.text.id), - 'source': 'some_classifier', 'confidence': 0.42, 'high_confidence': False, + 'worker_version': str(self.worker_version.id) }) self.assertEqual(response.status_code, status.HTTP_201_CREATED) classification = self.element.classifications.get() - self.assertEqual(classification.source, self.classifier_source) - self.assertEqual(classification.worker_version, None) + self.assertEqual(classification.worker_version, self.worker_version) self.assertEqual(classification.ml_class, self.text) self.assertEqual(classification.state, ClassificationState.Pending) self.assertEqual(classification.confidence, 0.42) self.assertFalse(classification.high_confidence) - def test_classification_creation_worker_version_requires_internal(self): - """ - Test creating a classification on a worker_version requires an internal user - """ - self.client.force_login(self.user) - response = self.client.post(reverse('api:classification-create'), { - 'element': str(self.element.id), - 'ml_class': str(self.text.id), - 'worker_version': str(self.worker_version.id), - }) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertDictEqual(response.json(), { - 'worker_version': [( - 'An internal user is required to create a classification ' - f'with the worker_version "{self.worker_version.id}"' - )], - 'confidence': ['This field is required for non manual sources.'], - 'high_confidence': ['This field is required for non manual sources.'], - }) - def test_classification_creation_worker_version(self): self.client.force_login(self.internal_user) response = self.client.post(reverse('api:classification-create'), { @@ -244,14 +182,13 @@ class TestClasses(FixtureAPITestCase): response = self.client.post(reverse('api:classification-create'), { 'element': str(self.element.id), 'ml_class': str(self.text.id), - 'source': 'some_classifier', + 'worker_version': str(self.worker_version.id), 'confidence': 0, 'high_confidence': False, }) self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.json()) classification = self.element.classifications.get() - self.assertEqual(classification.source, self.classifier_source) - self.assertEqual(classification.worker_version, None) + self.assertEqual(classification.worker_version, self.worker_version) self.assertEqual(classification.ml_class, self.text) self.assertEqual(classification.state, ClassificationState.Pending) self.assertEqual(classification.confidence, 0) @@ -259,7 +196,7 @@ class TestClasses(FixtureAPITestCase): def test_classification_validate(self): self.client.force_login(self.user) - classification = self._create_classification() + classification = self._create_classification_from_source() response = self.client.put(reverse('api:classification-validate', kwargs={'pk': classification.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertDictEqual(response.json(), { @@ -287,13 +224,13 @@ class TestClasses(FixtureAPITestCase): self.assertEqual(classification.moderator, self.user) def test_classification_validate_without_permissions(self): - classification = self._create_classification() + classification = self._create_classification_from_source() response = self.client.put(reverse('api:classification-validate', kwargs={'pk': classification.id})) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_classification_reject(self): self.client.force_login(self.user) - classification = self._create_classification() + classification = self._create_classification_from_source() response = self.client.put(reverse('api:classification-reject', kwargs={'pk': classification.id})) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertDictEqual(response.json(), { @@ -334,13 +271,13 @@ class TestClasses(FixtureAPITestCase): classification.refresh_from_db() def test_classification_reject_without_permissions(self): - classification = self._create_classification() + classification = self._create_classification_from_source() response = self.client.put(reverse('api:classification-reject', kwargs={'pk': classification.id})) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_classification_can_still_be_moderated(self): self.client.force_login(self.user) - classification = self._create_classification() + classification = self._create_classification_from_source() classification.moderator = self.user classification.state = ClassificationState.Validated.value classification.save() diff --git a/arkindex/project/serializer_fields.py b/arkindex/project/serializer_fields.py index 018290826c90249199473fd40a998e7958752255..4199a71d5a0a83b810e13dad92cf5ac12f6c2892 100644 --- a/arkindex/project/serializer_fields.py +++ b/arkindex/project/serializer_fields.py @@ -3,7 +3,6 @@ from rest_framework import serializers from enum import Enum from uuid import UUID from urllib.parse import quote, unquote -from arkindex_common.ml_tool import MLToolType from arkindex.project.gis import ensure_linear_ring @@ -31,52 +30,6 @@ class EnumField(serializers.ChoiceField): raise serializers.ValidationError('Value is not of type {}'.format(self.enum.__name__)) -class DataSourceSlugField(serializers.CharField): - - def __init__(self, tool_type=None, *args, **kwargs): - super().__init__(*args, **kwargs) - if tool_type is not None: - assert isinstance(tool_type, MLToolType) - self.tool_type = tool_type - - def to_representation(self, obj): - from arkindex.documents.models import DataSource - if isinstance(obj, DataSource): - return super().to_representation(obj.slug) - return super().to_representation(obj) - - def to_internal_value(self, data): - from arkindex.documents.models import DataSource - - slug = super().to_internal_value(data) - if self.tool_type and slug == 'manual': - # Get or create a manual source for this tool. - # This behavior is not RESTful but allows to serialize - # a manual source for any tool. - manual_source, _ = DataSource.objects.get_or_create( - type=self.tool_type, - slug=slug, - defaults={ - 'revision': '', - 'internal': False, - } - ) - return manual_source - - # Pick the most recent source by ordering by revision - # Note that there still might be duplicate sources if there is no tool type filter - # as DataSource only requires unique slugs by type and revision; - # there could be a "sklearn 0.1" classifier and a "sklearn 0.1" DLA. - queryset = DataSource.objects.filter(slug=slug).order_by('-revision') - if self.tool_type: - queryset = queryset.filter(type=self.tool_type) - - source = queryset.first() - if not source: - raise serializers.ValidationError('Source with slug {!r} not found'.format(data)) - return source - - class PointField(serializers.ListField): child = serializers.IntegerField() diff --git a/arkindex/project/tests/test_datasource_slug_field.py b/arkindex/project/tests/test_datasource_slug_field.py deleted file mode 100644 index f3930922a07dc3d67c602262da101122c9634ec4..0000000000000000000000000000000000000000 --- a/arkindex/project/tests/test_datasource_slug_field.py +++ /dev/null @@ -1,70 +0,0 @@ -from django.test import TestCase -from rest_framework.serializers import ValidationError -from arkindex_common.ml_tool import MLToolType -from arkindex.project.serializer_fields import DataSourceSlugField -from arkindex.documents.models import DataSource - - -class TestDataSourceSlugField(TestCase): - - @classmethod - def setUpTestData(cls): - super().setUpTestData() - cls.external_source = DataSource.objects.create( - name='External sauce', - type=MLToolType.Classifier, - slug='sauce', - internal=False, - ) - cls.internal_source = DataSource.objects.create( - name='Internal sauce', - type=MLToolType.Recognizer, - slug='better_sauce', - revision='13.37', - internal=True, - ) - - def test_init(self): - with self.assertRaises(AssertionError): - DataSourceSlugField(tool_type='oops') - - def test_to_representation(self): - field = DataSourceSlugField() - self.assertEqual(field.to_representation('something'), 'something') - self.assertEqual(field.to_representation(self.external_source), 'sauce') - - def test_to_internal_value(self): - field = DataSourceSlugField() - self.assertEqual(field.to_internal_value('sauce'), self.external_source) - with self.assertRaises(ValidationError, msg="Source with slug 'mayo' not found"): - field.to_internal_value('mayo') - - def test_to_internal_value_latest(self): - """ - Test DataSourceSlugField.to_internal_value picks the latest source by revision - """ - field = DataSourceSlugField() - DataSource.objects.create( - name='Spoiled mayonnaise', - type=MLToolType.Recognizer, - slug='mayo', - revision='42', - internal=False, - ) - mayo = DataSource.objects.create( - name='Edible mayonnaise', - type=MLToolType.Recognizer, - slug='mayo', - revision='43', - internal=False, - ) - self.assertEqual(field.to_internal_value('mayo'), mayo) - - def test_to_internal_value_tool_type_not_found(self): - - field = DataSourceSlugField(tool_type=MLToolType.Classifier) - with self.assertRaises(ValidationError, msg="Source with slug 'better_sauce' not found"): - field.to_internal_value('better_sauce') - - field = DataSourceSlugField(tool_type=MLToolType.Recognizer) - self.assertEqual(field.to_internal_value('better_sauce'), self.internal_source)