From d2c8950fd2b6d7abdf68e8262726ca706fee787d Mon Sep 17 00:00:00 2001
From: Eva Bardou <ebardou@teklia.com>
Date: Fri, 9 Apr 2021 08:42:16 +0000
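Subject: [PATCH] Support giving a CachedElement in helper functions using the
 local cache

create_classification, list_element_children and list_transcriptions now
accept either an Element or a CachedElement for their `element` argument.
The assertion message becomes "element shouldn't be null and should be an
Element or CachedElement"; tests are updated for the new message and
exercise CachedElement inputs.
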
---
 arkindex_worker/worker/classification.py      |   5 +-
 arkindex_worker/worker/element.py             |   4 +-
 arkindex_worker/worker/transcription.py       |   4 +-
 .../test_classifications.py                   |  51 +++++++-
 tests/test_elements_worker/test_elements.py   |  18 ++-
 .../test_transcriptions.py                    | 118 +++++++++++++++++-
 6 files changed, 180 insertions(+), 20 deletions(-)

diff --git a/arkindex_worker/worker/classification.py b/arkindex_worker/worker/classification.py
index 3e979c04..8bc3b7ba 100644
--- a/arkindex_worker/worker/classification.py
+++ b/arkindex_worker/worker/classification.py
@@ -4,6 +4,7 @@ import os
 from apistar.exceptions import ErrorResponse
 
 from arkindex_worker import logger
+from arkindex_worker.cache import CachedElement
 from arkindex_worker.models import Element
 
 
@@ -65,8 +66,8 @@ class ClassificationMixin(object):
         Create a classification on the given element through API
         """
         assert element and isinstance(
-            element, Element
-        ), "element shouldn't be null and should be of type Element"
+            element, (Element, CachedElement)
+        ), "element shouldn't be null and should be an Element or CachedElement"
         assert ml_class and isinstance(
             ml_class, str
         ), "ml_class shouldn't be null and should be of type str"
diff --git a/arkindex_worker/worker/element.py b/arkindex_worker/worker/element.py
index ecedb26a..b25e8c51 100644
--- a/arkindex_worker/worker/element.py
+++ b/arkindex_worker/worker/element.py
@@ -172,8 +172,8 @@ class ElementMixin(object):
         List children of an element
         """
         assert element and isinstance(
-            element, Element
-        ), "element shouldn't be null and should be of type Element"
+            element, (Element, CachedElement)
+        ), "element shouldn't be null and should be an Element or CachedElement"
         query_params = {}
         if best_class is not None:
             assert isinstance(best_class, str) or isinstance(
diff --git a/arkindex_worker/worker/transcription.py b/arkindex_worker/worker/transcription.py
index 23f1088e..2dd27a60 100644
--- a/arkindex_worker/worker/transcription.py
+++ b/arkindex_worker/worker/transcription.py
@@ -233,8 +233,8 @@ class TranscriptionMixin(object):
         List transcriptions on an element
         """
         assert element and isinstance(
-            element, Element
-        ), "element shouldn't be null and should be of type Element"
+            element, (Element, CachedElement)
+        ), "element shouldn't be null and should be an Element or CachedElement"
         query_params = {}
         if element_type:
             assert isinstance(element_type, str), "element_type should be of type str"
diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py
index 20a67707..7ea525ea 100644
--- a/tests/test_elements_worker/test_classifications.py
+++ b/tests/test_elements_worker/test_classifications.py
@@ -4,6 +4,7 @@ import json
 import pytest
 from apistar.exceptions import ErrorResponse
 
+from arkindex_worker.cache import CachedElement
 from arkindex_worker.models import Element
 
 
@@ -159,7 +160,10 @@ def test_create_classification_wrong_element(mock_elements_worker):
             confidence=0.42,
             high_confidence=True,
         )
-    assert str(e.value) == "element shouldn't be null and should be of type Element"
+    assert (
+        str(e.value)
+        == "element shouldn't be null and should be an Element or CachedElement"
+    )
 
     with pytest.raises(AssertionError) as e:
         mock_elements_worker.create_classification(
@@ -168,7 +172,10 @@ def test_create_classification_wrong_element(mock_elements_worker):
             confidence=0.42,
             high_confidence=True,
         )
-    assert str(e.value) == "element shouldn't be null and should be of type Element"
+    assert (
+        str(e.value)
+        == "element shouldn't be null and should be an Element or CachedElement"
+    )
 
 
 def test_create_classification_wrong_ml_class(mock_elements_worker, responses):
@@ -394,6 +401,46 @@ def test_create_classification(responses, mock_elements_worker):
     ] == {"a_class": 1}
 
 
+def test_create_classification_with_cached_element(responses, mock_elements_worker):
+    mock_elements_worker.classes = {
+        "11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
+    }
+    elt = CachedElement(id="12341234-1234-1234-1234-123412341234")
+
+    responses.add(
+        responses.POST,
+        "http://testserver/api/v1/classifications/",
+        status=200,
+    )
+
+    mock_elements_worker.create_classification(
+        element=elt,
+        ml_class="a_class",
+        confidence=0.42,
+        high_confidence=True,
+    )
+
+    assert len(responses.calls) == 3
+    assert [call.request.url for call in responses.calls] == [
+        "http://testserver/api/v1/user/",
+        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+        "http://testserver/api/v1/classifications/",
+    ]
+
+    assert json.loads(responses.calls[2].request.body) == {
+        "element": "12341234-1234-1234-1234-123412341234",
+        "ml_class": "0000",
+        "worker_version": "12341234-1234-1234-1234-123412341234",
+        "confidence": 0.42,
+        "high_confidence": True,
+    }
+
+    # Classification has been created and reported
+    assert mock_elements_worker.report.report_data["elements"][elt.id][
+        "classifications"
+    ] == {"a_class": 1}
+
+
 def test_create_classification_duplicate(responses, mock_elements_worker):
     mock_elements_worker.classes = {
         "11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py
index a903a74d..018bd9cb 100644
--- a/tests/test_elements_worker/test_elements.py
+++ b/tests/test_elements_worker/test_elements.py
@@ -900,11 +900,17 @@ def test_create_elements_integrity_error(
 def test_list_element_children_wrong_element(mock_elements_worker):
     with pytest.raises(AssertionError) as e:
         mock_elements_worker.list_element_children(element=None)
-    assert str(e.value) == "element shouldn't be null and should be of type Element"
+    assert (
+        str(e.value)
+        == "element shouldn't be null and should be an Element or CachedElement"
+    )
 
     with pytest.raises(AssertionError) as e:
         mock_elements_worker.list_element_children(element="not element type")
-    assert str(e.value) == "element shouldn't be null and should be of type Element"
+    assert (
+        str(e.value)
+        == "element shouldn't be null and should be an Element or CachedElement"
+    )
 
 
 def test_list_element_children_wrong_best_class(mock_elements_worker):
@@ -1125,7 +1131,7 @@ def test_list_element_children_with_cache_unhandled_param(
         # Filter on element should give all elements inserted
         (
             {
-                "element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
+                "element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
             },
             (
                 "11111111-1111-1111-1111-111111111111",
@@ -1135,7 +1141,7 @@ def test_list_element_children_with_cache_unhandled_param(
         # Filter on element and page should give the second element
         (
             {
-                "element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
+                "element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
                 "type": "page",
             },
             ("22222222-2222-2222-2222-222222222222",),
@@ -1143,7 +1149,7 @@ def test_list_element_children_with_cache_unhandled_param(
         # Filter on element and worker version should give all elements
         (
             {
-                "element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
+                "element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
                 "worker_version": "56785678-5678-5678-5678-567856785678",
             },
             (
@@ -1154,7 +1160,7 @@ def test_list_element_children_with_cache_unhandled_param(
         # Filter on element, type something  and worker version should give first
         (
             {
-                "element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
+                "element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
                 "type": "something",
                 "worker_version": "56785678-5678-5678-5678-567856785678",
             },
diff --git a/tests/test_elements_worker/test_transcriptions.py b/tests/test_elements_worker/test_transcriptions.py
index 20042876..72c9a6ab 100644
--- a/tests/test_elements_worker/test_transcriptions.py
+++ b/tests/test_elements_worker/test_transcriptions.py
@@ -143,7 +143,42 @@ def test_create_transcription_api_error(responses, mock_elements_worker):
     ]
 
 
-def test_create_transcription(responses, mock_elements_worker_with_cache):
+def test_create_transcription(responses, mock_elements_worker):
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
+    responses.add(
+        responses.POST,
+        f"http://testserver/api/v1/element/{elt.id}/transcription/",
+        status=200,
+        json={
+            "id": "56785678-5678-5678-5678-567856785678",
+            "text": "i am a line",
+            "score": 0.42,
+            "confidence": 0.42,
+            "worker_version_id": "12341234-1234-1234-1234-123412341234",
+        },
+    )
+
+    mock_elements_worker.create_transcription(
+        element=elt,
+        text="i am a line",
+        score=0.42,
+    )
+
+    assert len(responses.calls) == 3
+    assert [call.request.url for call in responses.calls] == [
+        "http://testserver/api/v1/user/",
+        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+        f"http://testserver/api/v1/element/{elt.id}/transcription/",
+    ]
+
+    assert json.loads(responses.calls[2].request.body) == {
+        "text": "i am a line",
+        "worker_version": "12341234-1234-1234-1234-123412341234",
+        "score": 0.42,
+    }
+
+
+def test_create_transcription_with_cache(responses, mock_elements_worker_with_cache):
     elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
 
     responses.add(
@@ -933,7 +968,72 @@ def test_create_element_transcriptions_api_error(responses, mock_elements_worker
     ]
 
 
-def test_create_element_transcriptions(responses, mock_elements_worker_with_cache):
+def test_create_element_transcriptions(responses, mock_elements_worker):
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
+    responses.add(
+        responses.POST,
+        f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
+        status=200,
+        json=[
+            {
+                "id": "56785678-5678-5678-5678-567856785678",
+                "element_id": "11111111-1111-1111-1111-111111111111",
+                "created": True,
+            },
+            {
+                "id": "67896789-6789-6789-6789-678967896789",
+                "element_id": "22222222-2222-2222-2222-222222222222",
+                "created": False,
+            },
+            {
+                "id": "78907890-7890-7890-7890-789078907890",
+                "element_id": "11111111-1111-1111-1111-111111111111",
+                "created": True,
+            },
+        ],
+    )
+
+    annotations = mock_elements_worker.create_element_transcriptions(
+        element=elt,
+        sub_element_type="page",
+        transcriptions=TRANSCRIPTIONS_SAMPLE,
+    )
+
+    assert len(responses.calls) == 3
+    assert [call.request.url for call in responses.calls] == [
+        "http://testserver/api/v1/user/",
+        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+        f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
+    ]
+
+    assert json.loads(responses.calls[2].request.body) == {
+        "element_type": "page",
+        "worker_version": "12341234-1234-1234-1234-123412341234",
+        "transcriptions": TRANSCRIPTIONS_SAMPLE,
+        "return_elements": True,
+    }
+    assert annotations == [
+        {
+            "id": "56785678-5678-5678-5678-567856785678",
+            "element_id": "11111111-1111-1111-1111-111111111111",
+            "created": True,
+        },
+        {
+            "id": "67896789-6789-6789-6789-678967896789",
+            "element_id": "22222222-2222-2222-2222-222222222222",
+            "created": False,
+        },
+        {
+            "id": "78907890-7890-7890-7890-789078907890",
+            "element_id": "11111111-1111-1111-1111-111111111111",
+            "created": True,
+        },
+    ]
+
+
+def test_create_element_transcriptions_with_cache(
+    responses, mock_elements_worker_with_cache
+):
     elt = CachedElement(id="12341234-1234-1234-1234-123412341234", type="thing")
 
     responses.add(
@@ -1041,11 +1141,17 @@ def test_create_element_transcriptions(responses, mock_elements_worker_with_cach
 def test_list_transcriptions_wrong_element(mock_elements_worker):
     with pytest.raises(AssertionError) as e:
         mock_elements_worker.list_transcriptions(element=None)
-    assert str(e.value) == "element shouldn't be null and should be of type Element"
+    assert (
+        str(e.value)
+        == "element shouldn't be null and should be an Element or CachedElement"
+    )
 
     with pytest.raises(AssertionError) as e:
         mock_elements_worker.list_transcriptions(element="not element type")
-    assert str(e.value) == "element shouldn't be null and should be of type Element"
+    assert (
+        str(e.value)
+        == "element shouldn't be null and should be an Element or CachedElement"
+    )
 
 
 def test_list_transcriptions_wrong_element_type(mock_elements_worker):
@@ -1215,7 +1321,7 @@ def test_list_transcriptions_with_cache_skip_recursive(
         # Filter on element should give all elements inserted
         (
             {
-                "element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
+                "element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
             },
             (
                 "11111111-1111-1111-1111-111111111111",
@@ -1225,7 +1331,7 @@ def test_list_transcriptions_with_cache_skip_recursive(
         # Filter on element and worker version should give first element
         (
             {
-                "element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
+                "element": CachedElement(id="12341234-1234-1234-1234-123412341234"),
                 "worker_version": "56785678-5678-5678-5678-567856785678",
             },
             ("11111111-1111-1111-1111-111111111111",),
-- 
GitLab