Skip to content
Snippets Groups Projects

Add create_classifications method calling CreateClassifications endpoint

Merged Eva Bardou requested to merge add-create-classifications into master
1 file
+ 370
0
Compare changes
  • Side-by-side
  • Inline
@@ -501,3 +501,373 @@ def test_create_classification_duplicate(responses, mock_elements_worker):
# Classification has NOT been created
assert mock_elements_worker.report.report_data["elements"] == {}
def test_create_classifications_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=None,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element="not element type",
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
def test_create_classifications_wrong_classifications(mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=None,
)
assert (
str(e.value) == "classifications shouldn't be null and should be of type list"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=1234,
)
assert (
str(e.value) == "classifications shouldn't be null and should be of type list"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": None,
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": 1234,
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": None,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": "wrong confidence",
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 2.00,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": "wrong high_confidence",
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: high_confidence should be of type bool"
)
def test_create_classifications_api_error(responses, mock_elements_worker):
responses.add(
responses.POST,
"http://testserver/api/v1/classification/bulk/",
status=500,
)
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
classes = [
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": False,
},
]
with pytest.raises(ErrorResponse):
mock_elements_worker.create_classifications(
element=elt, classifications=classes
)
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
("POST", "http://testserver/api/v1/classification/bulk/"),
("POST", "http://testserver/api/v1/classification/bulk/"),
("POST", "http://testserver/api/v1/classification/bulk/"),
("POST", "http://testserver/api/v1/classification/bulk/"),
("POST", "http://testserver/api/v1/classification/bulk/"),
]
def test_create_classifications(responses, mock_elements_worker_with_cache):
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
classes = [
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": False,
},
]
responses.add(
responses.POST,
"http://testserver/api/v1/classification/bulk/",
status=200,
json={
"parent": str(elt.id),
"worker_version": "12341234-1234-1234-1234-123412341234",
"classifications": [
{
"id": "00000000-0000-0000-0000-000000000000",
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
"state": "pending",
},
{
"id": "11111111-1111-1111-1111-111111111111",
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": False,
"state": "pending",
},
],
},
)
mock_elements_worker_with_cache.create_classifications(
element=elt, classifications=classes
)
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/classification/bulk/"),
]
assert json.loads(responses.calls[-1].request.body) == {
"parent": str(elt.id),
"worker_version": "12341234-1234-1234-1234-123412341234",
"classifications": classes,
}
# Check that created classifications were properly stored in SQLite cache
assert list(CachedClassification.select()) == [
CachedClassification(
id=UUID("00000000-0000-0000-0000-000000000000"),
element_id=UUID(elt.id),
class_name="portrait",
confidence=0.75,
state="pending",
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
),
CachedClassification(
id=UUID("11111111-1111-1111-1111-111111111111"),
element_id=UUID(elt.id),
class_name="landscape",
confidence=0.25,
state="pending",
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
),
]
Loading