Skip to content
Snippets Groups Projects
Commit 10527c74 authored by Yoann Schneider's avatar Yoann Schneider :tennis:
Browse files

add test for retrieve ml class

parent 6399ac95
No related branches found
No related tags found
1 merge request!278Use MLClass when using the CreateClassifications helper
Pipeline #79929 passed
This commit is part of merge request !278. Comments created here will be created in the context of that merge request.
......@@ -64,6 +64,30 @@ class ClassificationMixin(object):
return ml_class_id
def retrieve_ml_class(self, ml_class_id: str) -> str:
"""
Retrieve the name of the MLClass from its ID.
:param ml_class_id: ID of the searched MLClass.
:return: The MLClass's name
"""
# Load the corpus' MLclasses if they are not available yet
if self.corpus_id not in self.classes:
self.load_corpus_classes()
# Filter classes by this ml_class_id
ml_class_name = next(
filter(
lambda x: self.classes[self.corpus_id][x] == ml_class_id,
self.classes[self.corpus_id],
),
None,
)
assert (
ml_class_name is not None
), f"Missing class with id ({ml_class_id}) in corpus ({self.corpus_id})"
return ml_class_name
def create_classification(
self,
element: Union[Element, CachedElement],
......@@ -97,7 +121,6 @@ class ClassificationMixin(object):
"Cannot create classification as this worker is in read-only mode"
)
return
try:
created = self.request(
"CreateClassification",
......@@ -166,7 +189,7 @@ class ClassificationMixin(object):
:param element: The element to create classifications on.
:param classifications: The classifications to create, a list of dicts. Each of them contains
a **class_name** (str), the name of the MLClass for this classification;
a **ml_class_id** (str), the ID of the MLClass for this classification;
a **confidence** (float), the confidence score, between 0 and 1;
a **high_confidence** (bool), the high confidence state of the classification.
......@@ -181,10 +204,10 @@ class ClassificationMixin(object):
), "classifications shouldn't be null and should be of type list"
for index, classification in enumerate(classifications):
class_name = classification.get("class_name")
assert class_name and isinstance(
class_name, str
), f"Classification at index {index} in classifications: class_name shouldn't be null and should be of type str"
ml_class_id = classification.get("ml_class_id")
assert ml_class_id and isinstance(
ml_class_id, str
), f"Classification at index {index} in classifications: ml_class_id shouldn't be null and should be of type str"
confidence = classification.get("confidence")
assert (
......@@ -215,6 +238,7 @@ class ClassificationMixin(object):
)["classifications"]
for created_cl in created_cls:
created_cl["class_name"] = self.retrieve_ml_class(created_cl["ml_class"])
self.report.add_classification(element.id, created_cl["class_name"])
if self.use_cache:
......@@ -224,7 +248,7 @@ class ClassificationMixin(object):
{
"id": created_cl["id"],
"element_id": element.id,
"class_name": created_cl["class_name"],
"class_name": created_cl.pop("class_name"),
"confidence": created_cl["confidence"],
"state": created_cl["state"],
"worker_run_id": self.worker_run_id,
......
......@@ -162,6 +162,46 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
]
def test_retrieve_ml_class_in_cache(mock_elements_worker):
"""
Look for a class that exists in cache -> No API Call
"""
mock_elements_worker.classes[mock_elements_worker.corpus_id] = {"class1": "uuid1"}
assert mock_elements_worker.retrieve_ml_class("uuid1") == "class1"
def test_retrieve_ml_class_not_in_cache(responses, mock_elements_worker):
"""
Retrieve class not in cache -> Retrieve corpus ml classes via API
"""
responses.add(
responses.GET,
f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
status=200,
json={
"count": 1,
"next": None,
"results": [
{
"id": "uuid1",
"name": "class1",
},
],
},
)
assert mock_elements_worker.retrieve_ml_class("uuid1") == "class1"
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"GET",
f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
),
]
def test_create_classification_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classification(
......@@ -520,12 +560,12 @@ def test_create_classifications_wrong_element(mock_elements_worker):
element=None,
classifications=[
{
"class_name": "portrait",
"ml_class_id": "uuid1",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"ml_class_id": "uuid2",
"confidence": 0.25,
"high_confidence": False,
},
......@@ -541,12 +581,12 @@ def test_create_classifications_wrong_element(mock_elements_worker):
element="not element type",
classifications=[
{
"class_name": "portrait",
"ml_class_id": "uuid1",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"ml_class_id": "uuid2",
"confidence": 0.25,
"high_confidence": False,
},
......@@ -584,19 +624,19 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
element=elt,
classifications=[
{
"class_name": "portrait",
"ml_class_id": "uuid1",
"confidence": 0.75,
"high_confidence": False,
},
{
"confidence": 0.25,
"ml_class_id": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
== "Classification at index 1 in classifications: ml_class_id shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
......@@ -604,12 +644,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
element=elt,
classifications=[
{
"class_name": "portrait",
"ml_class_id": "uuid1",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": None,
"ml_class_id": None,
"confidence": 0.25,
"high_confidence": False,
},
......@@ -617,7 +657,7 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
)
assert (
str(e.value)
== "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
== "Classification at index 1 in classifications: ml_class_id shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
......@@ -625,12 +665,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
element=elt,
classifications=[
{
"class_name": "portrait",
"ml_class_id": "uuid1",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": 1234,
"ml_class_id": 1234,
"confidence": 0.25,
"high_confidence": False,
},
......@@ -638,7 +678,7 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
)
assert (
str(e.value)
== "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
== "Classification at index 1 in classifications: ml_class_id shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
......@@ -646,12 +686,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
element=elt,
classifications=[
{
"class_name": "portrait",
"ml_class_id": "uuid1",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"ml_class_id": "uuid2",
"high_confidence": False,
},
],
......@@ -666,12 +706,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
element=elt,
classifications=[
{
"class_name": "portrait",
"ml_class_id": "uuid1",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"ml_class_id": "uuid2",
"confidence": None,
"high_confidence": False,
},
......@@ -687,12 +727,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
element=elt,
classifications=[
{
"class_name": "portrait",
"ml_class_id": "uuid1",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"ml_class_id": "uuid2",
"confidence": "wrong confidence",
"high_confidence": False,
},
......@@ -708,12 +748,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
element=elt,
classifications=[
{
"class_name": "portrait",
"ml_class_id": "uuid1",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"ml_class_id": "uuid2",
"confidence": 0,
"high_confidence": False,
},
......@@ -729,12 +769,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
element=elt,
classifications=[
{
"class_name": "portrait",
"ml_class_id": "uuid1",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"ml_class_id": "uuid2",
"confidence": 2.00,
"high_confidence": False,
},
......@@ -750,12 +790,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
element=elt,
classifications=[
{
"class_name": "portrait",
"ml_class_id": "uuid1",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"ml_class_id": "uuid2",
"confidence": 0.25,
"high_confidence": "wrong high_confidence",
},
......@@ -776,12 +816,12 @@ def test_create_classifications_api_error(responses, mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
classes = [
{
"class_name": "portrait",
"ml_class_id": "uuid1",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"ml_class_id": "uuid2",
"confidence": 0.25,
"high_confidence": False,
},
......@@ -806,15 +846,20 @@ def test_create_classifications_api_error(responses, mock_elements_worker):
def test_create_classifications(responses, mock_elements_worker_with_cache):
# Set MLClass in cache
mock_elements_worker_with_cache.classes[
mock_elements_worker_with_cache.corpus_id
] = {"portrait": "uuid1", "landscape": "uuid2"}
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
classes = [
{
"class_name": "portrait",
"ml_class_id": "uuid1",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"ml_class_id": "uuid2",
"confidence": 0.25,
"high_confidence": False,
},
......@@ -830,14 +875,14 @@ def test_create_classifications(responses, mock_elements_worker_with_cache):
"classifications": [
{
"id": "00000000-0000-0000-0000-000000000000",
"class_name": "portrait",
"ml_class": "uuid1",
"confidence": 0.75,
"high_confidence": False,
"state": "pending",
},
{
"id": "11111111-1111-1111-1111-111111111111",
"class_name": "landscape",
"ml_class": "uuid2",
"confidence": 0.25,
"high_confidence": False,
"state": "pending",
......@@ -882,3 +927,108 @@ def test_create_classifications(responses, mock_elements_worker_with_cache):
worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
),
]
def test_create_classifications_not_in_cache(
responses, mock_elements_worker_with_cache
):
"""
CreateClassifications using ID that are not in `.classes` attribute.
Will load corpus MLClass to insert the corresponding name in Cache.
"""
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
classes = [
{
"ml_class_id": "uuid1",
"confidence": 0.75,
"high_confidence": False,
},
{
"ml_class_id": "uuid2",
"confidence": 0.25,
"high_confidence": False,
},
]
responses.add(
responses.POST,
"http://testserver/api/v1/classification/bulk/",
status=200,
json={
"parent": str(elt.id),
"worker_run_id": "56785678-5678-5678-5678-567856785678",
"classifications": [
{
"id": "00000000-0000-0000-0000-000000000000",
"ml_class": "uuid1",
"confidence": 0.75,
"high_confidence": False,
"state": "pending",
},
{
"id": "11111111-1111-1111-1111-111111111111",
"ml_class": "uuid2",
"confidence": 0.25,
"high_confidence": False,
"state": "pending",
},
],
},
)
responses.add(
responses.GET,
f"http://testserver/api/v1/corpus/{mock_elements_worker_with_cache.corpus_id}/classes/",
status=200,
json={
"count": 2,
"next": None,
"results": [
{
"id": "uuid1",
"name": "portrait",
},
{"id": "uuid2", "name": "landscape"},
],
},
)
mock_elements_worker_with_cache.create_classifications(
element=elt, classifications=classes
)
assert len(responses.calls) == len(BASE_API_CALLS) + 2
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/classification/bulk/"),
(
"GET",
f"http://testserver/api/v1/corpus/{mock_elements_worker_with_cache.corpus_id}/classes/",
),
]
assert json.loads(responses.calls[-2].request.body) == {
"parent": str(elt.id),
"worker_run_id": "56785678-5678-5678-5678-567856785678",
"classifications": classes,
}
# Check that created classifications were properly stored in SQLite cache
assert list(CachedClassification.select()) == [
CachedClassification(
id=UUID("00000000-0000-0000-0000-000000000000"),
element_id=UUID(elt.id),
class_name="portrait",
confidence=0.75,
state="pending",
worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
),
CachedClassification(
id=UUID("11111111-1111-1111-1111-111111111111"),
element_id=UUID(elt.id),
class_name="landscape",
confidence=0.25,
state="pending",
worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
),
]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment