Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
Commits on Source (4)
......@@ -109,7 +109,7 @@ class Element(MagicDict):
@property
def requires_tiles(self) -> bool:
"""
:returns: Whether or not downloading and combining IIIF tiles will be necessary
Whether or not downloading and combining IIIF tiles will be necessary
to retrieve this element's image. Will be False if the element has no image.
"""
from arkindex_worker.image import polygon_bounding_box
......
......@@ -153,7 +153,7 @@ class BaseWorker(object):
"""
Whether or not the worker can publish data.
:returns: False when dev mode is enabled with the ``--dev`` CLI argument,
False when dev mode is enabled with the ``--dev`` CLI argument,
when no worker run ID is provided
"""
return self.args.dev or self.worker_run_id is None
......
......@@ -4,6 +4,7 @@ ElementsWorker methods for classifications and ML classes.
"""
from typing import Dict, List, Optional, Union
from uuid import UUID
from apistar.exceptions import ErrorResponse
from peewee import IntegrityError
......@@ -64,6 +65,30 @@ class ClassificationMixin(object):
return ml_class_id
def retrieve_ml_class(self, ml_class_id: str) -> str:
    """
    Look up the name of an MLClass from its ID.

    When ``self.classes`` has no entry for the current corpus yet,
    ``load_corpus_classes`` is called first.

    :param ml_class_id: ID of the searched MLClass.
    :return: The MLClass's name
    :raises AssertionError: When no class of the current corpus has this ID.
    """
    # Populate the classes of this corpus on first use
    if self.corpus_id not in self.classes:
        self.load_corpus_classes()

    # The per-corpus mapping is keyed by class name, so scan it for the
    # first name whose stored ID matches the requested one
    corpus_classes = self.classes[self.corpus_id]
    matching_names = (
        name for name, class_id in corpus_classes.items() if class_id == ml_class_id
    )
    ml_class_name = next(matching_names, None)

    assert (
        ml_class_name is not None
    ), f"Missing class with id ({ml_class_id}) in corpus ({self.corpus_id})"
    return ml_class_name
def create_classification(
self,
element: Union[Element, CachedElement],
......@@ -97,7 +122,6 @@ class ClassificationMixin(object):
"Cannot create classification as this worker is in read-only mode"
)
return
try:
created = self.request(
"CreateClassification",
......@@ -166,7 +190,7 @@ class ClassificationMixin(object):
:param element: The element to create classifications on.
:param classifications: The classifications to create, a list of dicts. Each of them contains
a **class_name** (str), the name of the MLClass for this classification;
a **ml_class_id** (str), the ID of the MLClass for this classification;
a **confidence** (float), the confidence score, between 0 and 1;
a **high_confidence** (bool), the high confidence state of the classification.
......@@ -181,10 +205,18 @@ class ClassificationMixin(object):
), "classifications shouldn't be null and should be of type list"
for index, classification in enumerate(classifications):
class_name = classification.get("class_name")
assert class_name and isinstance(
class_name, str
), f"Classification at index {index} in classifications: class_name shouldn't be null and should be of type str"
ml_class_id = classification.get("ml_class_id")
assert ml_class_id and isinstance(
ml_class_id, str
), f"Classification at index {index} in classifications: ml_class_id shouldn't be null and should be of type str"
# Make sure it's a valid UUID
try:
UUID(ml_class_id)
except ValueError:
raise ValueError(
f"Classification at index {index} in classifications: ml_class_id is not a valid uuid."
)
confidence = classification.get("confidence")
assert (
......@@ -215,6 +247,7 @@ class ClassificationMixin(object):
)["classifications"]
for created_cl in created_cls:
created_cl["class_name"] = self.retrieve_ml_class(created_cl["ml_class"])
self.report.add_classification(element.id, created_cl["class_name"])
if self.use_cache:
......@@ -224,7 +257,7 @@ class ClassificationMixin(object):
{
"id": created_cl["id"],
"element_id": element.id,
"class_name": created_cl["class_name"],
"class_name": created_cl.pop("class_name"),
"confidence": created_cl["confidence"],
"state": created_cl["state"],
"worker_run_id": self.worker_run_id,
......
......@@ -2,6 +2,6 @@ black==22.10.0
doc8==1.0.0
mkdocs==1.4.2
mkdocs-material==8.5.11
mkdocstrings==0.19.0
mkdocstrings==0.19.1
mkdocstrings-python==0.8.2
recommonmark==0.7.1
......@@ -67,6 +67,21 @@ include:
- `word-segmenter`
- `paragraph-creator`
`gpu_usage`
: Whether or not this worker requires or supports GPUs. Defaults to `disabled`. May take one of the following values:
`required`
: This worker requires a GPU, and will only be run on Ponos agents whose hosts have a GPU.
`supported`
: This worker supports using a GPU, but may run on any available host, including those without GPUs.
`disabled`
: This worker does not support GPUs. It may run on a host that has a GPU, but it will ignore it.
`model_usage`
: Boolean. Whether or not this worker requires a model version to run. Defaults to `false`.
`docker`
: Groups Docker-related configuration attributes:
<!--
......
# -*- coding: utf-8 -*-
import json
from uuid import UUID
from uuid import UUID, uuid4
import pytest
from apistar.exceptions import ErrorResponse
......@@ -162,6 +162,46 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
]
def test_retrieve_ml_class_in_cache(mock_elements_worker):
    """
    The class is already stored on the worker, so its name is resolved
    from `.classes` (no API call is expected; this test does not assert it).
    """
    corpus_id = mock_elements_worker.corpus_id
    mock_elements_worker.classes[corpus_id] = {"class1": "uuid1"}

    found_name = mock_elements_worker.retrieve_ml_class("uuid1")
    assert found_name == "class1"
def test_retrieve_ml_class_not_in_cache(responses, mock_elements_worker):
    """
    The class is not stored on the worker yet -> the corpus ML classes
    are fetched through the API before the name is returned.
    """
    classes_url = (
        f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/"
    )
    # Single-page listing that contains the class being looked up
    corpus_classes_page = {
        "count": 1,
        "next": None,
        "results": [{"id": "uuid1", "name": "class1"}],
    }
    responses.add(responses.GET, classes_url, status=200, json=corpus_classes_page)

    assert mock_elements_worker.retrieve_ml_class("uuid1") == "class1"

    # Exactly one extra request, the class listing, on top of the base calls
    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    performed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]
    assert performed_calls == BASE_API_CALLS + [("GET", classes_url)]
def test_create_classification_wrong_element(mock_elements_worker):
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_classification(
......@@ -520,12 +560,12 @@ def test_create_classifications_wrong_element(mock_elements_worker):
element=None,
classifications=[
{
"class_name": "portrait",
"ml_class_id": "uuid1",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"ml_class_id": "uuid2",
"confidence": 0.25,
"high_confidence": False,
},
......@@ -541,12 +581,12 @@ def test_create_classifications_wrong_element(mock_elements_worker):
element="not element type",
classifications=[
{
"class_name": "portrait",
"ml_class_id": "uuid1",
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"ml_class_id": "uuid2",
"confidence": 0.25,
"high_confidence": False,
},
......@@ -584,19 +624,19 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
element=elt,
classifications=[
{
"class_name": "portrait",
"ml_class_id": str(uuid4()),
"confidence": 0.75,
"high_confidence": False,
},
{
"confidence": 0.25,
"ml_class_id": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
== "Classification at index 1 in classifications: ml_class_id shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
......@@ -604,12 +644,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
element=elt,
classifications=[
{
"class_name": "portrait",
"ml_class_id": str(uuid4()),
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": None,
"ml_class_id": None,
"confidence": 0.25,
"high_confidence": False,
},
......@@ -617,7 +657,7 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
)
assert (
str(e.value)
== "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
== "Classification at index 1 in classifications: ml_class_id shouldn't be null and should be of type str"
)
with pytest.raises(AssertionError) as e:
......@@ -625,12 +665,33 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
element=elt,
classifications=[
{
"class_name": "portrait",
"ml_class_id": str(uuid4()),
"confidence": 0.75,
"high_confidence": False,
},
{
"ml_class_id": 1234,
"confidence": 0.25,
"high_confidence": False,
},
],
)
assert (
str(e.value)
== "Classification at index 1 in classifications: ml_class_id shouldn't be null and should be of type str"
)
with pytest.raises(ValueError) as e:
mock_elements_worker.create_classifications(
element=elt,
classifications=[
{
"ml_class_id": str(uuid4()),
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": 1234,
"ml_class_id": "not_an_uuid",
"confidence": 0.25,
"high_confidence": False,
},
......@@ -638,7 +699,7 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
)
assert (
str(e.value)
== "Classification at index 1 in classifications: class_name shouldn't be null and should be of type str"
== "Classification at index 1 in classifications: ml_class_id is not a valid uuid."
)
with pytest.raises(AssertionError) as e:
......@@ -646,12 +707,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
element=elt,
classifications=[
{
"class_name": "portrait",
"ml_class_id": str(uuid4()),
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"ml_class_id": str(uuid4()),
"high_confidence": False,
},
],
......@@ -666,12 +727,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
element=elt,
classifications=[
{
"class_name": "portrait",
"ml_class_id": str(uuid4()),
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"ml_class_id": str(uuid4()),
"confidence": None,
"high_confidence": False,
},
......@@ -687,12 +748,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
element=elt,
classifications=[
{
"class_name": "portrait",
"ml_class_id": str(uuid4()),
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"ml_class_id": str(uuid4()),
"confidence": "wrong confidence",
"high_confidence": False,
},
......@@ -708,12 +769,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
element=elt,
classifications=[
{
"class_name": "portrait",
"ml_class_id": str(uuid4()),
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"ml_class_id": str(uuid4()),
"confidence": 0,
"high_confidence": False,
},
......@@ -729,12 +790,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
element=elt,
classifications=[
{
"class_name": "portrait",
"ml_class_id": str(uuid4()),
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"ml_class_id": str(uuid4()),
"confidence": 2.00,
"high_confidence": False,
},
......@@ -750,12 +811,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
element=elt,
classifications=[
{
"class_name": "portrait",
"ml_class_id": str(uuid4()),
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"ml_class_id": str(uuid4()),
"confidence": 0.25,
"high_confidence": "wrong high_confidence",
},
......@@ -776,12 +837,12 @@ def test_create_classifications_api_error(responses, mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
classes = [
{
"class_name": "portrait",
"ml_class_id": str(uuid4()),
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"ml_class_id": str(uuid4()),
"confidence": 0.25,
"high_confidence": False,
},
......@@ -806,15 +867,22 @@ def test_create_classifications_api_error(responses, mock_elements_worker):
def test_create_classifications(responses, mock_elements_worker_with_cache):
# Set MLClass in cache
portrait_uuid = str(uuid4())
landscape_uuid = str(uuid4())
mock_elements_worker_with_cache.classes[
mock_elements_worker_with_cache.corpus_id
] = {"portrait": portrait_uuid, "landscape": landscape_uuid}
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
classes = [
{
"class_name": "portrait",
"ml_class_id": portrait_uuid,
"confidence": 0.75,
"high_confidence": False,
},
{
"class_name": "landscape",
"ml_class_id": landscape_uuid,
"confidence": 0.25,
"high_confidence": False,
},
......@@ -830,14 +898,14 @@ def test_create_classifications(responses, mock_elements_worker_with_cache):
"classifications": [
{
"id": "00000000-0000-0000-0000-000000000000",
"class_name": "portrait",
"ml_class": portrait_uuid,
"confidence": 0.75,
"high_confidence": False,
"state": "pending",
},
{
"id": "11111111-1111-1111-1111-111111111111",
"class_name": "landscape",
"ml_class": landscape_uuid,
"confidence": 0.25,
"high_confidence": False,
"state": "pending",
......@@ -882,3 +950,110 @@ def test_create_classifications(responses, mock_elements_worker_with_cache):
worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
),
]
def test_create_classifications_not_in_cache(
    responses, mock_elements_worker_with_cache
):
    """
    CreateClassifications with IDs that are absent from the `.classes` attribute.
    The corpus MLClasses get loaded so that the matching names can be stored
    in the cache.
    """
    worker = mock_elements_worker_with_cache
    worker_run_id = "56785678-5678-5678-5678-567856785678"
    bulk_url = "http://testserver/api/v1/classification/bulk/"
    classes_url = f"http://testserver/api/v1/corpus/{worker.corpus_id}/classes/"

    portrait_uuid = str(uuid4())
    landscape_uuid = str(uuid4())
    elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
    classes = [
        {"ml_class_id": portrait_uuid, "confidence": 0.75, "high_confidence": False},
        {"ml_class_id": landscape_uuid, "confidence": 0.25, "high_confidence": False},
    ]

    # (classification ID, MLClass ID, MLClass name, confidence) for each
    # classification the mocked API reports as created
    created = [
        ("00000000-0000-0000-0000-000000000000", portrait_uuid, "portrait", 0.75),
        ("11111111-1111-1111-1111-111111111111", landscape_uuid, "landscape", 0.25),
    ]

    # Bulk creation endpoint: only returns MLClass IDs, never their names
    responses.add(
        responses.POST,
        bulk_url,
        status=200,
        json={
            "parent": str(elt.id),
            "worker_run_id": worker_run_id,
            "classifications": [
                {
                    "id": classification_id,
                    "ml_class": ml_class_id,
                    "confidence": confidence,
                    "high_confidence": False,
                    "state": "pending",
                }
                for classification_id, ml_class_id, _, confidence in created
            ],
        },
    )
    # Corpus class listing, used to map the returned IDs back to names
    responses.add(
        responses.GET,
        classes_url,
        status=200,
        json={
            "count": 2,
            "next": None,
            "results": [
                {"id": portrait_uuid, "name": "portrait"},
                {"id": landscape_uuid, "name": "landscape"},
            ],
        },
    )

    worker.create_classifications(element=elt, classifications=classes)

    # The creation request comes first, then the class listing
    assert len(responses.calls) == len(BASE_API_CALLS) + 2
    performed_calls = [
        (call.request.method, call.request.url) for call in responses.calls
    ]
    assert performed_calls == BASE_API_CALLS + [
        ("POST", bulk_url),
        ("GET", classes_url),
    ]

    # The POST is the second-to-last call; its payload carries the IDs as given
    assert json.loads(responses.calls[-2].request.body) == {
        "parent": str(elt.id),
        "worker_run_id": worker_run_id,
        "classifications": classes,
    }

    # Check that created classifications were properly stored in SQLite cache
    assert list(CachedClassification.select()) == [
        CachedClassification(
            id=UUID(classification_id),
            element_id=UUID(elt.id),
            class_name=class_name,
            confidence=confidence,
            state="pending",
            worker_run_id=UUID(worker_run_id),
        )
        for classification_id, _, class_name, confidence in created
    ]