From 62fca186ba743e48f688747fc3fb98ffcb11ce10 Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Wed, 15 Feb 2023 15:30:57 +0000 Subject: [PATCH] Simplify classes ID attribute --- arkindex_worker/worker/classification.py | 24 +++---- .../test_classifications.py | 65 ++++++------------- tests/test_elements_worker/test_elements.py | 8 +-- 3 files changed, 35 insertions(+), 62 deletions(-) diff --git a/arkindex_worker/worker/classification.py b/arkindex_worker/worker/classification.py index 98b5db4b..98cc6eae 100644 --- a/arkindex_worker/worker/classification.py +++ b/arkindex_worker/worker/classification.py @@ -23,10 +23,10 @@ class ClassificationMixin(object): "ListCorpusMLClasses", id=self.corpus_id, ) - self.classes[self.corpus_id] = { - ml_class["name"]: ml_class["id"] for ml_class in corpus_classes - } - logger.info(f"Loaded {len(self.classes[self.corpus_id])} ML classes") + self.classes = {ml_class["name"]: ml_class["id"] for ml_class in corpus_classes} + logger.info( + f"Loaded {len(self.classes)} ML classes in corpus ({self.corpus_id})" + ) def get_ml_class_id(self, ml_class: str) -> str: """ @@ -36,17 +36,17 @@ class ClassificationMixin(object): :param ml_class: Name of the MLClass. :returns: ID of the retrieved or created MLClass. """ - if not self.classes.get(self.corpus_id): + if not self.classes: self.load_corpus_classes() - ml_class_id = self.classes[self.corpus_id].get(ml_class) + ml_class_id = self.classes.get(ml_class) if ml_class_id is None: logger.info(f"Creating ML class {ml_class} on corpus {self.corpus_id}") try: response = self.request( "CreateMLClass", id=self.corpus_id, body={"name": ml_class} ) - ml_class_id = self.classes[self.corpus_id][ml_class] = response["id"] + ml_class_id = self.classes[ml_class] = response["id"] logger.debug(f"Created ML class {response['id']}") except ErrorResponse as e: # Only reload for 400 errors @@ -59,9 +59,9 @@ class ClassificationMixin(object): ) self.load_corpus_classes() assert ( - ml_class in self.classes[self.corpus_id] + ml_class in self.classes ), "Missing class {ml_class} even after reloading" - ml_class_id = self.classes[self.corpus_id][ml_class] + ml_class_id = self.classes[ml_class] return ml_class_id @@ -73,14 +73,14 @@ class ClassificationMixin(object): :return: The MLClass's name """ # Load the corpus' MLclasses if they are not available yet - if self.corpus_id not in self.classes: + if not self.classes: self.load_corpus_classes() # Filter classes by this ml_class_id ml_class_name = next( filter( - lambda x: self.classes[self.corpus_id][x] == ml_class_id, - self.classes[self.corpus_id], + lambda x: self.classes[x] == ml_class_id, + self.classes, ), None, ) diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py index 63b06a7a..23db7426 100644 --- a/tests/test_elements_worker/test_classifications.py +++ b/tests/test_elements_worker/test_classifications.py @@ -41,18 +41,14 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker): f"http://testserver/api/v1/corpus/{corpus_id}/classes/", ), ] - assert mock_elements_worker.classes == { - "11111111-1111-1111-1111-111111111111": {"good": "0000"} - } + assert mock_elements_worker.classes == {"good": "0000"} assert ml_class_id == "0000" def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses): # A missing class is now created automatically corpus_id = "11111111-1111-1111-1111-111111111111" - mock_elements_worker.classes = { - "11111111-1111-1111-1111-111111111111": {"good": "0000"} - } + mock_elements_worker.classes = {"good": "0000"} responses.add( responses.POST, @@ -62,26 +58,20 @@ def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses): ) # Missing class at first - assert mock_elements_worker.classes == { - "11111111-1111-1111-1111-111111111111": {"good": "0000"} - } + assert mock_elements_worker.classes == {"good": "0000"} ml_class_id = mock_elements_worker.get_ml_class_id("bad") assert ml_class_id == "new-ml-class-1234" # Now it's available assert mock_elements_worker.classes == { - "11111111-1111-1111-1111-111111111111": { - "good": "0000", - "bad": "new-ml-class-1234", - } + "good": "0000", + "bad": "new-ml-class-1234", } def test_get_ml_class_id(mock_elements_worker): - mock_elements_worker.classes = { - "11111111-1111-1111-1111-111111111111": {"good": "0000"} - } + mock_elements_worker.classes = {"good": "0000"} ml_class_id = mock_elements_worker.get_ml_class_id("good") assert ml_class_id == "0000" @@ -139,10 +129,8 @@ def test_get_ml_class_reload(responses, mock_elements_worker): assert len(responses.calls) == len(BASE_API_CALLS) + 3 assert mock_elements_worker.classes == { - corpus_id: { - "class1": "class1_id", - "class2": "class2_id", - } + "class1": "class1_id", + "class2": "class2_id", } assert [ (call.request.method, call.request.url) for call in responses.calls @@ -166,7 +154,7 @@ def test_retrieve_ml_class_in_cache(mock_elements_worker): """ Look for a class that exists in cache -> No API Call """ - mock_elements_worker.classes[mock_elements_worker.corpus_id] = {"class1": "uuid1"} + mock_elements_worker.classes = {"class1": "uuid1"} assert mock_elements_worker.retrieve_ml_class("uuid1") == "class1" @@ -262,9 +250,7 @@ def test_create_classification_wrong_ml_class(mock_elements_worker, responses): status=201, json={"id": "new-classification-1234"}, ) - mock_elements_worker.classes = { - "11111111-1111-1111-1111-111111111111": {"another_class": "0000"} - } + mock_elements_worker.classes = {"another_class": "0000"} mock_elements_worker.create_classification( element=elt, ml_class="a_class", @@ -298,9 +284,7 @@ def test_create_classification_wrong_ml_class(mock_elements_worker, responses): def test_create_classification_wrong_confidence(mock_elements_worker): - mock_elements_worker.classes = { - "11111111-1111-1111-1111-111111111111": {"a_class": "0000"} - } + mock_elements_worker.classes = {"a_class": "0000"} elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError) as e: mock_elements_worker.create_classification( @@ -352,9 +336,7 @@ def test_create_classification_wrong_confidence(mock_elements_worker): def test_create_classification_wrong_high_confidence(mock_elements_worker): - mock_elements_worker.classes = { - "11111111-1111-1111-1111-111111111111": {"a_class": "0000"} - } + mock_elements_worker.classes = {"a_class": "0000"} elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError) as e: @@ -381,9 +363,7 @@ def test_create_classification_wrong_high_confidence(mock_elements_worker): def test_create_classification_api_error(responses, mock_elements_worker): - mock_elements_worker.classes = { - "11111111-1111-1111-1111-111111111111": {"a_class": "0000"} - } + mock_elements_worker.classes = {"a_class": "0000"} elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) responses.add( responses.POST, @@ -413,9 +393,7 @@ def test_create_classification_api_error(responses, mock_elements_worker): def test_create_classification(responses, mock_elements_worker): - mock_elements_worker.classes = { - "11111111-1111-1111-1111-111111111111": {"a_class": "0000"} - } + mock_elements_worker.classes = {"a_class": "0000"} elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) responses.add( responses.POST, @@ -452,9 +430,7 @@ def test_create_classification(responses, mock_elements_worker): def test_create_classification_with_cache(responses, mock_elements_worker_with_cache): - mock_elements_worker_with_cache.classes = { - "11111111-1111-1111-1111-111111111111": {"a_class": "0000"} - } + mock_elements_worker_with_cache.classes = {"a_class": "0000"} elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing") responses.add( @@ -513,9 +489,7 @@ def test_create_classification_with_cache(responses, mock_elements_worker_with_c def test_create_classification_duplicate_worker_run(responses, mock_elements_worker): - mock_elements_worker.classes = { - "11111111-1111-1111-1111-111111111111": {"a_class": "0000"} - } + mock_elements_worker.classes = {"a_class": "0000"} elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) responses.add( responses.POST, @@ -870,9 +844,10 @@ def test_create_classifications(responses, mock_elements_worker_with_cache): # Set MLClass in cache portrait_uuid = str(uuid4()) landscape_uuid = str(uuid4()) - mock_elements_worker_with_cache.classes[ - mock_elements_worker_with_cache.corpus_id - ] = {"portrait": portrait_uuid, "landscape": landscape_uuid} + mock_elements_worker_with_cache.classes = { + "portrait": portrait_uuid, + "landscape": landscape_uuid, + } elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing") classes = [ diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py index 502a20cb..7eb714e0 100644 --- a/tests/test_elements_worker/test_elements.py +++ b/tests/test_elements_worker/test_elements.py @@ -350,11 +350,9 @@ def test_load_corpus_classes(responses, mock_elements_worker): ), ] assert mock_elements_worker.classes == { - "11111111-1111-1111-1111-111111111111": { - "good": "0000", - "average": "1111", - "bad": "2222", - } + "good": "0000", + "average": "1111", + "bad": "2222", } -- GitLab