Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
......@@ -532,6 +532,63 @@ class ElementsWorker(BaseWorker):
f"Couldn't save created transcription in local cache: {e}"
)
def create_transcriptions(self, transcriptions):
"""
Create multiple transcriptions at once on existing elements through the API.
"""
assert transcriptions and isinstance(
transcriptions, list
), "transcriptions shouldn't be null and should be of type list"
for index, transcription in enumerate(transcriptions):
element_id = transcription.get("element_id")
assert element_id and isinstance(
element_id, str
), f"Transcription at index {index} in transcriptions: element_id shouldn't be null and should be of type str"
text = transcription.get("text")
assert text and isinstance(
text, str
), f"Transcription at index {index} in transcriptions: text shouldn't be null and should be of type str"
score = transcription.get("score")
assert (
score is not None and isinstance(score, float) and 0 <= score <= 1
), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range"
created_trs = self.api_client.request(
"CreateTranscriptions",
body={
"worker_version": self.worker_version_id,
"transcriptions": transcriptions,
},
)["transcriptions"]
for created_tr in created_trs:
self.report.add_transcription(created_tr["element_id"])
if self.cache:
# Store transcriptions in local cache
try:
to_insert = [
CachedTranscription(
id=convert_str_uuid_to_hex(created_tr["id"]),
element_id=convert_str_uuid_to_hex(created_tr["element_id"]),
text=created_tr["text"],
confidence=created_tr["confidence"],
worker_version_id=convert_str_uuid_to_hex(
self.worker_version_id
),
)
for created_tr in created_trs
]
self.cache.insert("transcriptions", to_insert)
except sqlite3.IntegrityError as e:
logger.warning(
f"Couldn't save created transcriptions in local cache: {e}"
)
def create_classification(
self, element, ml_class, confidence, high_confidence=False
):
......@@ -687,7 +744,7 @@ class ElementsWorker(BaseWorker):
f"A sub_element of {element.id} with type {sub_element_type} was created during transcriptions bulk creation"
)
self.report.add_element(element.id, sub_element_type)
self.report.add_transcription(annotation["id"])
self.report.add_transcription(annotation["element_id"])
if self.cache:
# Store transcriptions and their associated element (if created) in local cache
......@@ -698,8 +755,11 @@ class ElementsWorker(BaseWorker):
worker_version_id_hex = convert_str_uuid_to_hex(self.worker_version_id)
for index, annotation in enumerate(annotations):
transcription = transcriptions[index]
element_id_hex = convert_str_uuid_to_hex(annotation["id"])
if annotation["created"] and annotation["id"] not in created_ids:
element_id_hex = convert_str_uuid_to_hex(annotation["element_id"])
if (
annotation["created"]
and annotation["element_id"] not in created_ids
):
elements_to_insert.append(
CachedElement(
id=element_id_hex,
......@@ -709,12 +769,11 @@ class ElementsWorker(BaseWorker):
worker_version_id=worker_version_id_hex,
)
)
created_ids.append(annotation["id"])
created_ids.append(annotation["element_id"])
transcriptions_to_insert.append(
CachedTranscription(
# TODO: Retrieve real transcription_id through API
id=convert_str_uuid_to_hex(str(uuid.uuid4())),
id=convert_str_uuid_to_hex(annotation["id"]),
element_id=element_id_hex,
text=transcription["text"],
confidence=transcription["score"],
......
......@@ -189,6 +189,349 @@ def test_create_transcription(responses, mock_elements_worker_with_cache):
]
def test_create_transcriptions_wrong_transcriptions(mock_elements_worker):
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcriptions(
transcriptions=None,
)
assert str(e.value) == "transcriptions shouldn't be null and should be of type list"
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcriptions(
transcriptions=1234,
)
assert str(e.value) == "transcriptions shouldn't be null and should be of type list"
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcriptions(
transcriptions=[
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "The",
"score": 0.75,
},
{
"text": "word",
"score": 0.5,
},
],
)
assert (
str(e.value)
== "Transcription at index 1 in transcriptions: element_id shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcriptions(
transcriptions=[
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "The",
"score": 0.75,
},
{
"element_id": None,
"text": "word",
"score": 0.5,
},
],
)
assert (
str(e.value)
== "Transcription at index 1 in transcriptions: element_id shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcriptions(
transcriptions=[
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "The",
"score": 0.75,
},
{
"element_id": 1234,
"text": "word",
"score": 0.5,
},
],
)
assert (
str(e.value)
== "Transcription at index 1 in transcriptions: element_id shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcriptions(
transcriptions=[
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "The",
"score": 0.75,
},
{
"element_id": "11111111-1111-1111-1111-111111111111",
"score": 0.5,
},
],
)
assert (
str(e.value)
== "Transcription at index 1 in transcriptions: text shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcriptions(
transcriptions=[
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "The",
"score": 0.75,
},
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": None,
"score": 0.5,
},
],
)
assert (
str(e.value)
== "Transcription at index 1 in transcriptions: text shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcriptions(
transcriptions=[
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "The",
"score": 0.75,
},
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": 1234,
"score": 0.5,
},
],
)
assert (
str(e.value)
== "Transcription at index 1 in transcriptions: text shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcriptions(
transcriptions=[
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "The",
"score": 0.75,
},
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "word",
},
],
)
assert (
str(e.value)
== "Transcription at index 1 in transcriptions: score shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcriptions(
transcriptions=[
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "The",
"score": 0.75,
},
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "word",
"score": None,
},
],
)
assert (
str(e.value)
== "Transcription at index 1 in transcriptions: score shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcriptions(
transcriptions=[
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "The",
"score": 0.75,
},
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "word",
"score": "a wrong score",
},
],
)
assert (
str(e.value)
== "Transcription at index 1 in transcriptions: score shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcriptions(
transcriptions=[
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "The",
"score": 0.75,
},
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "word",
"score": 0,
},
],
)
assert (
str(e.value)
== "Transcription at index 1 in transcriptions: score shouldn't be null and should be a float in [0..1] range"
)
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_transcriptions(
transcriptions=[
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "The",
"score": 0.75,
},
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "word",
"score": 2.00,
},
],
)
assert (
str(e.value)
== "Transcription at index 1 in transcriptions: score shouldn't be null and should be a float in [0..1] range"
)
def test_create_transcriptions_api_error(responses, mock_elements_worker):
responses.add(
responses.POST,
"http://testserver/api/v1/transcription/bulk/",
status=500,
)
trans = [
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "The",
"score": 0.75,
},
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "word",
"score": 0.42,
},
]
with pytest.raises(ErrorResponse):
mock_elements_worker.create_transcriptions(transcriptions=trans)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/transcription/bulk/",
]
def test_create_transcriptions(responses, mock_elements_worker_with_cache):
trans = [
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "The",
"score": 0.75,
},
{
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "word",
"score": 0.42,
},
]
responses.add(
responses.POST,
"http://testserver/api/v1/transcription/bulk/",
status=200,
json={
"worker_version": "12341234-1234-1234-1234-123412341234",
"transcriptions": [
{
"id": "00000000-0000-0000-0000-000000000000",
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "The",
"confidence": 0.75,
},
{
"id": "11111111-1111-1111-1111-111111111111",
"element_id": "11111111-1111-1111-1111-111111111111",
"text": "word",
"confidence": 0.42,
},
],
},
)
mock_elements_worker_with_cache.create_transcriptions(
transcriptions=trans,
)
assert len(responses.calls) == 3
assert [call.request.url for call in responses.calls] == [
"http://testserver/api/v1/user/",
"http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
"http://testserver/api/v1/transcription/bulk/",
]
assert json.loads(responses.calls[2].request.body) == {
"worker_version": "12341234-1234-1234-1234-123412341234",
"transcriptions": trans,
}
# Check that created transcriptions were properly stored in SQLite cache
cache_path = f"{CACHE_DIR}/db.sqlite"
assert os.path.isfile(cache_path)
rows = mock_elements_worker_with_cache.cache.cursor.execute(
"SELECT * FROM transcriptions"
).fetchall()
assert [CachedTranscription(**dict(row)) for row in rows] == [
CachedTranscription(
id=convert_str_uuid_to_hex("00000000-0000-0000-0000-000000000000"),
element_id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
text="The",
confidence=0.75,
worker_version_id=convert_str_uuid_to_hex(
"12341234-1234-1234-1234-123412341234"
),
),
CachedTranscription(
id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
element_id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"),
text="word",
confidence=0.42,
worker_version_id=convert_str_uuid_to_hex(
"12341234-1234-1234-1234-123412341234"
),
),
]
def test_create_element_transcriptions_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_element_transcriptions(
......@@ -582,26 +925,28 @@ def test_create_element_transcriptions_api_error(responses, mock_elements_worker
]
def test_create_element_transcriptions(
mocker, responses, mock_elements_worker_with_cache
):
mocker.patch(
"uuid.uuid4",
side_effect=[
"56785678-5678-5678-5678-567856785678",
"67896789-6789-6789-6789-678967896789",
"78907890-7890-7890-7890-789078907890",
],
)
def test_create_element_transcriptions(responses, mock_elements_worker_with_cache):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
responses.add(
responses.POST,
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
status=200,
json=[
{"id": "11111111-1111-1111-1111-111111111111", "created": True},
{"id": "22222222-2222-2222-2222-222222222222", "created": False},
{"id": "11111111-1111-1111-1111-111111111111", "created": True},
{
"id": "56785678-5678-5678-5678-567856785678",
"element_id": "11111111-1111-1111-1111-111111111111",
"created": True,
},
{
"id": "67896789-6789-6789-6789-678967896789",
"element_id": "22222222-2222-2222-2222-222222222222",
"created": False,
},
{
"id": "78907890-7890-7890-7890-789078907890",
"element_id": "11111111-1111-1111-1111-111111111111",
"created": True,
},
],
)
......@@ -625,9 +970,21 @@ def test_create_element_transcriptions(
"return_elements": True,
}
assert annotations == [
{"id": "11111111-1111-1111-1111-111111111111", "created": True},
{"id": "22222222-2222-2222-2222-222222222222", "created": False},
{"id": "11111111-1111-1111-1111-111111111111", "created": True},
{
"id": "56785678-5678-5678-5678-567856785678",
"element_id": "11111111-1111-1111-1111-111111111111",
"created": True,
},
{
"id": "67896789-6789-6789-6789-678967896789",
"element_id": "22222222-2222-2222-2222-222222222222",
"created": False,
},
{
"id": "78907890-7890-7890-7890-789078907890",
"element_id": "11111111-1111-1111-1111-111111111111",
"created": True,
},
]
# Check that created transcriptions and elements were properly stored in SQLite cache
......