Skip to content
Snippets Groups Projects
Commit 6b7ba8d8 authored by Eva Bardou's avatar Eva Bardou Committed by Bastien Abadie
Browse files

Add create_classifications method calling CreateClassifications endpoint

parent 8d90063f
No related branches found
No related tags found
1 merge request!122Add create_classifications method calling CreateClassifications endpoint
Pipeline #78689 passed
......@@ -131,3 +131,71 @@ class ClassificationMixin(object):
raise
self.report.add_classification(element.id, ml_class)
def create_classifications(self, element, classifications):
"""
Create multiple classifications at once on the given element through the API
"""
assert element and isinstance(
element, (Element, CachedElement)
), "element shouldn't be null and should be an Element or CachedElement"
assert classifications and isinstance(
classifications, list
), "classifications shouldn't be null and should be of type list"
for index, classification in enumerate(classifications):
class_name = classification.get("class_name")
assert class_name and isinstance(
class_name, str
), f"Classification at index {index} in classifications: class_name shouldn't be null and should be of type str"
confidence = classification.get("confidence")
assert (
confidence is not None
and isinstance(confidence, float)
and 0 <= confidence <= 1
), f"Classification at index {index} in classifications: confidence shouldn't be null and should be a float in [0..1] range"
high_confidence = classification.get("high_confidence")
if high_confidence is not None:
assert isinstance(
high_confidence, bool
), f"Classification at index {index} in classifications: high_confidence should be of type bool"
if self.is_read_only:
logger.warning(
"Cannot create classifications as this worker is in read-only mode"
)
return
created_cls = self.request(
"CreateClassifications",
body={
"parent": str(element.id),
"worker_version": self.worker_version_id,
"classifications": classifications,
},
)["classifications"]
for created_cl in created_cls:
self.report.add_classification(element.id, created_cl["class_name"])
if self.use_cache:
# Store classifications in local cache
try:
to_insert = [
{
"id": created_cl["id"],
"element_id": element.id,
"class_name": created_cl["class_name"],
"confidence": created_cl["confidence"],
"state": created_cl["state"],
"worker_version_id": self.worker_version_id,
}
for created_cl in created_cls
]
CachedClassification.insert_many(to_insert).execute()
except IntegrityError as e:
logger.warning(
f"Couldn't save created classifications in local cache: {e}"
)
......@@ -501,3 +501,373 @@ def test_create_classification_duplicate(responses, mock_elements_worker):
# Classification has NOT been created
assert mock_elements_worker.report.report_data["elements"] == {}
def test_create_classifications_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=None,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element="not element type",
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "element shouldn't be null and should be an Element or CachedElement"
)
def test_create_classifications_wrong_classifications(mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=None,
)
assert (
str(e.value) == "classifications shouldn't be null and should be of type list"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=1234,
)
assert (
str(e.value) == "classifications shouldn't be null and should be of type list"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": None,
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": 1234,
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": None,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": "wrong confidence",
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 2.00,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": "wrong high_confidence",
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: high_confidence should be of type bool"
)
def test_create_classifications_api_error(responses, mock_elements_worker):
responses.add(
responses.POST,
"http://testserver/api/v1/classification/bulk/",
status=500,
)
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
classes = [
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": False,
},
]
with pytest.raises(ErrorResponse):
mock_elements_worker.create_classifications(
element=elt, classifications=classes
)
assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
# We retry 5 times the API call
("POST", "http://testserver/api/v1/classification/bulk/"),
("POST", "http://testserver/api/v1/classification/bulk/"),
("POST", "http://testserver/api/v1/classification/bulk/"),
("POST", "http://testserver/api/v1/classification/bulk/"),
("POST", "http://testserver/api/v1/classification/bulk/"),
]
def test_create_classifications(responses, mock_elements_worker_with_cache):
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
classes = [
{
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": False,
},
]
responses.add(
responses.POST,
"http://testserver/api/v1/classification/bulk/",
status=200,
json={
"parent": str(elt.id),
"worker_version": "12341234-1234-1234-1234-123412341234",
"classifications": [
{
"id": "00000000-0000-0000-0000-000000000000",
"class_name": "portrait",
"confidence": 0.75,
"high_confidence": False,
"state": "pending",
},
{
"id": "11111111-1111-1111-1111-111111111111",
"class_name": "landscape",
"confidence": 0.25,
"high_confidence": False,
"state": "pending",
},
],
},
)
mock_elements_worker_with_cache.create_classifications(
element=elt, classifications=classes
)
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/classification/bulk/"),
]
assert json.loads(responses.calls[-1].request.body) == {
"parent": str(elt.id),
"worker_version": "12341234-1234-1234-1234-123412341234",
"classifications": classes,
}
# Check that created classifications were properly stored in SQLite cache
assert list(CachedClassification.select()) == [
CachedClassification(
id=UUID("00000000-0000-0000-0000-000000000000"),
element_id=UUID(elt.id),
class_name="portrait",
confidence=0.75,
state="pending",
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
),
CachedClassification(
id=UUID("11111111-1111-1111-1111-111111111111"),
element_id=UUID(elt.id),
class_name="landscape",
confidence=0.25,
state="pending",
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
),
]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment