From b2c864463ca693439627d2b0dc1af068d2106139 Mon Sep 17 00:00:00 2001 From: Erwan Rouchet <rouchet@teklia.com> Date: Tue, 16 Apr 2019 09:58:27 +0000 Subject: [PATCH] Proper polygon handling in the API --- arkindex/documents/serializers/ml.py | 20 +----- arkindex/images/serializers.py | 11 +--- arkindex/project/polygon.py | 11 +++- arkindex/project/serializer_fields.py | 27 ++++++++ .../project/tests/test_drf_polygonfield.py | 64 +++++++++++++++++++ 5 files changed, 105 insertions(+), 28 deletions(-) create mode 100644 arkindex/project/tests/test_drf_polygonfield.py diff --git a/arkindex/documents/serializers/ml.py b/arkindex/documents/serializers/ml.py index 5482b0d7c8..0978bd1e2b 100644 --- a/arkindex/documents/serializers/ml.py +++ b/arkindex/documents/serializers/ml.py @@ -1,6 +1,6 @@ from rest_framework import serializers from arkindex_common.ml_tool import MLToolType -from arkindex.project.serializer_fields import EnumField, MLToolField +from arkindex.project.serializer_fields import EnumField, MLToolField, PolygonField from arkindex.documents.models import \ Corpus, Element, Page, Transcription, TranscriptionType, DataSource, Classification from arkindex.images.serializers import ZoneSerializer @@ -72,14 +72,7 @@ class TranscriptionCreateSerializer(serializers.Serializer): source = serializers.PrimaryKeyRelatedField( queryset=DataSource.objects.filter(type=MLToolType.Recognizer), ) - polygon = serializers.ListField( - child=serializers.ListField( - child=serializers.IntegerField(), - min_length=2, - max_length=2 - ), - min_length=3 - ) + polygon = PolygonField() text = serializers.CharField() score = serializers.FloatField(min_value=0, max_value=1) type = EnumField(TranscriptionType) @@ -100,14 +93,7 @@ class TranscriptionBulkSerializer(serializers.Serializer): in Bulk (used by serializer below) Note: no element ! """ - polygon = serializers.ListField( - child=serializers.ListField( - child=serializers.IntegerField(), - min_length=2, - max_length=2 - ), - min_length=3 - ) + polygon = PolygonField() text = serializers.CharField() score = serializers.FloatField(min_value=0, max_value=1) type = EnumField(TranscriptionType) diff --git a/arkindex/images/serializers.py b/arkindex/images/serializers.py index bddd22a9aa..9ee0781e7e 100644 --- a/arkindex/images/serializers.py +++ b/arkindex/images/serializers.py @@ -1,22 +1,17 @@ from rest_framework import serializers +from arkindex.project.serializer_fields import PolygonField from arkindex.images.models import Image, Zone class ZoneSerializer(serializers.ModelSerializer): - x = serializers.IntegerField(source='polygon.x') - y = serializers.IntegerField(source='polygon.y') - width = serializers.IntegerField(source='polygon.width') - height = serializers.IntegerField(source='polygon.height') + polygon = PolygonField() image_url = serializers.URLField(source='image.url') class Meta: model = Zone fields = ( 'id', - 'x', - 'y', - 'width', - 'height', + 'polygon', 'url', 'image_url', ) diff --git a/arkindex/project/polygon.py b/arkindex/project/polygon.py index 5087bba4f2..6b8184c58f 100644 --- a/arkindex/project/polygon.py +++ b/arkindex/project/polygon.py @@ -44,7 +44,7 @@ class Polygon(collections.abc.MutableSequence): A hashable Polygon in-memory ''' def __init__(self, points): - assert len(points) > 0 + assert len(points) > 1 if not all(isinstance(point, Point) for point in points): points = [Point(*p) for p in points] @@ -60,11 +60,16 @@ class Polygon(collections.abc.MutableSequence): * from a direct polygon as list of tuples * from x,y + width,height ''' + if isinstance(data, list): + return Polygon(data) assert isinstance(data, dict) # Direct usage - if 'polygon' in data and isinstance(data['polygon'], list): - return Polygon(data['polygon']) + if 'polygon' in data: + if isinstance(data['polygon'], list): + return Polygon(data['polygon']) + elif isinstance(data['polygon'], Polygon): + return data['polygon'] # Build from coords for k in ('x', 'y', 'width', 'height'): diff --git a/arkindex/project/serializer_fields.py b/arkindex/project/serializer_fields.py index c3d3dd98ed..6c3df8d5c7 100644 --- a/arkindex/project/serializer_fields.py +++ b/arkindex/project/serializer_fields.py @@ -2,6 +2,7 @@ from django.conf import settings from rest_framework import serializers from enum import Enum from arkindex_common.ml_tool import MLTool, MLToolType +from arkindex.project.polygon import Point, Polygon from arkindex.project.tools import elasticsearch_escape @@ -50,3 +51,29 @@ class SearchTermsField(serializers.CharField): """ def to_internal_value(self, query_terms): return elasticsearch_escape(query_terms) + + +class PointField(serializers.ListField): + child = serializers.IntegerField() + min_length = 2 + max_length = 2 + + def to_representation(self, point): + return [point.x, point.y] + + def to_internal_value(self, coords): + try: + return Point(*super().to_internal_value(coords)) + except (AssertionError, TypeError, ValueError) as e: + raise serializers.ValidationError(str(e)) + + +class PolygonField(serializers.ListField): + child = PointField() + min_length = 3 + + def to_internal_value(self, data): + try: + return Polygon(super().to_internal_value(data)) + except (AssertionError, ValueError) as e: + raise serializers.ValidationError(str(e)) diff --git a/arkindex/project/tests/test_drf_polygonfield.py b/arkindex/project/tests/test_drf_polygonfield.py new file mode 100644 index 0000000000..be6610ce49 --- /dev/null +++ b/arkindex/project/tests/test_drf_polygonfield.py @@ -0,0 +1,64 @@ +from unittest import TestCase +from arkindex.project.polygon import Point, Polygon +from arkindex.project.serializer_fields import PointField, PolygonField +from rest_framework.serializers import ValidationError + + +class TestPolygonSerializerField(TestCase): + + def test_pointfield_to_json(self): + self.assertListEqual( + PointField().to_representation(Point(13, 37)), + [13, 37], + ) + + def test_json_to_pointfield(self): + self.assertEqual( + PointField().to_internal_value([13, 37]), + Point(13, 37), + ) + + def test_bad_json_to_pointfield(self): + with self.assertRaises(ValidationError): + PointField().to_internal_value([]) + with self.assertRaises(ValidationError): + PointField().to_internal_value([1, 2, 3]) + with self.assertRaises(ValidationError): + PointField().to_internal_value(['a', 5]) + + def test_polygonfield_to_json(self): + self.assertListEqual( + PolygonField().to_representation(Polygon([ + Point(10, 20), + Point(30, 40), + Point(-40, 50), + ])), + [ + [10, 20], + [30, 40], + [-40, 50], + [10, 20], + ] + ) + + def test_json_to_polygonfield(self): + self.assertEqual( + PolygonField().to_internal_value([ + [10, 20], + [30, 40], + [-40, 50], + ]), Polygon([ + Point(10, 20), + Point(30, 40), + Point(-40, 50), + Point(10, 20), + ]) + ) + + def test_bad_json_to_polygonfield(self): + with self.assertRaises(ValidationError): + PolygonField().to_internal_value([]) + with self.assertRaises(ValidationError): + PolygonField().to_internal_value([1, 2, 3]) + with self.assertRaises(ValidationError): + PolygonField().to_internal_value([[1, 2]]) -- GitLab