From 0400360ecb6cebfd40d513ea4592c0d9dafb166b Mon Sep 17 00:00:00 2001
From: ml bonhomme <bonhomme@teklia.com>
Date: Fri, 19 Jul 2024 08:55:03 +0000
Subject: [PATCH] Add with_transcription param to elements list endpoints

---
 arkindex/documents/api/elements.py            |  15 ++
 arkindex/documents/serializers/elements.py    |  16 +-
 .../documents/tests/test_children_elements.py |   6 +
 .../documents/tests/test_parents_elements.py  |   1 +
 .../documents/tests/test_transcriptions.py    | 187 ++++++++++++++++++
 arkindex/training/tests/test_datasets_api.py  |   4 +
 6 files changed, 228 insertions(+), 1 deletion(-)

diff --git a/arkindex/documents/api/elements.py b/arkindex/documents/api/elements.py
index b5dc503dd4..e135afb0f1 100644
--- a/arkindex/documents/api/elements.py
+++ b/arkindex/documents/api/elements.py
@@ -475,6 +475,13 @@ class ElementsListAutoSchema(AutoSchema):
                     type=bool,
                     required=False,
                 ),
+                OpenApiParameter(
+                    "with_transcriptions",
+                    description="Returns all transcriptions for each element. "
+                                "Otherwise, `transcriptions` will always be null.",
+                    type=bool,
+                    required=False,
+                ),
                 OpenApiParameter(
                     "with_has_children",
                     description="Include the `has_children` boolean to tell if each element has direct children. "
@@ -924,6 +931,14 @@ class ElementsListBase(CorpusACLMixin, DestroyModelMixin, ListAPIView):
                 to_attr="prefetched_metadata",
             ))
 
+        with_transcriptions = self.clean_params.get("with_transcriptions")
+        if with_transcriptions and with_transcriptions.lower() not in ("false", "0"):
+            prefetch.add(Prefetch(
+                "transcriptions",
+                queryset=Transcription.objects.select_related("worker_run").order_by("-confidence"),
+                to_attr="prefetched_transcriptions"
+            ))
+
         return prefetch
 
     @property
diff --git a/arkindex/documents/serializers/elements.py b/arkindex/documents/serializers/elements.py
index 3c00f5019f..1f443526a7 100644
--- a/arkindex/documents/serializers/elements.py
+++ b/arkindex/documents/serializers/elements.py
@@ -21,7 +21,11 @@ from arkindex.documents.serializers.light import (
     ElementTypeLightSerializer,
     MetaDataLightSerializer,
 )
-from arkindex.documents.serializers.ml import ClassificationSerializer, WorkerRunSummarySerializer
+from arkindex.documents.serializers.ml import (
+    ClassificationSerializer,
+    TranscriptionSerializer,
+    WorkerRunSummarySerializer,
+)
 from arkindex.images.models import Image
 from arkindex.images.serializers import ZoneSerializer
 from arkindex.process.models import WorkerVersion
@@ -427,6 +431,15 @@ class ElementListSerializer(ElementTinySerializer):
         # all of the element's metadata if they have not been prefetched.
         source="prefetched_metadata",
     )
+    transcriptions = TranscriptionSerializer(
+        many=True,
+        default=None,
+        read_only=True,
+        help_text="Transcriptions on this element, if the `with_transcriptions` option has been enabled.",
+        # Use a custom attribute here, so that the serializer does not try to load
+        # all of the element's transcriptions if they have not been prefetched.
+        source="prefetched_transcriptions",
+    )
     has_children = serializers.BooleanField(default=None, read_only=True)
     worker_run = WorkerRunSummarySerializer(read_only=True, allow_null=True)
 
@@ -450,6 +463,7 @@ class ElementListSerializer(ElementTinySerializer):
             "created",
             "classes",
             "metadata",
+            "transcriptions",
             "has_children",
             "worker_run",
             "confidence",
diff --git a/arkindex/documents/tests/test_children_elements.py b/arkindex/documents/tests/test_children_elements.py
index 74aad75eda..c127584ea8 100644
--- a/arkindex/documents/tests/test_children_elements.py
+++ b/arkindex/documents/tests/test_children_elements.py
@@ -107,6 +107,7 @@ class TestChildrenElements(FixtureAPITestCase):
                 "classes": None,
                 "has_children": None,
                 "metadata": None,
+                "transcriptions": None,
                 "thumbnail_url": nested_volume.thumbnail.s3_url,
             }
         ])
@@ -137,6 +138,7 @@ class TestChildrenElements(FixtureAPITestCase):
                 "classes": None,
                 "has_children": None,
                 "metadata": None,
+                "transcriptions": None,
                 "thumbnail_url": nested_volume.thumbnail.s3_url,
             }
         ])
@@ -186,6 +188,7 @@ class TestChildrenElements(FixtureAPITestCase):
                 "classes": None,
                 "has_children": None,
                 "metadata": None,
+                "transcriptions": None,
                 "thumbnail_url": None,
             }
         ])
@@ -215,6 +218,7 @@ class TestChildrenElements(FixtureAPITestCase):
                 "classes": None,
                 "has_children": None,
                 "metadata": None,
+                "transcriptions": None,
                 "thumbnail_url": lonely_element.iiif_thumbnail_url,
             }
         ])
@@ -259,6 +263,7 @@ class TestChildrenElements(FixtureAPITestCase):
                 "classes": None,
                 "has_children": None,
                 "metadata": None,
+                "transcriptions": None,
                 "thumbnail_url": None,
             }
         ])
@@ -293,6 +298,7 @@ class TestChildrenElements(FixtureAPITestCase):
                 "classes": None,
                 "has_children": None,
                 "metadata": None,
+                "transcriptions": None,
                 "thumbnail_url": worker_run_child.iiif_thumbnail_url,
             }
         ])
diff --git a/arkindex/documents/tests/test_parents_elements.py b/arkindex/documents/tests/test_parents_elements.py
index b24752640c..829eaf9868 100644
--- a/arkindex/documents/tests/test_parents_elements.py
+++ b/arkindex/documents/tests/test_parents_elements.py
@@ -87,6 +87,7 @@ class TestParentsElements(FixtureAPITestCase):
                 "classes": None,
                 "has_children": None,
                 "metadata": None,
+                "transcriptions": None,
                 "thumbnail_url": self.vol.thumbnail.s3_url,
                 "rotation_angle": 0,
                 "mirrored": False,
diff --git a/arkindex/documents/tests/test_transcriptions.py b/arkindex/documents/tests/test_transcriptions.py
index f5f59dff0e..c0fcffb796 100644
--- a/arkindex/documents/tests/test_transcriptions.py
+++ b/arkindex/documents/tests/test_transcriptions.py
@@ -297,3 +297,190 @@ class TestTranscriptions(FixtureAPITestCase):
             self.assertEqual(response.status_code, status.HTTP_200_OK)
 
         self.assertListEqual(response.json()["results"], [])
