From 62fca186ba743e48f688747fc3fb98ffcb11ce10 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Wed, 15 Feb 2023 15:30:57 +0000
Subject: [PATCH] Simplify classes ID attribute

---
 arkindex_worker/worker/classification.py      | 24 +++----
 .../test_classifications.py                   | 65 ++++++-------------
 tests/test_elements_worker/test_elements.py   |  8 +--
 3 files changed, 35 insertions(+), 62 deletions(-)

diff --git a/arkindex_worker/worker/classification.py b/arkindex_worker/worker/classification.py
index 98b5db4b..98cc6eae 100644
--- a/arkindex_worker/worker/classification.py
+++ b/arkindex_worker/worker/classification.py
@@ -23,10 +23,10 @@ class ClassificationMixin(object):
             "ListCorpusMLClasses",
             id=self.corpus_id,
         )
-        self.classes[self.corpus_id] = {
-            ml_class["name"]: ml_class["id"] for ml_class in corpus_classes
-        }
-        logger.info(f"Loaded {len(self.classes[self.corpus_id])} ML classes")
+        self.classes = {ml_class["name"]: ml_class["id"] for ml_class in corpus_classes}
+        logger.info(
+            f"Loaded {len(self.classes)} ML classes in corpus ({self.corpus_id})"
+        )
 
     def get_ml_class_id(self, ml_class: str) -> str:
         """
@@ -36,17 +36,17 @@ class ClassificationMixin(object):
         :param ml_class: Name of the MLClass.
         :returns: ID of the retrieved or created MLClass.
         """
-        if not self.classes.get(self.corpus_id):
+        if not self.classes:
             self.load_corpus_classes()
 
-        ml_class_id = self.classes[self.corpus_id].get(ml_class)
+        ml_class_id = self.classes.get(ml_class)
         if ml_class_id is None:
             logger.info(f"Creating ML class {ml_class} on corpus {self.corpus_id}")
             try:
                 response = self.request(
                     "CreateMLClass", id=self.corpus_id, body={"name": ml_class}
                 )
-                ml_class_id = self.classes[self.corpus_id][ml_class] = response["id"]
+                ml_class_id = self.classes[ml_class] = response["id"]
                 logger.debug(f"Created ML class {response['id']}")
             except ErrorResponse as e:
                 # Only reload for 400 errors
@@ -59,9 +59,9 @@ class ClassificationMixin(object):
                 )
                 self.load_corpus_classes()
                 assert (
-                    ml_class in self.classes[self.corpus_id]
+                    ml_class in self.classes
                 ), "Missing class {ml_class} even after reloading"
-                ml_class_id = self.classes[self.corpus_id][ml_class]
+                ml_class_id = self.classes[ml_class]
 
         return ml_class_id
 
@@ -73,14 +73,14 @@ class ClassificationMixin(object):
         :return: The MLClass's name
         """
         # Load the corpus' MLclasses if they are not available yet
-        if self.corpus_id not in self.classes:
+        if not self.classes:
             self.load_corpus_classes()
 
         # Filter classes by this ml_class_id
         ml_class_name = next(
             filter(
-                lambda x: self.classes[self.corpus_id][x] == ml_class_id,
-                self.classes[self.corpus_id],
+                lambda x: self.classes[x] == ml_class_id,
+                self.classes,
             ),
             None,
         )
diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py
index 63b06a7a..23db7426 100644
--- a/tests/test_elements_worker/test_classifications.py
+++ b/tests/test_elements_worker/test_classifications.py
@@ -41,18 +41,14 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
             f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
         ),
     ]
-    assert mock_elements_worker.classes == {
-        "11111111-1111-1111-1111-111111111111": {"good": "0000"}
-    }
+    assert mock_elements_worker.classes == {"good": "0000"}
     assert ml_class_id == "0000"
 
 
 def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses):
     # A missing class is now created automatically
     corpus_id = "11111111-1111-1111-1111-111111111111"
-    mock_elements_worker.classes = {
-        "11111111-1111-1111-1111-111111111111": {"good": "0000"}
-    }
+    mock_elements_worker.classes = {"good": "0000"}
 
     responses.add(
         responses.POST,
@@ -62,26 +58,20 @@ def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses):
     )
 
     # Missing class at first
-    assert mock_elements_worker.classes == {
-        "11111111-1111-1111-1111-111111111111": {"good": "0000"}
-    }
+    assert mock_elements_worker.classes == {"good": "0000"}
 
     ml_class_id = mock_elements_worker.get_ml_class_id("bad")
     assert ml_class_id == "new-ml-class-1234"
 
     # Now it's available
     assert mock_elements_worker.classes == {
-        "11111111-1111-1111-1111-111111111111": {
-            "good": "0000",
-            "bad": "new-ml-class-1234",
-        }
+        "good": "0000",
+        "bad": "new-ml-class-1234",
     }
 
 
 def test_get_ml_class_id(mock_elements_worker):
-    mock_elements_worker.classes = {
-        "11111111-1111-1111-1111-111111111111": {"good": "0000"}
-    }
+    mock_elements_worker.classes = {"good": "0000"}
 
     ml_class_id = mock_elements_worker.get_ml_class_id("good")
     assert ml_class_id == "0000"
@@ -139,10 +129,8 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
 
     assert len(responses.calls) == len(BASE_API_CALLS) + 3
     assert mock_elements_worker.classes == {
-        corpus_id: {
-            "class1": "class1_id",
-            "class2": "class2_id",
-        }
+        "class1": "class1_id",
+        "class2": "class2_id",
     }
     assert [
         (call.request.method, call.request.url) for call in responses.calls
@@ -166,7 +154,7 @@ def test_retrieve_ml_class_in_cache(mock_elements_worker):
     """
     Look for a class that exists in cache -> No API Call
     """
-    mock_elements_worker.classes[mock_elements_worker.corpus_id] = {"class1": "uuid1"}
+    mock_elements_worker.classes = {"class1": "uuid1"}
 
     assert mock_elements_worker.retrieve_ml_class("uuid1") == "class1"
 
@@ -262,9 +250,7 @@ def test_create_classification_wrong_ml_class(mock_elements_worker, responses):
         status=201,
         json={"id": "new-classification-1234"},
     )
-    mock_elements_worker.classes = {
-        "11111111-1111-1111-1111-111111111111": {"another_class": "0000"}
-    }
+    mock_elements_worker.classes = {"another_class": "0000"}
     mock_elements_worker.create_classification(
         element=elt,
         ml_class="a_class",
@@ -298,9 +284,7 @@ def test_create_classification_wrong_ml_class(mock_elements_worker, responses):
 
 
 def test_create_classification_wrong_confidence(mock_elements_worker):
-    mock_elements_worker.classes = {
-        "11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
-    }
+    mock_elements_worker.classes = {"a_class": "0000"}
     elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
     with pytest.raises(AssertionError) as e:
         mock_elements_worker.create_classification(
@@ -352,9 +336,7 @@ def test_create_classification_wrong_confidence(mock_elements_worker):
 
 
 def test_create_classification_wrong_high_confidence(mock_elements_worker):
-    mock_elements_worker.classes = {
-        "11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
-    }
+    mock_elements_worker.classes = {"a_class": "0000"}
     elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
 
     with pytest.raises(AssertionError) as e:
@@ -381,9 +363,7 @@ def test_create_classification_wrong_high_confidence(mock_elements_worker):
 
 
 def test_create_classification_api_error(responses, mock_elements_worker):
-    mock_elements_worker.classes = {
-        "11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
-    }
+    mock_elements_worker.classes = {"a_class": "0000"}
     elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
     responses.add(
         responses.POST,
@@ -413,9 +393,7 @@ def test_create_classification_api_error(responses, mock_elements_worker):
 
 
 def test_create_classification(responses, mock_elements_worker):
-    mock_elements_worker.classes = {
-        "11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
-    }
+    mock_elements_worker.classes = {"a_class": "0000"}
     elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
     responses.add(
         responses.POST,
@@ -452,9 +430,7 @@ def test_create_classification(responses, mock_elements_worker):
 
 
 def test_create_classification_with_cache(responses, mock_elements_worker_with_cache):
-    mock_elements_worker_with_cache.classes = {
-        "11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
-    }
+    mock_elements_worker_with_cache.classes = {"a_class": "0000"}
     elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
 
     responses.add(
@@ -513,9 +489,7 @@ def test_create_classification_with_cache(responses, mock_elements_worker_with_c
 
 
 def test_create_classification_duplicate_worker_run(responses, mock_elements_worker):
-    mock_elements_worker.classes = {
-        "11111111-1111-1111-1111-111111111111": {"a_class": "0000"}
-    }
+    mock_elements_worker.classes = {"a_class": "0000"}
     elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
     responses.add(
         responses.POST,
@@ -870,9 +844,10 @@ def test_create_classifications(responses, mock_elements_worker_with_cache):
     # Set MLClass in cache
     portrait_uuid = str(uuid4())
     landscape_uuid = str(uuid4())
-    mock_elements_worker_with_cache.classes[
-        mock_elements_worker_with_cache.corpus_id
-    ] = {"portrait": portrait_uuid, "landscape": landscape_uuid}
+    mock_elements_worker_with_cache.classes = {
+        "portrait": portrait_uuid,
+        "landscape": landscape_uuid,
+    }
 
     elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
     classes = [
diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py
index 502a20cb..7eb714e0 100644
--- a/tests/test_elements_worker/test_elements.py
+++ b/tests/test_elements_worker/test_elements.py
@@ -350,11 +350,9 @@ def test_load_corpus_classes(responses, mock_elements_worker):
         ),
     ]
     assert mock_elements_worker.classes == {
-        "11111111-1111-1111-1111-111111111111": {
-            "good": "0000",
-            "average": "1111",
-            "bad": "2222",
-        }
+        "good": "0000",
+        "average": "1111",
+        "bad": "2222",
     }
 
 
-- 
GitLab