From 40a3a12704be3da03a98c4c83f102080ab8bbf1c Mon Sep 17 00:00:00 2001 From: Eva Bardou <ebardou@teklia.com> Date: Fri, 9 Apr 2021 09:54:10 +0200 Subject: [PATCH] Refactor code --- arkindex_worker/worker/classification.py | 6 +++--- tests/test_elements_worker/test_classifications.py | 11 ++++------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/arkindex_worker/worker/classification.py b/arkindex_worker/worker/classification.py index e9cb5b8a..3e979c04 100644 --- a/arkindex_worker/worker/classification.py +++ b/arkindex_worker/worker/classification.py @@ -21,12 +21,12 @@ class ClassificationMixin(object): } logger.info(f"Loaded {len(self.classes[corpus_id])} ML classes") - def get_ml_class_id(self, ml_class, corpus_id=None): + def get_ml_class_id(self, corpus_id, ml_class): """ Return the ID corresponding to the given class name on a specific corpus This method will automatically create missing classes """ - if not corpus_id: + if corpus_id is None: corpus_id = os.environ.get("ARKINDEX_CORPUS_ID") if not self.classes.get(corpus_id): @@ -87,7 +87,7 @@ class ClassificationMixin(object): "CreateClassification", body={ "element": element.id, - "ml_class": self.get_ml_class_id(ml_class), + "ml_class": self.get_ml_class_id(None, ml_class), "worker_version": self.worker_version_id, "confidence": confidence, "high_confidence": high_confidence, diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py index f85ea3b9..20a67707 100644 --- a/tests/test_elements_worker/test_classifications.py +++ b/tests/test_elements_worker/test_classifications.py @@ -27,7 +27,7 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker): ) assert not mock_elements_worker.classes - ml_class_id = mock_elements_worker.get_ml_class_id("good", corpus_id=corpus_id) + ml_class_id = mock_elements_worker.get_ml_class_id(corpus_id, "good") assert len(responses.calls) == 3 assert [call.request.url for call in responses.calls] == [ @@ -60,7 +60,7 @@ def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses): "12341234-1234-1234-1234-123412341234": {"good": "0000"} } - ml_class_id = mock_elements_worker.get_ml_class_id("bad", corpus_id=corpus_id) + ml_class_id = mock_elements_worker.get_ml_class_id(corpus_id, "bad") assert ml_class_id == "new-ml-class-1234" # Now it's available @@ -78,7 +78,7 @@ def test_get_ml_class_id(mock_elements_worker): "12341234-1234-1234-1234-123412341234": {"good": "0000"} } - ml_class_id = mock_elements_worker.get_ml_class_id("good", corpus_id=corpus_id) + ml_class_id = mock_elements_worker.get_ml_class_id(corpus_id, "good") assert ml_class_id == "0000" @@ -130,10 +130,7 @@ def test_get_ml_class_reload(responses, mock_elements_worker): ) # Simply request class 2, it should be reloaded - assert ( - mock_elements_worker.get_ml_class_id("class2", corpus_id=corpus_id) - == "class2_id" - ) + assert mock_elements_worker.get_ml_class_id(corpus_id, "class2") == "class2_id" assert len(responses.calls) == 5 assert mock_elements_worker.classes == { -- GitLab