+
+    # Elements listing with with_transcriptions param
+
+    def test_list_elements_with_transcriptions_false(self):
+        self.assertTrue(self.page.transcriptions.exists())
+        self.client.force_login(self.user)
+        with self.assertNumQueries(8):
+            response = self.client.get(
+                reverse("api:corpus-elements", kwargs={"corpus": self.corpus.id}),
+                data={"type": "page", "with_transcriptions": False}
+            )
+            self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+        data = response.json()
+        self.assertEqual(data["count"], 6)
+        for elem in data["results"]:
+            self.assertIsNone(elem["transcriptions"])
+
+    def test_list_elements_with_transcriptions(self):
+        tr1 = self.line.transcriptions.create(
+            text="Demain dès l'aube, à l'heure où blanchit la campagne",
+            worker_run=self.worker_run,
+            worker_version=self.worker_run.version,
+            confidence=0.89
+        )
+        tr2 = self.line.transcriptions.create(
+            text="Je partirai, vois-tu je says que tu m'attends.",
+            worker_run=self.worker_run,
+            worker_version=self.worker_run.version,
+            confidence=0.72
+        )
+        tl_2 = self.corpus.elements.create(
+            type=self.line.type,
+            name="Text Line 2",
+            image=self.page.image,
+            polygon=self.line.polygon
+        )
+
+        self.client.force_login(self.user)
+        with self.assertNumQueries(9):
+            response = self.client.get(
+                reverse("api:corpus-elements", kwargs={"corpus": self.corpus.id}),
+                data={"type": "text_line", "with_transcriptions": True}
+            )
+            self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+        data = response.json()
+        self.assertEqual(data["count"], 2)
+        self.assertEqual(data["results"][0]["id"], str(self.line.id))
+        self.assertListEqual(data["results"][0]["transcriptions"], [
+            {
+                "id": str(tr1.id),
+                "text": "Demain dès l'aube, à l'heure où blanchit la campagne",
+                "confidence": 0.89,
+                "orientation": "horizontal-lr",
+                "worker_run": {
+                    "id": str(self.worker_run.id),
+                    "summary": "Worker Recognizer @ version 1"
+                }
+            },
+            {
+                "id": str(tr2.id),
+                "text": "Je partirai, vois-tu je says que tu m'attends.",
+                "confidence": 0.72,
+                "orientation": "horizontal-lr",
+                "worker_run": {
+                    "id": str(self.worker_run.id),
+                    "summary": "Worker Recognizer @ version 1"
+                }
+            }
+        ])
+        self.assertEqual(data["results"][1]["id"], str(tl_2.id))
+        self.assertListEqual(data["results"][1]["transcriptions"], [])
+
+    def test_list_element_parents_with_transcriptions_false(self):
+        self.assertTrue(self.page.transcriptions.exists())
+        self.client.force_login(self.user)
+        with self.assertNumQueries(8):
+            response = self.client.get(
+                reverse("api:elements-parents", kwargs={"pk": self.line.id}),
+                data={"type": "page", "with_transcriptions": False}
+            )
+            self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+        data = response.json()
+        self.assertEqual(data["count"], 1)
+        self.assertEqual(data["results"][0]["id"], str(self.page.id))
+        self.assertIsNone(data["results"][0]["transcriptions"])
+
+    def test_list_element_parents_with_transcriptions(self):
+        tr = self.page.transcriptions.get()
+        self.client.force_login(self.user)
+        with self.assertNumQueries(9):
+            response = self.client.get(
+                reverse("api:elements-parents", kwargs={"pk": self.line.id}),
+                data={"type": "page", "with_transcriptions": True}
+            )
+            self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+        data = response.json()
+        self.assertEqual(data["count"], 1)
+        self.assertEqual(data["results"][0]["id"], str(self.page.id))
+        self.assertListEqual(data["results"][0]["transcriptions"], [
+            {
+                "id": str(tr.id),
+                "text": "Lorem ipsum dolor sit amet",
+                "confidence": 1,
+                "orientation": "horizontal-lr",
+                "worker_run": {
+                    "id": str(self.worker_run.id),
+                    "summary": "Worker Recognizer @ version 1"
+                }
+            }
+        ])
+
+    def test_list_element_children_with_transcriptions_false(self):
+        self.assertTrue(self.page.transcriptions.exists())
+        self.client.force_login(self.user)
+        with self.assertNumQueries(8):
+            response = self.client.get(
+                reverse("api:elements-children", kwargs={"pk": self.volume.id}),
+                data={"type": "page", "with_transcriptions": False}
+            )
+            self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+        data = response.json()
+        self.assertEqual(data["count"], 3)
+        self.assertEqual(data["results"][0]["id"], str(self.page.id))
+        for page_elem in data["results"]:
+            self.assertIsNone(page_elem["transcriptions"])
+
+    def test_list_element_children_with_transcriptions(self):
+        tr1 = self.line.transcriptions.create(
+            text="Demain dès l'aube, à l'heure où blanchit la campagne",
+            worker_run=self.worker_run,
+            worker_version=self.worker_run.version,
+            confidence=0.89
+        )
+        tr2 = self.line.transcriptions.create(
+            text="Je partirai, vois-tu je says que tu m'attends.",
+            worker_run=self.worker_run,
+            worker_version=self.worker_run.version,
+            confidence=0.72
+        )
+        tl_2 = self.corpus.elements.create(
+            type=self.line.type,
+            name="Text Line 2",
+            image=self.page.image,
+            polygon=self.line.polygon
+        )
+        tl_2.add_parent(self.page)
+
+        self.client.force_login(self.user)
+        with self.assertNumQueries(9):
+            response = self.client.get(
+                reverse("api:elements-children", kwargs={"pk": self.page.id}),
+                data={"type": "text_line", "with_transcriptions": True}
+            )
+            self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+        data = response.json()
+        self.assertEqual(data["count"], 2)
+        self.assertEqual(data["results"][0]["id"], str(self.line.id))
+        self.assertListEqual(data["results"][0]["transcriptions"], [
+            {
+                "id": str(tr1.id),
+                "text": "Demain dès l'aube, à l'heure où blanchit la campagne",
+                "confidence": 0.89,
+                "orientation": "horizontal-lr",
+                "worker_run": {
+                    "id": str(self.worker_run.id),
+                    "summary": "Worker Recognizer @ version 1"
+                }
+            },
+            {
+                "id": str(tr2.id),
+                "text": "Je partirai, vois-tu je says que tu m'attends.",
+                "confidence": 0.72,
+                "orientation": "horizontal-lr",
+                "worker_run": {
+                    "id": str(self.worker_run.id),
+                    "summary": "Worker Recognizer @ version 1"
+                }
+            }
+        ])
+        self.assertEqual(data["results"][1]["id"], str(tl_2.id))
+        self.assertListEqual(data["results"][1]["transcriptions"], [])
diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py
index c2ebc9e145..0113b00ca0 100644
--- a/arkindex/training/tests/test_datasets_api.py
+++ b/arkindex/training/tests/test_datasets_api.py
@@ -1371,6 +1371,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "classes": None,
                     "has_children": None,
                     "metadata": None,
+                    "transcriptions": None,
                     "worker_run": None,
                     "thumbnail_url": None,
                     "created": FAKE_CREATED
@@ -1419,6 +1420,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "classes": None,
                     "has_children": None,
                     "metadata": None,
+                    "transcriptions": None,
                     "worker_run": None,
                     "thumbnail_url": None,
                     "created": FAKE_CREATED
@@ -1467,6 +1469,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "classes": None,
                     "has_children": None,
                     "metadata": None,
+                    "transcriptions": None,
                     "worker_run": None,
                     "thumbnail_url": None,
                     "created": FAKE_CREATED
@@ -1490,6 +1493,7 @@ class TestDatasetsAPI(FixtureAPITestCase):
                     "classes": None,
                     "has_children": None,
                     "metadata": None,
+                    "transcriptions": None,
                     "worker_run": None,
                     "thumbnail_url": "s3_url",
                     "created": FAKE_CREATED
-- 
GitLab