From 1ade7cfa821fb7c44e1657a8e55afc5ea2ebc908 Mon Sep 17 00:00:00 2001
From: Eva Bardou <ebardou@teklia.com>
Date: Thu, 8 Apr 2021 17:23:03 +0200
Subject: [PATCH] Support giving a CachedElement in create_classification

---
 arkindex_worker/worker/classification.py           |  5 +++--
 tests/test_elements_worker/test_classifications.py | 10 ++++++++--
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/arkindex_worker/worker/classification.py b/arkindex_worker/worker/classification.py
index d9dfaad1..6f17ee61 100644
--- a/arkindex_worker/worker/classification.py
+++ b/arkindex_worker/worker/classification.py
@@ -2,6 +2,7 @@
 from apistar.exceptions import ErrorResponse
 
 from arkindex_worker import logger
+from arkindex_worker.cache import CachedElement
 from arkindex_worker.models import Element
 
 
@@ -60,8 +61,8 @@ class ClassificationMixin(object):
         Create a classification on the given element through API
         """
         assert element and isinstance(
-            element, Element
-        ), "element shouldn't be null and should be of type Element"
+            element, (Element, CachedElement)
+        ), "element shouldn't be null and should be an Element or CachedElement"
         assert ml_class and isinstance(
             ml_class, str
         ), "ml_class shouldn't be null and should be of type str"
diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py
index 05e8d8d6..9e1655aa 100644
--- a/tests/test_elements_worker/test_classifications.py
+++ b/tests/test_elements_worker/test_classifications.py
@@ -159,7 +159,10 @@ def test_create_classification_wrong_element(mock_elements_worker):
             confidence=0.42,
             high_confidence=True,
         )
-    assert str(e.value) == "element shouldn't be null and should be of type Element"
+    assert (
+        str(e.value)
+        == "element shouldn't be null and should be an Element or CachedElement"
+    )
 
     with pytest.raises(AssertionError) as e:
         mock_elements_worker.create_classification(
@@ -168,7 +171,10 @@ def test_create_classification_wrong_element(mock_elements_worker):
             confidence=0.42,
             high_confidence=True,
         )
-    assert str(e.value) == "element shouldn't be null and should be of type Element"
+    assert (
+        str(e.value)
+        == "element shouldn't be null and should be an Element or CachedElement"
+    )
 
 
 def test_create_classification_wrong_ml_class(mock_elements_worker, responses):
-- 
GitLab