diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py
index fbec01a38d6e66a80e20d390eb75d58be5922256..f3e9cbfe4c5bf8446f0a0a2b371c759537b3224e 100644
--- a/arkindex_worker/worker.py
+++ b/arkindex_worker/worker.py
@@ -299,6 +299,7 @@ class ElementsWorker(BaseWorker):
                 response = self.api_client.request(
                     "CreateMLClass", id=corpus_id, body={"name": ml_class}
                 )
+                ml_class_id = self.classes[corpus_id][ml_class] = response["id"]
                 logger.debug(f"Created ML class {response['id']}")
             except apistar.exceptions.ErrorResponse as e:
                 # Only reload for 400 errors
@@ -306,12 +307,14 @@ class ElementsWorker(BaseWorker):
                     raise
 
                 # Reload and make sure we have the class
+                logger.info(
+                    f"Reloading corpus classes to see if {ml_class} already exists"
+                )
                 self.load_corpus_classes(corpus_id)
                 assert (
                     ml_class in self.classes[corpus_id]
                 ), "Missing class {ml_class} even after reloading"
-
-            ml_class_id = self.classes[corpus_id][ml_class] = response["id"]
+                ml_class_id = self.classes[corpus_id][ml_class]
 
         return ml_class_id
 
diff --git a/tests/test_elements_worker.py b/tests/test_elements_worker.py
index 8d0bc4114537bc55e37e6cb9aeadeaaac83ad2d0..6316cf76d584139ccdc0b89a10fbb3b32a5a1a00 100644
--- a/tests/test_elements_worker.py
+++ b/tests/test_elements_worker.py
@@ -367,6 +367,70 @@ def test_get_ml_class_id(mock_elements_worker):
     assert ml_class_id == "0000"
 
 
+def test_get_ml_class_reload(responses, mock_elements_worker):
+    corpus_id = "12341234-1234-1234-1234-123412341234"
+
+    # Add some initial classes
+    responses.add(
+        responses.GET,
+        f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
+        json={
+            "results": [
+                {
+                    "id": "class1_id",
+                    "name": "class1",
+                }
+            ]
+        },
+    )
+
+    # Invalid response when trying to create class2
+    responses.add(
+        responses.POST,
+        f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
+        status=400,
+        json={"non_field_errors": "Already exists"},
+    )
+
+    # Add both classes (class2 is created by another process)
+    responses.add(
+        responses.GET,
+        f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
+        json={
+            "results": [
+                {
+                    "id": "class1_id",
+                    "name": "class1",
+                },
+                {
+                    "id": "class2_id",
+                    "name": "class2",
+                },
+            ]
+        },
+    )
+
+    # Simply request class 2, it should be reloaded
+    assert mock_elements_worker.get_ml_class_id(corpus_id, "class2") == "class2_id"
+
+    assert len(responses.calls) == 4
+    assert mock_elements_worker.classes == {
+        corpus_id: {
+            "class1": "class1_id",
+            "class2": "class2_id",
+        }
+    }
+    assert [(call.request.method, call.request.url) for call in responses.calls] == [
+        (
+            "GET",
+            "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+        ),
+        ("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/?page=1"),
+        ("POST", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
+        ("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/?page=1"),
+    ]
+
+
 def test_create_sub_element_wrong_element(mock_elements_worker):
     with pytest.raises(AssertionError) as e:
         mock_elements_worker.create_sub_element(