def create_classifications(self, element, classifications):
    """
    Create multiple classifications at once on the given element through the API

    :param element: Element or CachedElement that receives the classifications.
        Must not be null.
    :param classifications: Non-empty list of dicts. Each dict must hold a
        non-empty str ``class_name`` and a float ``confidence`` in the [0..1]
        range, and may hold a bool ``high_confidence``.
    :raises AssertionError: When ``element``, ``classifications`` or any item
        of the list does not satisfy the constraints above.

    In read-only mode nothing is sent: a warning is logged and the method
    returns early. Otherwise every created classification is added to the
    report and, when the local cache is enabled, stored in it. A cache
    IntegrityError is logged as a warning and not re-raised.
    """
    assert element and isinstance(
        element, (Element, CachedElement)
    ), "element shouldn't be null and should be an Element or CachedElement"
    assert classifications and isinstance(
        classifications, list
    ), "classifications shouldn't be null and should be of type list"

    for index, classification in enumerate(classifications):
        # Without this guard a non-dict item would fail on `.get` below with a
        # bare AttributeError instead of an indexed, descriptive AssertionError
        # like every other malformed input.
        assert isinstance(
            classification, dict
        ), f"Classification at index {index} in classifications: should be of type dict"

        class_name = classification.get("class_name")
        assert class_name and isinstance(
            class_name, str
        ), f"Classification at index {index} in classifications: class_name shouldn't be null and should be of type str"

        confidence = classification.get("confidence")
        assert (
            confidence is not None
            and isinstance(confidence, float)
            and 0 <= confidence <= 1
        ), f"Classification at index {index} in classifications: confidence shouldn't be null and should be a float in [0..1] range"

        # high_confidence is optional: only its type is checked when provided
        high_confidence = classification.get("high_confidence")
        if high_confidence is not None:
            assert isinstance(
                high_confidence, bool
            ), f"Classification at index {index} in classifications: high_confidence should be of type bool"

    # Validation runs even in read-only mode; only the API call is skipped
    if self.is_read_only:
        logger.warning(
            "Cannot create classifications as this worker is in read-only mode"
        )
        return

    created_cls = self.request(
        "CreateClassifications",
        body={
            "parent": str(element.id),
            "worker_version": self.worker_version_id,
            "classifications": classifications,
        },
    )["classifications"]

    for created_cl in created_cls:
        self.report.add_classification(element.id, created_cl["class_name"])

    if self.use_cache:
        # Build the rows outside the try block so that it only guards the
        # database insert, the one statement expected to raise IntegrityError
        to_insert = [
            {
                "id": created_cl["id"],
                "element_id": element.id,
                "class_name": created_cl["class_name"],
                "confidence": created_cl["confidence"],
                "state": created_cl["state"],
                "worker_version_id": self.worker_version_id,
            }
            for created_cl in created_cls
        ]
        # Store classifications in local cache
        try:
            CachedClassification.insert_many(to_insert).execute()
        except IntegrityError as e:
            logger.warning(
                f"Couldn't save created classifications in local cache: {e}"
            )