Skip to content
Snippets Groups Projects
Commit ba179c45 authored by Erwan Rouchet's avatar Erwan Rouchet Committed by Bastien Abadie
Browse files

Prevent using elements without zones in ML APIs

parent 1dc6bbe9
No related branches found
No related tags found
No related merge requests found
......@@ -66,8 +66,12 @@ class TranscriptionCreateSerializer(serializers.Serializer):
"""
Allows for insertion of new transcriptions and zones
"""
element = serializers.PrimaryKeyRelatedField(queryset=Element.objects.all())
source = serializers.PrimaryKeyRelatedField(queryset=DataSource.objects.filter(type=MLToolType.Recognizer))
element = serializers.PrimaryKeyRelatedField(
queryset=Element.objects.filter(zone__isnull=False),
)
source = serializers.PrimaryKeyRelatedField(
queryset=DataSource.objects.filter(type=MLToolType.Recognizer),
)
polygon = serializers.ListField(
child=serializers.ListField(
child=serializers.IntegerField(),
......@@ -85,6 +89,7 @@ class TranscriptionCreateSerializer(serializers.Serializer):
if not self.context.get('request'): # May be None when generating an OpenAPI schema or using from a REPL
return
self.fields['element'].queryset = Element.objects.filter(
zone__isnull=False,
corpus__in=Corpus.objects.writable(self.context['request'].user),
)
......@@ -113,7 +118,9 @@ class TranscriptionsSerializer(serializers.Serializer):
Allows for insertion of new transcriptions and zones
in Bulk (uses serializer above) on a common parent
"""
parent = serializers.PrimaryKeyRelatedField(queryset=Element.objects.all())
parent = serializers.PrimaryKeyRelatedField(
queryset=Element.objects.filter(zone__isnull=False),
)
recognizer = MLToolField(MLToolType.Recognizer)
transcriptions = TranscriptionBulkSerializer(many=True, allow_empty=False)
......@@ -122,6 +129,7 @@ class TranscriptionsSerializer(serializers.Serializer):
if not self.context.get('request'): # May be None when generating an OpenAPI schema or using from a REPL
return
self.fields['parent'].queryset = Element.objects.filter(
zone__isnull=False,
corpus__in=Corpus.objects.writable(self.context['request'].user),
)
......@@ -139,7 +147,9 @@ class ClassificationsSerializer(serializers.Serializer):
"""
Insert N classifications on a single element from a single ML tool
"""
parent = serializers.PrimaryKeyRelatedField(queryset=Page.objects.all())
parent = serializers.PrimaryKeyRelatedField(
queryset=Page.objects.filter(zone__isnull=False),
)
classifier = MLToolField(MLToolType.Classifier)
classifications = ClassificationBulkSerializer(many=True, allow_empty=False)
......@@ -148,5 +158,6 @@ class ClassificationsSerializer(serializers.Serializer):
if not self.context.get('request'): # May be None when generating an OpenAPI schema or using from a REPL
return
self.fields['parent'].queryset = Page.objects.filter(
zone__isnull=False,
corpus__in=Corpus.objects.writable(self.context['request'].user),
)
......@@ -16,6 +16,7 @@ class TestTranscriptionCreate(FixtureAPITestCase):
def setUpTestData(cls):
super().setUpTestData()
cls.page = Page.objects.get(corpus=cls.corpus, zone__image__path='img1')
cls.vol = cls.corpus.elements.get(name='Volume 1')
cls.src = DataSource.objects.get(slug='test')
def test_require_login(self):
......@@ -136,7 +137,6 @@ class TestTranscriptionCreate(FixtureAPITestCase):
self.client.force_login(self.user)
response = self.client.post(reverse('api:transcription-bulk'), format='json', data={
"parent": str(self.page.id),
"image": str(self.page.zone.image.id),
"recognizer": self.src.slug,
"transcriptions": [
{
......@@ -167,3 +167,28 @@ class TestTranscriptionCreate(FixtureAPITestCase):
# Indexer called
self.assertEqual(indexer.return_value.run_index.call_count, 2)
@patch('arkindex.project.serializer_fields.MLTool.get')
def test_bulk_transcription_no_zone(self, ml_get_mock):
self.assertIsNone(self.vol.zone)
ml_get_mock.return_value.type = self.src.type
ml_get_mock.return_value.slug = self.src.slug
ml_get_mock.return_value.version = self.src.revision
self.src.internal = True
self.src.save()
self.client.force_login(self.user)
response = self.client.post(reverse('api:transcription-bulk'), format='json', data={
"parent": str(self.vol.id),
"recognizer": self.src.slug,
"transcriptions": [
{
"type": "word",
"polygon": [(0, 0), (100, 0), (100, 100), (0, 100), (0, 0)],
"text": "NEKUDOTAYIM",
"score": 0.83,
},
]
})
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment