Skip to content
Snippets Groups Projects

Use MLClass when using the CreateClassifications helper

Merged Yoann Schneider requested to merge use-ml-class-create-classifs into master
All threads resolved!
2 files
+ 212
38
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -64,6 +64,30 @@ class ClassificationMixin(object):
return ml_class_id
def retrieve_ml_class(self, ml_class_id: str) -> str:
    """
    Look up the name of an MLClass of the current corpus from its ID.

    :param ml_class_id: ID of the MLClass to look for.
    :return: Name of the MLClass whose ID is ``ml_class_id``.
    :raises AssertionError: When no class in the corpus has this ID.
    """
    # Populate the classes of this corpus on first use
    if self.corpus_id not in self.classes:
        self.load_corpus_classes()

    # Classes are stored as {name: id}, so search the values and
    # keep the name of the first entry whose ID matches
    ml_class_name = None
    for candidate_name, candidate_id in self.classes[self.corpus_id].items():
        if candidate_id == ml_class_id:
            ml_class_name = candidate_name
            break

    assert (
        ml_class_name is not None
    ), f"Missing class with id ({ml_class_id}) in corpus ({self.corpus_id})"

    return ml_class_name
def create_classification(
self,
element: Union[Element, CachedElement],
@@ -97,7 +121,6 @@ class ClassificationMixin(object):
"Cannot create classification as this worker is in read-only mode"
)
return
try:
created = self.request(
"CreateClassification",
@@ -166,7 +189,7 @@ class ClassificationMixin(object):
:param element: The element to create classifications on.
:param classifications: The classifications to create, a list of dicts. Each of them contains
a **class_name** (str), the name of the MLClass for this classification;
a **ml_class_id** (str), the ID of the MLClass for this classification;
a **confidence** (float), the confidence score, between 0 and 1;
a **high_confidence** (bool), the high confidence state of the classification.
:returns: List of created classifications, as returned in the ``classifications`` field by
the ``CreateClassifications`` API endpoint.
"""
assert element and isinstance(
element, (Element, CachedElement)
), "element shouldn't be null and should be an Element or CachedElement"
assert classifications and isinstance(
@@ -181,7 +204,7 @@
), "classifications shouldn't be null and should be of type list"
for index, classification in enumerate(classifications):
class_name = classification.get("class_name")
assert class_name and isinstance(
class_name, str
), f"Classification at index {index} in classifications: class_name shouldn't be null and should be of type str"
ml_class_id = classification.get("ml_class_id")
assert ml_class_id and isinstance(
ml_class_id, str
), f"Classification at index {index} in classifications: ml_class_id shouldn't be null and should be of type str"
confidence = classification.get("confidence")
assert (
@@ -215,6 +238,7 @@ class ClassificationMixin(object):
)["classifications"]
for created_cl in created_cls:
created_cl["class_name"] = self.retrieve_ml_class(created_cl["ml_class"])
self.report.add_classification(element.id, created_cl["class_name"])
if self.use_cache:
@@ -224,7 +248,7 @@ class ClassificationMixin(object):
{
"id": created_cl["id"],
"element_id": element.id,
"class_name": created_cl["class_name"],
"class_name": created_cl.pop("class_name"),
"confidence": created_cl["confidence"],
"state": created_cl["state"],
"worker_run_id": self.worker_run_id,
Loading