From 1ade7cfa821fb7c44e1657a8e55afc5ea2ebc908 Mon Sep 17 00:00:00 2001 From: Eva Bardou <ebardou@teklia.com> Date: Thu, 8 Apr 2021 17:23:03 +0200 Subject: [PATCH] Support giving a CachedElement in create_classification --- arkindex_worker/worker/classification.py | 5 +++-- tests/test_elements_worker/test_classifications.py | 10 ++++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/arkindex_worker/worker/classification.py b/arkindex_worker/worker/classification.py index d9dfaad1..6f17ee61 100644 --- a/arkindex_worker/worker/classification.py +++ b/arkindex_worker/worker/classification.py @@ -2,6 +2,7 @@ from apistar.exceptions import ErrorResponse from arkindex_worker import logger +from arkindex_worker.cache import CachedElement from arkindex_worker.models import Element @@ -60,8 +61,8 @@ class ClassificationMixin(object): Create a classification on the given element through API """ assert element and isinstance( - element, Element - ), "element shouldn't be null and should be of type Element" + element, (Element, CachedElement) + ), "element shouldn't be null and should be an Element or CachedElement" assert ml_class and isinstance( ml_class, str ), "ml_class shouldn't be null and should be of type str" diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py index 05e8d8d6..9e1655aa 100644 --- a/tests/test_elements_worker/test_classifications.py +++ b/tests/test_elements_worker/test_classifications.py @@ -159,7 +159,10 @@ def test_create_classification_wrong_element(mock_elements_worker): confidence=0.42, high_confidence=True, ) - assert str(e.value) == "element shouldn't be null and should be of type Element" + assert ( + str(e.value) + == "element shouldn't be null and should be an Element or CachedElement" + ) with pytest.raises(AssertionError) as e: mock_elements_worker.create_classification( @@ -168,7 +171,10 @@ def test_create_classification_wrong_element(mock_elements_worker): confidence=0.42, high_confidence=True, ) - assert str(e.value) == "element shouldn't be null and should be of type Element" + assert ( + str(e.value) + == "element shouldn't be null and should be an Element or CachedElement" + ) def test_create_classification_wrong_ml_class(mock_elements_worker, responses): -- GitLab