Skip to content
Snippets Groups Projects
Commit 24ef418a authored by Valentin Rigal's avatar Valentin Rigal
Browse files

CreateClassifications

parent 2eaef534
No related branches found
No related tags found
No related merge requests found
......@@ -407,7 +407,6 @@ class ClassificationsSerializer(serializers.Serializer):
# The real queryset is set in __init__
queryset=Element.objects.none(),
)
classifier = DataSourceSlugField(tool_type=MLToolType.Classifier, default=None)
worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all(), default=None)
classifications = ClassificationBulkSerializer(many=True, allow_empty=False)
......@@ -425,19 +424,13 @@ class ClassificationsSerializer(serializers.Serializer):
def validate(self, data):
data = super().validate(data)
if not (data['classifier'] is None) ^ (data['worker_version'] is None):
raise ValidationError({
'classifier': ['This field XOR worker_version field must be set to create classifications'],
'worker_version': ['This field XOR classifier field must be set to create classifications']
})
ml_class_names = [
classification['ml_class']
for classification in data['classifications']
]
if len(ml_class_names) != len(set(ml_class_names)):
raise ValidationError({
'classifications': ['Duplicated ML classes are not allowed from the same source or worker version.']
'classifications': ['Duplicated ML classes are not allowed from the same worker version.']
})
return data
......@@ -467,12 +460,9 @@ class ClassificationsSerializer(serializers.Serializer):
ml_classes.update({ml_class.name: ml_class.id for ml_class in new_classes})
source = validated_data.get('classifier')
worker_version = validated_data.get('worker_version')
origin = {'source': source} if source else {'worker_version': worker_version}
# Delete classifications with the same origin
parent.classifications.filter(**origin).delete()
parent.classifications.filter(worker_version=worker_version).delete()
Classification.objects.bulk_create([
Classification(
......@@ -480,7 +470,7 @@ class ClassificationsSerializer(serializers.Serializer):
ml_class_id=ml_classes[cl['ml_class']],
confidence=cl['confidence'],
high_confidence=cl['high_confidence'],
**origin
worker_version=worker_version
)
for cl in validated_data['classifications']
])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment