Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
Commits on Source (5)
......@@ -198,6 +198,9 @@ class BaseWorker(object):
"'ARKINDEX_CORPUS_ID' was not set in the environment. Any API request involving a `corpus_id` will fail."
)
# Define model_version_id from environment
self.model_version_id = os.environ.get("ARKINDEX_MODEL_VERSION_ID")
# Load all required secrets
self.secrets = {name: self.load_secret(name) for name in required_secrets}
......@@ -258,6 +261,9 @@ class BaseWorker(object):
logger.info("Loaded model version configuration from WorkerRun")
self.model_configuration.update(model_version.get("configuration"))
# Set model_version ID as worker attribute
self.model_version_id = model_version.get("id")
# if debug mode is set to true activate debug mode in logger
if self.user_configuration.get("debug"):
logger.setLevel(logging.DEBUG)
......
......@@ -3,7 +3,7 @@
ElementsWorker methods for entities.
"""
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, TypedDict, Union
from peewee import IntegrityError
......@@ -11,6 +11,10 @@ from arkindex_worker import logger
from arkindex_worker.cache import CachedElement, CachedEntity, CachedTranscriptionEntity
from arkindex_worker.models import Element, Transcription
Entity = TypedDict(
"Entity", name=str, type_id=str, length=int, offset=int, confidence=Optional[float]
)
class MissingEntityType(Exception):
"""
......@@ -108,7 +112,7 @@ class EntityMixin(object):
"worker_run_id": self.worker_run_id,
},
)
self.report.add_entity(element.id, entity["id"], type, name)
self.report.add_entity(element.id, entity["id"], entity_type_id, name)
if self.use_cache:
# Store entity in local cache
......@@ -204,6 +208,110 @@ class EntityMixin(object):
)
return transcription_ent
def create_transcription_entities(
self,
transcription: Transcription,
entities: List[Entity],
) -> List[Dict[str, str]]:
"""
Create multiple entities attached to a transcription in a single API request.
:param transcription: Transcription to create the entity on.
:param entities: List of dicts, one per element. Each dict can have the following keys:
name (str)
Required. Name of the entity.
type_id (str)
Required. ID of the EntityType of the entity.
length (int)
Required. Length of the entity in the transcription's text.
offset (int)
Required. Starting position of the entity in the transcription's text, as a 0-based index.
confidence (float or None)
Optional confidence score, between 0.0 and 1.0.
:return: List of dicts, with each dict having a two keys, `transcription_entity_id` and `entity_id`, holding the UUID of each created object.
"""
assert transcription and isinstance(
transcription, Transcription
), "transcription shouldn't be null and should be of type Transcription"
# Needed for MLreport
assert (
hasattr(transcription, "element") and transcription.element
), f"No element linked to {transcription}"
assert entities and isinstance(
entities, list
), "entities shouldn't be null and should be of type list"
for index, entity in enumerate(entities):
assert isinstance(
entity, dict
), f"Entity at index {index} in entities: Should be of type dict"
name = entity.get("name")
assert name and isinstance(
name, str
), f"Entity at index {index} in entities: name shouldn't be null and should be of type str"
type_id = entity.get("type_id")
assert type_id and isinstance(
type_id, str
), f"Entity at index {index} in entities: type_id shouldn't be null and should be of type str"
offset = entity.get("offset")
assert (
offset is not None and isinstance(offset, int) and offset >= 0
), f"Entity at index {index} in entities: offset shouldn't be null and should be a positive integer"
length = entity.get("length")
assert (
length is not None and isinstance(length, int) and length > 0
), f"Entity at index {index} in entities: length shouldn't be null and should be a strictly positive integer"
confidence = entity.get("confidence")
assert confidence is None or (
isinstance(confidence, float) and 0 <= confidence <= 1
), f"Entity at index {index} in entities: confidence should be None or a float in [0..1] range"
if self.is_read_only:
logger.warning(
"Cannot create transcription entities in bulk as this worker is in read-only mode"
)
return
created_ids = self.request(
"CreateTranscriptionEntities",
id=transcription.id,
body={
"worker_run_id": self.worker_run_id,
"entities": entities,
},
)
for entity, created_objects in zip(entities, created_ids["entities"]):
# Report entity creation
self.report.add_entity(
transcription.element.id,
created_objects["entity_id"],
entity.get("type_id"),
entity.get("name"),
)
# Report transcription entity creation
self.report.add_transcription_entity(
created_objects["entity_id"],
transcription,
created_objects["transcription_entity_id"],
)
return created_ids["entities"]
def list_transcription_entities(
self,
transcription: Transcription,
......
black==22.12.0
doc8==1.1.1
mkdocs==1.4.2
mkdocs-material==9.0.10
mkdocs-material==9.0.15
mkdocstrings==0.20.0
mkdocstrings-python==0.8.3
recommonmark==0.7.1
......@@ -126,6 +126,9 @@ Many attributes are set on the worker during at the configuration stage. Here is
`model_configuration`
: The parsed configuration as stored in the `ModelVersion` object on Arkindex.
`model_version_id`
: The ID of the model version linked to the current `WorkerRun` object on Arkindex. You may set it in developer mode via the `ARKINDEX_MODEL_VERSION_ID` environment variable.
`process_information`
: The details about the process parent to this worker execution. Only set in Arkindex mode.
......
arkindex-client==1.0.11
peewee==3.15.4
peewee==3.16.0
Pillow==9.4.0
pymdown-extensions==9.9.2
python-gitlab==3.13.0
......
pytest==7.2.1
pytest-mock==3.10.0
pytest-responses==0.5.1
requests==2.28.1
......@@ -494,6 +494,7 @@ def test_configure_load_model_configuration(mocker, monkeypatch, responses):
"param2": 2,
"param3": None,
}
assert worker.model_version_id == "12341234-1234-1234-1234-123412341234"
def test_load_missing_secret():
......
......@@ -893,3 +893,239 @@ def test_check_required_entity_types_no_creation_allowed(
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS
@pytest.mark.parametrize("transcription", (None, "not a transcription", 1))
def test_create_transcription_entities_wrong_transcription(
mock_elements_worker, transcription
):
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcription_entities(
transcription=transcription,
entities=[],
)
assert (
str(e.value)
== "transcription shouldn't be null and should be of type Transcription"
)
def test_create_transcription_entities_no_transcription_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcription_entities(
transcription=Transcription(id="transcription_id"),
entities=[],
)
assert str(e.value) == "No element linked to Transcription (transcription_id)"
@pytest.mark.parametrize("entities", (None, "not a list of entities", 1))
def test_create_transcription_entities_wrong_entities(mock_elements_worker, entities):
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcription_entities(
transcription=Transcription(
id="transcription_id", element={"id": "element_id"}
),
entities=entities,
)
assert str(e.value) == "entities shouldn't be null and should be of type list"
def test_create_transcription_entities_wrong_entities_subtype(mock_elements_worker):
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcription_entities(
transcription=Transcription(
id="transcription_id", element={"id": "element_id"}
),
entities=["not a dict"],
)
assert str(e.value) == "Entity at index 0 in entities: Should be of type dict"
@pytest.mark.parametrize(
"entity, error",
(
(
{
"name": None,
"type_id": "12341234-1234-1234-1234-123412341234",
"offset": 0,
"length": 1,
"confidence": 0.5,
},
"Entity at index 0 in entities: name shouldn't be null and should be of type str",
),
(
{"name": "A", "type_id": None, "offset": 0, "length": 1, "confidence": 0.5},
"Entity at index 0 in entities: type_id shouldn't be null and should be of type str",
),
(
{"name": "A", "type_id": 0, "offset": 0, "length": 1, "confidence": 0.5},
"Entity at index 0 in entities: type_id shouldn't be null and should be of type str",
),
(
{
"name": "A",
"type_id": "12341234-1234-1234-1234-123412341234",
"offset": None,
"length": 1,
"confidence": 0.5,
},
"Entity at index 0 in entities: offset shouldn't be null and should be a positive integer",
),
(
{
"name": "A",
"type_id": "12341234-1234-1234-1234-123412341234",
"offset": -2,
"length": 1,
"confidence": 0.5,
},
"Entity at index 0 in entities: offset shouldn't be null and should be a positive integer",
),
(
{
"name": "A",
"type_id": "12341234-1234-1234-1234-123412341234",
"offset": 0,
"length": None,
"confidence": 0.5,
},
"Entity at index 0 in entities: length shouldn't be null and should be a strictly positive integer",
),
(
{
"name": "A",
"type_id": "12341234-1234-1234-1234-123412341234",
"offset": 0,
"length": 0,
"confidence": 0.5,
},
"Entity at index 0 in entities: length shouldn't be null and should be a strictly positive integer",
),
(
{
"name": "A",
"type_id": "12341234-1234-1234-1234-123412341234",
"offset": 0,
"length": 1,
"confidence": "not None or a float",
},
"Entity at index 0 in entities: confidence should be None or a float in [0..1] range",
),
(
{
"name": "A",
"type_id": "12341234-1234-1234-1234-123412341234",
"offset": 0,
"length": 1,
"confidence": 1.3,
},
"Entity at index 0 in entities: confidence should be None or a float in [0..1] range",
),
),
)
def test_create_transcription_entities_wrong_entity(
mock_elements_worker, entity, error
):
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcription_entities(
transcription=Transcription(
id="transcription_id", element={"id": "element_id"}
),
entities=[entity],
)
assert str(e.value) == error
def test_create_transcription_entities(responses, mock_elements_worker):
element_id = "element_id"
transcription = Transcription(id="transcription-id", element={"id": element_id})
# Call to Transcription entities creation in bulk
responses.add(
responses.POST,
"http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
status=201,
match=[
matchers.json_params_matcher(
{
"worker_run_id": "56785678-5678-5678-5678-567856785678",
"entities": [
{
"name": "Teklia",
"type_id": "22222222-2222-2222-2222-222222222222",
"offset": 0,
"length": 6,
"confidence": 1.0,
}
],
}
)
],
json={
"entities": [
{
"transcription_entity_id": "transc-entity-id",
"entity_id": "entity-id",
}
]
},
)
# Store entity type/slug correspondence on the worker
mock_elements_worker.entity_types = {
"22222222-2222-2222-2222-222222222222": "organization"
}
created_objects = mock_elements_worker.create_transcription_entities(
transcription=transcription,
entities=[
{
"name": "Teklia",
"type_id": "22222222-2222-2222-2222-222222222222",
"offset": 0,
"length": 6,
"confidence": 1.0,
}
],
)
assert len(created_objects) == 1
assert element_id in mock_elements_worker.report.report_data["elements"]
ml_report = mock_elements_worker.report.report_data["elements"][element_id]
assert "started" in ml_report
del ml_report["started"]
# Check reporting
assert ml_report == {
"elements": {},
"transcriptions": 0,
"classifications": {},
"entities": [
{
"id": "entity-id",
"type": "22222222-2222-2222-2222-222222222222",
"name": "Teklia",
}
],
"transcription_entities": [
{
"transcription_id": "transcription-id",
"entity_id": "entity-id",
"transcription_entity_id": "transc-entity-id",
}
],
"metadata": [],
"errors": [],
}
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"POST",
"http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
),
]
......@@ -167,7 +167,10 @@ def test_add_transcription_count():
def test_add_entity():
reporter = Reporter("worker")
reporter.add_entity(
"myelement", "12341234-1234-1234-1234-123412341234", "person", "Bob Bob"
"myelement",
"12341234-1234-1234-1234-123412341234",
"person-entity-type-id",
"Bob Bob",
)
assert "myelement" in reporter.report_data["elements"]
element_data = reporter.report_data["elements"]["myelement"]
......@@ -179,7 +182,7 @@ def test_add_entity():
"entities": [
{
"id": "12341234-1234-1234-1234-123412341234",
"type": "person",
"type": "person-entity-type-id",
"name": "Bob Bob",
}
],
......