From 4195b24114df79d9dbb3e13558b21f369340a59e Mon Sep 17 00:00:00 2001
From: Bastien Abadie <bastien@nextcairn.com>
Date: Mon, 19 Oct 2020 09:36:52 +0000
Subject: [PATCH] Reload known ML classes when a 400 is received on creation

---
 arkindex_worker/worker.py     | 26 +++++++++++---
 tests/test_elements_worker.py | 64 +++++++++++++++++++++++++++++++++++
 2 files changed, 85 insertions(+), 5 deletions(-)

diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py
index fabd80b1..fc0b54bf 100644
--- a/arkindex_worker/worker.py
+++ b/arkindex_worker/worker.py
@@ -8,6 +8,7 @@ import uuid
 from enum import Enum
 from pathlib import Path
 
+import apistar
 import gnupg
 import yaml
 from apistar.exceptions import ErrorResponse
@@ -296,11 +297,26 @@ class ElementsWorker(BaseWorker):
         ml_class_id = self.classes[corpus_id].get(ml_class)
         if ml_class_id is None:
             logger.info(f"Creating ML class {ml_class} on corpus {corpus_id}")
-            response = self.api_client.request(
-                "CreateMLClass", id=corpus_id, body={"name": ml_class}
-            )
-            ml_class_id = self.classes[corpus_id][ml_class] = response["id"]
-            logger.debug(f"Created ML class {ml_class_id}")
+            try:
+                response = self.api_client.request(
+                    "CreateMLClass", id=corpus_id, body={"name": ml_class}
+                )
+                ml_class_id = self.classes[corpus_id][ml_class] = response["id"]
+                logger.debug(f"Created ML class {response['id']}")
+            except apistar.exceptions.ErrorResponse as e:
+                # Only reload for 400 errors
+                if e.status_code != 400:
+                    raise
+
+                # Reload and make sure we have the class
+                logger.info(
+                    f"Reloading corpus classes to see if {ml_class} already exists"
+                )
+                self.load_corpus_classes(corpus_id)
+                assert (
+                    ml_class in self.classes[corpus_id]
+                ), "Missing class {ml_class} even after reloading"
+                ml_class_id = self.classes[corpus_id][ml_class]
 
         return ml_class_id
 
diff --git a/tests/test_elements_worker.py b/tests/test_elements_worker.py
index 3b7ebf02..0794f89c 100644
--- a/tests/test_elements_worker.py
+++ b/tests/test_elements_worker.py
@@ -370,6 +370,70 @@ def test_get_ml_class_id(mock_elements_worker):
     assert ml_class_id == "0000"
 
 
+def test_get_ml_class_reload(responses, mock_elements_worker):
+    corpus_id = "12341234-1234-1234-1234-123412341234"
+
+    # Add some initial classes
+    responses.add(
+        responses.GET,
+        f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
+        json={
+            "results": [
+                {
+                    "id": "class1_id",
+                    "name": "class1",
+                }
+            ]
+        },
+    )
+
+    # Invalid response when trying to create class2
+    responses.add(
+        responses.POST,
+        f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
+        status=400,
+        json={"non_field_errors": "Already exists"},
+    )
+
+    # Add both classes (class2 is created by another process)
+    responses.add(
+        responses.GET,
+        f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
+        json={
+            "results": [
+                {
+                    "id": "class1_id",
+                    "name": "class1",
+                },
+                {
+                    "id": "class2_id",
+                    "name": "class2",
+                },
+            ]
+        },
+    )
+
+    # Simply request class 2, it should be reloaded
+    assert mock_elements_worker.get_ml_class_id(corpus_id, "class2") == "class2_id"
+
+    assert len(responses.calls) == 4
+    assert mock_elements_worker.classes == {
+        corpus_id: {
+            "class1": "class1_id",
+            "class2": "class2_id",
+        }
+    }
+    assert [(call.request.method, call.request.url) for call in responses.calls] == [
+        (
+            "GET",
+            "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+        ),
+        ("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/?page=1"),
+        ("POST", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
+        ("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/?page=1"),
+    ]
+
+
 def test_create_sub_element_wrong_element(mock_elements_worker):
     with pytest.raises(AssertionError) as e:
         mock_elements_worker.create_sub_element(
-- 
GitLab