Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
Commits on Source (4)
0.2.0-rc3
0.2.0-rc5
......@@ -6,6 +6,8 @@ import sqlite3
from peewee import (
BooleanField,
CharField,
Check,
CompositeKey,
Field,
FloatField,
ForeignKeyField,
......@@ -138,6 +140,20 @@ class CachedEntity(Model):
table_name = "entities"
class CachedTranscriptionEntity(Model):
    """
    Cached link between a transcription and an entity.

    Rows live in the ``transcription_entities`` table. The primary key is the
    (transcription, entity) pair, so a given pair can only be stored once.
    """

    # Both foreign keys expose the reverse relation as ``transcription_entities``
    transcription = ForeignKeyField(
        model=CachedTranscription,
        backref="transcription_entities",
    )
    entity = ForeignKeyField(
        model=CachedEntity,
        backref="transcription_entities",
    )
    # Presumably the start position and size of the entity inside the
    # transcription text (TODO confirm); the database itself enforces
    # offset >= 0 and length > 0 through CHECK constraints.
    offset = IntegerField(constraints=[Check("offset >= 0")])
    length = IntegerField(constraints=[Check("length > 0")])

    class Meta:
        database = db
        table_name = "transcription_entities"
        primary_key = CompositeKey("transcription", "entity")
# Add every managed model to this list.
# It is used here, and also in the unit tests.
MODELS = [
......@@ -146,6 +162,7 @@ MODELS = [
CachedTranscription,
CachedClassification,
CachedEntity,
CachedTranscriptionEntity,
]
......
......@@ -5,7 +5,7 @@ from enum import Enum
from peewee import IntegrityError
from arkindex_worker import logger
from arkindex_worker.cache import CachedElement, CachedEntity
from arkindex_worker.cache import CachedElement, CachedEntity, CachedTranscriptionEntity
from arkindex_worker.models import Element
......@@ -81,3 +81,53 @@ class EntityMixin(object):
logger.warning(f"Couldn't save created entity in local cache: {e}")
return entity["id"]
def create_transcription_entity(self, transcription, entity, offset, length):
    """
    Create a link between an existing entity and an existing transcription through API

    :param transcription: ID of the transcription, as a non-empty ``str``.
    :param entity: ID of the entity, as a non-empty ``str``.
    :param offset: ``int`` greater than or equal to zero. Booleans are rejected.
    :param length: ``int`` strictly greater than zero. Booleans are rejected.
    :raises AssertionError: When one of the arguments fails the checks above.

    When the worker is in read-only mode, a warning is logged and nothing is
    sent or cached. When the cache is enabled, the link is also inserted in the
    local cache; an ``IntegrityError`` raised by that insert is only logged.
    """
    assert transcription and isinstance(
        transcription, str
    ), "transcription shouldn't be null and should be of type str"
    assert entity and isinstance(
        entity, str
    ), "entity shouldn't be null and should be of type str"
    # bool is a subclass of int, so True/False would otherwise slip through
    # the isinstance check and be sent to the API as numbers.
    assert (
        isinstance(offset, int) and not isinstance(offset, bool) and offset >= 0
    ), "offset shouldn't be null and should be a positive integer"
    assert (
        isinstance(length, int) and not isinstance(length, bool) and length > 0
    ), "length shouldn't be null and should be a strictly positive integer"

    if self.is_read_only:
        logger.warning(
            "Cannot create transcription entity as this worker is in read-only mode"
        )
        return

    self.request(
        "CreateTranscriptionEntity",
        id=transcription,
        body={
            "entity": entity,
            "length": length,
            "offset": offset,
        },
    )
    # TODO: Report transcription entity creation

    if not self.use_cache:
        return

    # Store transcription entity in local cache
    try:
        to_insert = [
            {
                "transcription": transcription,
                "entity": entity,
                "offset": offset,
                "length": length,
            }
        ]
        CachedTranscriptionEntity.insert_many(to_insert).execute()
    except IntegrityError as e:
        # The API call already succeeded, so a cache failure is not fatal
        logger.warning(
            f"Couldn't save created transcription entity in local cache: {e}"
        )
......@@ -58,6 +58,7 @@ def test_create_tables(tmp_path):
CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id"))
CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL)
CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
actual_schema = "\n".join(
......
......@@ -5,7 +5,12 @@ from uuid import UUID
import pytest
from apistar.exceptions import ErrorResponse
from arkindex_worker.cache import CachedElement, CachedEntity
from arkindex_worker.cache import (
CachedElement,
CachedEntity,
CachedTranscription,
CachedTranscriptionEntity,
)
from arkindex_worker.models import Element
from arkindex_worker.worker import EntityType
......@@ -258,3 +263,232 @@ def test_create_entity_with_cache(responses, mock_elements_worker_with_cache):
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
)
]
def test_create_transcription_entity_wrong_transcription(mock_elements_worker):
    """A null or non-string transcription is rejected with an AssertionError."""
    expected_message = "transcription shouldn't be null and should be of type str"

    # First a missing value, then a value of the wrong type
    for invalid_transcription in (None, 1234):
        with pytest.raises(AssertionError) as excinfo:
            mock_elements_worker.create_transcription_entity(
                transcription=invalid_transcription,
                entity="11111111-1111-1111-1111-111111111111",
                offset=5,
                length=10,
            )
        assert str(excinfo.value) == expected_message
def test_create_transcription_entity_wrong_entity(mock_elements_worker):
    """A null or non-string entity is rejected with an AssertionError."""
    expected_message = "entity shouldn't be null and should be of type str"

    # First a missing value, then a value of the wrong type
    for invalid_entity in (None, 1234):
        with pytest.raises(AssertionError) as excinfo:
            mock_elements_worker.create_transcription_entity(
                transcription="11111111-1111-1111-1111-111111111111",
                entity=invalid_entity,
                offset=5,
                length=10,
            )
        assert str(excinfo.value) == expected_message
def test_create_transcription_entity_wrong_offset(mock_elements_worker):
    """A null, non-integer or negative offset is rejected with an AssertionError."""
    expected_message = "offset shouldn't be null and should be a positive integer"

    # Missing value, wrong type, then out-of-range value
    for invalid_offset in (None, "not an int", -1):
        with pytest.raises(AssertionError) as excinfo:
            mock_elements_worker.create_transcription_entity(
                transcription="11111111-1111-1111-1111-111111111111",
                entity="11111111-1111-1111-1111-111111111111",
                offset=invalid_offset,
                length=10,
            )
        assert str(excinfo.value) == expected_message
def test_create_transcription_entity_wrong_length(mock_elements_worker):
    """A null, non-integer or zero length is rejected with an AssertionError."""
    expected_message = (
        "length shouldn't be null and should be a strictly positive integer"
    )

    # Missing value, wrong type, then out-of-range value
    for invalid_length in (None, "not an int", 0):
        with pytest.raises(AssertionError) as excinfo:
            mock_elements_worker.create_transcription_entity(
                transcription="11111111-1111-1111-1111-111111111111",
                entity="11111111-1111-1111-1111-111111111111",
                offset=5,
                length=invalid_length,
            )
        assert str(excinfo.value) == expected_message
def test_create_transcription_entity_api_error(responses, mock_elements_worker):
    """A 500 answer from the API surfaces as an ErrorResponse after the retries."""
    endpoint = "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/"
    responses.add(responses.POST, endpoint, status=500)

    with pytest.raises(ErrorResponse):
        mock_elements_worker.create_transcription_entity(
            transcription="11111111-1111-1111-1111-111111111111",
            entity="11111111-1111-1111-1111-111111111111",
            offset=5,
            length=10,
        )

    requested_urls = [call.request.url for call in responses.calls]
    assert len(responses.calls) == 7
    assert requested_urls == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        # We retry 5 times the API call
        *([endpoint] * 5),
    ]
def test_create_transcription_entity(responses, mock_elements_worker):
    """A successful creation sends one POST carrying entity, offset and length."""
    endpoint = "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/"
    payload = {
        "entity": "11111111-1111-1111-1111-111111111111",
        "offset": 5,
        "length": 10,
    }
    # The mocked API answers with the same fields as the ones sent
    responses.add(responses.POST, endpoint, status=200, json=dict(payload))

    mock_elements_worker.create_transcription_entity(
        transcription="11111111-1111-1111-1111-111111111111",
        entity="11111111-1111-1111-1111-111111111111",
        offset=5,
        length=10,
    )

    requested_urls = [call.request.url for call in responses.calls]
    assert len(responses.calls) == 3
    assert requested_urls == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        endpoint,
    ]
    # The third recorded call is the creation request itself
    sent_body = json.loads(responses.calls[2].request.body)
    assert sent_body == payload
def test_create_transcription_entity_with_cache(
    responses, mock_elements_worker_with_cache
):
    """
    With the cache enabled, a successful creation sends the POST request and
    stores the link, including its offset and length, in the SQLite cache.
    """
    # The cached link has foreign keys to a transcription (itself attached to
    # an element) and to an entity, so create those rows first
    CachedElement.create(
        id=UUID("12341234-1234-1234-1234-123412341234"),
        type="page",
    )
    CachedTranscription.create(
        id=UUID("11111111-1111-1111-1111-111111111111"),
        element=UUID("12341234-1234-1234-1234-123412341234"),
        text="Hello, it's me.",
        confidence=0.42,
        worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
    )
    CachedEntity.create(
        id=UUID("11111111-1111-1111-1111-111111111111"),
        type="person",
        name="Bob Bob",
        worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
    )

    responses.add(
        responses.POST,
        "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
        status=200,
        json={
            "entity": "11111111-1111-1111-1111-111111111111",
            "offset": 5,
            "length": 10,
        },
    )

    mock_elements_worker_with_cache.create_transcription_entity(
        transcription="11111111-1111-1111-1111-111111111111",
        entity="11111111-1111-1111-1111-111111111111",
        offset=5,
        length=10,
    )

    assert len(responses.calls) == 3
    assert [call.request.url for call in responses.calls] == [
        "http://testserver/api/v1/user/",
        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
        "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
    ]
    assert json.loads(responses.calls[2].request.body) == {
        "entity": "11111111-1111-1111-1111-111111111111",
        "offset": 5,
        "length": 10,
    }

    # Check that created transcription entity was properly stored in SQLite cache
    cached_links = list(CachedTranscriptionEntity.select())
    assert cached_links == [
        CachedTranscriptionEntity(
            transcription=UUID("11111111-1111-1111-1111-111111111111"),
            entity=UUID("11111111-1111-1111-1111-111111111111"),
            offset=5,
            length=10,
        )
    ]
    # peewee compares model instances by class and primary key only, so the
    # non-key columns have to be verified explicitly
    assert (cached_links[0].offset, cached_links[0].length) == (5, 10)