From 4bc68933016ee94a7bf7b8cf0a3723d073d73441 Mon Sep 17 00:00:00 2001
From: Valentin Rigal <rigal@teklia.com>
Date: Mon, 9 Dec 2024 09:11:45 +0000
Subject: [PATCH] Return an HTTP 400 for names that are too long in
 CreateTranscriptionEntities and CreateEntity

---
 arkindex/documents/api/entities.py            | 66 +++++++++++++------
 arkindex/documents/serializers/entities.py    | 13 +++-
 .../tests/test_bulk_transcription_entities.py | 22 +++++++
 arkindex/documents/tests/test_entities_api.py | 17 +++++
 4 files changed, 97 insertions(+), 21 deletions(-)

diff --git a/arkindex/documents/api/entities.py b/arkindex/documents/api/entities.py
index 68234eeb51..7c60ae8f43 100644
--- a/arkindex/documents/api/entities.py
+++ b/arkindex/documents/api/entities.py
@@ -3,8 +3,10 @@ from textwrap import dedent
 from uuid import UUID
 
 from django.core.exceptions import ValidationError as DjangoValidationError
+from django.db.utils import OperationalError
 from django.shortcuts import get_object_or_404
 from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema, extend_schema_view
+from psycopg2.errors import ProgramLimitExceeded
 from rest_framework import permissions, serializers, status
 from rest_framework.exceptions import NotFound, PermissionDenied, ValidationError
 from rest_framework.generics import CreateAPIView, ListAPIView, RetrieveUpdateDestroyAPIView
@@ -199,6 +201,41 @@ class EntityCreate(CreateAPIView):
     permission_classes = (IsVerified, )
     serializer_class = EntityCreateSerializer
 
+    def handle_creation(
+        self,
+        *,
+        use_existing,
+        name,
+        type,
+        corpus,
+        metas,
+        worker_version_id,
+        worker_run,
+    ):
+        status_code = status.HTTP_201_CREATED
+        if use_existing:
+            entity, created = Entity.objects.get_or_create(name=name, corpus=corpus, type=type, defaults={
+                "metas": metas,
+                "worker_version_id": worker_version_id,
+                "worker_run": worker_run
+            })
+            # With the "use_existing" option, an already existing entity is returned with a 200_OK instead of a 201_CREATED status code
+            if not created:
+                status_code = status.HTTP_200_OK
+        else:
+            entity = Entity.objects.create(
+                name=name,
+                type=type,
+                corpus=corpus,
+                metas=metas,
+                worker_version_id=worker_version_id,
+                worker_run=worker_run
+            )
+
+        entity_serializer = EntitySerializer(entity)
+        headers = self.get_success_headers(entity_serializer.data)
+        return Response(entity_serializer.data, status=status_code, headers=headers)
+
     def create(self, request, *args, **kwargs):
         # Overriding create in order to return EntitySerializer, not EntityCreateSerializer
         serializer = self.get_serializer(data=request.data)
@@ -211,30 +248,21 @@ class EntityCreate(CreateAPIView):
         worker_run = serializer.validated_data["worker_run"]
         worker_version_id = worker_run.version_id if worker_run else None
 
-        # When using the "use_existing" option, we return a 200_OK instead of a 201_CREATED status code
-        if request.data.get("use_existing"):
-            entity, created = Entity.objects.get_or_create(name=name, corpus=corpus, type=type, defaults={
-                "metas": metas,
-                "worker_version_id": worker_version_id,
-                "worker_run": worker_run
-            })
-            entity = EntitySerializer(entity)
-            if created:
-                status_code = status.HTTP_201_CREATED
-            else:
-                status_code = status.HTTP_200_OK
-        else:
-            entity = EntitySerializer(Entity.objects.create(
+        try:
+            return self.handle_creation(
+                use_existing=request.data.get("use_existing"),
                 name=name,
                 type=type,
                 corpus=corpus,
                 metas=metas,
                 worker_version_id=worker_version_id,
-                worker_run=worker_run
-            ))
-            status_code = status.HTTP_201_CREATED
-        headers = self.get_success_headers(serializer.data)
-        return Response(entity.data, status=status_code, headers=headers)
+                worker_run=worker_run,
+            )
+        except OperationalError as e:
+            if isinstance(getattr(e, "__cause__", None), ProgramLimitExceeded):
+                # The maximum length is dynamic and depends on the content, so we cannot simply enforce a fixed length limit
+                raise ValidationError({"name": ["Value is too long for this field."]})
+            raise e
 
 
 @extend_schema_view(post=extend_schema(
diff --git a/arkindex/documents/serializers/entities.py b/arkindex/documents/serializers/entities.py
index 2b5d3fc2be..bcc2778368 100644
--- a/arkindex/documents/serializers/entities.py
+++ b/arkindex/documents/serializers/entities.py
@@ -2,6 +2,8 @@ from collections import defaultdict
 from textwrap import dedent
 
 from django.db import transaction
+from django.db.utils import OperationalError
+from psycopg2.errors import ProgramLimitExceeded
 from rest_framework import serializers
 from rest_framework.exceptions import ValidationError
 
@@ -373,7 +375,7 @@ class TranscriptionEntitiesBulkSerializer(serializers.Serializer):
 
     @transaction.atomic
     def save(self):
-        entities = Entity.objects.bulk_create([
+        entities_to_create = [
             Entity(
                 corpus=self.transcription.element.corpus,
                 name=item["name"],
@@ -383,7 +385,14 @@ class TranscriptionEntitiesBulkSerializer(serializers.Serializer):
                 worker_version_id=self.validated_data["worker_run"].version_id,
             )
             for item in self.validated_data["entities"]
-        ])
+        ]
+        try:
+            entities = Entity.objects.bulk_create(entities_to_create)
+        except OperationalError as e:
+            if isinstance(getattr(e, "__cause__", None), ProgramLimitExceeded):
+                # The maximum length is dynamic and depends on the content, so we cannot simply enforce a fixed length limit
+                raise ValidationError({"entities": {"name": ["Value is too long for this field."]}})
+            raise e
 
         transcription_entities = TranscriptionEntity.objects.bulk_create([
             TranscriptionEntity(
diff --git a/arkindex/documents/tests/test_bulk_transcription_entities.py b/arkindex/documents/tests/test_bulk_transcription_entities.py
index dc0b505c64..8c60037671 100644
--- a/arkindex/documents/tests/test_bulk_transcription_entities.py
+++ b/arkindex/documents/tests/test_bulk_transcription_entities.py
@@ -416,3 +416,25 @@ class TestBulkTranscriptionEntities(FixtureAPITestCase):
             ),
             [("Guzzlord", self.person_ent_type.id, 6, 9, .42, self.local_worker_run.id, self.local_worker_run.id)],
         )
+
+    def test_create_name_too_long(self):
+        self.client.force_login(self.user)
+        response = self.client.post(
+            reverse("api:transcription-entities-bulk", kwargs={"pk": str(self.transcription.id)}),
+            data={
+                "entities": [
+                    {
+                        "name": "A" * 500000,
+                        "type_id": str(self.person_ent_type.id),
+                        "offset": 6,
+                        "length": 9,
+                        "confidence": .42,
+                    },
+                ],
+                "worker_run_id": str(self.local_worker_run.id),
+            }
+        )
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertEqual(response.json(), {
+            "entities": {"name": ["Value is too long for this field."]},
+        })
diff --git a/arkindex/documents/tests/test_entities_api.py b/arkindex/documents/tests/test_entities_api.py
index d4ab32bfe2..32b4dffa23 100644
--- a/arkindex/documents/tests/test_entities_api.py
+++ b/arkindex/documents/tests/test_entities_api.py
@@ -534,6 +534,23 @@ class TestEntitiesAPI(FixtureAPITestCase):
         self.assertEqual(entity.worker_version_id, local_worker_run.version_id)
         self.assertEqual(entity.worker_run, local_worker_run)
 
+    def test_create_entity_name_too_long(self):
+        self.client.force_login(self.user)
+        with self.assertNumQueries(6):
+            response = self.client.post(
+                reverse("api:entity-create"),
+                data={
+                    "name": "A" * 500000,
+                    "type_id": str(self.person_type.id),
+                    "corpus": str(self.corpus.id),
+                    "worker_run_id": str(self.local_worker_run.id),
+                },
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertEqual(response.json(), {
+            "name": ["Value is too long for this field."],
+        })
+
     def test_create_transcription_entity(self):
         self.client.force_login(self.user)
         with self.assertNumQueries(6):
-- 
GitLab