From 0400360ecb6cebfd40d513ea4592c0d9dafb166b Mon Sep 17 00:00:00 2001 From: ml bonhomme <bonhomme@teklia.com> Date: Fri, 19 Jul 2024 08:55:03 +0000 Subject: [PATCH] Add with_transcriptions param to elements list endpoints --- arkindex/documents/api/elements.py | 15 ++ arkindex/documents/serializers/elements.py | 16 +- .../documents/tests/test_children_elements.py | 6 + .../documents/tests/test_parents_elements.py | 1 + .../documents/tests/test_transcriptions.py | 187 ++++++++++++++++++ arkindex/training/tests/test_datasets_api.py | 4 + 6 files changed, 228 insertions(+), 1 deletion(-) diff --git a/arkindex/documents/api/elements.py b/arkindex/documents/api/elements.py index b5dc503dd4..e135afb0f1 100644 --- a/arkindex/documents/api/elements.py +++ b/arkindex/documents/api/elements.py @@ -475,6 +475,13 @@ class ElementsListAutoSchema(AutoSchema): type=bool, required=False, ), + OpenApiParameter( + "with_transcriptions", + description="Returns all transcriptions for each element. " + "Otherwise, `transcriptions` will always be null.", + type=bool, + required=False, + ), OpenApiParameter( "with_has_children", description="Include the `has_children` boolean to tell if each element has direct children. 
" @@ -924,6 +931,14 @@ class ElementsListBase(CorpusACLMixin, DestroyModelMixin, ListAPIView): to_attr="prefetched_metadata", )) + with_transcriptions = self.clean_params.get("with_transcriptions") + if with_transcriptions and with_transcriptions.lower() not in ("false", "0"): + prefetch.add(Prefetch( + "transcriptions", + queryset=Transcription.objects.select_related("worker_run").order_by("-confidence"), + to_attr="prefetched_transcriptions" + )) + return prefetch @property diff --git a/arkindex/documents/serializers/elements.py b/arkindex/documents/serializers/elements.py index 3c00f5019f..1f443526a7 100644 --- a/arkindex/documents/serializers/elements.py +++ b/arkindex/documents/serializers/elements.py @@ -21,7 +21,11 @@ from arkindex.documents.serializers.light import ( ElementTypeLightSerializer, MetaDataLightSerializer, ) -from arkindex.documents.serializers.ml import ClassificationSerializer, WorkerRunSummarySerializer +from arkindex.documents.serializers.ml import ( + ClassificationSerializer, + TranscriptionSerializer, + WorkerRunSummarySerializer, +) from arkindex.images.models import Image from arkindex.images.serializers import ZoneSerializer from arkindex.process.models import WorkerVersion @@ -427,6 +431,15 @@ class ElementListSerializer(ElementTinySerializer): # all of the element's metadata if they have not been prefetched. source="prefetched_metadata", ) + transcriptions = TranscriptionSerializer( + many=True, + default=None, + read_only=True, + help_text="Transcriptions on this element, if the `with_transcriptions` option has been enabled.", + # Use a custom attribute here, so that the serializer does not try to load + # all of the element's transcriptions if they have not been prefetched. 
+ source="prefetched_transcriptions", + ) has_children = serializers.BooleanField(default=None, read_only=True) worker_run = WorkerRunSummarySerializer(read_only=True, allow_null=True) @@ -450,6 +463,7 @@ class ElementListSerializer(ElementTinySerializer): "created", "classes", "metadata", + "transcriptions", "has_children", "worker_run", "confidence", diff --git a/arkindex/documents/tests/test_children_elements.py b/arkindex/documents/tests/test_children_elements.py index 74aad75eda..c127584ea8 100644 --- a/arkindex/documents/tests/test_children_elements.py +++ b/arkindex/documents/tests/test_children_elements.py @@ -107,6 +107,7 @@ class TestChildrenElements(FixtureAPITestCase): "classes": None, "has_children": None, "metadata": None, + "transcriptions": None, "thumbnail_url": nested_volume.thumbnail.s3_url, } ]) @@ -137,6 +138,7 @@ class TestChildrenElements(FixtureAPITestCase): "classes": None, "has_children": None, "metadata": None, + "transcriptions": None, "thumbnail_url": nested_volume.thumbnail.s3_url, } ]) @@ -186,6 +188,7 @@ class TestChildrenElements(FixtureAPITestCase): "classes": None, "has_children": None, "metadata": None, + "transcriptions": None, "thumbnail_url": None, } ]) @@ -215,6 +218,7 @@ class TestChildrenElements(FixtureAPITestCase): "classes": None, "has_children": None, "metadata": None, + "transcriptions": None, "thumbnail_url": lonely_element.iiif_thumbnail_url, } ]) @@ -259,6 +263,7 @@ class TestChildrenElements(FixtureAPITestCase): "classes": None, "has_children": None, "metadata": None, + "transcriptions": None, "thumbnail_url": None, } ]) @@ -293,6 +298,7 @@ class TestChildrenElements(FixtureAPITestCase): "classes": None, "has_children": None, "metadata": None, + "transcriptions": None, "thumbnail_url": worker_run_child.iiif_thumbnail_url, } ]) diff --git a/arkindex/documents/tests/test_parents_elements.py b/arkindex/documents/tests/test_parents_elements.py index b24752640c..829eaf9868 100644 --- 
a/arkindex/documents/tests/test_parents_elements.py +++ b/arkindex/documents/tests/test_parents_elements.py @@ -87,6 +87,7 @@ class TestParentsElements(FixtureAPITestCase): "classes": None, "has_children": None, "metadata": None, + "transcriptions": None, "thumbnail_url": self.vol.thumbnail.s3_url, "rotation_angle": 0, "mirrored": False, diff --git a/arkindex/documents/tests/test_transcriptions.py b/arkindex/documents/tests/test_transcriptions.py index f5f59dff0e..c0fcffb796 100644 --- a/arkindex/documents/tests/test_transcriptions.py +++ b/arkindex/documents/tests/test_transcriptions.py @@ -297,3 +297,190 @@ class TestTranscriptions(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertListEqual(response.json()["results"], []) + + # Elements listing with with_transcriptions param + + def test_list_elements_with_transcriptions_false(self): + self.assertTrue(self.page.transcriptions.exists()) + self.client.force_login(self.user) + with self.assertNumQueries(8): + response = self.client.get( + reverse("api:corpus-elements", kwargs={"corpus": self.corpus.id}), + data={"type": "page", "with_transcriptions": False} + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + data = response.json() + self.assertEqual(data["count"], 6) + for elem in data["results"]: + self.assertIsNone(elem["transcriptions"]) + + def test_list_elements_with_transcriptions(self): + tr1 = self.line.transcriptions.create( + text="Demain dès l'aube, à l'heure où blanchit la campagne", + worker_run=self.worker_run, + worker_version=self.worker_run.version, + confidence=0.89 + ) + tr2 = self.line.transcriptions.create( + text="Je partirai, vois-tu je says que tu m'attends.", + worker_run=self.worker_run, + worker_version=self.worker_run.version, + confidence=0.72 + ) + tl_2 = self.corpus.elements.create( + type=self.line.type, + name="Text Line 2", + image=self.page.image, + polygon=self.line.polygon + ) + + self.client.force_login(self.user) + with 
self.assertNumQueries(9): + response = self.client.get( + reverse("api:corpus-elements", kwargs={"corpus": self.corpus.id}), + data={"type": "text_line", "with_transcriptions": True} + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + data = response.json() + self.assertEqual(data["count"], 2) + self.assertEqual(data["results"][0]["id"], str(self.line.id)) + self.assertListEqual(data["results"][0]["transcriptions"], [ + { + "id": str(tr1.id), + "text": "Demain dès l'aube, à l'heure où blanchit la campagne", + "confidence": 0.89, + "orientation": "horizontal-lr", + "worker_run": { + "id": str(self.worker_run.id), + "summary": "Worker Recognizer @ version 1" + } + }, + { + "id": str(tr2.id), + "text": "Je partirai, vois-tu je says que tu m'attends.", + "confidence": 0.72, + "orientation": "horizontal-lr", + "worker_run": { + "id": str(self.worker_run.id), + "summary": "Worker Recognizer @ version 1" + } + } + ]) + self.assertEqual(data["results"][1]["id"], str(tl_2.id)) + self.assertListEqual(data["results"][1]["transcriptions"], []) + + def test_list_element_parents_with_transcriptions_false(self): + self.assertTrue(self.page.transcriptions.exists()) + self.client.force_login(self.user) + with self.assertNumQueries(8): + response = self.client.get( + reverse("api:elements-parents", kwargs={"pk": self.line.id}), + data={"type": "page", "with_transcriptions": False} + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + data = response.json() + self.assertEqual(data["count"], 1) + self.assertEqual(data["results"][0]["id"], str(self.page.id)) + self.assertIsNone(data["results"][0]["transcriptions"]) + + def test_list_element_parents_with_transcriptions(self): + tr = self.page.transcriptions.get() + self.client.force_login(self.user) + with self.assertNumQueries(9): + response = self.client.get( + reverse("api:elements-parents", kwargs={"pk": self.line.id}), + data={"type": "page", "with_transcriptions": True} + ) + 
self.assertEqual(response.status_code, status.HTTP_200_OK) + + data = response.json() + self.assertEqual(data["count"], 1) + self.assertEqual(data["results"][0]["id"], str(self.page.id)) + self.assertListEqual(data["results"][0]["transcriptions"], [ + { + "id": str(tr.id), + "text": "Lorem ipsum dolor sit amet", + "confidence": 1, + "orientation": "horizontal-lr", + "worker_run": { + "id": str(self.worker_run.id), + "summary": "Worker Recognizer @ version 1" + } + } + ]) + + def test_list_element_children_with_transcriptions_false(self): + self.assertTrue(self.page.transcriptions.exists()) + self.client.force_login(self.user) + with self.assertNumQueries(8): + response = self.client.get( + reverse("api:elements-children", kwargs={"pk": self.volume.id}), + data={"type": "page", "with_transcriptions": False} + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + data = response.json() + self.assertEqual(data["count"], 3) + self.assertEqual(data["results"][0]["id"], str(self.page.id)) + for page_elem in data["results"]: + self.assertIsNone(page_elem["transcriptions"]) + + def test_list_element_children_with_transcriptions(self): + tr1 = self.line.transcriptions.create( + text="Demain dès l'aube, à l'heure où blanchit la campagne", + worker_run=self.worker_run, + worker_version=self.worker_run.version, + confidence=0.89 + ) + tr2 = self.line.transcriptions.create( + text="Je partirai, vois-tu je says que tu m'attends.", + worker_run=self.worker_run, + worker_version=self.worker_run.version, + confidence=0.72 + ) + tl_2 = self.corpus.elements.create( + type=self.line.type, + name="Text Line 2", + image=self.page.image, + polygon=self.line.polygon + ) + tl_2.add_parent(self.page) + + self.client.force_login(self.user) + with self.assertNumQueries(9): + response = self.client.get( + reverse("api:elements-children", kwargs={"pk": self.page.id}), + data={"type": "text_line", "with_transcriptions": True} + ) + self.assertEqual(response.status_code, 
status.HTTP_200_OK) + + data = response.json() + self.assertEqual(data["count"], 2) + self.assertEqual(data["results"][0]["id"], str(self.line.id)) + self.assertListEqual(data["results"][0]["transcriptions"], [ + { + "id": str(tr1.id), + "text": "Demain dès l'aube, à l'heure où blanchit la campagne", + "confidence": 0.89, + "orientation": "horizontal-lr", + "worker_run": { + "id": str(self.worker_run.id), + "summary": "Worker Recognizer @ version 1" + } + }, + { + "id": str(tr2.id), + "text": "Je partirai, vois-tu je says que tu m'attends.", + "confidence": 0.72, + "orientation": "horizontal-lr", + "worker_run": { + "id": str(self.worker_run.id), + "summary": "Worker Recognizer @ version 1" + } + } + ]) + self.assertEqual(data["results"][1]["id"], str(tl_2.id)) + self.assertListEqual(data["results"][1]["transcriptions"], []) diff --git a/arkindex/training/tests/test_datasets_api.py b/arkindex/training/tests/test_datasets_api.py index c2ebc9e145..0113b00ca0 100644 --- a/arkindex/training/tests/test_datasets_api.py +++ b/arkindex/training/tests/test_datasets_api.py @@ -1371,6 +1371,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "classes": None, "has_children": None, "metadata": None, + "transcriptions": None, "worker_run": None, "thumbnail_url": None, "created": FAKE_CREATED @@ -1419,6 +1420,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "classes": None, "has_children": None, "metadata": None, + "transcriptions": None, "worker_run": None, "thumbnail_url": None, "created": FAKE_CREATED @@ -1467,6 +1469,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "classes": None, "has_children": None, "metadata": None, + "transcriptions": None, "worker_run": None, "thumbnail_url": None, "created": FAKE_CREATED @@ -1490,6 +1493,7 @@ class TestDatasetsAPI(FixtureAPITestCase): "classes": None, "has_children": None, "metadata": None, + "transcriptions": None, "worker_run": None, "thumbnail_url": "s3_url", "created": FAKE_CREATED -- GitLab