From 97d66a04202123117df4d048f60500cf5d0a76d8 Mon Sep 17 00:00:00 2001
From: Eva Bardou <ebardou@teklia.com>
Date: Thu, 8 Apr 2021 16:41:15 +0200
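Subject: [PATCH] Use ARKINDEX_CORPUS_ID env variable in get_ml_class_id

get_ml_class_id now takes the class name first and an optional
corpus_id keyword. When corpus_id is not provided, it falls back to the
ARKINDEX_CORPUS_ID environment variable; no check is made that the
variable is actually set.

create_classification no longer reads element.corpus.id and relies on
this fallback instead.

Note: the positional argument order changed from
(corpus_id, ml_class) to (ml_class, corpus_id=None), so existing
positional callers must be updated.
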
---
 arkindex_worker/worker/classification.py      |  9 +++-
 tests/conftest.py                             |  2 +
 .../test_classifications.py                   | 53 +++++--------------
 3 files changed, 22 insertions(+), 42 deletions(-)

diff --git a/arkindex_worker/worker/classification.py b/arkindex_worker/worker/classification.py
index d9dfaad1..e9cb5b8a 100644
--- a/arkindex_worker/worker/classification.py
+++ b/arkindex_worker/worker/classification.py
@@ -1,4 +1,6 @@
 # -*- coding: utf-8 -*-
+import os
+
 from apistar.exceptions import ErrorResponse
 
 from arkindex_worker import logger
@@ -19,11 +21,14 @@ class ClassificationMixin(object):
         }
         logger.info(f"Loaded {len(self.classes[corpus_id])} ML classes")
 
-    def get_ml_class_id(self, corpus_id, ml_class):
+    def get_ml_class_id(self, ml_class, corpus_id=None):
         """
         Return the ID corresponding to the given class name on a specific corpus
         This method will automatically create missing classes
         """
+        if not corpus_id:
+            corpus_id = os.environ.get("ARKINDEX_CORPUS_ID")
+
         if not self.classes.get(corpus_id):
             self.load_corpus_classes(corpus_id)
 
@@ -82,7 +87,7 @@ class ClassificationMixin(object):
                 "CreateClassification",
                 body={
                     "element": element.id,
-                    "ml_class": self.get_ml_class_id(element.corpus.id, ml_class),
+                    "ml_class": self.get_ml_class_id(ml_class),
                     "worker_version": self.worker_version_id,
                     "confidence": confidence,
                     "high_confidence": high_confidence,
diff --git a/tests/conftest.py b/tests/conftest.py
index 7141bad3..265f5d05 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -165,6 +165,7 @@ def mock_user_api(responses):
 def mock_elements_worker(monkeypatch, mock_worker_version_api):
     """Build and configure an ElementsWorker with fixed CLI parameters to avoid issues with pytest"""
     monkeypatch.setattr(sys, "argv", ["worker"])
+    monkeypatch.setenv("ARKINDEX_CORPUS_ID", "11111111-1111-1111-1111-111111111111")
 
     worker = ElementsWorker()
     worker.configure()
@@ -185,6 +186,7 @@ def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_version_api):
 def mock_elements_worker_with_cache(monkeypatch, mock_worker_version_api):
     """Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest"""
     monkeypatch.setattr(sys, "argv", ["worker"])
+    monkeypatch.setenv("ARKINDEX_CORPUS_ID", "11111111-1111-1111-1111-111111111111")
 
     worker = ElementsWorker(use_cache=True)
     worker.configure()
diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py
index 05e8d8d6..f85ea3b9 100644
--- a/tests/test_elements_worker/test_classifications.py
+++ b/tests/test_elements_worker/test_classifications.py
@@ -27,7 +27,7 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
     )
 
     assert not mock_elements_worker.classes
-    ml_class_id = mock_elements_worker.get_ml_class_id(corpus_id, "good")
+    ml_class_id = mock_elements_worker.get_ml_class_id("good", corpus_id=corpus_id)
 
     assert len(responses.calls) == 3
     assert [call.request.url for call in responses.calls] == [
@@ -60,7 +60,7 @@ def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses):
         "12341234-1234-1234-1234-123412341234": {"good": "0000"}
     }
 
-    ml_class_id = mock_elements_worker.get_ml_class_id(corpus_id, "bad")
+    ml_class_id = mock_elements_worker.get_ml_class_id("bad", corpus_id=corpus_id)
     assert ml_class_id == "new-ml-class-1234"
 
     # Now it's available
@@ -78,7 +78,7 @@ def test_get_ml_class_id(mock_elements_worker):
         "12341234-1234-1234-1234-123412341234": {"good": "0000"}
     }
 
-    ml_class_id = mock_elements_worker.get_ml_class_id(corpus_id, "good")
+    ml_class_id = mock_elements_worker.get_ml_class_id("good", corpus_id=corpus_id)
     assert ml_class_id == "0000"
 
 
@@ -130,7 +130,10 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
     )
 
     # Simply request class 2, it should be reloaded
-    assert mock_elements_worker.get_ml_class_id(corpus_id, "class2") == "class2_id"
+    assert (
+        mock_elements_worker.get_ml_class_id("class2", corpus_id=corpus_id)
+        == "class2_id"
+    )
 
     assert len(responses.calls) == 5
     assert mock_elements_worker.classes == {
@@ -172,12 +175,7 @@ def test_create_classification_wrong_element(mock_elements_worker):
 
 
 def test_create_classification_wrong_ml_class(mock_elements_worker, responses):
-    elt = Element(
-        {
-            "id": "12341234-1234-1234-1234-123412341234",
-            "corpus": {"id": "11111111-1111-1111-1111-111111111111"},
-        }
-    )
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
 
     with pytest.raises(AssertionError) as e:
         mock_elements_worker.create_classification(
@@ -249,12 +247,7 @@ def test_create_classification_wrong_confidence(mock_elements_worker):
     mock_elements_worker.classes = {
         "11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
     }
-    elt = Element(
-        {
-            "id": "12341234-1234-1234-1234-123412341234",
-            "corpus": {"id": "11111111-1111-1111-1111-111111111111"},
-        }
-    )
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
     with pytest.raises(AssertionError) as e:
         mock_elements_worker.create_classification(
             element=elt,
@@ -308,12 +301,7 @@ def test_create_classification_wrong_high_confidence(mock_elements_worker):
     mock_elements_worker.classes = {
         "11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
     }
-    elt = Element(
-        {
-            "id": "12341234-1234-1234-1234-123412341234",
-            "corpus": {"id": "11111111-1111-1111-1111-111111111111"},
-        }
-    )
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
 
     with pytest.raises(AssertionError) as e:
         mock_elements_worker.create_classification(
@@ -342,12 +330,7 @@ def test_create_classification_api_error(responses, mock_elements_worker):
     mock_elements_worker.classes = {
         "11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
     }
-    elt = Element(
-        {
-            "id": "12341234-1234-1234-1234-123412341234",
-            "corpus": {"id": "11111111-1111-1111-1111-111111111111"},
-        }
-    )
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
     responses.add(
         responses.POST,
         "http://testserver/api/v1/classifications/",
@@ -379,12 +362,7 @@ def test_create_classification(responses, mock_elements_worker):
     mock_elements_worker.classes = {
         "11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
     }
-    elt = Element(
-        {
-            "id": "12341234-1234-1234-1234-123412341234",
-            "corpus": {"id": "11111111-1111-1111-1111-111111111111"},
-        }
-    )
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
     responses.add(
         responses.POST,
         "http://testserver/api/v1/classifications/",
@@ -423,12 +401,7 @@ def test_create_classification_duplicate(responses, mock_elements_worker):
     mock_elements_worker.classes = {
         "11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
     }
-    elt = Element(
-        {
-            "id": "12341234-1234-1234-1234-123412341234",
-            "corpus": {"id": "11111111-1111-1111-1111-111111111111"},
-        }
-    )
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
     responses.add(
         responses.POST,
         "http://testserver/api/v1/classifications/",
-- 
GitLab