Skip to content
Snippets Groups Projects
Commit a9aad29f authored by Martin Maarand's avatar Martin Maarand Committed by Bastien Abadie
Browse files

Add helpers to retrieve and cache worker versions

parent 74b734fb
No related branches found
No related tags found
1 merge request!29Add helpers to retrieve and cache worker versions
Pipeline #77998 passed
......@@ -198,6 +198,8 @@ class ElementsWorker(BaseWorker):
)
self.classes = {}
self._worker_version_cache = {}
def list_elements(self):
assert not (
self.args.elements_list and self.args.element
......@@ -543,3 +545,48 @@ class ElementsWorker(BaseWorker):
self.report.add_transcription(annotation["id"], transcription_type.value)
return annotations
def get_worker_version(self, worker_version_id: str) -> dict:
"""
Get worker version from cache if possible, otherwise make API request
"""
if worker_version_id in self._worker_version_cache:
return self._worker_version_cache[worker_version_id]
worker_version = self.api_client.request(
"RetrieveWorkerVersion", id=worker_version_id
)
self._worker_version_cache[worker_version_id] = worker_version
return worker_version
def get_worker_version_slug(self, worker_version_id: str) -> str:
"""
Get worker version slug from cache if possible, otherwise make API request
Should use `get_ml_result_slug` instead of using this method directly
"""
worker_version = self.get_worker_version(worker_version_id)
return worker_version["worker"]["slug"]
def get_ml_result_slug(self, ml_result) -> str:
"""
Helper function to get the slug from element, classification or transcription
Can handle old and new (source vs worker_version)
:type ml_result: Element or classification or transcription
"""
if (
"source" in ml_result
and ml_result["source"]
and ml_result["source"]["slug"]
):
return ml_result["source"]["slug"]
elif "worker_version" in ml_result and ml_result["worker_version"]:
return self.get_worker_version_slug(ml_result["worker_version"])
# transcriptions have worker_version_id but elements have worker_version
elif "worker_version_id" in ml_result and ml_result["worker_version_id"]:
return self.get_worker_version_slug(ml_result["worker_version_id"])
else:
raise ValueError(f"Unable to get slug from: {ml_result}")
......@@ -6,8 +6,11 @@ from pathlib import Path
import pytest
from arkindex.mock import MockApiClient
from arkindex_worker.worker import ElementsWorker
FIXTURES_DIR = Path(__file__).resolve().parent / "data"
@pytest.fixture(autouse=True)
def setup_api(responses, monkeypatch):
......@@ -87,3 +90,29 @@ def mock_elements_worker(monkeypatch, mock_worker_version_api):
worker = ElementsWorker()
worker.configure()
return worker
@pytest.fixture
def fake_page_element():
with open(FIXTURES_DIR / "page_element.json", "r") as f:
return json.load(f)
@pytest.fixture
def fake_ufcn_worker_version():
with open(FIXTURES_DIR / "ufcn_line_historical_worker_version.json", "r") as f:
return json.load(f)
@pytest.fixture
def fake_transcriptions_small():
with open(FIXTURES_DIR / "line_transcriptions_small.json", "r") as f:
return json.load(f)
@pytest.fixture
def fake_dummy_worker():
api_client = MockApiClient()
worker = ElementsWorker()
worker.api_client = api_client
return worker
{
"count": 1,
"number": 1,
"next": null,
"previous": null,
"results": [
{
"id": "008691ae-8133-48c4-88d5-d4cc9f65c06c",
"type": "line",
"text": "J . Caron &",
"score": 0.4781,
"zone": null,
"source": null,
"worker_version_id": "3ca4a8e3-91d1-4b78-8d83-d8bbbf487996",
"element": {
"id": "9e7ff0a5-bf89-42f5-ad85-fc3fb64ac7e8",
"type": "text_line",
"name": "16",
"zone": {
"id": "d09eb012-f367-443b-b97c-b5f6cf8ea754",
"polygon": [
[
872,
1705
],
[
872,
1747
],
[
1198,
1747
],
[
1198,
1705
],
[
872,
1705
]
]
}
}
}
]
}
\ No newline at end of file
{
"id": "863ab6e8-b409-41df-98cf-303053933986",
"type": "page",
"name": "1",
"corpus": {
"id": "b26df920-e7dd-4429-bd89-d76639ae21c4",
"name": "Test Martin",
"public": false
},
"thumbnail_url": null,
"thumbnail_put_url": null,
"zone": {
"id": "fff2cda1-94cc-4370-bc34-9d7a820d153e",
"polygon": [
[
0,
0
],
[
0,
3818
],
[
2400,
3818
],
[
2400,
0
],
[
0,
0
]
],
"image": {
"id": "dab2d320-dd55-4138-a826-3d49e1e93d7d",
"path": "ad7c2394-865f-4377-a8f2-5a41d66298f9",
"width": 2400,
"height": 3818,
"url": "https://iiif.teklia.com/preprod/iiif/2/ad7c2394-865f-4377-a8f2-5a41d66298f9",
"s3_url": "https://preprod-arkindex-iiif.s3.amazonaws.com/ad7c2394-865f-4377-a8f2-5a41d66298f9?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIXZXPY5BU27ELOUQ%2F20201014%2Feu-west-3%2Fs3%2Faws4_request&X-Amz-Date=20201014T111554Z&X-Amz-Expires=3600&X-Amz-SignedHeaders=host&X-Amz-Signature=67cdfc376f365cc8d698a6470dc80a611ed88d0f263a2616f73bc52640861774",
"status": "checked",
"server": {
"display_name": "Arkindex dev",
"url": "https://iiif.teklia.com/preprod/iiif/2",
"max_width": null,
"max_height": null
}
},
"url": "https://iiif.teklia.com/preprod/iiif/2/ad7c2394-865f-4377-a8f2-5a41d66298f9/0,0,2400,3818/full/0/default.jpg"
},
"metadata": [
{
"id": "2750842e-5e7c-404e-933c-ffe38a94ea1c",
"type": "reference",
"name": "balsac_id",
"value": "test_page_ref2",
"revision": null,
"dates": []
},
{
"id": "dd12acb4-54c4-4701-a335-d4e16f80182f",
"type": "text",
"name": "folio",
"value": "1",
"revision": null,
"dates": []
}
],
"classifications": [
{
"id": "572481c9-0367-492b-ad99-719c81838fcb",
"source": {
"id": "c718dfe8-4a08-4fe1-851f-ec3451383e81",
"type": "classifier",
"slug": "scikit_portrait_outlier_balsac",
"name": "Scikit Portrait Outlier Classifier Balsac",
"revision": "0.1",
"internal": true
},
"ml_class": {
"id": "def18db8-9202-485e-be29-1eee4522fa08",
"name": "portrait"
},
"state": "pending",
"confidence": 0.7,
"high_confidence": true,
"worker_version": null
},
{
"id": "2f0158e0-3121-4105-91db-5b210354e42c",
"source": {
"id": "c718dfe8-4a08-4fe1-851f-ec3451383e81",
"type": "classifier",
"slug": "scikit_portrait_outlier_balsac",
"name": "Scikit Portrait Outlier Classifier Balsac",
"revision": "0.1",
"internal": true
},
"ml_class": {
"id": "887fc197-0755-48be-9549-403652239aa4",
"name": "act"
},
"state": "pending",
"confidence": 0.079,
"high_confidence": true,
"worker_version": null
},
{
"id": "66baa42e-b99d-4399-a810-e88659246b11",
"source": null,
"ml_class": {
"id": "79e0b521-0806-49ed-9077-f4608611bb61",
"name": "bad_line_1-5%"
},
"state": "pending",
"confidence": 0.03773584905660377,
"high_confidence": true,
"worker_version": "fcc0cc70-4254-43cb-beaf-29dd94b4e684"
}
],
"source": null,
"worker_version": null
}
\ No newline at end of file
{
"id": "fcc0cc70-4254-43cb-beaf-29dd94b4e684",
"configuration": {
"docker": {
"build": "Dockerfile",
"image": "",
"command": null,
"environment": {}
},
"secrets": [],
"configuration": {
"std": [
94,
91,
80
],
"mean": [
132,
126,
111
],
"model": "/usr/share/teklia/models/line_historical.pth",
"min_cc": 50,
"classes": [
"background",
"text_line"
],
"input_size": 768,
"element_types": {
"text_line": "text_line"
},
"bad_prediction_score": {
"text_line": "bad_line"
}
}
},
"revision": {
"id": "d86ce269-9b4e-49cf-bf39-b1dfe7f7bcc0",
"state": "available",
"hash": "ca3ec5895982ee97d9dc78a6d9f0e9879ff1d3de",
"author": "Mélodie Boillet",
"message": "Merge branch 'use-create-elements' into 'master'\n\nUse CreateElements endpoint to speedup ufcn process\n\nSee merge request teklia/workers/ufcn!15",
"created": "2020-10-06T09:03:10.623041Z",
"commit_url": "https://gitlab.com/teklia/workers/ufcn/commit/ca3ec5895982ee97d9dc78a6d9f0e9879ff1d3de",
"refs": []
},
"docker_image": "f9dbfa89-6350-4134-bee4-909f26552434",
"docker_image_iid": "sha256:754b83209f737817a978871f013cfc57ab6ccc0e514415d97b63723ada9aba16",
"docker_image_name": "gitlab.com/teklia/workers/ufcn/ufcn_line_historical:fcc0cc70-4254-43cb-beaf-29dd94b4e684",
"state": "available",
"worker": {
"id": "13b84978-cb85-410c-86e4-9747a1dd3d66",
"name": "U-FCN Line Historical",
"slug": "ufcn_line_historical",
"type": "dla"
}
}
\ No newline at end of file
......@@ -30,6 +30,9 @@ TRANSCRIPTIONS_SAMPLE = [
},
]
TEST_VERSION_ID = "test_123"
TEST_SLUG = "some_slug"
def test_cli_default(monkeypatch, mock_worker_version_api):
_, path = tempfile.mkstemp()
......@@ -1713,3 +1716,100 @@ def test_create_element_transcriptions(responses, mock_elements_worker):
{"id": "word1_1_2", "created": False},
{"id": "word1_1_3", "created": False},
]
def test_get_worker_version(fake_dummy_worker):
api_client = fake_dummy_worker.api_client
response = {"worker": {"slug": TEST_SLUG}}
api_client.add_response("RetrieveWorkerVersion", response, id=TEST_VERSION_ID)
res = fake_dummy_worker.get_worker_version(TEST_VERSION_ID)
assert res == response
assert fake_dummy_worker._worker_version_cache[TEST_VERSION_ID] == response
def test_get_worker_version__uses_cache(fake_dummy_worker):
api_client = fake_dummy_worker.api_client
response = {"worker": {"slug": TEST_SLUG}}
api_client.add_response("RetrieveWorkerVersion", response, id=TEST_VERSION_ID)
response_1 = fake_dummy_worker.get_worker_version(TEST_VERSION_ID)
response_2 = fake_dummy_worker.get_worker_version(TEST_VERSION_ID)
assert response_1 == response
assert response_1 == response_2
# assert that only one call to the API
assert len(api_client.history) == 1
assert not api_client.responses
def test_get_slug__old_style(fake_dummy_worker):
element = {"source": {"slug": TEST_SLUG}}
slug = fake_dummy_worker.get_ml_result_slug(element)
assert slug == TEST_SLUG
def test_get_slug__worker_version(fake_dummy_worker):
api_client = fake_dummy_worker.api_client
response = {"worker": {"slug": TEST_SLUG}}
api_client.add_response("RetrieveWorkerVersion", response, id=TEST_VERSION_ID)
element = {"worker_version": TEST_VERSION_ID}
slug = fake_dummy_worker.get_ml_result_slug(element)
assert slug == TEST_SLUG
# assert that only one call to the API
assert len(api_client.history) == 1
assert not api_client.responses
def test_get_slug__both(fake_page_element, fake_ufcn_worker_version, fake_dummy_worker):
api_client = fake_dummy_worker.api_client
api_client.add_response(
"RetrieveWorkerVersion",
fake_ufcn_worker_version,
id=fake_ufcn_worker_version["id"],
)
expected_slugs = [
"scikit_portrait_outlier_balsac",
"scikit_portrait_outlier_balsac",
"ufcn_line_historical",
]
slugs = [
fake_dummy_worker.get_ml_result_slug(clf)
for clf in fake_page_element["classifications"]
]
assert slugs == expected_slugs
assert len(api_client.history) == 1
assert not api_client.responses
def test_get_slug__transcriptions(fake_transcriptions_small, fake_dummy_worker):
api_client = fake_dummy_worker.api_client
version_id = "3ca4a8e3-91d1-4b78-8d83-d8bbbf487996"
response = {"worker": {"slug": TEST_SLUG}}
api_client.add_response("RetrieveWorkerVersion", response, id=version_id)
slug = fake_dummy_worker.get_ml_result_slug(fake_transcriptions_small["results"][0])
assert slug == TEST_SLUG
assert len(api_client.history) == 1
assert not api_client.responses
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment