From 40a3a12704be3da03a98c4c83f102080ab8bbf1c Mon Sep 17 00:00:00 2001
From: Eva Bardou <ebardou@teklia.com>
Date: Fri, 9 Apr 2021 09:54:10 +0200
Subject: [PATCH] Refactor code

---
 arkindex_worker/worker/classification.py           |  6 +++---
 tests/test_elements_worker/test_classifications.py | 11 ++++-------
 2 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/arkindex_worker/worker/classification.py b/arkindex_worker/worker/classification.py
index e9cb5b8a..3e979c04 100644
--- a/arkindex_worker/worker/classification.py
+++ b/arkindex_worker/worker/classification.py
@@ -21,12 +21,12 @@ class ClassificationMixin(object):
         }
         logger.info(f"Loaded {len(self.classes[corpus_id])} ML classes")
 
-    def get_ml_class_id(self, ml_class, corpus_id=None):
+    def get_ml_class_id(self, corpus_id, ml_class):
         """
         Return the ID corresponding to the given class name on a specific corpus
         This method will automatically create missing classes
         """
-        if not corpus_id:
+        if corpus_id is None:
             corpus_id = os.environ.get("ARKINDEX_CORPUS_ID")
 
         if not self.classes.get(corpus_id):
@@ -87,7 +87,7 @@ class ClassificationMixin(object):
                 "CreateClassification",
                 body={
                     "element": element.id,
-                    "ml_class": self.get_ml_class_id(ml_class),
+                    "ml_class": self.get_ml_class_id(None, ml_class),
                     "worker_version": self.worker_version_id,
                     "confidence": confidence,
                     "high_confidence": high_confidence,
diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py
index f85ea3b9..20a67707 100644
--- a/tests/test_elements_worker/test_classifications.py
+++ b/tests/test_elements_worker/test_classifications.py
@@ -27,7 +27,7 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
     )
 
     assert not mock_elements_worker.classes
-    ml_class_id = mock_elements_worker.get_ml_class_id("good", corpus_id=corpus_id)
+    ml_class_id = mock_elements_worker.get_ml_class_id(corpus_id, "good")
 
     assert len(responses.calls) == 3
     assert [call.request.url for call in responses.calls] == [
@@ -60,7 +60,7 @@ def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses):
         "12341234-1234-1234-1234-123412341234": {"good": "0000"}
     }
 
-    ml_class_id = mock_elements_worker.get_ml_class_id("bad", corpus_id=corpus_id)
+    ml_class_id = mock_elements_worker.get_ml_class_id(corpus_id, "bad")
     assert ml_class_id == "new-ml-class-1234"
 
     # Now it's available
@@ -78,7 +78,7 @@ def test_get_ml_class_id(mock_elements_worker):
         "12341234-1234-1234-1234-123412341234": {"good": "0000"}
     }
 
-    ml_class_id = mock_elements_worker.get_ml_class_id("good", corpus_id=corpus_id)
+    ml_class_id = mock_elements_worker.get_ml_class_id(corpus_id, "good")
     assert ml_class_id == "0000"
 
 
@@ -130,10 +130,7 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
     )
 
     # Simply request class 2, it should be reloaded
-    assert (
-        mock_elements_worker.get_ml_class_id("class2", corpus_id=corpus_id)
-        == "class2_id"
-    )
+    assert mock_elements_worker.get_ml_class_id(corpus_id, "class2") == "class2_id"
 
     assert len(responses.calls) == 5
     assert mock_elements_worker.classes == {
-- 
GitLab