# test_entities.py (33.07 KiB) -- scraped page header ("Skip to content", "Snippets Groups Projects") kept as a comment
# -*- coding: utf-8 -*-
import json
import re
from uuid import UUID

import pytest
from apistar.exceptions import ErrorResponse
from responses import matchers

from arkindex_worker.cache import (
    CachedElement,
    CachedEntity,
    CachedTranscription,
    CachedTranscriptionEntity,
)
from arkindex_worker.models import Transcription
from arkindex_worker.worker.entity import MissingEntityType
from arkindex_worker.worker.transcription import TextOrientation

from . import BASE_API_CALLS


def test_create_entity_wrong_name(mock_elements_worker):
    """Check that create_entity raises an AssertionError for a null or non-str name."""
    expected_error = "name shouldn't be null and should be of type str"

    # First a missing name, then a name of the wrong type
    for invalid_name in (None, 1234):
        with pytest.raises(AssertionError, match=expected_error):
            mock_elements_worker.create_entity(name=invalid_name, type="person")


def test_create_entity_wrong_type(mock_elements_worker):
    """Check that create_entity raises an AssertionError for a null or non-str type."""
    expected_error = "type shouldn't be null and should be of type str"

    # First a missing type, then a type that is not a string
    for invalid_type in (None, 1234):
        with pytest.raises(AssertionError, match=expected_error):
            mock_elements_worker.create_entity(name="Bob Bob", type=invalid_type)


def test_create_entity_wrong_corpus(monkeypatch, mock_elements_worker):
    """Check that create_entity fails on metas, not on corpus, when no corpus is given."""
    # No corpus is passed on purpose: this should work since the
    # ARKINDEX_CORPUS_ID environment variable is set on mock_elements_worker,
    # so the error triggered below is the one on the metas param.
    # NOTE(review): the monkeypatch fixture is requested but never used here.
    invalid_arguments = {
        "name": "Bob Bob",
        "type": "person",
        "metas": "wrong metas",
    }
    with pytest.raises(AssertionError, match="metas should be of type dict"):
        mock_elements_worker.create_entity(**invalid_arguments)


def test_create_entity_wrong_metas(mock_elements_worker):
    """Check that create_entity raises an AssertionError when metas is not a dict."""
    not_a_dict = "wrong metas"
    expected_error = "metas should be of type dict"

    with pytest.raises(AssertionError, match=expected_error):
        mock_elements_worker.create_entity(
            name="Bob Bob", type="person", metas=not_a_dict
        )


def test_create_entity_wrong_validated(mock_elements_worker):
    """Check that create_entity raises an AssertionError when validated is not a bool."""
    not_a_bool = "wrong validated"
    expected_error = "validated should be of type bool"

    with pytest.raises(AssertionError, match=expected_error):
        mock_elements_worker.create_entity(
            name="Bob Bob", type="person", validated=not_a_bool
        )


def test_create_entity_api_error(responses, mock_elements_worker):
    """Check that create_entity raises ErrorResponse after 5 POST attempts answered with a 500."""
    entity_endpoint = "http://testserver/api/v1/entity/"

    # Set one entity type
    mock_elements_worker.entity_types = {"person": "person-entity-type-id"}
    responses.add(responses.POST, entity_endpoint, status=500)

    with pytest.raises(ErrorResponse):
        mock_elements_worker.create_entity(name="Bob Bob", type="person")

    assert len(responses.calls) == len(BASE_API_CALLS) + 5

    # We retry 5 times the API call
    retried_calls = [("POST", entity_endpoint)] * 5
    performed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]
    assert performed_calls == BASE_API_CALLS + retried_calls


def test_create_entity(responses, mock_elements_worker):
    """Check the single POST sent by create_entity, its JSON body and the returned ID."""
    entity_endpoint = "http://testserver/api/v1/entity/"
    created_id = "12345678-1234-1234-1234-123456789123"

    # Set one entity type
    mock_elements_worker.entity_types = {"person": "person-entity-type-id"}

    responses.add(
        responses.POST, entity_endpoint, status=200, json={"id": created_id}
    )

    entity_id = mock_elements_worker.create_entity(name="Bob Bob", type="person")

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    performed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]
    assert performed_calls == BASE_API_CALLS + [("POST", entity_endpoint)]

    # The type name must have been translated to the stored type ID in the payload
    sent_payload = json.loads(responses.calls[-1].request.body)
    assert sent_payload == {
        "name": "Bob Bob",
        "type_id": "person-entity-type-id",
        "metas": {},
        "validated": None,
        "corpus": "11111111-1111-1111-1111-111111111111",
        "worker_run_id": "56785678-5678-5678-5678-567856785678",
    }
    assert entity_id == created_id


def test_create_entity_missing_type(responses, mock_elements_worker):
    """
    Create entity with an unknown type will fail.

    The only extra API call expected is the GET listing the corpus entity
    types, which here returns the single `person` type.
    """
    # Call to list entity types
    responses.add(
        responses.GET,
        "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/entity-types/",
        status=200,
        json={
            "count": 1,
            "next": None,
            "results": [
                {"id": "person-entity-type-id", "name": "person", "color": "00d1b2"}
            ],
        },
    )

    with pytest.raises(
        AssertionError,
        # Escape the message so that its final "." is matched literally
        # instead of acting as a regex wildcard
        match=re.escape("Entity type `new-entity` not found in the corpus."),
    ):
        mock_elements_worker.create_entity(
            name="Bob Bob",
            type="new-entity",
        )

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == BASE_API_CALLS + [
        (
            "GET",
            "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/entity-types/",
        ),
    ]


def test_create_entity_with_cache(responses, mock_elements_worker_with_cache):
    """Check that create_entity sends the expected POST and stores the entity in the cache."""
    entity_endpoint = "http://testserver/api/v1/entity/"
    created_id = "12345678-1234-1234-1234-123456789123"

    # Set one entity type
    mock_elements_worker_with_cache.entity_types = {"person": "person-entity-type-id"}
    responses.add(
        responses.POST, entity_endpoint, status=200, json={"id": created_id}
    )

    entity_id = mock_elements_worker_with_cache.create_entity(
        name="Bob Bob", type="person"
    )

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    performed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]
    assert performed_calls == BASE_API_CALLS + [("POST", entity_endpoint)]

    sent_payload = json.loads(responses.calls[-1].request.body)
    assert sent_payload == {
        "name": "Bob Bob",
        "type_id": "person-entity-type-id",
        "metas": {},
        "validated": None,
        "corpus": "11111111-1111-1111-1111-111111111111",
        "worker_run_id": "56785678-5678-5678-5678-567856785678",
    }
    assert entity_id == created_id

    # Check that created entity was properly stored in SQLite cache
    expected_cached_entity = CachedEntity(
        id=UUID("12345678-1234-1234-1234-123456789123"),
        type="person",
        name="Bob Bob",
        validated=False,
        metas={},
        worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
    )
    assert list(CachedEntity.select()) == [expected_cached_entity]


def test_create_transcription_entity_wrong_transcription(mock_elements_worker):
    """Check that create_transcription_entity rejects a null or non-Transcription transcription."""
    expected_error = "transcription shouldn't be null and should be a Transcription"

    # A missing transcription, then a value of an unrelated type
    for invalid_transcription in (None, 1234):
        with pytest.raises(AssertionError, match=expected_error):
            mock_elements_worker.create_transcription_entity(
                transcription=invalid_transcription,
                entity="11111111-1111-1111-1111-111111111111",
                offset=5,
                length=10,
            )


def test_create_transcription_entity_wrong_entity(mock_elements_worker):
    """Check that create_transcription_entity rejects a null or non-str entity."""
    expected_error = "entity shouldn't be null and should be of type str"

    # A missing entity, then an entity that is not a string
    for invalid_entity in (None, 1234):
        # A fresh Transcription is built for each rejected call
        transcription = Transcription(
            {
                "id": "11111111-1111-1111-1111-111111111111",
                "element": {"id": "myelement"},
            }
        )
        with pytest.raises(AssertionError, match=expected_error):
            mock_elements_worker.create_transcription_entity(
                transcription=transcription,
                entity=invalid_entity,
                offset=5,
                length=10,
            )


def test_create_transcription_entity_wrong_offset(mock_elements_worker):
    """
    Check that create_transcription_entity rejects an offset that is null,
    not an integer, or negative.
    """
    expected_error = "offset shouldn't be null and should be a positive integer"

    # Null, wrong type, then negative value
    for invalid_offset in (None, "not an int", -1):
        # A fresh Transcription is built for each rejected call
        transcription = Transcription(
            {
                "id": "11111111-1111-1111-1111-111111111111",
                "element": {"id": "myelement"},
            }
        )
        with pytest.raises(AssertionError, match=expected_error):
            mock_elements_worker.create_transcription_entity(
                transcription=transcription,
                entity="11111111-1111-1111-1111-111111111111",
                offset=invalid_offset,
                length=10,
            )


def test_create_transcription_entity_wrong_length(mock_elements_worker):
    """
    Check that create_transcription_entity rejects a length that is null,
    not an integer, or zero.
    """
    expected_error = (
        "length shouldn't be null and should be a strictly positive integer"
    )

    # Null, wrong type, then zero (not strictly positive)
    for invalid_length in (None, "not an int", 0):
        # A fresh Transcription is built for each rejected call
        transcription = Transcription(
            {
                "id": "11111111-1111-1111-1111-111111111111",
                "element": {"id": "myelement"},
            }
        )
        with pytest.raises(AssertionError, match=expected_error):
            mock_elements_worker.create_transcription_entity(
                transcription=transcription,
                entity="11111111-1111-1111-1111-111111111111",
                offset=5,
                length=invalid_length,
            )


def test_create_transcription_entity_api_error(responses, mock_elements_worker):
    """
    Check that create_transcription_entity raises ErrorResponse after 5 POST
    attempts answered with a 500.
    """
    endpoint = "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/"
    responses.add(responses.POST, endpoint, status=500)

    transcription = Transcription(
        {
            "id": "11111111-1111-1111-1111-111111111111",
            "element": {"id": "myelement"},
        }
    )
    with pytest.raises(ErrorResponse):
        mock_elements_worker.create_transcription_entity(
            transcription=transcription,
            entity="11111111-1111-1111-1111-111111111111",
            offset=5,
            length=10,
        )

    assert len(responses.calls) == len(BASE_API_CALLS) + 5

    # We retry 5 times the API call
    retried_calls = [("POST", endpoint)] * 5
    performed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]
    assert performed_calls == BASE_API_CALLS + retried_calls


def test_create_transcription_entity_no_confidence(responses, mock_elements_worker):
    """Check the POST sent by create_transcription_entity when no confidence is passed."""
    entity_id = "11111111-1111-1111-1111-111111111111"
    endpoint = "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/"

    responses.add(
        responses.POST,
        endpoint,
        status=200,
        json={"entity": entity_id, "offset": 5, "length": 10},
    )

    transcription = Transcription(
        {
            "id": "11111111-1111-1111-1111-111111111111",
            "element": {"id": "myelement"},
        }
    )
    mock_elements_worker.create_transcription_entity(
        transcription=transcription, entity=entity_id, offset=5, length=10
    )

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    performed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]
    assert performed_calls == BASE_API_CALLS + [("POST", endpoint)]

    # The request body must not carry any "confidence" key
    sent_payload = json.loads(responses.calls[-1].request.body)
    assert sent_payload == {
        "entity": "11111111-1111-1111-1111-111111111111",
        "offset": 5,
        "length": 10,
        "worker_run_id": "56785678-5678-5678-5678-567856785678",
    }


def test_create_transcription_entity_with_confidence(responses, mock_elements_worker):
    """Check that a confidence passed to create_transcription_entity is sent in the POST body."""
    entity_id = "11111111-1111-1111-1111-111111111111"
    endpoint = "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/"

    responses.add(
        responses.POST,
        endpoint,
        status=200,
        json={"entity": entity_id, "offset": 5, "length": 10, "confidence": 0.33},
    )

    transcription = Transcription(
        {
            "id": "11111111-1111-1111-1111-111111111111",
            "element": {"id": "myelement"},
        }
    )
    mock_elements_worker.create_transcription_entity(
        transcription=transcription,
        entity=entity_id,
        offset=5,
        length=10,
        confidence=0.33,
    )

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    performed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]
    assert performed_calls == BASE_API_CALLS + [("POST", endpoint)]

    # The confidence value must be part of the request body
    sent_payload = json.loads(responses.calls[-1].request.body)
    assert sent_payload == {
        "entity": "11111111-1111-1111-1111-111111111111",
        "offset": 5,
        "length": 10,
        "worker_run_id": "56785678-5678-5678-5678-567856785678",
        "confidence": 0.33,
    }


def test_create_transcription_entity_confidence_none(responses, mock_elements_worker):
    """Check that an explicit confidence=None is left out of the POST body."""
    entity_id = "11111111-1111-1111-1111-111111111111"
    endpoint = "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/"

    responses.add(
        responses.POST,
        endpoint,
        status=200,
        json={"entity": entity_id, "offset": 5, "length": 10, "confidence": None},
    )

    transcription = Transcription(
        {
            "id": "11111111-1111-1111-1111-111111111111",
            "element": {"id": "myelement"},
        }
    )
    mock_elements_worker.create_transcription_entity(
        transcription=transcription,
        entity=entity_id,
        offset=5,
        length=10,
        confidence=None,
    )

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    performed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]
    assert performed_calls == BASE_API_CALLS + [("POST", endpoint)]

    # No "confidence" key is expected in the request body
    sent_payload = json.loads(responses.calls[-1].request.body)
    assert sent_payload == {
        "entity": "11111111-1111-1111-1111-111111111111",
        "offset": 5,
        "length": 10,
        "worker_run_id": "56785678-5678-5678-5678-567856785678",
    }


def test_create_transcription_entity_with_cache(
    responses, mock_elements_worker_with_cache
):
    """
    Check that create_transcription_entity sends the expected POST and stores
    the created link in the cache.
    """
    endpoint = "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/"

    # Cache rows created up front: an element, its transcription and an entity
    CachedElement.create(
        id=UUID("12341234-1234-1234-1234-123412341234"),
        type="page",
    )
    CachedTranscription.create(
        id=UUID("11111111-1111-1111-1111-111111111111"),
        element=UUID("12341234-1234-1234-1234-123412341234"),
        text="Hello, it's me.",
        confidence=0.42,
        orientation=TextOrientation.HorizontalLeftToRight,
        worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
    )
    CachedEntity.create(
        id=UUID("11111111-1111-1111-1111-111111111111"),
        type="person",
        name="Bob Bob",
        worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
    )

    responses.add(
        responses.POST,
        endpoint,
        status=200,
        json={
            "entity": "11111111-1111-1111-1111-111111111111",
            "offset": 5,
            "length": 10,
        },
    )

    transcription = Transcription(
        {
            "id": "11111111-1111-1111-1111-111111111111",
            "element": {"id": "myelement"},
        }
    )
    mock_elements_worker_with_cache.create_transcription_entity(
        transcription=transcription,
        entity="11111111-1111-1111-1111-111111111111",
        offset=5,
        length=10,
    )

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    performed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]
    assert performed_calls == BASE_API_CALLS + [("POST", endpoint)]

    sent_payload = json.loads(responses.calls[-1].request.body)
    assert sent_payload == {
        "entity": "11111111-1111-1111-1111-111111111111",
        "offset": 5,
        "length": 10,
        "worker_run_id": "56785678-5678-5678-5678-567856785678",
    }

    # Check that created transcription entity was properly stored in SQLite cache
    expected_cached_link = CachedTranscriptionEntity(
        transcription=UUID("11111111-1111-1111-1111-111111111111"),
        entity=UUID("11111111-1111-1111-1111-111111111111"),
        offset=5,
        length=10,
        worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
    )
    assert list(CachedTranscriptionEntity.select()) == [expected_cached_link]


def test_create_transcription_entity_with_confidence_with_cache(
    responses, mock_elements_worker_with_cache
):
    """
    Check that create_transcription_entity, called with a confidence, sends it
    in the POST body and stores it with the created link in the cache.
    """
    endpoint = "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/"

    # Cache rows created up front: an element, its transcription and an entity
    CachedElement.create(
        id=UUID("12341234-1234-1234-1234-123412341234"),
        type="page",
    )
    CachedTranscription.create(
        id=UUID("11111111-1111-1111-1111-111111111111"),
        element=UUID("12341234-1234-1234-1234-123412341234"),
        text="Hello, it's me.",
        confidence=0.42,
        orientation=TextOrientation.HorizontalLeftToRight,
        worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
    )
    CachedEntity.create(
        id=UUID("11111111-1111-1111-1111-111111111111"),
        type="person",
        name="Bob Bob",
        worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
    )

    responses.add(
        responses.POST,
        endpoint,
        status=200,
        json={
            "entity": "11111111-1111-1111-1111-111111111111",
            "offset": 5,
            "length": 10,
            "confidence": 0.77,
        },
    )

    transcription = Transcription(
        {
            "id": "11111111-1111-1111-1111-111111111111",
            "element": {"id": "myelement"},
        }
    )
    mock_elements_worker_with_cache.create_transcription_entity(
        transcription=transcription,
        entity="11111111-1111-1111-1111-111111111111",
        offset=5,
        length=10,
        confidence=0.77,
    )

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    performed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]
    assert performed_calls == BASE_API_CALLS + [("POST", endpoint)]

    sent_payload = json.loads(responses.calls[-1].request.body)
    assert sent_payload == {
        "entity": "11111111-1111-1111-1111-111111111111",
        "offset": 5,
        "length": 10,
        "worker_run_id": "56785678-5678-5678-5678-567856785678",
        "confidence": 0.77,
    }

    # Check that created transcription entity was properly stored in SQLite cache
    expected_cached_link = CachedTranscriptionEntity(
        transcription=UUID("11111111-1111-1111-1111-111111111111"),
        entity=UUID("11111111-1111-1111-1111-111111111111"),
        offset=5,
        length=10,
        worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
        confidence=0.77,
    )
    assert list(CachedTranscriptionEntity.select()) == [expected_cached_link]


def test_list_transcription_entities(fake_dummy_worker):
    """Check that list_transcription_entities returns the response registered on the fake client."""
    worker_version = "worker_version_id"
    transcription = Transcription({"id": "fake_transcription_id"})

    # Register the response for the ListTranscriptionEntities operation
    fake_dummy_worker.api_client.add_response(
        "ListTranscriptionEntities",
        id=transcription.id,
        worker_version=worker_version,
        response={"id": "entity_id"},
    )

    listed = fake_dummy_worker.list_transcription_entities(
        transcription, worker_version
    )
    assert listed == {"id": "entity_id"}

    # The client history must hold one entry and its responses must be empty
    assert len(fake_dummy_worker.api_client.history) == 1
    assert len(fake_dummy_worker.api_client.responses) == 0


def test_list_corpus_entities(responses, mock_elements_worker):
    """Check that list_corpus_entities yields the results of the corpus entities endpoint."""
    corpus_id = "11111111-1111-1111-1111-111111111111"
    entities_url = f"http://testserver/api/v1/corpus/{corpus_id}/entities/"

    # Single-page response holding one entity
    responses.add(
        responses.GET,
        entities_url,
        json={
            "count": 1,
            "next": None,
            "results": [{"id": "fake_entity_id"}],
        },
    )

    # list is required to actually do the request
    listed_entities = list(mock_elements_worker.list_corpus_entities())
    assert listed_entities == [{"id": "fake_entity_id"}]

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    performed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]
    assert performed_calls == BASE_API_CALLS + [("GET", entities_url)]


@pytest.mark.parametrize("wrong_name", [1234, 12.5])
def test_list_corpus_entities_wrong_name(mock_elements_worker, wrong_name):
    """Check that list_corpus_entities raises an AssertionError for a non-str name."""
    expected_error = "name should be of type str"
    with pytest.raises(AssertionError, match=expected_error):
        mock_elements_worker.list_corpus_entities(name=wrong_name)


@pytest.mark.parametrize(
    "wrong_parent",
    [
        {"id": "element_id"},
        12.5,
        "blabla",
    ],
)
def test_list_corpus_entities_wrong_parent(mock_elements_worker, wrong_parent):
    """Check that list_corpus_entities raises an AssertionError for a non-Element parent."""
    expected_error = "parent should be of type Element"
    with pytest.raises(AssertionError, match=expected_error):
        mock_elements_worker.list_corpus_entities(parent=wrong_parent)


def test_check_required_entity_types(responses, mock_elements_worker):
    """
    Check that check_required_entity_types creates the missing type with one
    POST and records its ID in the worker entity_types.
    """
    creation_url = "http://testserver/api/v1/entity/types/"

    # Set one entity type
    mock_elements_worker.entity_types = {"person": "person-entity-type-id"}

    # Call to create new entity type: only matched for this exact JSON body
    expected_body = matchers.json_params_matcher(
        {
            "name": "new-entity",
            "corpus": "11111111-1111-1111-1111-111111111111",
        }
    )
    responses.add(
        responses.POST,
        creation_url,
        status=200,
        match=[expected_body],
        json={
            "id": "new-entity-id",
            "corpus": "11111111-1111-1111-1111-111111111111",
            "name": "new-entity",
            "color": "ffd1b3",
        },
    )

    mock_elements_worker.check_required_entity_types(
        entity_types=["person", "new-entity"],
    )

    # Make sure the entity_types entry has been updated
    assert mock_elements_worker.entity_types == {
        "person": "person-entity-type-id",
        "new-entity": "new-entity-id",
    }

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    performed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]
    assert performed_calls == BASE_API_CALLS + [("POST", creation_url)]


def test_check_required_entity_types_no_creation_allowed(
    responses, mock_elements_worker
):
    """
    Check that check_required_entity_types raises MissingEntityType, without
    any extra API call, when a type is missing and create_missing is False.
    """
    # Set one entity type
    mock_elements_worker.entity_types = {"person": "person-entity-type-id"}

    checked_types = ["person", "new-entity"]

    with pytest.raises(
        MissingEntityType,
        # Escape the message so that its final "." is matched literally
        # instead of acting as a regex wildcard
        match=re.escape("Entity type `new-entity` was not in the corpus."),
    ):
        mock_elements_worker.check_required_entity_types(
            entity_types=checked_types, create_missing=False
        )

    # No request beyond the base calls is expected
    assert len(responses.calls) == len(BASE_API_CALLS)
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == BASE_API_CALLS


@pytest.mark.parametrize("transcription", (None, "not a transcription", 1))
def test_create_transcription_entities_wrong_transcription(
    mock_elements_worker, transcription
):
    """Check that create_transcription_entities rejects a null or non-Transcription transcription."""
    expected_error = (
        "transcription shouldn't be null and should be of type Transcription"
    )
    with pytest.raises(AssertionError, match=expected_error):
        mock_elements_worker.create_transcription_entities(
            transcription=transcription, entities=[]
        )


@pytest.mark.parametrize(
    "entities, error",
    (
        # Values that are not a list at all
        (None, "entities shouldn't be null and should be of type list"),
        (
            "not a list of entities",
            "entities shouldn't be null and should be of type list",
        ),
        (1, "entities shouldn't be null and should be of type list"),
        # The same entity listed twice
        (
            [
                {
                    "name": "A",
                    "type_id": "12341234-1234-1234-1234-123412341234",
                    "offset": 0,
                    "length": 1,
                    "confidence": 0.5,
                }
            ]
            * 2,
            "entities should be unique",
        ),
    ),
)
def test_create_transcription_entities_wrong_entities(
    mock_elements_worker, entities, error
):
    """Check that create_transcription_entities rejects a non-list or duplicated entities value."""
    transcription = Transcription(id="transcription_id")
    with pytest.raises(AssertionError, match=error):
        mock_elements_worker.create_transcription_entities(
            transcription=transcription, entities=entities
        )


def test_create_transcription_entities_wrong_entities_subtype(mock_elements_worker):
    """Check that create_transcription_entities rejects a list item that is not a dict."""
    expected_error = "Entity at index 0 in entities: Should be of type dict"
    transcription = Transcription(id="transcription_id")

    with pytest.raises(AssertionError, match=expected_error):
        mock_elements_worker.create_transcription_entities(
            transcription=transcription, entities=["not a dict"]
        )


@pytest.mark.parametrize(
    "entity, error",
    (
        # name: null
        (
            {
                "name": None,
                "type_id": "12341234-1234-1234-1234-123412341234",
                "offset": 0,
                "length": 1,
                "confidence": 0.5,
            },
            "Entity at index 0 in entities: name shouldn't be null and should be of type str",
        ),
        # type_id: null
        (
            {
                "name": "A",
                "type_id": None,
                "offset": 0,
                "length": 1,
                "confidence": 0.5,
            },
            "Entity at index 0 in entities: type_id shouldn't be null and should be of type str",
        ),
        # type_id: not a string
        (
            {
                "name": "A",
                "type_id": 0,
                "offset": 0,
                "length": 1,
                "confidence": 0.5,
            },
            "Entity at index 0 in entities: type_id shouldn't be null and should be of type str",
        ),
        # offset: null
        (
            {
                "name": "A",
                "type_id": "12341234-1234-1234-1234-123412341234",
                "offset": None,
                "length": 1,
                "confidence": 0.5,
            },
            "Entity at index 0 in entities: offset shouldn't be null and should be a positive integer",
        ),
        # offset: negative
        (
            {
                "name": "A",
                "type_id": "12341234-1234-1234-1234-123412341234",
                "offset": -2,
                "length": 1,
                "confidence": 0.5,
            },
            "Entity at index 0 in entities: offset shouldn't be null and should be a positive integer",
        ),
        # length: null
        (
            {
                "name": "A",
                "type_id": "12341234-1234-1234-1234-123412341234",
                "offset": 0,
                "length": None,
                "confidence": 0.5,
            },
            "Entity at index 0 in entities: length shouldn't be null and should be a strictly positive integer",
        ),
        # length: zero
        (
            {
                "name": "A",
                "type_id": "12341234-1234-1234-1234-123412341234",
                "offset": 0,
                "length": 0,
                "confidence": 0.5,
            },
            "Entity at index 0 in entities: length shouldn't be null and should be a strictly positive integer",
        ),
        # confidence: not a float
        (
            {
                "name": "A",
                "type_id": "12341234-1234-1234-1234-123412341234",
                "offset": 0,
                "length": 1,
                "confidence": "not None or a float",
            },
            "Entity at index 0 in entities: confidence should be None or a float in [0..1] range",
        ),
        # confidence: above 1
        (
            {
                "name": "A",
                "type_id": "12341234-1234-1234-1234-123412341234",
                "offset": 0,
                "length": 1,
                "confidence": 1.3,
            },
            "Entity at index 0 in entities: confidence should be None or a float in [0..1] range",
        ),
    ),
)
def test_create_transcription_entities_wrong_entity(
    mock_elements_worker, entity, error
):
    """Check that create_transcription_entities rejects an entity with one invalid field."""
    # Messages are escaped because some contain regex metacharacters ("[0..1]")
    expected_pattern = re.escape(error)
    transcription = Transcription(id="transcription_id")

    with pytest.raises(AssertionError, match=expected_pattern):
        mock_elements_worker.create_transcription_entities(
            transcription=transcription, entities=[entity]
        )


def test_create_transcription_entities(responses, mock_elements_worker):
    """
    Check that create_transcription_entities sends one bulk POST with the
    given entities and returns one created object.
    """
    bulk_url = "http://testserver/api/v1/transcription/transcription-id/entities/bulk/"
    transcription = Transcription(id="transcription-id")

    # Call to Transcription entities creation in bulk: only matched when the
    # request body is exactly the worker run ID plus the entity passed below
    expected_body = matchers.json_params_matcher(
        {
            "worker_run_id": "56785678-5678-5678-5678-567856785678",
            "entities": [
                {
                    "name": "Teklia",
                    "type_id": "22222222-2222-2222-2222-222222222222",
                    "offset": 0,
                    "length": 6,
                    "confidence": 1.0,
                }
            ],
        }
    )
    responses.add(
        responses.POST,
        bulk_url,
        status=201,
        match=[expected_body],
        json={
            "entities": [
                {
                    "transcription_entity_id": "transc-entity-id",
                    "entity_id": "entity-id",
                }
            ]
        },
    )

    # Store entity type/slug correspondence on the worker
    # NOTE(review): this maps a type ID to a name, while other tests in this
    # file assign entity_types as name -> ID; confirm which is intended.
    mock_elements_worker.entity_types = {
        "22222222-2222-2222-2222-222222222222": "organization"
    }

    entity_to_create = {
        "name": "Teklia",
        "type_id": "22222222-2222-2222-2222-222222222222",
        "offset": 0,
        "length": 6,
        "confidence": 1.0,
    }
    created_objects = mock_elements_worker.create_transcription_entities(
        transcription=transcription,
        entities=[entity_to_create],
    )

    assert len(created_objects) == 1

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    performed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]
    assert performed_calls == BASE_API_CALLS + [("POST", bulk_url)]