Skip to content
Snippets Groups Projects
Commit 137ecd29 authored by Manon Blanco, committed by Yoann Schneider
Browse files

Fix 'create_classifications' helper

parent 988f3090
No related branches found
No related tags found
1 merge request: !509 Fix 'create_classifications' helper
Pipeline #163002 passed
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
ElementsWorker methods for classifications and ML classes. ElementsWorker methods for classifications and ML classes.
""" """
from uuid import UUID
from apistar.exceptions import ErrorResponse from apistar.exceptions import ErrorResponse
from peewee import IntegrityError from peewee import IntegrityError
...@@ -178,10 +176,14 @@ class ClassificationMixin: ...@@ -178,10 +176,14 @@ class ClassificationMixin:
Create multiple classifications at once on the given element through the API. Create multiple classifications at once on the given element through the API.
:param element: The element to create classifications on. :param element: The element to create classifications on.
:param classifications: The classifications to create, a list of dicts. Each of them contains :param classifications: A list of dicts representing a classification each, with the following keys:
a **ml_class_id** (str), the ID of the MLClass for this classification;
a **confidence** (float), the confidence score, between 0 and 1; ml_class (str)
a **high_confidence** (bool), the high confidence state of the classification. Required. Name of the MLClass to use.
confidence (float)
Required. Confidence score for the classification. Must be between 0 and 1.
high_confidence (bool)
Optional. Whether or not the classification is of high confidence.
:returns: List of created classifications, as returned in the ``classifications`` field by :returns: List of created classifications, as returned in the ``classifications`` field by
the ``CreateClassifications`` API endpoint. the ``CreateClassifications`` API endpoint.
...@@ -194,18 +196,10 @@ class ClassificationMixin: ...@@ -194,18 +196,10 @@ class ClassificationMixin:
), "classifications shouldn't be null and should be of type list" ), "classifications shouldn't be null and should be of type list"
for index, classification in enumerate(classifications): for index, classification in enumerate(classifications):
ml_class_id = classification.get("ml_class_id") ml_class = classification.get("ml_class")
assert ( assert (
ml_class_id and isinstance(ml_class_id, str) ml_class and isinstance(ml_class, str)
), f"Classification at index {index} in classifications: ml_class_id shouldn't be null and should be of type str" ), f"Classification at index {index} in classifications: ml_class shouldn't be null and should be of type str"
# Make sure it's a valid UUID
try:
UUID(ml_class_id)
except ValueError as e:
raise ValueError(
f"Classification at index {index} in classifications: ml_class_id is not a valid uuid."
) from e
confidence = classification.get("confidence") confidence = classification.get("confidence")
assert ( assert (
...@@ -231,7 +225,13 @@ class ClassificationMixin: ...@@ -231,7 +225,13 @@ class ClassificationMixin:
body={ body={
"parent": str(element.id), "parent": str(element.id),
"worker_run_id": self.worker_run_id, "worker_run_id": self.worker_run_id,
"classifications": classifications, "classifications": [
{
**classification,
"ml_class": self.get_ml_class_id(classification["ml_class"]),
}
for classification in classifications
],
}, },
)["classifications"] )["classifications"]
......
import json import json
import re import re
from uuid import UUID, uuid4 from uuid import UUID
import pytest import pytest
from apistar.exceptions import ErrorResponse from apistar.exceptions import ErrorResponse
...@@ -325,9 +325,6 @@ def test_create_classification_create_ml_class(mock_elements_worker, responses): ...@@ -325,9 +325,6 @@ def test_create_classification_create_ml_class(mock_elements_worker, responses):
) )
# Check a class & classification has been created # Check a class & classification has been created
for call in responses.calls:
print(call.request.url, call.request.body)
assert [ assert [
(call.request.url, json.loads(call.request.body)) (call.request.url, json.loads(call.request.body))
for call in responses.calls[-2:] for call in responses.calls[-2:]
...@@ -506,12 +503,12 @@ def test_create_classifications_wrong_data( ...@@ -506,12 +503,12 @@ def test_create_classifications_wrong_data(
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}), "element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
"classifications": [ "classifications": [
{ {
"ml_class_id": "uuid1", "ml_class": "cat",
"confidence": 0.75, "confidence": 0.75,
"high_confidence": False, "high_confidence": False,
}, },
{ {
"ml_class_id": "uuid2", "ml_class": "dog",
"confidence": 0.25, "confidence": 0.25,
"high_confidence": False, "high_confidence": False,
}, },
...@@ -523,86 +520,71 @@ def test_create_classifications_wrong_data( ...@@ -523,86 +520,71 @@ def test_create_classifications_wrong_data(
@pytest.mark.parametrize( @pytest.mark.parametrize(
("arg_name", "data", "error_message", "error_type"), ("arg_name", "data", "error_message"),
[ [
# Wrong classifications > ml_class_id # Wrong classifications > ml_class
( (
"ml_class_id", "ml_class",
DELETE_PARAMETER, DELETE_PARAMETER,
"ml_class_id shouldn't be null and should be of type str", "ml_class shouldn't be null and should be of type str",
AssertionError, ),
), # Updated
( (
"ml_class_id", "ml_class",
None, None,
"ml_class_id shouldn't be null and should be of type str", "ml_class shouldn't be null and should be of type str",
AssertionError,
), ),
( (
"ml_class_id", "ml_class",
1234, 1234,
"ml_class_id shouldn't be null and should be of type str", "ml_class shouldn't be null and should be of type str",
AssertionError,
),
(
"ml_class_id",
"not_an_uuid",
"ml_class_id is not a valid uuid.",
ValueError,
), ),
# Wrong classifications > confidence # Wrong classifications > confidence
( (
"confidence", "confidence",
DELETE_PARAMETER, DELETE_PARAMETER,
"confidence shouldn't be null and should be a float in [0..1] range", "confidence shouldn't be null and should be a float in [0..1] range",
AssertionError,
), ),
( (
"confidence", "confidence",
None, None,
"confidence shouldn't be null and should be a float in [0..1] range", "confidence shouldn't be null and should be a float in [0..1] range",
AssertionError,
), ),
( (
"confidence", "confidence",
"wrong confidence", "wrong confidence",
"confidence shouldn't be null and should be a float in [0..1] range", "confidence shouldn't be null and should be a float in [0..1] range",
AssertionError,
), ),
( (
"confidence", "confidence",
0, 0,
"confidence shouldn't be null and should be a float in [0..1] range", "confidence shouldn't be null and should be a float in [0..1] range",
AssertionError,
), ),
( (
"confidence", "confidence",
2.00, 2.00,
"confidence shouldn't be null and should be a float in [0..1] range", "confidence shouldn't be null and should be a float in [0..1] range",
AssertionError,
), ),
# Wrong classifications > high_confidence # Wrong classifications > high_confidence
( (
"high_confidence", "high_confidence",
"wrong high_confidence", "wrong high_confidence",
"high_confidence should be of type bool", "high_confidence should be of type bool",
AssertionError,
), ),
], ],
) )
def test_create_classifications_wrong_classifications_data( def test_create_classifications_wrong_classifications_data(
arg_name, data, error_message, error_type, mock_elements_worker arg_name, data, error_message, mock_elements_worker
): ):
all_data = { all_data = {
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}), "element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
"classifications": [ "classifications": [
{ {
"ml_class_id": str(uuid4()), "ml_class": "cat",
"confidence": 0.75, "confidence": 0.75,
"high_confidence": False, "high_confidence": False,
}, },
{ {
"ml_class_id": str(uuid4()), "ml_class": "dog",
"confidence": 0.25, "confidence": 0.25,
"high_confidence": False, "high_confidence": False,
# Overwrite with wrong data # Overwrite with wrong data
...@@ -614,7 +596,7 @@ def test_create_classifications_wrong_classifications_data( ...@@ -614,7 +596,7 @@ def test_create_classifications_wrong_classifications_data(
del all_data["classifications"][1][arg_name] del all_data["classifications"][1][arg_name]
with pytest.raises( with pytest.raises(
error_type, AssertionError,
match=re.escape( match=re.escape(
f"Classification at index 1 in classifications: {error_message}" f"Classification at index 1 in classifications: {error_message}"
), ),
...@@ -623,6 +605,7 @@ def test_create_classifications_wrong_classifications_data( ...@@ -623,6 +605,7 @@ def test_create_classifications_wrong_classifications_data(
def test_create_classifications_api_error(responses, mock_elements_worker): def test_create_classifications_api_error(responses, mock_elements_worker):
mock_elements_worker.classes = {"cat": "0000", "dog": "1111"}
responses.add( responses.add(
responses.POST, responses.POST,
"http://testserver/api/v1/classification/bulk/", "http://testserver/api/v1/classification/bulk/",
...@@ -631,12 +614,12 @@ def test_create_classifications_api_error(responses, mock_elements_worker): ...@@ -631,12 +614,12 @@ def test_create_classifications_api_error(responses, mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
classes = [ classes = [
{ {
"ml_class_id": str(uuid4()), "ml_class": "cat",
"confidence": 0.75, "confidence": 0.75,
"high_confidence": False, "high_confidence": False,
}, },
{ {
"ml_class_id": str(uuid4()), "ml_class": "dog",
"confidence": 0.25, "confidence": 0.25,
"high_confidence": False, "high_confidence": False,
}, },
...@@ -660,57 +643,96 @@ def test_create_classifications_api_error(responses, mock_elements_worker): ...@@ -660,57 +643,96 @@ def test_create_classifications_api_error(responses, mock_elements_worker):
] ]
def test_create_classifications(responses, mock_elements_worker_with_cache): def test_create_classifications_create_ml_class(mock_elements_worker, responses):
# Set MLClass in cache elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
portrait_uuid = str(uuid4())
landscape_uuid = str(uuid4())
mock_elements_worker_with_cache.classes = {
"portrait": portrait_uuid,
"landscape": landscape_uuid,
}
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
classes = [
{
"ml_class_id": portrait_uuid,
"confidence": 0.75,
"high_confidence": False,
},
{
"ml_class_id": landscape_uuid,
"confidence": 0.25,
"high_confidence": False,
},
]
# Automatically create a missing class!
responses.add(
responses.POST,
"http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/",
status=201,
json={"id": "new-ml-class-1234"},
)
responses.add( responses.add(
responses.POST, responses.POST,
"http://testserver/api/v1/classification/bulk/", "http://testserver/api/v1/classification/bulk/",
status=200, status=201,
json={ json={
"parent": str(elt.id), "parent": str(elt.id),
"worker_run_id": "56785678-5678-5678-5678-567856785678", "worker_run_id": "56785678-5678-5678-5678-567856785678",
"classifications": [ "classifications": [
{ {
"id": "00000000-0000-0000-0000-000000000000", "id": "00000000-0000-0000-0000-000000000000",
"ml_class": portrait_uuid, "ml_class": "new-ml-class-1234",
"confidence": 0.75, "confidence": 0.75,
"high_confidence": False, "high_confidence": False,
"state": "pending", "state": "pending",
}, },
{
"id": "11111111-1111-1111-1111-111111111111",
"ml_class": landscape_uuid,
"confidence": 0.25,
"high_confidence": False,
"state": "pending",
},
], ],
}, },
) )
mock_elements_worker.classes = {"another_class": "0000"}
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"ml_class": "a_class",
"confidence": 0.75,
"high_confidence": False,
}
],
)
mock_elements_worker_with_cache.create_classifications( # Check a class & classification has been created
element=elt, classifications=classes assert len(responses.calls) == len(BASE_API_CALLS) + 2
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"POST",
"http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/",
),
("POST", "http://testserver/api/v1/classification/bulk/"),
]
assert json.loads(responses.calls[-2].request.body) == {"name": "a_class"}
assert json.loads(responses.calls[-1].request.body) == {
"parent": "12341234-1234-1234-1234-123412341234",
"worker_run_id": "56785678-5678-5678-5678-567856785678",
"classifications": [
{
"ml_class": "new-ml-class-1234",
"confidence": 0.75,
"high_confidence": False,
}
],
}
def test_create_classifications(responses, mock_elements_worker):
mock_elements_worker.classes = {"portrait": "0000", "landscape": "1111"}
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
"http://testserver/api/v1/classification/bulk/",
status=200,
json={"classifications": []},
)
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"ml_class": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"ml_class": "landscape",
"confidence": 0.25,
"high_confidence": False,
},
],
) )
assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert len(responses.calls) == len(BASE_API_CALLS) + 1
...@@ -723,52 +745,24 @@ def test_create_classifications(responses, mock_elements_worker_with_cache): ...@@ -723,52 +745,24 @@ def test_create_classifications(responses, mock_elements_worker_with_cache):
assert json.loads(responses.calls[-1].request.body) == { assert json.loads(responses.calls[-1].request.body) == {
"parent": str(elt.id), "parent": str(elt.id),
"worker_run_id": "56785678-5678-5678-5678-567856785678", "worker_run_id": "56785678-5678-5678-5678-567856785678",
"classifications": classes, "classifications": [
{
"confidence": 0.75,
"high_confidence": False,
"ml_class": "0000",
},
{
"confidence": 0.25,
"high_confidence": False,
"ml_class": "1111",
},
],
} }
# Check that created classifications were properly stored in SQLite cache
assert list(CachedClassification.select()) == [
CachedClassification(
id=UUID("00000000-0000-0000-0000-000000000000"),
element_id=UUID(elt.id),
class_name="portrait",
confidence=0.75,
state="pending",
worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
),
CachedClassification(
id=UUID("11111111-1111-1111-1111-111111111111"),
element_id=UUID(elt.id),
class_name="landscape",
confidence=0.25,
state="pending",
worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
),
]
def test_create_classifications_not_in_cache( def test_create_classifications_with_cache(responses, mock_elements_worker_with_cache):
responses, mock_elements_worker_with_cache mock_elements_worker_with_cache.classes = {"portrait": "0000", "landscape": "1111"}
):
"""
CreateClassifications using ID that are not in `.classes` attribute.
Will load corpus MLClass to insert the corresponding name in Cache.
"""
portrait_uuid = str(uuid4())
landscape_uuid = str(uuid4())
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing") elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
classes = [
{
"ml_class_id": portrait_uuid,
"confidence": 0.75,
"high_confidence": False,
},
{
"ml_class_id": landscape_uuid,
"confidence": 0.25,
"high_confidence": False,
},
]
responses.add( responses.add(
responses.POST, responses.POST,
...@@ -780,14 +774,14 @@ def test_create_classifications_not_in_cache( ...@@ -780,14 +774,14 @@ def test_create_classifications_not_in_cache(
"classifications": [ "classifications": [
{ {
"id": "00000000-0000-0000-0000-000000000000", "id": "00000000-0000-0000-0000-000000000000",
"ml_class": portrait_uuid, "ml_class": "0000",
"confidence": 0.75, "confidence": 0.75,
"high_confidence": False, "high_confidence": False,
"state": "pending", "state": "pending",
}, },
{ {
"id": "11111111-1111-1111-1111-111111111111", "id": "11111111-1111-1111-1111-111111111111",
"ml_class": landscape_uuid, "ml_class": "1111",
"confidence": 0.25, "confidence": 0.25,
"high_confidence": False, "high_confidence": False,
"state": "pending", "state": "pending",
...@@ -795,42 +789,45 @@ def test_create_classifications_not_in_cache( ...@@ -795,42 +789,45 @@ def test_create_classifications_not_in_cache(
], ],
}, },
) )
responses.add(
responses.GET,
f"http://testserver/api/v1/corpus/{mock_elements_worker_with_cache.corpus_id}/classes/",
status=200,
json={
"count": 2,
"next": None,
"results": [
{
"id": portrait_uuid,
"name": "portrait",
},
{"id": landscape_uuid, "name": "landscape"},
],
},
)
mock_elements_worker_with_cache.create_classifications( mock_elements_worker_with_cache.create_classifications(
element=elt, classifications=classes element=elt,
classifications=[
{
"ml_class": "portrait",
"confidence": 0.75,
"high_confidence": False,
},
{
"ml_class": "landscape",
"confidence": 0.25,
"high_confidence": False,
},
],
) )
assert len(responses.calls) == len(BASE_API_CALLS) + 2 assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [ assert [
(call.request.method, call.request.url) for call in responses.calls (call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [ ] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/classification/bulk/"), ("POST", "http://testserver/api/v1/classification/bulk/"),
(
"GET",
f"http://testserver/api/v1/corpus/{mock_elements_worker_with_cache.corpus_id}/classes/",
),
] ]
assert json.loads(responses.calls[-2].request.body) == { assert json.loads(responses.calls[-1].request.body) == {
"parent": str(elt.id), "parent": str(elt.id),
"worker_run_id": "56785678-5678-5678-5678-567856785678", "worker_run_id": "56785678-5678-5678-5678-567856785678",
"classifications": classes, "classifications": [
{
"confidence": 0.75,
"high_confidence": False,
"ml_class": "0000",
},
{
"confidence": 0.25,
"high_confidence": False,
"ml_class": "1111",
},
],
} }
# Check that created classifications were properly stored in SQLite cache # Check that created classifications were properly stored in SQLite cache
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment