Skip to content
Snippets Groups Projects
Verified Commit 1bb2ad35 authored by Erwan Rouchet's avatar Erwan Rouchet
Browse files

Support confidence on create_sub_element and create_elements

parent 727703a0
No related branches found
No related tags found
1 merge request!162Support confidence on create_sub_element and create_elements
Pipeline #79160 passed
...@@ -68,6 +68,7 @@ class CachedElement(Model): ...@@ -68,6 +68,7 @@ class CachedElement(Model):
mirrored = BooleanField(default=False) mirrored = BooleanField(default=False)
initial = BooleanField(default=False) initial = BooleanField(default=False)
worker_version_id = UUIDField(null=True) worker_version_id = UUIDField(null=True)
confidence = FloatField(null=True)
class Meta: class Meta:
database = db database = db
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import uuid import uuid
from typing import Dict, List, Optional, Union
from peewee import IntegrityError from peewee import IntegrityError
...@@ -40,10 +41,26 @@ class ElementMixin(object): ...@@ -40,10 +41,26 @@ class ElementMixin(object):
return True return True
def create_sub_element(self, element, type, name, polygon): def create_sub_element(
self,
element: Element,
type: str,
name: str,
polygon: List[List[Union[int, float]]],
confidence: Optional[float] = None,
) -> str:
""" """
Create a child element on the given element through API Create a child element on the given element through the API.
Return the ID of the created sub element
:param Element element: The parent element.
:param str type: Slug of the element type for this child element.
:param str name: Name of the child element.
:param polygon: Polygon of the child element.
:type polygon: list(list(int or float))
:param confidence: Optional confidence score, between 0.0 and 1.0.
:type confidence: float or None
:returns: UUID of the created element.
:rtype: str
""" """
assert element and isinstance( assert element and isinstance(
element, Element element, Element
...@@ -64,6 +81,10 @@ class ElementMixin(object): ...@@ -64,6 +81,10 @@ class ElementMixin(object):
assert all( assert all(
isinstance(coord, (int, float)) for point in polygon for coord in point isinstance(coord, (int, float)) for point in polygon for coord in point
), "polygon points should be lists of two numbers" ), "polygon points should be lists of two numbers"
assert confidence is None or (
isinstance(confidence, float) and 0 <= confidence <= 1
), "confidence should be None or a float in [0..1] range"
if self.is_read_only: if self.is_read_only:
logger.warning("Cannot create element as this worker is in read-only mode") logger.warning("Cannot create element as this worker is in read-only mode")
return return
...@@ -78,16 +99,43 @@ class ElementMixin(object): ...@@ -78,16 +99,43 @@ class ElementMixin(object):
"polygon": polygon, "polygon": polygon,
"parent": element.id, "parent": element.id,
"worker_version": self.worker_version_id, "worker_version": self.worker_version_id,
"confidence": confidence,
}, },
) )
self.report.add_element(element.id, type) self.report.add_element(element.id, type)
return sub_element["id"] return sub_element["id"]
def create_elements(self, parent, elements): def create_elements(
self,
parent: Union[Element, CachedElement],
elements: List[
Dict[str, Union[str, List[List[Union[int, float]]], float, None]]
],
) -> List[Dict[str, str]]:
""" """
Create children elements on the given element through API Create child elements on the given element in a single API request.
Return the IDs of created elements
:param parent: Parent element for all the new child elements. The parent must have an image and a polygon.
:type parent: Element or CachedElement
:param elements: List of dicts, one per element. Each dict can have the following keys:
name (str)
Required. Name of the element.
type (str)
Required. Slug of the element type for this element.
polygon (list(list(int or float)))
Required. Polygon for this child element. Must have at least three points, with each point
having two non-negative coordinates and being inside of the parent element's image.
confidence (float or None)
Optional confidence score, between 0.0 and 1.0.
:type elements: list(dict(str, Any))
:return: List of dicts, with each dict having a single key, ``id``, holding the UUID of each created element.
:rtype: list(dict(str, str))
""" """
if isinstance(parent, Element): if isinstance(parent, Element):
assert parent.get( assert parent.get(
...@@ -135,6 +183,11 @@ class ElementMixin(object): ...@@ -135,6 +183,11 @@ class ElementMixin(object):
isinstance(coord, (int, float)) for point in polygon for coord in point isinstance(coord, (int, float)) for point in polygon for coord in point
), f"Element at index {index} in elements: polygon points should be lists of two numbers" ), f"Element at index {index} in elements: polygon points should be lists of two numbers"
confidence = element.get("confidence")
assert confidence is None or (
isinstance(confidence, float) and 0 <= confidence <= 1
), f"Element at index {index} in elements: confidence should be None or a float in [0..1] range"
if self.is_read_only: if self.is_read_only:
logger.warning("Cannot create elements as this worker is in read-only mode") logger.warning("Cannot create elements as this worker is in read-only mode")
return return
...@@ -176,6 +229,7 @@ class ElementMixin(object): ...@@ -176,6 +229,7 @@ class ElementMixin(object):
"image_id": image_id, "image_id": image_id,
"polygon": element["polygon"], "polygon": element["polygon"],
"worker_version_id": self.worker_version_id, "worker_version_id": self.worker_version_id,
"confidence": element.get("confidence"),
} }
for idx, element in enumerate(elements) for idx, element in enumerate(elements)
] ]
......
...@@ -59,7 +59,7 @@ def test_create_tables(tmp_path): ...@@ -59,7 +59,7 @@ def test_create_tables(tmp_path):
create_tables() create_tables()
expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id")) expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))
CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "rotation_angle" INTEGER NOT NULL, "mirrored" INTEGER NOT NULL, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id")) CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "rotation_angle" INTEGER NOT NULL, "mirrored" INTEGER NOT NULL, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, "confidence" REAL, FOREIGN KEY ("image_id") REFERENCES "images" ("id"))
CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL) CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL)
CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL) CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id")) CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
......
...@@ -424,6 +424,19 @@ def test_create_sub_element_wrong_polygon(mock_elements_worker): ...@@ -424,6 +424,19 @@ def test_create_sub_element_wrong_polygon(mock_elements_worker):
assert str(e.value) == "polygon points should be lists of two numbers" assert str(e.value) == "polygon points should be lists of two numbers"
@pytest.mark.parametrize("confidence", ["lol", "0.2", -1.0, 1.42, float("inf")])
def test_create_sub_element_wrong_confidence(mock_elements_worker, confidence):
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_sub_element(
element=Element({"zone": None}),
type="something",
name="blah",
polygon=[[0, 0], [0, 10], [10, 10], [10, 0], [0, 0]],
confidence=confidence,
)
assert str(e.value) == "confidence should be None or a float in [0..1] range"
def test_create_sub_element_api_error(responses, mock_elements_worker): def test_create_sub_element_api_error(responses, mock_elements_worker):
elt = Element( elt = Element(
{ {
...@@ -495,6 +508,49 @@ def test_create_sub_element(responses, mock_elements_worker): ...@@ -495,6 +508,49 @@ def test_create_sub_element(responses, mock_elements_worker):
"polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
"parent": "12341234-1234-1234-1234-123412341234", "parent": "12341234-1234-1234-1234-123412341234",
"worker_version": "12341234-1234-1234-1234-123412341234", "worker_version": "12341234-1234-1234-1234-123412341234",
"confidence": None,
}
assert sub_element_id == "12345678-1234-1234-1234-123456789123"
def test_create_sub_element_confidence(responses, mock_elements_worker):
elt = Element(
{
"id": "12341234-1234-1234-1234-123412341234",
"corpus": {"id": "11111111-1111-1111-1111-111111111111"},
"zone": {"image": {"id": "22222222-2222-2222-2222-222222222222"}},
}
)
responses.add(
responses.POST,
"http://testserver/api/v1/elements/create/",
status=200,
json={"id": "12345678-1234-1234-1234-123456789123"},
)
sub_element_id = mock_elements_worker.create_sub_element(
element=elt,
type="something",
name="0",
polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
confidence=0.42,
)
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
("POST", "http://testserver/api/v1/elements/create/"),
]
assert json.loads(responses.calls[-1].request.body) == {
"type": "something",
"name": "0",
"image": "22222222-2222-2222-2222-222222222222",
"corpus": "11111111-1111-1111-1111-111111111111",
"polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
"parent": "12341234-1234-1234-1234-123412341234",
"worker_version": "12341234-1234-1234-1234-123412341234",
"confidence": 0.42,
} }
assert sub_element_id == "12345678-1234-1234-1234-123456789123" assert sub_element_id == "12345678-1234-1234-1234-123456789123"
...@@ -740,6 +796,26 @@ def test_create_elements_wrong_elements_polygon(mock_elements_worker): ...@@ -740,6 +796,26 @@ def test_create_elements_wrong_elements_polygon(mock_elements_worker):
) )
@pytest.mark.parametrize("confidence", ["lol", "0.2", -1.0, 1.42, float("inf")])
def test_create_elements_wrong_elements_confidence(mock_elements_worker, confidence):
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_elements(
parent=Element({"zone": {"image": {"id": "image_id"}}}),
elements=[
{
"name": "a",
"type": "something",
"polygon": [[0, 0], [0, 10], [10, 10], [10, 0], [0, 0]],
"confidence": confidence,
}
],
)
assert (
str(e.value)
== "Element at index 0 in elements: confidence should be None or a float in [0..1] range"
)
def test_create_elements_api_error(responses, mock_elements_worker): def test_create_elements_api_error(responses, mock_elements_worker):
elt = Element( elt = Element(
{ {
...@@ -930,6 +1006,80 @@ def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path): ...@@ -930,6 +1006,80 @@ def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
image_id="c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe", image_id="c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe",
polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
confidence=None,
)
]
def test_create_elements_confidence(
responses, mock_elements_worker_with_cache, tmp_path
):
elt = Element(
{
"id": "12341234-1234-1234-1234-123412341234",
"zone": {
"image": {
"id": "c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe",
"width": 42,
"height": 42,
"url": "http://aaaa",
}
},
}
)
responses.add(
responses.POST,
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
status=200,
json=[{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}],
)
created_ids = mock_elements_worker_with_cache.create_elements(
parent=elt,
elements=[
{
"name": "0",
"type": "something",
"polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
"confidence": 0.42,
}
],
)
assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [
(call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [
(
"POST",
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
),
]
assert json.loads(responses.calls[-1].request.body) == {
"elements": [
{
"name": "0",
"type": "something",
"polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
"confidence": 0.42,
}
],
"worker_version": "12341234-1234-1234-1234-123412341234",
}
assert created_ids == [{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}]
# Check that created elements were properly stored in SQLite cache
assert (tmp_path / "db.sqlite").is_file()
assert list(CachedElement.select()) == [
CachedElement(
id=UUID("497f6eca-6276-4993-bfeb-53cbbbba6f08"),
parent_id=UUID("12341234-1234-1234-1234-123412341234"),
type="something",
image_id="c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe",
polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
confidence=0.42,
) )
] ]
......
...@@ -345,6 +345,7 @@ def test_create_transcription_orientation_with_cache( ...@@ -345,6 +345,7 @@ def test_create_transcription_orientation_with_cache(
"mirrored": False, "mirrored": False,
"initial": False, "initial": False,
"worker_version_id": None, "worker_version_id": None,
"confidence": None,
}, },
"text": "Animula vagula blandula", "text": "Animula vagula blandula",
"confidence": 0.42, "confidence": 0.42,
...@@ -809,6 +810,7 @@ def test_create_transcriptions_orientation(responses, mock_elements_worker_with_ ...@@ -809,6 +810,7 @@ def test_create_transcriptions_orientation(responses, mock_elements_worker_with_
"mirrored": False, "mirrored": False,
"initial": False, "initial": False,
"worker_version_id": None, "worker_version_id": None,
"confidence": None,
}, },
"text": "Animula vagula blandula", "text": "Animula vagula blandula",
"confidence": 0.12, "confidence": 0.12,
...@@ -827,6 +829,7 @@ def test_create_transcriptions_orientation(responses, mock_elements_worker_with_ ...@@ -827,6 +829,7 @@ def test_create_transcriptions_orientation(responses, mock_elements_worker_with_
"mirrored": False, "mirrored": False,
"initial": False, "initial": False,
"worker_version_id": None, "worker_version_id": None,
"confidence": None,
}, },
"text": "Hospes comesque corporis", "text": "Hospes comesque corporis",
"confidence": 0.21, "confidence": 0.21,
...@@ -1585,6 +1588,7 @@ def test_create_transcriptions_orientation_with_cache( ...@@ -1585,6 +1588,7 @@ def test_create_transcriptions_orientation_with_cache(
"mirrored": False, "mirrored": False,
"initial": False, "initial": False,
"worker_version_id": UUID("12341234-1234-1234-1234-123412341234"), "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
"confidence": None,
}, },
"text": "Animula vagula blandula", "text": "Animula vagula blandula",
"confidence": 0.5, "confidence": 0.5,
...@@ -1603,6 +1607,7 @@ def test_create_transcriptions_orientation_with_cache( ...@@ -1603,6 +1607,7 @@ def test_create_transcriptions_orientation_with_cache(
"mirrored": False, "mirrored": False,
"initial": False, "initial": False,
"worker_version_id": UUID("12341234-1234-1234-1234-123412341234"), "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
"confidence": None,
}, },
"text": "Hospes comesque corporis", "text": "Hospes comesque corporis",
"confidence": 0.75, "confidence": 0.75,
...@@ -1621,6 +1626,7 @@ def test_create_transcriptions_orientation_with_cache( ...@@ -1621,6 +1626,7 @@ def test_create_transcriptions_orientation_with_cache(
"mirrored": False, "mirrored": False,
"initial": False, "initial": False,
"worker_version_id": UUID("12341234-1234-1234-1234-123412341234"), "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
"confidence": None,
}, },
"text": "Quae nunc abibis in loca", "text": "Quae nunc abibis in loca",
"confidence": 0.9, "confidence": 0.9,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment