Skip to content
Snippets Groups Projects

Add create_classifications method calling CreateClassifications endpoint

Merged Eva Bardou requested to merge add-create-classifications into master
2 files
+ 438
0
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -131,3 +131,71 @@ class ClassificationMixin(object):
raise
self.report.add_classification(element.id, ml_class)
def create_classifications(self, element, classifications):
"""
Create multiple classifications at once on the given element through the API
"""
assert element and isinstance(
element, (Element, CachedElement)
), "element shouldn't be null and should be an Element or CachedElement"
assert classifications and isinstance(
classifications, list
), "classifications shouldn't be null and should be of type list"
for index, classification in enumerate(classifications):
class_name = classification.get("class_name")
assert class_name and isinstance(
class_name, str
), f"Classification at index {index} in classifications: class_name shouldn't be null and should be of type str"
confidence = classification.get("confidence")
assert (
confidence is not None
and isinstance(confidence, float)
and 0 <= confidence <= 1
), f"Classification at index {index} in classifications: confidence shouldn't be null and should be a float in [0..1] range"
high_confidence = classification.get("high_confidence")
if high_confidence is not None:
assert isinstance(
high_confidence, bool
), f"Classification at index {index} in classifications: high_confidence should be of type bool"
if self.is_read_only:
logger.warning(
"Cannot create classifications as this worker is in read-only mode"
)
return
created_cls = self.request(
"CreateClassifications",
body={
"parent": str(element.id),
"worker_version": self.worker_version_id,
"classifications": classifications,
},
)["classifications"]
for created_cl in created_cls:
self.report.add_classification(element.id, created_cl["class_name"])
if self.use_cache:
# Store classifications in local cache
try:
to_insert = [
{
"id": created_cl["id"],
"element_id": element.id,
"class_name": created_cl["class_name"],
"confidence": created_cl["confidence"],
"state": created_cl["state"],
"worker_version_id": self.worker_version_id,
}
for created_cl in created_cls
]
CachedClassification.insert_many(to_insert).execute()
except IntegrityError as e:
logger.warning(
f"Couldn't save created classifications in local cache: {e}"
)
Loading