Skip to content
Snippets Groups Projects
Commit 4195b241 authored by Bastien Abadie's avatar Bastien Abadie
Browse files

Reload known ML classes when a 400 is received on creation

parent a9aad29f
No related branches found
No related tags found
1 merge request!30Reload known ML classes when a 400 is received on creation
Pipeline #78001 passed
......@@ -8,6 +8,7 @@ import uuid
from enum import Enum
from pathlib import Path
import apistar
import gnupg
import yaml
from apistar.exceptions import ErrorResponse
......@@ -296,11 +297,26 @@ class ElementsWorker(BaseWorker):
ml_class_id = self.classes[corpus_id].get(ml_class)
if ml_class_id is None:
logger.info(f"Creating ML class {ml_class} on corpus {corpus_id}")
response = self.api_client.request(
"CreateMLClass", id=corpus_id, body={"name": ml_class}
)
ml_class_id = self.classes[corpus_id][ml_class] = response["id"]
logger.debug(f"Created ML class {ml_class_id}")
try:
response = self.api_client.request(
"CreateMLClass", id=corpus_id, body={"name": ml_class}
)
ml_class_id = self.classes[corpus_id][ml_class] = response["id"]
logger.debug(f"Created ML class {response['id']}")
except apistar.exceptions.ErrorResponse as e:
# Only reload for 400 errors
if e.status_code != 400:
raise
# Reload and make sure we have the class
logger.info(
f"Reloading corpus classes to see if {ml_class} already exists"
)
self.load_corpus_classes(corpus_id)
assert (
ml_class in self.classes[corpus_id]
), "Missing class {ml_class} even after reloading"
ml_class_id = self.classes[corpus_id][ml_class]
return ml_class_id
......
......@@ -370,6 +370,70 @@ def test_get_ml_class_id(mock_elements_worker):
assert ml_class_id == "0000"
def test_get_ml_class_reload(responses, mock_elements_worker):
corpus_id = "12341234-1234-1234-1234-123412341234"
# Add some initial classes
responses.add(
responses.GET,
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
json={
"results": [
{
"id": "class1_id",
"name": "class1",
}
]
},
)
# Invalid response when trying to create class2
responses.add(
responses.POST,
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
status=400,
json={"non_field_errors": "Already exists"},
)
# Add both classes (class2 is created by another process)
responses.add(
responses.GET,
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
json={
"results": [
{
"id": "class1_id",
"name": "class1",
},
{
"id": "class2_id",
"name": "class2",
},
]
},
)
# Simply request class 2, it should be reloaded
assert mock_elements_worker.get_ml_class_id(corpus_id, "class2") == "class2_id"
assert len(responses.calls) == 4
assert mock_elements_worker.classes == {
corpus_id: {
"class1": "class1_id",
"class2": "class2_id",
}
}
assert [(call.request.method, call.request.url) for call in responses.calls] == [
(
"GET",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
),
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/?page=1"),
("POST", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"),
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/?page=1"),
]
def test_create_sub_element_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_sub_element(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment