From 97d66a04202123117df4d048f60500cf5d0a76d8 Mon Sep 17 00:00:00 2001 From: Eva Bardou <ebardou@teklia.com> Date: Thu, 8 Apr 2021 16:41:15 +0200 Subject: [PATCH] Use ARKINDEX_CORPUS_ID env variable in get_ml_class_id --- arkindex_worker/worker/classification.py | 9 +++- tests/conftest.py | 2 + .../test_classifications.py | 53 +++++-------------- 3 files changed, 22 insertions(+), 42 deletions(-) diff --git a/arkindex_worker/worker/classification.py b/arkindex_worker/worker/classification.py index d9dfaad1..e9cb5b8a 100644 --- a/arkindex_worker/worker/classification.py +++ b/arkindex_worker/worker/classification.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +import os + from apistar.exceptions import ErrorResponse from arkindex_worker import logger @@ -19,11 +21,14 @@ class ClassificationMixin(object): } logger.info(f"Loaded {len(self.classes[corpus_id])} ML classes") - def get_ml_class_id(self, corpus_id, ml_class): + def get_ml_class_id(self, ml_class, corpus_id=None): """ Return the ID corresponding to the given class name on a specific corpus This method will automatically create missing classes """ + if not corpus_id: + corpus_id = os.environ.get("ARKINDEX_CORPUS_ID") + if not self.classes.get(corpus_id): self.load_corpus_classes(corpus_id) @@ -82,7 +87,7 @@ class ClassificationMixin(object): "CreateClassification", body={ "element": element.id, - "ml_class": self.get_ml_class_id(element.corpus.id, ml_class), + "ml_class": self.get_ml_class_id(ml_class), "worker_version": self.worker_version_id, "confidence": confidence, "high_confidence": high_confidence, diff --git a/tests/conftest.py b/tests/conftest.py index 7141bad3..265f5d05 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -165,6 +165,7 @@ def mock_user_api(responses): def mock_elements_worker(monkeypatch, mock_worker_version_api): """Build and configure an ElementsWorker with fixed CLI parameters to avoid issues with pytest""" monkeypatch.setattr(sys, "argv", ["worker"]) + monkeypatch.setenv("ARKINDEX_CORPUS_ID", "11111111-1111-1111-1111-111111111111") worker = ElementsWorker() worker.configure() @@ -185,6 +186,7 @@ def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_version_api): def mock_elements_worker_with_cache(monkeypatch, mock_worker_version_api): """Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest""" monkeypatch.setattr(sys, "argv", ["worker"]) + monkeypatch.setenv("ARKINDEX_CORPUS_ID", "11111111-1111-1111-1111-111111111111") worker = ElementsWorker(use_cache=True) worker.configure() diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py index 05e8d8d6..f85ea3b9 100644 --- a/tests/test_elements_worker/test_classifications.py +++ b/tests/test_elements_worker/test_classifications.py @@ -27,7 +27,7 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker): ) assert not mock_elements_worker.classes - ml_class_id = mock_elements_worker.get_ml_class_id(corpus_id, "good") + ml_class_id = mock_elements_worker.get_ml_class_id("good", corpus_id=corpus_id) assert len(responses.calls) == 3 assert [call.request.url for call in responses.calls] == [ @@ -60,7 +60,7 @@ def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses): "12341234-1234-1234-1234-123412341234": {"good": "0000"} } - ml_class_id = mock_elements_worker.get_ml_class_id(corpus_id, "bad") + ml_class_id = mock_elements_worker.get_ml_class_id("bad", corpus_id=corpus_id) assert ml_class_id == "new-ml-class-1234" # Now it's available @@ -78,7 +78,7 @@ def test_get_ml_class_id(mock_elements_worker): "12341234-1234-1234-1234-123412341234": {"good": "0000"} } - ml_class_id = mock_elements_worker.get_ml_class_id(corpus_id, "good") + ml_class_id = mock_elements_worker.get_ml_class_id("good", corpus_id=corpus_id) assert ml_class_id == "0000" @@ -130,7 +130,10 @@ def test_get_ml_class_reload(responses, mock_elements_worker): ) # Simply request class 2, it should be reloaded - assert mock_elements_worker.get_ml_class_id(corpus_id, "class2") == "class2_id" + assert ( + mock_elements_worker.get_ml_class_id("class2", corpus_id=corpus_id) + == "class2_id" + ) assert len(responses.calls) == 5 assert mock_elements_worker.classes == { @@ -172,12 +175,7 @@ def test_create_classification_wrong_element(mock_elements_worker): def test_create_classification_wrong_ml_class(mock_elements_worker, responses): - elt = Element( - { - "id": "12341234-1234-1234-1234-123412341234", - "corpus": {"id": "11111111-1111-1111-1111-111111111111"}, - } - ) + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError) as e: mock_elements_worker.create_classification( @@ -249,12 +247,7 @@ def test_create_classification_wrong_confidence(mock_elements_worker): mock_elements_worker.classes = { "11111111-1111-1111-1111-111111111111": {"a_class": "0000"} } - elt = Element( - { - "id": "12341234-1234-1234-1234-123412341234", - "corpus": {"id": "11111111-1111-1111-1111-111111111111"}, - } - ) + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError) as e: mock_elements_worker.create_classification( element=elt, @@ -308,12 +301,7 @@ def test_create_classification_wrong_high_confidence(mock_elements_worker): mock_elements_worker.classes = { "11111111-1111-1111-1111-111111111111": {"a_class": "0000"} } - elt = Element( - { - "id": "12341234-1234-1234-1234-123412341234", - "corpus": {"id": "11111111-1111-1111-1111-111111111111"}, - } - ) + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError) as e: mock_elements_worker.create_classification( @@ -342,12 +330,7 @@ def test_create_classification_api_error(responses, mock_elements_worker): mock_elements_worker.classes = { "11111111-1111-1111-1111-111111111111": {"a_class": "0000"} } - elt = Element( - { - "id": "12341234-1234-1234-1234-123412341234", - "corpus": {"id": "11111111-1111-1111-1111-111111111111"}, - } - ) + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) responses.add( responses.POST, "http://testserver/api/v1/classifications/", @@ -379,12 +362,7 @@ def test_create_classification(responses, mock_elements_worker): mock_elements_worker.classes = { "11111111-1111-1111-1111-111111111111": {"a_class": "0000"} } - elt = Element( - { - "id": "12341234-1234-1234-1234-123412341234", - "corpus": {"id": "11111111-1111-1111-1111-111111111111"}, - } - ) + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) responses.add( responses.POST, "http://testserver/api/v1/classifications/", @@ -423,12 +401,7 @@ def test_create_classification_duplicate(responses, mock_elements_worker): mock_elements_worker.classes = { "11111111-1111-1111-1111-111111111111": {"a_class": "0000"} } - elt = Element( - { - "id": "12341234-1234-1234-1234-123412341234", - "corpus": {"id": "11111111-1111-1111-1111-111111111111"}, - } - ) + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) responses.add( responses.POST, "http://testserver/api/v1/classifications/", -- GitLab