Skip to content
Snippets Groups Projects
Commit 62fca186 authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Bastien Abadie
Browse files

Simplify classes ID attribute

parent 3b0a327c
No related branches found
No related tags found
1 merge request!305Simplify classes ID attribute
Pipeline #80089 passed
......@@ -23,10 +23,10 @@ class ClassificationMixin(object):
"ListCorpusMLClasses",
id=self.corpus_id,
)
self.classes[self.corpus_id] = {
ml_class["name"]: ml_class["id"] for ml_class in corpus_classes
}
logger.info(f"Loaded {len(self.classes[self.corpus_id])} ML classes")
self.classes = {ml_class["name"]: ml_class["id"] for ml_class in corpus_classes}
logger.info(
f"Loaded {len(self.classes)} ML classes in corpus ({self.corpus_id})"
)
def get_ml_class_id(self, ml_class: str) -> str:
"""
......@@ -36,17 +36,17 @@ class ClassificationMixin(object):
:param ml_class: Name of the MLClass.
:returns: ID of the retrieved or created MLClass.
"""
if not self.classes.get(self.corpus_id):
if not self.classes:
self.load_corpus_classes()
ml_class_id = self.classes[self.corpus_id].get(ml_class)
ml_class_id = self.classes.get(ml_class)
if ml_class_id is None:
logger.info(f"Creating ML class {ml_class} on corpus {self.corpus_id}")
try:
response = self.request(
"CreateMLClass", id=self.corpus_id, body={"name": ml_class}
)
ml_class_id = self.classes[self.corpus_id][ml_class] = response["id"]
ml_class_id = self.classes[ml_class] = response["id"]
logger.debug(f"Created ML class {response['id']}")
except ErrorResponse as e:
# Only reload for 400 errors
......@@ -59,9 +59,9 @@ class ClassificationMixin(object):
)
self.load_corpus_classes()
assert (
ml_class in self.classes[self.corpus_id]
ml_class in self.classes
), "Missing class {ml_class} even after reloading"
ml_class_id = self.classes[self.corpus_id][ml_class]
ml_class_id = self.classes[ml_class]
return ml_class_id
......@@ -73,14 +73,14 @@ class ClassificationMixin(object):
:return: The MLClass's name
"""
# Load the corpus' MLclasses if they are not available yet
if self.corpus_id not in self.classes:
if not self.classes:
self.load_corpus_classes()
# Filter classes by this ml_class_id
ml_class_name = next(
filter(
lambda x: self.classes[self.corpus_id][x] == ml_class_id,
self.classes[self.corpus_id],
lambda x: self.classes[x] == ml_class_id,
self.classes,
),
None,
)
......
......@@ -41,18 +41,14 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
),
]
assert mock_elements_worker.classes == {
"11111111-1111-1111-1111-111111111111": {"good": "0000"}
}
assert mock_elements_worker.classes == {"good": "0000"}
assert ml_class_id == "0000"
def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses):
# A missing class is now created automatically
corpus_id = "11111111-1111-1111-1111-111111111111"
mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"good": "0000"}
}
mock_elements_worker.classes = {"good": "0000"}
responses.add(
responses.POST,
......@@ -62,26 +58,20 @@ def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses):
)
# Missing class at first
assert mock_elements_worker.classes == {
"11111111-1111-1111-1111-111111111111": {"good": "0000"}
}
assert mock_elements_worker.classes == {"good": "0000"}
ml_class_id = mock_elements_worker.get_ml_class_id("bad")
assert ml_class_id == "new-ml-class-1234"
# Now it's available
assert mock_elements_worker.classes == {
"11111111-1111-1111-1111-111111111111": {
"good": "0000",
"bad": "new-ml-class-1234",
}
"good": "0000",
"bad": "new-ml-class-1234",
}
def test_get_ml_class_id(mock_elements_worker):
mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"good": "0000"}
}
mock_elements_worker.classes = {"good": "0000"}
ml_class_id = mock_elements_worker.get_ml_class_id("good")
assert ml_class_id == "0000"
......@@ -139,10 +129,8 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
assert len(responses.calls) == len(BASE_API_CALLS) + 3
assert mock_elements_worker.classes == {
corpus_id: {
"class1": "class1_id",
"class2": "class2_id",
}
"class1": "class1_id",
"class2": "class2_id",
}
assert [
(call.request.method, call.request.url) for call in responses.calls
......@@ -166,7 +154,7 @@ def test_retrieve_ml_class_in_cache(mock_elements_worker):
"""
Look for a class that exists in cache -> No API Call
"""
mock_elements_worker.classes[mock_elements_worker.corpus_id] = {"class1": "uuid1"}
mock_elements_worker.classes = {"class1": "uuid1"}
assert mock_elements_worker.retrieve_ml_class("uuid1") == "class1"
......@@ -262,9 +250,7 @@ def test_create_classification_wrong_ml_class(mock_elements_worker, responses):
status=201,
json={"id": "new-classification-1234"},
)
mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"another_class": "0000"}
}
mock_elements_worker.classes = {"another_class": "0000"}
mock_elements_worker.create_classification(
element=elt,
ml_class="a_class",
......@@ -298,9 +284,7 @@ def test_create_classification_wrong_ml_class(mock_elements_worker, responses):
def test_create_classification_wrong_confidence(mock_elements_worker):
mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
}
mock_elements_worker.classes = {"a_class": "0000"}
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classification(
......@@ -352,9 +336,7 @@ def test_create_classification_wrong_confidence(mock_elements_worker):
def test_create_classification_wrong_high_confidence(mock_elements_worker):
mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
}
mock_elements_worker.classes = {"a_class": "0000"}
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
with pytest.raises(AssertionError) as e:
......@@ -381,9 +363,7 @@ def test_create_classification_wrong_high_confidence(mock_elements_worker):
def test_create_classification_api_error(responses, mock_elements_worker):
mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
}
mock_elements_worker.classes = {"a_class": "0000"}
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
......@@ -413,9 +393,7 @@ def test_create_classification_api_error(responses, mock_elements_worker):
def test_create_classification(responses, mock_elements_worker):
mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
}
mock_elements_worker.classes = {"a_class": "0000"}
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
......@@ -452,9 +430,7 @@ def test_create_classification(responses, mock_elements_worker):
def test_create_classification_with_cache(responses, mock_elements_worker_with_cache):
mock_elements_worker_with_cache.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
}
mock_elements_worker_with_cache.classes = {"a_class": "0000"}
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
responses.add(
......@@ -513,9 +489,7 @@ def test_create_classification_with_cache(responses, mock_elements_worker_with_c
def test_create_classification_duplicate_worker_run(responses, mock_elements_worker):
mock_elements_worker.classes = {
"11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
}
mock_elements_worker.classes = {"a_class": "0000"}
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
......@@ -870,9 +844,10 @@ def test_create_classifications(responses, mock_elements_worker_with_cache):
# Set MLClass in cache
portrait_uuid = str(uuid4())
landscape_uuid = str(uuid4())
mock_elements_worker_with_cache.classes[
mock_elements_worker_with_cache.corpus_id
] = {"portrait": portrait_uuid, "landscape": landscape_uuid}
mock_elements_worker_with_cache.classes = {
"portrait": portrait_uuid,
"landscape": landscape_uuid,
}
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
classes = [
......
......@@ -350,11 +350,9 @@ def test_load_corpus_classes(responses, mock_elements_worker):
),
]
assert mock_elements_worker.classes == {
"11111111-1111-1111-1111-111111111111": {
"good": "0000",
"average": "1111",
"bad": "2222",
}
"good": "0000",
"average": "1111",
"bad": "2222",
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment