Skip to content
Snippets Groups Projects
Commit 4076f315 authored by Eva Bardou, committed by Bastien Abadie
Browse files

Add a FK on Classification towards a WorkerVersion

parent 98106751
No related branches found
No related tags found
No related merge requests found
......@@ -409,7 +409,7 @@ class ClassificationCreate(CreateAPIView):
}
def perform_create(self, serializer):
if serializer.validated_data['source'].slug == 'manual':
if serializer.validated_data['source'] and serializer.validated_data['source'].slug == 'manual':
# A manual classification is immediately valid
serializer.save(
moderator=self.request.user,
......
......@@ -27,4 +27,30 @@ class Migration(migrations.Migration):
name='source',
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='transcriptions', to='documents.DataSource'),
),
migrations.AddField(
model_name='classification',
name='worker_version',
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='classifications', to='dataimport.WorkerVersion'),
),
migrations.AlterField(
model_name='classification',
name='source',
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='classifications', to='documents.DataSource'),
),
migrations.AlterUniqueTogether(
name='classification',
unique_together=set(),
),
migrations.AddConstraint(
model_name='classification',
constraint=models.UniqueConstraint(condition=models.Q(worker_version_id__isnull=True), fields=('element', 'ml_class', 'source'), name='classification_unique_source'),
),
migrations.AddConstraint(
model_name='classification',
constraint=models.UniqueConstraint(condition=models.Q(source_id__isnull=True), fields=('element', 'ml_class', 'worker_version'), name='classification_unique_worker_version'),
),
migrations.AddConstraint(
model_name='classification',
constraint=models.CheckConstraint(check=models.Q(models.Q(('source_id__isnull', False), ('worker_version_id__isnull', True)), models.Q(('source_id__isnull', True), ('worker_version_id__isnull', False)), _connector='OR'), name='classification_source_xor_workerversion'),
),
]
......@@ -2,6 +2,7 @@ from django.db import models, transaction
from django.conf import settings
from django.contrib.postgres.indexes import GinIndex
from django.contrib.postgres.fields import HStoreField
from django.db.models import Q
from django.utils.functional import cached_property
from django.core.exceptions import ValidationError
from enumfields import EnumField, Enum
......@@ -541,6 +542,15 @@ class Classification(models.Model):
DataSource,
on_delete=models.CASCADE,
related_name='classifications',
null=True,
blank=True,
)
worker_version = models.ForeignKey(
'dataimport.WorkerVersion',
on_delete=models.CASCADE,
related_name='classifications',
null=True,
blank=True,
)
moderator = models.ForeignKey(
'users.User',
......@@ -560,9 +570,22 @@ class Classification(models.Model):
confidence = models.FloatField(null=True, blank=True)
class Meta:
    # The legacy unique_together on (element, source, ml_class) is intentionally
    # not declared anymore: the accompanying migration resets it to an empty set
    # and unicity is now expressed through the partial constraints below.
    constraints = [
        # A given class can only be set once per element for a given source;
        # the condition restricts this constraint to rows without a worker version
        models.UniqueConstraint(
            fields=['element', 'ml_class', 'source'],
            name='classification_unique_source',
            condition=Q(worker_version_id__isnull=True),
        ),
        # A given class can only be set once per element for a given worker version;
        # the condition restricts this constraint to rows without a source
        models.UniqueConstraint(
            fields=['element', 'ml_class', 'worker_version'],
            name='classification_unique_worker_version',
            condition=Q(source_id__isnull=True),
        ),
        # Exactly one of source and worker_version must be set
        models.CheckConstraint(
            check=(
                Q(source_id__isnull=False, worker_version_id__isnull=True)
                | Q(source_id__isnull=True, worker_version_id__isnull=False)
            ),
            name='classification_source_xor_workerversion',
        ),
    ]
class AllowedMetaData(models.Model):
......
from enum import Enum
from django.conf import settings
from rest_framework import serializers
from rest_framework.validators import UniqueTogetherValidator
from rest_framework.exceptions import ValidationError
from arkindex_common.ml_tool import MLToolType
from arkindex_common.enums import TranscriptionType
......@@ -96,7 +97,8 @@ class ClassificationCreateSerializer(serializers.ModelSerializer):
"""
Serializer to create a single classification, defaulting to manual
"""
source = DataSourceSlugField(tool_type=MLToolType.Classifier)
source = DataSourceSlugField(tool_type=MLToolType.Classifier, default=None)
worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all(), default=None)
confidence = serializers.FloatField(
min_value=0,
max_value=1,
......@@ -118,9 +120,20 @@ class ClassificationCreateSerializer(serializers.ModelSerializer):
'element',
'ml_class',
'source',
'worker_version',
'confidence',
'high_confidence',
)
validators = [
UniqueTogetherValidator(
queryset=Classification.objects.filter(worker_version_id__isnull=True),
fields=['element', 'source', 'ml_class']
),
UniqueTogetherValidator(
queryset=Classification.objects.filter(source_id__isnull=True),
fields=['element', 'worker_version', 'ml_class']
)
]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
......@@ -135,25 +148,44 @@ class ClassificationCreateSerializer(serializers.ModelSerializer):
def validate(self, data):
    """
    Check that a classification refers to exactly one of a source or a worker version.

    Manual sources are accepted without further checks. Any other source, or any
    worker version, requires the requesting user to be internal and requires both
    the confidence and high_confidence attributes to be provided.

    Returns the validated data unchanged, or raises a ValidationError holding
    every detected error, indexed by field name.
    """
    # Note that (source, class, element) unicity is already checked by DRF
    errors = {}
    user = self.context['request'].user
    source = data.get('source')
    worker_version = data.get('worker_version')

    if not source and not worker_version:
        raise ValidationError({
            'source': ['This field XOR worker_version field must be set to create a classification'],
            'worker_version': ['This field XOR source field must be set to create a classification']
        })
    if source and worker_version:
        raise ValidationError({
            'source': ['You can only refer to a DataSource XOR a WorkerVersion on a classification'],
            'worker_version': ['You can only refer to a DataSource XOR a WorkerVersion on a classification']
        })

    if source:
        slug = source.slug
        if slug == 'manual':
            # Do not check data for manual classifications
            return data
        if not user or not user.is_internal:
            errors['source'] = [
                'An internal user is required to create a classification with '
                f'the non-manual source "{slug}"'
            ]
    elif not user or not user.is_internal:
        errors['worker_version'] = [
            'An internal user is required to create a classification with '
            f'the worker_version "{worker_version.id}"'
        ]

    # Additional validation for classifications that are not manual.
    # Compare against None: a confidence of 0 is a legitimate value and
    # must not be reported as a missing field.
    if data.get('confidence') is None:
        errors['confidence'] = ['This field is required for non manual sources.']
    if data.get('high_confidence') is None:
        errors['high_confidence'] = ['This field is required for non manual sources.']

    if errors:
        raise ValidationError(errors)
    return data
......
......@@ -2,6 +2,7 @@ from django.test import override_settings
from django.urls import reverse
from rest_framework import status
from arkindex.dataimport.models import Worker, WorkerVersion
from arkindex.documents.models import \
ClassificationState, DataSource, MLClass, Element, Corpus, Classification, MLToolType
from arkindex.project.tests import FixtureAPITestCase
......@@ -25,6 +26,9 @@ class TestClasses(FixtureAPITestCase):
revision='1.3.3.7',
internal=False,
)
cls.creds = cls.user.credentials.get()
cls.repo = cls.creds.repos.get()
cls.rev = cls.repo.revisions.get()
def _create_classification(self):
return self.element.classifications.create(
......@@ -109,6 +113,47 @@ class TestClasses(FixtureAPITestCase):
]
})
def test_classification_creation_no_source_no_worker_version(self):
    """
    Test creating a classification without a source and without a worker version is rejected
    """
    payload = {
        'element': str(self.element.id),
        'ml_class': str(self.text.id),
        'confidence': 0.42,
        'high_confidence': False,
    }
    expected_errors = {
        'source': ['This field XOR worker_version field must be set to create a classification'],
        'worker_version': ['This field XOR source field must be set to create a classification']
    }

    self.client.force_login(self.user)
    response = self.client.post(reverse('api:classification-create'), payload)

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    self.assertDictEqual(response.json(), expected_errors)
def test_classification_creation_source_and_worker_version_returns_error(self):
    """
    Test creating a classification with both a source and a worker version is rejected
    """
    classifier_worker = Worker.objects.create(
        repository=self.repo,
        name='Worker 1',
        slug='worker_1',
        type=MLToolType.Classifier
    )
    worker_version = WorkerVersion.objects.create(
        worker=classifier_worker,
        revision=self.rev,
        configuration={"test": "test1"}
    )
    self.client.force_login(self.user)

    payload = {
        'element': str(self.element.id),
        'ml_class': str(self.text.id),
        'source': 'manual',
        'worker_version': str(worker_version.id),
        'confidence': 0.42,
        'high_confidence': False,
    }
    response = self.client.post(reverse('api:classification-create'), payload)

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    xor_message = 'You can only refer to a DataSource XOR a WorkerVersion on a classification'
    self.assertDictEqual(response.json(), {
        'source': [xor_message],
        'worker_version': [xor_message]
    })
def test_classification_non_manual_requires_internal(self):
"""
Test creating a classification on a non-manual source requires an internal user
......@@ -143,6 +188,69 @@ class TestClasses(FixtureAPITestCase):
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
classification = self.element.classifications.get()
self.assertEqual(classification.source, self.classifier_source)
self.assertEqual(classification.worker_version, None)
self.assertEqual(classification.ml_class, self.text)
self.assertEqual(classification.state, ClassificationState.Pending)
self.assertEqual(classification.confidence, 0.42)
self.assertFalse(classification.high_confidence)
def test_classification_creation_worker_version_requires_internal(self):
    """
    A non-internal user posting a classification with only a worker version
    gets a 400 listing the worker_version, confidence and high_confidence errors
    """
    classifier_worker = Worker.objects.create(
        repository=self.repo,
        name='Worker 1',
        slug='worker_1',
        type=MLToolType.Classifier
    )
    worker_version = WorkerVersion.objects.create(
        worker=classifier_worker,
        revision=self.rev,
        configuration={"test": "test1"}
    )
    self.client.force_login(self.user)

    # Neither confidence nor high_confidence are sent here
    payload = {
        'element': str(self.element.id),
        'ml_class': str(self.text.id),
        'worker_version': str(worker_version.id),
    }
    response = self.client.post(reverse('api:classification-create'), payload)

    self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
    required_message = 'This field is required for non manual sources.'
    self.assertDictEqual(response.json(), {
        'worker_version': [
            f'An internal user is required to create a classification with the worker_version "{worker_version.id}"'
        ],
        'confidence': [required_message],
        'high_confidence': [required_message],
    })
def test_classification_creation_worker_version(self):
worker = Worker.objects.create(
repository=self.repo,
name='Worker 1',
slug='worker_1',
type=MLToolType.Classifier
)
version = WorkerVersion.objects.create(
worker=worker,
revision=self.rev,
configuration={"test": "test1"}
)
user = User.objects.create_user('internal@address.com')
user.is_internal = True
user.save()
self.client.force_login(user)
response = self.client.post(reverse('api:classification-create'), {
'element': str(self.element.id),
'ml_class': str(self.text.id),
'worker_version': str(version.id),
'confidence': 0.42,
'high_confidence': False,
})
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
classification = self.element.classifications.get()
self.assertEqual(classification.source, None)
self.assertEqual(classification.worker_version, version)
self.assertEqual(classification.ml_class, self.text)
self.assertEqual(classification.state, ClassificationState.Pending)
self.assertEqual(classification.confidence, 0.42)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment