Skip to content
Snippets Groups Projects
Commit 73fd76ca authored by Eva Bardou, committed by Eva Bardou
Browse files

Remove batch publication on CreateTranscriptionEntities helper

parent c4c63979
No related branches found
No related tags found
1 merge request: !602 — Remove batch publication on CreateTranscriptionEntities helper
Pipeline #202409 passed
......@@ -16,9 +16,6 @@ from arkindex_worker.cache import (
)
from arkindex_worker.models import Element, Transcription
from arkindex_worker.utils import (
DEFAULT_BATCH_SIZE,
batch_publication,
make_batches,
pluralize,
)
......@@ -219,12 +216,10 @@ class EntityMixin:
return transcription_ent
@unsupported_cache
@batch_publication
def create_transcription_entities(
self,
transcription: Transcription,
entities: list[Entity],
batch_size: int = DEFAULT_BATCH_SIZE,
) -> list[dict[str, str]]:
"""
Create multiple entities attached to a transcription in a single API request.
......@@ -247,8 +242,6 @@ class EntityMixin:
confidence (float or None)
Optional confidence score, between 0.0 and 1.0.
:param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
:return: List of dicts, with each dict having a two keys, `transcription_entity_id` and `entity_id`, holding the UUID of each created object.
"""
assert transcription and isinstance(
......@@ -300,18 +293,14 @@ class EntityMixin:
)
return
created_entities = [
created_entity
for batch in make_batches(entities, "entity", batch_size)
for created_entity in self.api_client.request(
"CreateTranscriptionEntities",
id=transcription.id,
body={
"worker_run_id": self.worker_run_id,
"entities": batch,
},
)["entities"]
]
created_entities = self.api_client.request(
"CreateTranscriptionEntities",
id=transcription.id,
body={
"worker_run_id": self.worker_run_id,
"entities": entities,
},
)["entities"]
return created_entities
......
......@@ -13,7 +13,6 @@ from arkindex_worker.cache import (
CachedTranscriptionEntity,
)
from arkindex_worker.models import Transcription
from arkindex_worker.utils import DEFAULT_BATCH_SIZE
from arkindex_worker.worker.transcription import TextOrientation
from tests import CORPUS_ID
......@@ -836,89 +835,50 @@ def test_create_transcription_entities_wrong_entity(
)
@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
def test_create_transcription_entities(batch_size, responses, mock_elements_worker):
def test_create_transcription_entities(responses, mock_elements_worker):
transcription = Transcription(id="transcription-id")
# Call to Transcription entities creation in bulk
if batch_size > 1:
responses.add(
responses.POST,
"http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
status=201,
match=[
matchers.json_params_matcher(
{
"worker_run_id": "56785678-5678-5678-5678-567856785678",
"entities": [
{
"name": "Teklia",
"type_id": "22222222-2222-2222-2222-222222222222",
"offset": 0,
"length": 6,
"confidence": 1.0,
},
{
"name": "Team Rocket",
"type_id": "22222222-2222-2222-2222-222222222222",
"offset": 7,
"length": 11,
"confidence": 1.0,
},
],
}
)
],
json={
"entities": [
{
"transcription_entity_id": "transc-entity-id",
"entity_id": "entity-id1",
},
{
"transcription_entity_id": "transc-entity-id",
"entity_id": "entity-id2",
},
]
},
)
else:
for idx, (name, offset, length) in enumerate(
[
("Teklia", 0, 6),
("Team Rocket", 7, 11),
],
start=1,
):
responses.add(
responses.POST,
"http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
status=201,
match=[
matchers.json_params_matcher(
{
"worker_run_id": "56785678-5678-5678-5678-567856785678",
"entities": [
{
"name": name,
"type_id": "22222222-2222-2222-2222-222222222222",
"offset": offset,
"length": length,
"confidence": 1.0,
}
],
}
)
],
json={
responses.add(
responses.POST,
"http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
status=201,
match=[
matchers.json_params_matcher(
{
"worker_run_id": "56785678-5678-5678-5678-567856785678",
"entities": [
{
"transcription_entity_id": "transc-entity-id",
"entity_id": f"entity-id{idx}",
}
]
},
"name": "Teklia",
"type_id": "22222222-2222-2222-2222-222222222222",
"offset": 0,
"length": 6,
"confidence": 1.0,
},
{
"name": "Team Rocket",
"type_id": "22222222-2222-2222-2222-222222222222",
"offset": 7,
"length": 11,
"confidence": 1.0,
},
],
}
)
],
json={
"entities": [
{
"transcription_entity_id": "transc-entity-id",
"entity_id": "entity-id1",
},
{
"transcription_entity_id": "transc-entity-id",
"entity_id": "entity-id2",
},
]
},
)
# Store entity type/slug correspondence on the worker
mock_elements_worker.entity_types = {
......@@ -942,26 +902,16 @@ def test_create_transcription_entities(batch_size, responses, mock_elements_work
"confidence": 1.0,
},
],
batch_size=batch_size,
)
assert len(created_objects) == 2
bulk_api_calls = [
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"POST",
"http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
)
]
if batch_size != DEFAULT_BATCH_SIZE:
bulk_api_calls.append(
(
"POST",
"http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
)
)
assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + bulk_api_calls
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment