Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: workers/base-worker
Commits on Source (89)
Showing with 767 additions and 477 deletions
......@@ -145,9 +145,7 @@ pypi-publication:
- pip install -e .[docs]
script:
- cd docs
- make html
- mv _build/html ../public
- mkdocs build --strict --verbose
docs-build:
extends: .docs
......
......@@ -8,4 +8,4 @@ line_length = 88
default_section=FIRSTPARTY
known_first_party = arkindex,arkindex_common
known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,responses,setuptools,sh,shapely,tenacity,yaml,zstandard
known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,requests,responses,setuptools,sh,shapely,tenacity,yaml,zstandard
......@@ -11,7 +11,7 @@ repos:
rev: 22.3.0
hooks:
- id: black
- repo: https://gitlab.com/pycqa/flake8
- repo: https://github.com/pycqa/flake8
rev: 3.8.3
hooks:
- id: flake8
......@@ -23,7 +23,6 @@ repos:
rev: v3.1.0
hooks:
- id: check-ast
- id: check-docstring-first
- id: check-executables-have-shebangs
- id: check-merge-conflict
- id: check-symlinks
......
Arkindex base Worker
====================
# Arkindex base Worker
An easy-to-use, high-level Python 3 API client to build ML tasks.
## Create a new worker using our template
```
pip install --user cookiecutter
cookiecutter git@gitlab.com:arkindex/base-worker.git
cookiecutter git@gitlab.com:teklia/workers/base-worker.git
```
0.3.0-rc5
0.3.2-rc3
......@@ -10,6 +10,8 @@ reducing network usage.
import json
import os
import sqlite3
from pathlib import Path
from typing import Optional, Union
from peewee import (
BooleanField,
......@@ -26,9 +28,9 @@ from peewee import (
TextField,
UUIDField,
)
from PIL import Image
from arkindex_worker import logger
from arkindex_worker.image import open_image, polygon_bounding_box
db = SqliteDatabase(None)
......@@ -65,6 +67,10 @@ class Version(Model):
class CachedImage(Model):
"""
Cache image table
"""
id = UUIDField(primary_key=True)
width = IntegerField()
height = IntegerField()
......@@ -76,6 +82,10 @@ class CachedImage(Model):
class CachedElement(Model):
"""
Cache element table
"""
id = UUIDField(primary_key=True)
parent_id = UUIDField(null=True)
type = CharField(max_length=50)
......@@ -84,26 +94,29 @@ class CachedElement(Model):
rotation_angle = IntegerField(default=0)
mirrored = BooleanField(default=False)
initial = BooleanField(default=False)
# Needed to filter elements with cache
worker_version_id = UUIDField(null=True)
worker_run_id = UUIDField(null=True)
confidence = FloatField(null=True)
class Meta:
database = db
table_name = "elements"
def open_image(self, *args, max_size=None, **kwargs):
def open_image(self, *args, max_size: Optional[int] = None, **kwargs) -> Image:
"""
Open this element's image as a Pillow image.
This does not crop the image to the element's polygon.
IIIF servers with maxWidth, maxHeight or maxArea restrictions on image size are not supported.
:param \\*args: Positional arguments passed to :meth:`arkindex_worker.image.open_image`
:param *args: Positional arguments passed to [arkindex_worker.image.open_image][]
:param max_size: Subresolution of the image.
:type max_size: int or None
:param \\**kwargs: Keyword arguments passed to :meth:`arkindex_worker.image.open_image`
:param **kwargs: Keyword arguments passed to [arkindex_worker.image.open_image][]
:raises ValueError: When this element does not have an image ID or a polygon.
:returns PIL.Image: A Pillow image.
:return: A Pillow image.
"""
from arkindex_worker.image import open_image, polygon_bounding_box
if not self.image_id or not self.polygon:
raise ValueError(f"Element {self.id} has no image")
......@@ -152,12 +165,18 @@ class CachedElement(Model):
class CachedTranscription(Model):
"""
Cache transcription table
"""
id = UUIDField(primary_key=True)
element = ForeignKeyField(CachedElement, backref="transcriptions")
text = TextField()
confidence = FloatField()
orientation = CharField(max_length=50)
# Needed to filter transcriptions with cache
worker_version_id = UUIDField(null=True)
worker_run_id = UUIDField(null=True)
class Meta:
database = db
......@@ -165,12 +184,16 @@ class CachedTranscription(Model):
class CachedClassification(Model):
"""
Cache classification table
"""
id = UUIDField(primary_key=True)
element = ForeignKeyField(CachedElement, backref="classifications")
class_name = TextField()
confidence = FloatField()
state = CharField(max_length=10)
worker_version_id = UUIDField(null=True)
worker_run_id = UUIDField(null=True)
class Meta:
database = db
......@@ -178,12 +201,16 @@ class CachedClassification(Model):
class CachedEntity(Model):
"""
Cache entity table
"""
id = UUIDField(primary_key=True)
type = CharField(max_length=50)
name = TextField()
validated = BooleanField(default=False)
metas = JSONField(null=True)
worker_version_id = UUIDField(null=True)
worker_run_id = UUIDField(null=True)
class Meta:
database = db
......@@ -191,13 +218,17 @@ class CachedEntity(Model):
class CachedTranscriptionEntity(Model):
"""
Cache transcription entity table
"""
transcription = ForeignKeyField(
CachedTranscription, backref="transcription_entities"
)
entity = ForeignKeyField(CachedEntity, backref="transcription_entities")
offset = IntegerField(constraints=[Check("offset >= 0")])
length = IntegerField(constraints=[Check("length > 0")])
worker_version_id = UUIDField(null=True)
worker_run_id = UUIDField(null=True)
confidence = FloatField(null=True)
class Meta:
......@@ -216,10 +247,14 @@ MODELS = [
CachedEntity,
CachedTranscriptionEntity,
]
SQL_VERSION = 1
SQL_VERSION = 2
def init_cache_db(path):
def init_cache_db(path: str):
"""
Create the cache database at the given path
:param path: Where the new database should be created
"""
db.init(
path,
pragmas={
......@@ -250,7 +285,12 @@ def create_version_table():
Version.create(version=SQL_VERSION)
def check_version(cache_path):
def check_version(cache_path: Union[str, Path]):
"""
Check that the SQLite database at the given path uses the expected cache version
:param cache_path: Path to a local SQLite database
"""
with SqliteDatabase(cache_path) as provided_db:
with provided_db.bind_ctx([Version]):
try:
......@@ -263,7 +303,16 @@ def check_version(cache_path):
), f"The SQLite database {cache_path} does not have the correct cache version, it should be {SQL_VERSION}"
def retrieve_parents_cache_path(parent_ids, data_dir="/data", chunk=None):
def retrieve_parents_cache_path(
parent_ids: list, data_dir: str = "/data", chunk: Optional[int] = None
) -> list:
"""
Retrieve the paths to the given parents' cache databases under the data directory
:param parent_ids: List of element IDs to search
:param data_dir: Base folder to look for the databases in
:param chunk: Index of the chunk of the db that might contain the paths
:return: The corresponding list of paths
"""
assert isinstance(parent_ids, list)
assert os.path.isdir(data_dir)
......@@ -288,9 +337,11 @@ def retrieve_parents_cache_path(parent_ids, data_dir="/data", chunk=None):
)
def merge_parents_cache(paths, current_database):
def merge_parents_cache(paths: list, current_database: str):
"""
Merge all the potential parent tasks' databases into the existing local one
:param paths: Paths to the parent cache databases
:param current_database: Path to the current database
"""
assert os.path.exists(current_database)
......
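For context, here is a minimal sketch of preparing a local cache database with the helpers touched in this hunk; the file name is a placeholder, and `db.create_tables` is the standard peewee call on the module-level database shown above.
```
from arkindex_worker.cache import (
    MODELS,
    check_version,
    create_version_table,
    db,
    init_cache_db,
)

# Point the module-level SqliteDatabase at a local file,
# applying the worker's pragmas
init_cache_db("db.sqlite")

# Create the cache tables (elements, transcriptions, ...) and stamp
# the database with the current SQL_VERSION
db.create_tables(MODELS)
create_version_table()

# A later run can verify that the database matches the expected version
check_version("db.sqlite")
```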
......@@ -6,10 +6,12 @@ import shutil
import time
from datetime import datetime
from pathlib import Path
from typing import Optional, Union
import gitlab
import requests
import sh
from gitlab.v4.objects import MergeRequest, ProjectMergeRequest
from arkindex_worker import logger
......@@ -22,13 +24,13 @@ class GitlabHelper:
def __init__(
self,
project_id,
gitlab_url,
gitlab_token,
branch,
rebase_wait_period=1,
delete_source_branch=True,
max_rebase_tries=10,
project_id: str,
gitlab_url: str,
gitlab_token: str,
branch: str,
rebase_wait_period: Optional[int] = 1,
delete_source_branch: Optional[bool] = True,
max_rebase_tries: Optional[int] = 10,
):
"""
:param project_id: the id of the gitlab project
......@@ -52,13 +54,13 @@ class GitlabHelper:
self.project = self._api.projects.get(self.project_id)
self.is_rebase_finished = False
def merge(self, branch_name, title) -> bool:
def merge(self, branch_name: str, title: str) -> bool:
"""
Create a merge request and try to merge.
Always rebase first to avoid conflicts from MRs made in parallel
:param branch_name: source branch name
:param title: title of the merge request
:return: was the branch successfully merged?
:param branch_name: Source branch name
:param title: Title of the merge request
:return: Whether the branch was successfully merged
"""
mr = None
# always rebase first, because other workers might have merged already
......@@ -95,7 +97,14 @@ class GitlabHelper:
return False
def _create_merge_request(self, branch_name, title):
def _create_merge_request(self, branch_name: str, title: str) -> MergeRequest:
"""
Create a merge request from the given branch with the given title
:param branch_name: Source branch of the merge request
:param title: Title of the merge request
:return: The created merge request
"""
logger.info(f"Creating a merge request for {branch_name}")
# retry_transient_error will retry the request on 50X errors
# https://github.com/python-gitlab/python-gitlab/blob/265dbbdd37af88395574564aeb3fd0350288a18c/gitlab/__init__.py#L539
......@@ -108,16 +117,24 @@ class GitlabHelper:
)
return mr
def _get_merge_request(self, merge_request_id, include_rebase_in_progress=True):
def _get_merge_request(
self, merge_request_id: Union[str, int], include_rebase_in_progress: bool = True
) -> ProjectMergeRequest:
"""
Retrieve a merge request by ID
:param merge_request_id: The ID of the merge request
:param include_rebase_in_progress: Whether the rebase in progress should be included
:return: The related merge request
"""
return self.project.mergerequests.get(
merge_request_id, include_rebase_in_progress=include_rebase_in_progress
)
def _wait_for_rebase_to_finish(self, merge_request_id) -> bool:
def _wait_for_rebase_to_finish(self, merge_request_id: Union[str, int]) -> bool:
"""
Poll the merge request until it has finished rebasing
:param merge_request_id:
:return: rebase finished successfully?
:param merge_request_id: The ID of the merge request
:return: Whether the rebase has finished successfully
"""
logger.info("Checking if rebase has finished..")
......@@ -134,10 +151,10 @@ class GitlabHelper:
return False
def make_backup(path):
def make_backup(path: str):
"""
Create a backup of the file in the same directory, with a timestamp suffix ".bak_{timestamp}"
:param path: file to be backed up
:param path: Path to the file to be backed up
"""
path = Path(path)
if not path.exists():
......@@ -150,10 +167,10 @@ def make_backup(path):
def prepare_git_key(
private_key,
known_hosts,
private_key_path="~/.ssh/id_ed25519",
known_hosts_path="~/.ssh/known_hosts",
private_key: str,
known_hosts: str,
private_key_path: Optional[str] = "~/.ssh/id_ed25519",
known_hosts_path: Optional[str] = "~/.ssh/known_hosts",
):
"""
Prepare the git keys (write them to the correct locations) so that git can be used.
......@@ -198,24 +215,25 @@ class GitHelper:
"""
A helper class for running git commands
At the beginning of the workflow call `run_clone_in_background`.
At the beginning of the workflow call [run_clone_in_background][arkindex_worker.git.GitHelper.run_clone_in_background].
When all the files are ready to be added to git then call
`save_files` to move the files in to the git repository
[save_files][arkindex_worker.git.GitHelper.save_files] to move the files in to the git repository
and try to push them.
Pseudo code example:
in worker.configure() configure the git helper and start the cloning:
```
gitlab = GitlabHelper(...)
prepare_git_key(...)
self.git_helper = GitHelper(workflow_id=workflow_id, gitlab_helper=gitlab, ...)
self.git_helper.run_clone_in_background()
```
at the end of the workflow (at the end of worker.run()) push the files to git:
```
self.git_helper.save_files(self.out_dir)
```
Examples
--------
in worker.configure() configure the git helper and start the cloning:
```
gitlab = GitlabHelper(...)
prepare_git_key(...)
self.git_helper = GitHelper(workflow_id=workflow_id, gitlab_helper=gitlab, ...)
self.git_helper.run_clone_in_background()
```
at the end of the workflow (at the end of worker.run()) push the files to git:
```
self.git_helper.save_files(self.out_dir)
```
"""
def __init__(
......@@ -294,6 +312,9 @@ class GitHelper:
"""
Move files in export_out_dir to the cloned git repository
and try to merge the created files if possible.
:param export_out_dir: Path to the files to be saved
:raises sh.ErrorReturnCode: When a git command exits with a non-zero status.
:raises Exception: When pushing or merging the files fails.
"""
self._wait_for_clone_to_finish()
......@@ -350,6 +371,8 @@ class GitHelper:
"""
Move all files in the export_out_dir to the git repository
while keeping the same directory structure
:param export_out_dir: Path to the files to be moved
:return: Total count of moved files
"""
file_count = 0
file_names = [
......
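A hedged sketch of the GitLab helper updated above; the project ID, token and branch names are placeholders, and constructing `GitlabHelper` performs API calls against the instance.
```
from arkindex_worker.git import GitlabHelper, prepare_git_key

# Install SSH credentials so git commands can authenticate
# (the key contents would normally come from Ponos secrets)
prepare_git_key(private_key="<private key>", known_hosts="<known hosts>")

gitlab = GitlabHelper(
    project_id="1234",               # placeholder project ID
    gitlab_url="https://gitlab.com",
    gitlab_token="<token>",          # placeholder token
    branch="master",
)

# Rebase first, then create a merge request and try to merge it
merged = gitlab.merge("my-export-branch", "Automated export")
```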
......@@ -2,11 +2,12 @@
"""
Helper methods to download and open IIIF images, and manage polygons.
"""
import os
from collections import namedtuple
from io import BytesIO
from math import ceil
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Union
import requests
from PIL import Image
......@@ -21,28 +22,38 @@ from tenacity import (
from arkindex_worker import logger
# Avoid circular imports error when type checking
if TYPE_CHECKING:
from arkindex_worker.cache import CachedElement
from arkindex_worker.models import Element
# See http://docs.python-requests.org/en/master/user/advanced/#timeouts
DOWNLOAD_TIMEOUT = (30, 60)
BoundingBox = namedtuple("BoundingBox", ["x", "y", "width", "height"])
def open_image(path, mode="RGB", rotation_angle=0, mirrored=False):
def open_image(
path: Union[str, Path],
mode: Optional[str] = "RGB",
rotation_angle: Optional[int] = 0,
mirrored: Optional[bool] = False,
) -> Image:
"""
Open an image from a path or a URL.
.. warning:: Prefer :meth:`Element.open_image` whenever possible.
Warns:
Prefer [arkindex_worker.models.Element.open_image][] whenever possible.
:param path str: Path or URL to open the image from.
:param path: Path or URL to open the image from.
This parameter will be interpreted as a URL when it has a `http` or `https` scheme
and no file exists at this path locally.
:param mode str: Pillow mode for the image. See `the Pillow documentation
<https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes>`_.
:param int rotation_angle: Rotation angle to apply to the image, in degrees.
:param mode: Pillow mode for the image. See [the Pillow documentation](https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes).
:param rotation_angle: Rotation angle to apply to the image, in degrees.
If it is not a multiple of 90°, then the rotation can cause empty pixels of
the mode's default color to be added for padding.
:param bool mirrored: Whether or not to mirror the image horizontally.
:returns PIL.Image: A Pillow image.
:param mirrored: Whether or not to mirror the image horizontally.
:returns: A Pillow image.
"""
if (
path.startswith("http://")
......@@ -68,12 +79,12 @@ def open_image(path, mode="RGB", rotation_angle=0, mirrored=False):
return image
def download_image(url):
def download_image(url: str) -> Image:
"""
Download an image and open it with Pillow.
:param url str: URL of the image.
:returns PIL.Image: A Pillow image.
:param url: URL of the image.
:returns: A Pillow image.
"""
assert url.startswith("http"), "Image URL must be HTTP(S)"
......@@ -110,13 +121,12 @@ def download_image(url):
return image
def polygon_bounding_box(polygon):
def polygon_bounding_box(polygon: List[List[Union[int, float]]]) -> BoundingBox:
"""
Compute the rectangle bounding box of a polygon.
:param polygon: Polygon to get the bounding box of.
:type polygon: list(list(int or float))
:returns BoundingBox: Bounding box of this polygon.
:returns: Bounding box of this polygon.
"""
x_coords, y_coords = zip(*polygon)
x, y = min(x_coords), min(y_coords)
......@@ -144,12 +154,12 @@ def _retried_request(url):
return resp
def download_tiles(url):
def download_tiles(url: str) -> Image:
"""
Reconstruct a full IIIF image on servers that cannot serve the full-sized image, using tiles.
:param str url: URL of the image.
:returns PIL.Image: A Pillow image.
:param url: URL of the image.
:returns: A Pillow image.
"""
if not url.endswith("/"):
url += "/"
......@@ -217,19 +227,19 @@ def download_tiles(url):
return full_image
def trim_polygon(polygon, image_width: int, image_height: int):
def trim_polygon(
polygon: List[List[int]], image_width: int, image_height: int
) -> List[List[int]]:
"""
Trim a polygon to an image's boundaries, with non-negative coordinates.
:param polygon: A polygon to trim.
:type: list(list(int or float) or tuple(int or float)) or tuple(tuple(int or float) or list(int or float))
:param image_width int: Width of the image.
:param image_height int: Height of the image.
:param image_width: Width of the image.
:param image_height: Height of the image.
:returns: A polygon trimmed to the image's bounds.
Some points may appear as missing, as the trimming can deduplicate points.
The first and last point are always equal, to reproduce the behavior
of the Arkindex backend.
:rtype: list(list(int or float))
:raises AssertionError: When argument types are invalid or when the trimmed polygon
is entirely outside of the image's bounds.
"""
......@@ -270,28 +280,28 @@ def trim_polygon(polygon, image_width: int, image_height: int):
return updated_polygon
def revert_orientation(element, polygon, reverse: bool = False):
def revert_orientation(
element: Union["Element", "CachedElement"],
polygon: List[List[Union[int, float]]],
reverse: Optional[bool] = False,
) -> List[List[int]]:
"""
Update the coordinates of the polygon of a child element based on the orientation of
its parent.
This method should be called before sending any polygon to Arkindex, to undo the possible
orientation applied by :meth:`Element.open_image`.
orientation applied by [arkindex_worker.models.Element.open_image][].
In some cases, we want to apply the parent's orientation on the child's polygon instead. This is done
by enabling `reverse=True`.
:param element: Parent element.
:type element: Element or CachedElement
:param polygon: Polygon corresponding to the child element.
:type polygon: list(list(int or float))
:param mode: Whether we should revert (`revert`) or apply (`apply`) the parent's orientation.
:type mode: str
:param reverse: Whether we should revert or apply the parent's orientation.
:return: A polygon with updated coordinates.
:rtype: list(list(int))
"""
from arkindex_worker.models import Element
from arkindex_worker.cache import CachedElement
from arkindex_worker.models import Element
assert element and isinstance(
element, (Element, CachedElement)
......
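As a quick reference for the image helpers touched here, a minimal sketch; the IIIF URL and coordinates are placeholders.
```
from arkindex_worker.image import (
    open_image,
    polygon_bounding_box,
    trim_polygon,
)

# A polygon that sticks out of a 1000x800 image; trim_polygon clips it
# to the image bounds, keeping the first and last points equal
polygon = [[-10, 20], [1050, 20], [1050, 820], [-10, 820], [-10, 20]]
polygon = trim_polygon(polygon, image_width=1000, image_height=800)

# Rectangle bounding box of the trimmed polygon, as a BoundingBox namedtuple
bbox = polygon_bounding_box(polygon)
print(bbox.x, bbox.y, bbox.width, bbox.height)

# Open an image from a URL, rotated by 90 degrees
image = open_image(
    "https://example.com/iiif/image/full/full/0/default.jpg",
    rotation_angle=90,
)
```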
......@@ -5,12 +5,12 @@ Wrappers around API results to provide more convenient attribute access and IIIF
import tempfile
from contextlib import contextmanager
from typing import Optional
from typing import Generator, List, Optional
from PIL import Image
from requests import HTTPError
from arkindex_worker import logger
from arkindex_worker.image import download_tiles, open_image, polygon_bounding_box
class MagicDict(dict):
......@@ -63,7 +63,12 @@ class Element(MagicDict):
Describes an Arkindex element.
"""
def resize_zone_url(self, size="full"):
def resize_zone_url(self, size: str = "full") -> str:
"""
Compute the URL of this element's image at the requested size
:param size: Requested size, following the IIIF resize syntax
:return: The image URL at the requested size
"""
if size == "full":
return self.zone.url
else:
......@@ -71,13 +76,12 @@ class Element(MagicDict):
parts[-3] = size
return "/".join(parts)
def image_url(self, size="full") -> Optional[str]:
def image_url(self, size: str = "full") -> Optional[str]:
"""
Build a URL to access the image.
When possible, will return the S3 URL for images, so an ML worker can bypass IIIF servers.
:param str size: Subresolution of the image, following the syntax of the IIIF resize parameter.
:param size: Subresolution of the image, following the syntax of the IIIF resize parameter.
:returns: A URL to the image, or None if the element does not have an image.
:rtype: str or None
"""
if not self.get("zone"):
return
......@@ -89,12 +93,27 @@ class Element(MagicDict):
url += "/"
return "{}full/{}/0/default.jpg".format(url, size)
@property
def polygon(self) -> List[float]:
"""
Access an Element's polygon.
This is a shortcut to the polygon normally accessed through
the Element's zone field, as `zone.polygon`. It mostly exists
to ease access to this important field by mirroring
the [CachedElement][arkindex_worker.cache.CachedElement].polygon field.
"""
if not self.get("zone"):
raise ValueError("Element {} has no zone".format(self.id))
return self.zone.polygon
@property
def requires_tiles(self) -> bool:
"""
:return bool: Whether or not downloading and combining IIIF tiles will be necessary
Whether or not downloading and combining IIIF tiles will be necessary
to retrieve this element's image. Will be False if the element has no image.
"""
from arkindex_worker.image import polygon_bounding_box
if not self.get("zone") or self.zone.image.get("s3_url"):
return False
max_width = self.zone.image.server.max_width or float("inf")
......@@ -102,7 +121,13 @@ class Element(MagicDict):
bounding_box = polygon_bounding_box(self.zone.polygon)
return bounding_box.width > max_width or bounding_box.height > max_height
def open_image(self, *args, max_size=None, use_full_image=False, **kwargs):
def open_image(
self,
*args,
max_size: Optional[int] = None,
use_full_image: Optional[bool] = False,
**kwargs
) -> Image:
"""
Open this element's image using Pillow, rotating and mirroring it according
to the ``rotation_angle`` and ``mirrored`` attributes.
......@@ -111,12 +136,12 @@ class Element(MagicDict):
to bypass IIIF servers, the image will be cropped to the rectangle bounding box
of the ``zone.polygon`` attribute.
.. warning::
Warns:
----
This method implicitly applies the element's orientation to the image.
If your process uses the returned image to find more polygons and send them
back to Arkindex, use the :meth:`arkindex_worker.image.revert_orientation`
back to Arkindex, use the [arkindex_worker.image.revert_orientation][]
helper to undo the orientation on all polygons before sending them, as the
Arkindex API expects unoriented polygons.
......@@ -125,19 +150,24 @@ class Element(MagicDict):
:param max_size: The maximum size of the requested image.
:type max_size: int or None
:param bool use_full_image: Ignore the ``zone.polygon`` and always
:param use_full_image: Ignore the ``zone.polygon`` and always
retrieve the image without cropping.
:param \\*args: Positional arguments passed to :meth:`arkindex_worker.image.open_image`.
:param \\**kwargs: Keyword arguments passed to :meth:`arkindex_worker.image.open_image`.
:param *args: Positional arguments passed to [arkindex_worker.image.open_image][].
:param **kwargs: Keyword arguments passed to [arkindex_worker.image.open_image][].
:raises ValueError: When the element does not have an image.
:raises NotImplementedError: When the ``max_size`` parameter is set,
but the IIIF server's configuration requires downloading and combining tiles
to retrieve the image.
:raises NotImplementedError: When an S3 URL has been used to download the image,
but the URL has expired. Re-fetching the URL automatically is not supported.
:return PIL.Image: A Pillow image.
:return: A Pillow image.
"""
from arkindex_worker.image import (
download_tiles,
open_image,
polygon_bounding_box,
)
if not self.get("zone"):
raise ValueError("Element {} has no zone".format(self.id))
......@@ -197,18 +227,25 @@ class Element(MagicDict):
raise
@contextmanager
def open_image_tempfile(self, format="jpeg", *args, **kwargs):
def open_image_tempfile(
self, format: Optional[str] = "jpeg", *args, **kwargs
) -> Generator[tempfile.NamedTemporaryFile, None, None]:
"""
Get the element's image as a temporary file stored on the disk.
To be used as a context manager::
To be used as a context manager.
with element.open_image_tempfile() as f:
...
Example
----
```
with element.open_image_tempfile() as f:
...
```
:param format str: File format to use the store the image on the disk.
:param format: File format to use to store the image on disk.
Must be a format supported by Pillow.
:param \\*args: Positional arguments passed to :meth:`arkindex_worker.image.open_image`.
:param \\**kwargs: Keyword arguments passed to :meth:`arkindex_worker.image.open_image`.
:param *args: Positional arguments passed to [arkindex_worker.image.open_image][].
:param **kwargs: Keyword arguments passed to [arkindex_worker.image.open_image][].
"""
with tempfile.NamedTemporaryFile() as f:
self.open_image(*args, **kwargs).save(f, format=format)
......@@ -229,12 +266,3 @@ class Transcription(MagicDict):
def __str__(self):
return "Transcription ({})".format(self.id)
class Corpus(MagicDict):
"""
Describes an Arkindex corpus.
"""
def __str__(self):
return "Corpus {} ({})".format(self.name, self.id)
......@@ -7,10 +7,14 @@ import json
import traceback
from collections import Counter
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Union
from uuid import UUID
from apistar.exceptions import ErrorResponse
from arkindex_worker import logger
from arkindex_worker.models import Transcription
class Reporter(object):
......@@ -19,7 +23,11 @@ class Reporter(object):
"""
def __init__(
self, name="Unknown worker", slug="unknown-slug", version=None, **kwargs
self,
name: Optional[str] = "Unknown worker",
slug: Optional[str] = "unknown-slug",
version: Optional[str] = None,
**kwargs,
):
self.report_data = {
"name": name,
......@@ -46,57 +54,56 @@ class Reporter(object):
"classifications": {},
# Created entities ({"id": "", "type": "", "name": ""}) from this element
"entities": [],
# Created transcription entities ({"transcription_id": "", "entity_id": ""}) from this element
"transcription_entities": [],
# Created metadata ({"id": "", "type": "", "name": ""}) from this element
"metadata": [],
"errors": [],
},
)
def process(self, element_id):
def process(self, element_id: Union[str, UUID]):
"""
Report that a specific element ID is being processed.
:param element_id: ID of the element being processed.
:type element_id: str or uuid.UUID
"""
# Just call the element initializer
self._get_element(element_id)
def add_element(self, parent_id, type, type_count=1):
def add_element(self, parent_id: Union[str, UUID], type: str, type_count: int = 1):
"""
Report creating an element as a child of another.
:param parent_id: ID of the parent element.
:type parent_id: str or uuid.UUID
:param type str: Slug of the type of the child element.
:param type_count int: How many elements of this type were created. Defaults to 1.
:param type: Slug of the type of the child element.
:param type_count: How many elements of this type were created.
"""
elements = self._get_element(parent_id)["elements"]
elements.setdefault(type, 0)
elements[type] += type_count
def add_classification(self, element_id, class_name):
def add_classification(self, element_id: Union[str, UUID], class_name: str):
"""
Report creating a classification on an element.
:param element_id: ID of the element.
:type element_id: str or uuid.UUID
:param class_name str: Name of the ML class of the new classification.
:param class_name: Name of the ML class of the new classification.
"""
classifications = self._get_element(element_id)["classifications"]
classifications.setdefault(class_name, 0)
classifications[class_name] += 1
def add_classifications(self, element_id, classifications):
def add_classifications(
self, element_id: Union[str, UUID], classifications: List[Dict[str, str]]
):
"""
Report creating one or more classifications at once on an element.
:param element_id: ID of the element.
:type element_id: str or uuid.UUID
:param classifications: List of classifications.
Each classification is represented as a ``dict`` with a ``class_name`` key
holding the name of the ML class being used.
:type classifications: List[Dict[str, str]]
"""
assert isinstance(
classifications, list
......@@ -110,29 +117,57 @@ class Reporter(object):
)
element["classifications"] = dict(counter)
def add_transcription(self, element_id, count=1):
def add_transcription(self, element_id: Union[str, UUID], count=1):
"""
Report creating a transcription on an element.
:param element_id: ID of the element.
:type element_id: str or uuid.UUID
:param count int: Number of transcriptions created at once, defaults to 1.
:param count: Number of transcriptions created at once
"""
self._get_element(element_id)["transcriptions"] += count
def add_entity(self, element_id, entity_id, type, name):
def add_entity(
self,
element_id: Union[str, UUID],
entity_id: Union[str, UUID],
type: str,
name: str,
):
"""
Report creating an entity on an element.
:param element_id: ID of the element.
:type element_id: str or uuid.UUID
:param entity_id str: ID of the new entity.
:param type str: Type of the entity.
:param name str: Name of the entity.
:param entity_id: ID of the new entity.
:param type: Type of the entity.
:param name: Name of the entity.
"""
entities = self._get_element(element_id)["entities"]
entities.append({"id": entity_id, "type": type, "name": name})
def add_transcription_entity(
self,
entity_id: Union[str, UUID],
transcription: Transcription,
transcription_entity_id: Union[str, UUID],
):
"""
Report creating a transcription entity on an element.
:param entity_id: ID of the entity.
:param transcription: Transcription to add the entity on
:param transcription_entity_id: ID of the transcription entity that is created.
"""
transcription_entities = self._get_element(transcription.element.id)[
"transcription_entities"
]
transcription_entities.append(
{
"transcription_id": transcription.id,
"entity_id": entity_id,
"transcription_entity_id": transcription_entity_id,
}
)
def add_entity_link(self, *args, **kwargs):
"""
Report creating an entity link. Not currently supported.
......@@ -149,26 +184,30 @@ class Reporter(object):
"""
raise NotImplementedError
def add_metadata(self, element_id, metadata_id, type, name):
def add_metadata(
self,
element_id: Union[str, UUID],
metadata_id: Union[str, UUID],
type: str,
name: str,
):
"""
Report creating a metadata from an element.
:param element_id: ID of the element.
:type element_id: str or uuid.UUID
:param metadata_id str: ID of the new metadata.
:param type str: Type of the metadata.
:param name str: Name of the metadata.
:param metadata_id: ID of the new metadata.
:param type: Type of the metadata.
:param name: Name of the metadata.
"""
metadata = self._get_element(element_id)["metadata"]
metadata.append({"id": metadata_id, "type": type, "name": name})
def error(self, element_id, exception):
def error(self, element_id: Union[str, UUID], exception: Exception):
"""
Report that a Python exception occurred when processing an element.
:param element_id: ID of the element.
:type element_id: str or uuid.UUID
:param exception Exception: A Python exception.
:param exception: A Python exception.
"""
error_data = {
"class": exception.__class__.__name__,
......@@ -186,12 +225,11 @@ class Reporter(object):
self._get_element(element_id)["errors"].append(error_data)
def save(self, path):
def save(self, path: Union[str, Path]):
"""
Save the ML report to the specified path.
:param path: Path to save the ML report to.
:type path: str or pathlib.Path
"""
logger.info(f"Saving ML report to {path}")
with open(path, "w") as f:
......
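A minimal sketch of the `Reporter` API as updated here; the element ID is a placeholder.
```
from arkindex_worker.reporting import Reporter

report = Reporter(name="Demo worker", slug="demo", version="0.1.0")

# Record that an element was processed, and that a child element
# and a classification were created on it
element_id = "11111111-1111-1111-1111-111111111111"  # placeholder
report.process(element_id)
report.add_element(element_id, type="text_line")
report.add_classification(element_id, class_name="handwritten")

# Write the JSON report to disk
report.save("ml_report.json")
```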
......@@ -9,11 +9,15 @@ from timeit import default_timer
class Timer(object):
"""
A context manager to help measure execution times. Example usage::
A context manager to help measure execution times.
with Timer() as t:
# do something interesting
print(t.delta) # X days, X:XX:XX
Example
---
```
with Timer() as t:
# do something interesting
print(t.delta) # X days, X:XX:XX
```
"""
def __init__(self):
......
......@@ -8,6 +8,7 @@ import os
import sys
import uuid
from enum import Enum
from typing import Iterable, List, Union
from apistar.exceptions import ErrorResponse
......@@ -66,7 +67,13 @@ class ElementsWorker(
``arkindex.worker``, which provide helpers to read and write to the Arkindex API.
"""
def __init__(self, description="Arkindex Elements Worker", support_cache=False):
def __init__(
self, description: str = "Arkindex Elements Worker", support_cache: bool = False
):
"""
:param description: The worker's description
:param support_cache: Whether the worker supports cache
"""
super().__init__(description, support_cache)
# Add mandatory argument to process elements
......@@ -87,14 +94,13 @@ class ElementsWorker(
self._worker_version_cache = {}
def list_elements(self):
def list_elements(self) -> Union[Iterable[CachedElement], List[str]]:
"""
List the elements to be processed, either from the CLI arguments or
the cache database when enabled.
:return: An iterable of :class:`CachedElement` when cache support is enabled,
:return: An iterable of [CachedElement][arkindex_worker.cache.CachedElement] when cache support is enabled,
and a list of strings representing element IDs otherwise.
:rtype: Iterable[CachedElement] or List[str]
"""
assert not (
self.args.elements_list and self.args.element
......@@ -121,12 +127,10 @@ class ElementsWorker(
return out
@property
def store_activity(self):
def store_activity(self) -> bool:
"""
Whether or not WorkerActivity support has been enabled on the DataImport
used to run this worker.
:rtype: bool
"""
if self.args.dev:
return False
......@@ -136,6 +140,9 @@ class ElementsWorker(
return self.process_information.get("activity_state") == "ready"
def configure(self):
"""
Set up the worker using CLI arguments and environment variables.
"""
# CLI args are stored on the instance so that implementations can access them
self.args = self.parser.parse_args()
......@@ -153,8 +160,8 @@ class ElementsWorker(
def run(self):
"""
Implements an Arkindex worker that goes through each element returned by
:meth:`list_elements`. It calls :meth:`process_element`, catching exceptions
and reporting them using the :class:`Reporter`, and handles saving the report
[list_elements][arkindex_worker.worker.ElementsWorker.list_elements]. It calls [process_element][arkindex_worker.worker.ElementsWorker.process_element], catching exceptions
and reporting them using the [Reporter][arkindex_worker.reporting.Reporter], and handles saving the report
once the process is complete as well as WorkerActivity updates when enabled.
"""
self.configure()
......@@ -235,24 +242,24 @@ class ElementsWorker(
if failed >= count: # Everything failed!
sys.exit(1)
def process_element(self, element):
def process_element(self, element: Union[Element, CachedElement]):
"""
Override this method to implement your worker and process a single Arkindex element.
:param element: The element to process.
Will be a CachedElement instance if cache support is enabled,
and an Element instance otherwise.
:type element: Element or CachedElement
"""
def update_activity(self, element_id, state) -> bool:
def update_activity(
self, element_id: Union[str, uuid.UUID], state: ActivityState
) -> bool:
"""
Update the WorkerActivity for this element and worker.
:param element_id: ID of the element.
:type element_id: str or uuid.UUID
:param state ActivityState: New WorkerActivity state for this element.
:returns bool: True if the update has been successful or WorkerActivity support is disabled.
:param state: New WorkerActivity state for this element.
:returns: True if the update has been successful or WorkerActivity support is disabled.
False if the update has failed due to a conflict; this worker might have already processed
this element.
"""
......
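Putting the `ElementsWorker` lifecycle together, a minimal runnable worker sketch; the class name and description are illustrative.
```
from arkindex_worker.worker import ElementsWorker


class CountingWorker(ElementsWorker):
    def __init__(self):
        super().__init__(description="Counting worker", support_cache=False)

    def process_element(self, element):
        # element is a CachedElement when cache support is enabled,
        # an Element otherwise
        print(f"Processing {element.id}")


if __name__ == "__main__":
    CountingWorker().run()
```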
......@@ -8,6 +8,7 @@ import json
import logging
import os
from pathlib import Path
from typing import Optional
import gnupg
import yaml
......@@ -32,11 +33,11 @@ from arkindex_worker.cache import (
)
def _is_500_error(exc) -> bool:
def _is_500_error(exc: Exception) -> bool:
"""
Check if an Arkindex API error has an HTTP 5xx error code.
Used to retry most API calls in :class:`BaseWorker`.
:rtype: bool
Used to retry most API calls in [BaseWorker][arkindex_worker.worker.base.BaseWorker].
:param exc: Exception to check
"""
if not isinstance(exc, ErrorResponse):
return False
......@@ -44,17 +45,27 @@ def _is_500_error(exc) -> bool:
return 500 <= exc.status_code < 600
class ModelNotFoundError(Exception):
"""
Exception raised when the path to the model is invalid
"""
class BaseWorker(object):
"""
Base class for Arkindex workers.
"""
def __init__(self, description="Arkindex Base Worker", support_cache=False):
def __init__(
self,
description: Optional[str] = "Arkindex Base Worker",
support_cache: Optional[bool] = False,
):
"""
Initialize the worker.
:param description str: Description shown in the ``worker-...`` command line tool.
:param support_cache bool: Whether or not this worker supports the cache database.
:param description: Description shown in the ``worker-...`` command line tool.
:param support_cache: Whether or not this worker supports the cache database.
Override the constructor and set this parameter to start using the cache database.
"""
......@@ -89,6 +100,12 @@ class BaseWorker(object):
action="store_true",
default=False,
)
# To load models locally
self.parser.add_argument(
"--model-dir",
help=("The path to a local model's directory (development only)."),
type=Path,
)
# Call potential extra arguments
self.add_arguments()
......@@ -105,11 +122,10 @@ class BaseWorker(object):
self.work_dir = os.path.join(xdg_data_home, "arkindex")
os.makedirs(self.work_dir, exist_ok=True)
self.worker_version_id = os.environ.get("WORKER_VERSION_ID")
if not self.worker_version_id:
logger.warning(
"Missing WORKER_VERSION_ID environment variable, worker is in read-only mode"
)
# Store task ID. This is only available when running in production
# through a ponos agent
self.task_id = os.environ.get("PONOS_TASK")
self.worker_run_id = os.environ.get("ARKINDEX_WORKER_RUN_ID")
if not self.worker_run_id:
logger.warning(
......@@ -119,7 +135,11 @@ class BaseWorker(object):
logger.info(f"Worker will use {self.work_dir} as working directory")
self.process_information = None
# corpus_id will be updated in configure() using the worker_run's corpus
# or in configure_for_developers() from the environment
self.corpus_id = None
self.user_configuration = {}
self.model_configuration = {}
self.support_cache = support_cache
# use_cache will be updated in configure() if the cache is supported and if there
# is at least one available sqlite database either given or in the parent tasks
......@@ -133,22 +153,24 @@ class BaseWorker(object):
"""
Whether or not the worker can publish data.
:returns: False when dev mode is enabled with the ``--dev`` CLI argument,
or when no worker version ID is provided.
:rtype: bool
False when dev mode is enabled with the ``--dev`` CLI argument,
when no worker run ID is provided
"""
return (
self.args.dev
or self.worker_version_id is None
or self.worker_run_id is None
)
return self.args.dev or self.worker_run_id is None
def setup_api_client(self):
"""
Create an ArkindexClient to make API requests towards Arkindex instances.
"""
# Build Arkindex API client from environment variables
self.api_client = ArkindexClient(**options_from_env())
logger.debug(f"Setup Arkindex API client on {self.api_client.document.url}")
def configure_for_developers(self):
"""
Set up the configuration needed when working in `read_only` mode.
This is the method called when running a worker locally.
"""
assert self.is_read_only
# Setup logging level if verbose or if ARKINDEX_DEBUG is set to true
if self.args.verbose or os.environ.get("ARKINDEX_DEBUG"):
......@@ -169,12 +191,20 @@ class BaseWorker(object):
required_secrets = []
logger.warning("Running without any extra configuration")
# Define corpus_id from environment
self.corpus_id = os.environ.get("ARKINDEX_CORPUS_ID")
if not self.corpus_id:
logger.warning(
"'ARKINDEX_CORPUS_ID' was not set in the environment. Any API request involving a `corpus_id` will fail."
)
# Load all required secrets
self.secrets = {name: self.load_secret(name) for name in required_secrets}
def configure(self):
"""
Configure worker using CLI args and environment variables.
Set up the configuration needed, using CLI args and environment variables.
This is the method called when running a worker on Arkindex.
"""
assert not self.is_read_only
# Setup logging level if verbose or if ARKINDEX_DEBUG is set to true
......@@ -188,15 +218,22 @@ class BaseWorker(object):
# Load process information
self.process_information = worker_run["process"]
# Load corpus id
self.corpus_id = worker_run["process"]["corpus"]
# Load worker version information
worker_version = worker_run["worker_version"]
# Store worker version id
self.worker_version_id = worker_version["id"]
self.worker_details = worker_version["worker"]
logger.info(
f"Loaded worker {self.worker_details['name']} revision {worker_version['revision']['hash'][0:7]} from API"
)
# Retrieve initial configuration from API
self.config = worker_version["configuration"].get("configuration")
self.config = worker_version["configuration"].get("configuration", {})
if "user_configuration" in worker_version["configuration"]:
# Add default values (if set) to user_configuration
for key, value in worker_version["configuration"][
......@@ -211,23 +248,30 @@ class BaseWorker(object):
# Load worker run configuration when available
worker_configuration = worker_run.get("configuration")
self.user_configuration = (
worker_configuration.get("configuration") if worker_configuration else None
)
if self.user_configuration:
if worker_configuration and worker_configuration.get("configuration"):
logger.info("Loaded user configuration from WorkerRun")
# if debug mode is set to true activate debug mode in logger
if self.user_configuration.get("debug"):
logger.setLevel(logging.DEBUG)
logger.debug("Debug output enabled")
self.user_configuration.update(worker_configuration.get("configuration"))
# Load model version configuration when available
model_version = worker_run.get("model_version")
if model_version and model_version.get("configuration"):
logger.info("Loaded model version configuration from WorkerRun")
self.model_configuration.update(model_version.get("configuration"))
# if debug mode is set to true activate debug mode in logger
if self.user_configuration.get("debug"):
logger.setLevel(logging.DEBUG)
logger.debug("Debug output enabled")
def configure_cache(self):
task_id = os.environ.get("PONOS_TASK")
"""
Set up the attributes needed to use the cache system of `Base-Worker`.
"""
paths = None
if self.support_cache and self.args.database is not None:
self.use_cache = True
elif self.support_cache and task_id:
task = self.request("RetrieveTaskFromAgent", id=task_id)
elif self.support_cache and self.task_id:
task = self.request("RetrieveTaskFromAgent", id=self.task_id)
paths = retrieve_parents_cache_path(
task["parents"],
data_dir=os.environ.get("PONOS_DATA", "/data"),
......@@ -242,7 +286,9 @@ class BaseWorker(object):
), f"Database in {self.args.database} does not exist"
self.cache_path = self.args.database
else:
cache_dir = os.path.join(os.environ.get("PONOS_DATA", "/data"), task_id)
cache_dir = os.path.join(
os.environ.get("PONOS_DATA", "/data"), self.task_id
)
assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}"
self.cache_path = os.path.join(cache_dir, "db.sqlite")
......@@ -261,11 +307,11 @@ class BaseWorker(object):
else:
logger.debug("Cache is disabled")
def load_secret(self, name):
def load_secret(self, name: str):
"""
Load a Ponos secret by name.
:param str name: Name of the Ponos secret.
:param name: Name of the Ponos secret.
:raises Exception: When the secret cannot be loaded from the API nor the local secrets directory.
"""
secret = None
......@@ -312,6 +358,35 @@ class BaseWorker(object):
# By default give raw secret payload
return secret
def find_model_directory(self) -> Path:
"""
Find the local path to the model. This supports two modes:
- the worker runs in ponos, the model is available at `/data/current`
- the worker runs locally, the developer may specify it using either
- the `model_dir` configuration parameter
- the `--model-dir` CLI parameter
:return: Path to the model on disk
"""
if self.task_id:
# When running in production with ponos, the agent
# downloads the model and sets it in the current task work dir
return Path(self.work_dir)
else:
model_dir = self.config.get("model_dir", self.args.model_dir)
if model_dir is None:
raise ModelNotFoundError(
"No path to the model was provided. "
"Please provide model_dir either through configuration "
"or as CLI argument."
)
model_dir = Path(model_dir)
if not model_dir.exists():
raise ModelNotFoundError(
f"The path {model_dir} does not link to any directory"
)
return model_dir
@retry(
retry=retry_if_exception(_is_500_error),
wait=wait_exponential(multiplier=2, min=3),
......
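A sketch of how a worker might use the new `--model-dir` argument through `find_model_directory`; overriding `configure()` is an assumption about where the resolution belongs, and `model.pt` is a hypothetical file name.
```
from arkindex_worker.worker import ElementsWorker


class ModelWorker(ElementsWorker):
    def configure(self):
        super().configure()
        # Resolves to the task working directory under Ponos, or to the
        # `model_dir` configuration key / --model-dir CLI argument locally;
        # raises ModelNotFoundError when neither is available
        self.model_dir = self.find_model_directory()
        # Hypothetical weights file, for illustration only
        self.weights_path = self.model_dir / "model.pt"
```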
......@@ -3,7 +3,8 @@
ElementsWorker methods for classifications and ML classes.
"""
import os
from typing import Dict, List, Optional, Union
from uuid import UUID
from apistar.exceptions import ErrorResponse
from peewee import IntegrityError
......@@ -14,50 +15,38 @@ from arkindex_worker.models import Element
class ClassificationMixin(object):
"""
Mixin for the :class:`ElementsWorker` to add ``MLClass`` and ``Classification`` helpers.
"""
def load_corpus_classes(self, corpus_id):
def load_corpus_classes(self):
"""
Load all ML classes for the given corpus ID and store them in the ``self.classes`` cache.
:param corpus_id str: ID of the corpus.
Load all ML classes available in the worker's corpus and store them in the ``self.classes`` cache.
"""
corpus_classes = self.api_client.paginate(
"ListCorpusMLClasses",
id=corpus_id,
id=self.corpus_id,
)
self.classes[corpus_id] = {
self.classes[self.corpus_id] = {
ml_class["name"]: ml_class["id"] for ml_class in corpus_classes
}
logger.info(f"Loaded {len(self.classes[corpus_id])} ML classes")
logger.info(f"Loaded {len(self.classes[self.corpus_id])} ML classes")
def get_ml_class_id(self, corpus_id, ml_class):
def get_ml_class_id(self, ml_class: str) -> str:
"""
Return the MLClass ID corresponding to the given class name in the worker's corpus.
If no MLClass exists for this class name, a new one is created.
:param corpus_id: ID of the corpus, or None to use the ``ARKINDEX_CORPUS_ID`` environment variable.
:type corpus_id: str or None
:param ml_class str: Name of the MLClass.
:returns str: ID of the retrieved or created MLClass.
:param ml_class: Name of the MLClass.
:returns: ID of the retrieved or created MLClass.
"""
if corpus_id is None:
corpus_id = os.environ.get("ARKINDEX_CORPUS_ID")
if not self.classes.get(self.corpus_id):
self.load_corpus_classes()
if not self.classes.get(corpus_id):
self.load_corpus_classes(corpus_id)
ml_class_id = self.classes[corpus_id].get(ml_class)
ml_class_id = self.classes[self.corpus_id].get(ml_class)
if ml_class_id is None:
logger.info(f"Creating ML class {ml_class} on corpus {corpus_id}")
logger.info(f"Creating ML class {ml_class} on corpus {self.corpus_id}")
try:
response = self.request(
"CreateMLClass", id=corpus_id, body={"name": ml_class}
"CreateMLClass", id=self.corpus_id, body={"name": ml_class}
)
ml_class_id = self.classes[corpus_id][ml_class] = response["id"]
ml_class_id = self.classes[self.corpus_id][ml_class] = response["id"]
logger.debug(f"Created ML class {response['id']}")
except ErrorResponse as e:
# Only reload for 400 errors
......@@ -68,26 +57,53 @@ class ClassificationMixin(object):
logger.info(
f"Reloading corpus classes to see if {ml_class} already exists"
)
self.load_corpus_classes(corpus_id)
self.load_corpus_classes()
assert (
ml_class in self.classes[corpus_id]
ml_class in self.classes[self.corpus_id]
), "Missing class {ml_class} even after reloading"
ml_class_id = self.classes[corpus_id][ml_class]
ml_class_id = self.classes[self.corpus_id][ml_class]
return ml_class_id
def retrieve_ml_class(self, ml_class_id: str) -> str:
"""
Retrieve the name of the MLClass from its ID.
:param ml_class_id: ID of the MLClass to look up.
:return: The MLClass's name
"""
# Load the corpus' MLclasses if they are not available yet
if self.corpus_id not in self.classes:
self.load_corpus_classes()
# Filter classes by this ml_class_id
ml_class_name = next(
filter(
lambda x: self.classes[self.corpus_id][x] == ml_class_id,
self.classes[self.corpus_id],
),
None,
)
assert (
ml_class_name is not None
), f"Missing class with id ({ml_class_id}) in corpus ({self.corpus_id})"
return ml_class_name
def create_classification(
self, element, ml_class, confidence, high_confidence=False
):
self,
element: Union[Element, CachedElement],
ml_class: str,
confidence: float,
high_confidence: Optional[bool] = False,
) -> Dict[str, str]:
"""
Create a classification on the given element through the API.
:param element: The element to create a classification on.
:type element: Element or CachedElement
:param ml_class str: Name of the MLClass to use.
:param confidence float: Confidence score for the classification. Must be between 0 and 1.
:param high_confidence bool: Whether or not the classification is of high confidence.
:returns dict: The created classification, as returned by the ``CreateClassification`` API endpoint.
:param ml_class: Name of the MLClass to use.
:param confidence: Confidence score for the classification. Must be between 0 and 1.
:param high_confidence: Whether or not the classification is of high confidence.
:returns: The created classification, as returned by the ``CreateClassification`` API endpoint.
"""
assert element and isinstance(
element, (Element, CachedElement)
......@@ -106,14 +122,13 @@ class ClassificationMixin(object):
"Cannot create classification as this worker is in read-only mode"
)
return
try:
created = self.request(
"CreateClassification",
body={
"element": str(element.id),
"ml_class": self.get_ml_class_id(None, ml_class),
"worker_version": self.worker_version_id,
"ml_class": self.get_ml_class_id(ml_class),
"worker_run_id": self.worker_run_id,
"confidence": confidence,
"high_confidence": high_confidence,
},
......@@ -129,7 +144,7 @@ class ClassificationMixin(object):
"class_name": ml_class,
"confidence": created["confidence"],
"state": created["state"],
"worker_version_id": self.worker_version_id,
"worker_run_id": self.worker_run_id,
}
]
CachedClassification.insert_many(to_insert).execute()
......@@ -139,15 +154,23 @@ class ClassificationMixin(object):
)
except ErrorResponse as e:
# Detect already existing classification
if (
e.status_code == 400
and "non_field_errors" in e.content
and "The fields element, worker_version, ml_class must make a unique set."
in e.content["non_field_errors"]
):
logger.warning(
f"This worker version has already set {ml_class} on element {element.id}"
)
if e.status_code == 400 and "non_field_errors" in e.content:
if (
"The fields element, worker_version, ml_class must make a unique set."
in e.content["non_field_errors"]
):
logger.warning(
f"This worker version has already set {ml_class} on element {element.id}"
)
elif (
"The fields element, worker_run, ml_class must make a unique set."
in e.content["non_field_errors"]
):
logger.warning(
f"This worker run has already set {ml_class} on element {element.id}"
)
else:
raise
return
# Propagate any other API error
......@@ -157,25 +180,22 @@ class ClassificationMixin(object):
return created
def create_classifications(self, element, classifications):
def create_classifications(
self,
element: Union[Element, CachedElement],
classifications: List[Dict[str, Union[str, float, bool]]],
) -> List[Dict[str, Union[str, float, bool]]]:
"""
Create multiple classifications at once on the given element through the API.
:param element: The element to create classifications on.
:type element: Element or CachedElement
:param classifications: The classifications to create, as a list of dicts with the following keys:
class_name (str)
Name of the MLClass for this classification.
confidence (float)
Confidence score, between 0 and 1.
high_confidence (bool)
High confidence state of the classification.
:param classifications: The classifications to create, as a list of dicts. Each dict contains
an **ml_class_id** (str), the ID of the MLClass for this classification;
a **confidence** (float), the confidence score, between 0 and 1;
and a **high_confidence** (bool), the high confidence state of the classification.
:type classifications: List[Dict[str, Union[str, float, bool]]]
:returns: List of created classifications, as returned in the ``classifications`` field by
the ``CreateClassifications`` API endpoint.
:rtype: List[Dict[str, Union[str, float, bool]]]
"""
assert element and isinstance(
element, (Element, CachedElement)
......@@ -185,10 +205,18 @@ class ClassificationMixin(object):
), "classifications shouldn't be null and should be of type list"
for index, classification in enumerate(classifications):
class_name = classification.get("class_name")
assert class_name and isinstance(
class_name, str
), f"Classification at index {index} in classifications: class_name shouldn't be null and should be of type str"
ml_class_id = classification.get("ml_class_id")
assert ml_class_id and isinstance(
ml_class_id, str
), f"Classification at index {index} in classifications: ml_class_id shouldn't be null and should be of type str"
# Make sure it's a valid UUID
try:
UUID(ml_class_id)
except ValueError:
raise ValueError(
f"Classification at index {index} in classifications: ml_class_id is not a valid uuid."
)
confidence = classification.get("confidence")
assert (
......@@ -213,12 +241,13 @@ class ClassificationMixin(object):
"CreateClassifications",
body={
"parent": str(element.id),
"worker_version": self.worker_version_id,
"worker_run_id": self.worker_run_id,
"classifications": classifications,
},
)["classifications"]
for created_cl in created_cls:
created_cl["class_name"] = self.retrieve_ml_class(created_cl["ml_class"])
self.report.add_classification(element.id, created_cl["class_name"])
if self.use_cache:
......@@ -228,10 +257,10 @@ class ClassificationMixin(object):
{
"id": created_cl["id"],
"element_id": element.id,
"class_name": created_cl["class_name"],
"class_name": created_cl.pop("class_name"),
"confidence": created_cl["confidence"],
"state": created_cl["state"],
"worker_version_id": self.worker_version_id,
"worker_run_id": self.worker_run_id,
}
for created_cl in created_cls
]
......
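A sketch of the updated classification helpers; note that `create_classifications` now expects MLClass IDs, which `get_ml_class_id` can resolve. The class names below are illustrative.
```
from arkindex_worker.worker import ElementsWorker


class ClassifierWorker(ElementsWorker):
    def process_element(self, element):
        # Single classification; the MLClass is resolved (or created)
        # by name internally
        self.create_classification(
            element,
            ml_class="handwritten",
            confidence=0.95,
            high_confidence=True,
        )

        # Several classifications at once; this endpoint expects
        # MLClass IDs, so resolve the name first
        self.create_classifications(
            element,
            classifications=[
                {
                    "ml_class_id": self.get_ml_class_id("printed"),
                    "confidence": 0.45,
                    "high_confidence": False,
                },
            ],
        )
```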
......@@ -2,17 +2,23 @@
"""
ElementsWorker methods for elements and element types.
"""
import uuid
from typing import Dict, Iterable, List, NamedTuple, Optional, Union
from peewee import IntegrityError
from arkindex_worker import logger
from arkindex_worker.cache import CachedElement, CachedImage
from arkindex_worker.models import Corpus, Element
from arkindex_worker.models import Element
class ElementType(NamedTuple):
"""
Arkindex Type of an element
"""
ElementType = NamedTuple("ElementType", name=str, slug=str, is_folder=bool)
name: str
slug: str
is_folder: bool
class MissingTypeError(Exception):
......@@ -22,15 +28,11 @@ class MissingTypeError(Exception):
class ElementMixin(object):
"""
Mixin for the :class:`ElementsWorker` to provide ``Element`` helpers.
"""
def create_required_types(self, corpus: Corpus, element_types: List[ElementType]):
def create_required_types(self, element_types: List[ElementType]):
"""Creates given element types in the corpus.
:param Corpus corpus: The corpus to create types on.
:param List[ElementType] element_types: The missing element types to create.
:param element_types: The missing element types to create.
"""
for element_type in element_types:
self.request(
......@@ -39,47 +41,42 @@ class ElementMixin(object):
"slug": element_type.slug,
"display_name": element_type.name,
"folder": element_type.is_folder,
"corpus": corpus.id,
"corpus": self.corpus_id,
},
)
logger.info(f"Created a new element type with slug {element_type.slug}")
def check_required_types(
self, corpus_id: str, *type_slugs: str, create_missing: bool = False
self, *type_slugs: str, create_missing: bool = False
) -> bool:
"""
Check that a corpus has a list of required element types,
and raise an exception if any of them are missing.
:param str corpus_id: ID of the corpus to check types on.
:param str \\*type_slugs: Type slugs to look for.
:param bool create_missing: Whether missing types should be created.
:returns bool: True if all of the specified type slugs have been found.
:param *type_slugs: Type slugs to look for.
:param create_missing: Whether missing types should be created.
:returns: Whether all of the specified type slugs have been found.
:raises MissingTypeError: If any of the specified type slugs were not found.
"""
assert isinstance(
corpus_id, (uuid.UUID, str)
), "Corpus ID should be a string or UUID"
assert len(type_slugs), "At least one element type slug is required."
assert all(
isinstance(slug, str) for slug in type_slugs
), "Element type slugs must be strings."
corpus = Corpus(self.request("RetrieveCorpus", id=corpus_id))
available_slugs = {element_type.slug for element_type in corpus.types}
corpus = self.request("RetrieveCorpus", id=self.corpus_id)
available_slugs = {element_type["slug"] for element_type in corpus["types"]}
missing_slugs = set(type_slugs) - available_slugs
if missing_slugs:
if create_missing:
self.create_required_types(
corpus,
element_types=[
ElementType(slug, slug, False) for slug in missing_slugs
],
)
else:
raise MissingTypeError(
f'Element type(s) {", ".join(sorted(missing_slugs))} were not found in the {corpus.name} corpus ({corpus.id}).'
f'Element type(s) {", ".join(sorted(missing_slugs))} were not found in the {corpus["name"]} corpus ({corpus["id"]}).'
)
return True
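# Usage sketch (hypothetical `worker`): make sure the corpus defines the types
# this worker needs, creating any that are missing as non-folder types.
worker.check_required_types("page", "text_line", create_missing=True)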
......@@ -91,19 +88,17 @@ class ElementMixin(object):
name: str,
polygon: List[List[Union[int, float]]],
confidence: Optional[float] = None,
slim_output: bool = True,
slim_output: Optional[bool] = True,
) -> str:
"""
Create a child element on the given element through the API.
:param element: The parent element.
:param str type: Slug of the element type for this child element.
:param str name: Name of the child element.
:param type: Slug of the element type for this child element.
:param name: Name of the child element.
:param polygon: Polygon of the child element.
:type polygon: list(list(int or float))
:param confidence: Optional confidence score, between 0.0 and 1.0.
:type confidence: float or None
:returns str: UUID of the created element.
:returns: UUID of the created element.
"""
assert element and isinstance(
element, Element
......@@ -143,7 +138,7 @@ class ElementMixin(object):
"corpus": element.corpus.id,
"polygon": polygon,
"parent": element.id,
"worker_version": self.worker_version_id,
"worker_run_id": self.worker_run_id,
"confidence": confidence,
},
)
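# Usage sketch (hypothetical `worker` and parent `element`): create a single
# text_line child from a detector's polygon; returns the new element's UUID.
sub_element_id = worker.create_sub_element(
    element,
    type="text_line",
    name="0",
    polygon=[[0, 0], [0, 100], [500, 100], [500, 0], [0, 0]],
    confidence=0.87,
)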
......@@ -162,7 +157,6 @@ class ElementMixin(object):
Create child elements on the given element in a single API request.
:param parent: Parent element for all the new child elements. The parent must have an image and a polygon.
:type parent: Element or CachedElement
:param elements: List of dicts, one per element. Each dict can have the following keys:
name (str)
......@@ -178,9 +172,7 @@ class ElementMixin(object):
confidence (float or None)
Optional confidence score, between 0.0 and 1.0.
:type elements: list(dict(str, Any))
:return: List of dicts, with each dict having a single key, ``id``, holding the UUID of each created element.
:rtype: list(dict(str, str))
"""
if isinstance(parent, Element):
assert parent.get(
......@@ -241,7 +233,7 @@ class ElementMixin(object):
"CreateElements",
id=parent.id,
body={
"worker_version": self.worker_version_id,
"worker_run_id": self.worker_run_id,
"elements": elements,
},
)
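# Usage sketch (hypothetical `worker` and parent `element`): create several
# children in one request, following the dict keys documented above.
created = worker.create_elements(
    parent=element,
    elements=[
        {
            "name": "0",
            "type": "text_line",
            "polygon": [[0, 0], [0, 50], [800, 50], [800, 0], [0, 0]],
            "confidence": 0.9,
        },
    ],
)
created_ids = [item["id"] for item in created]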
......@@ -273,7 +265,7 @@ class ElementMixin(object):
"type": element["type"],
"image_id": image_id,
"polygon": element["polygon"],
"worker_version_id": self.worker_version_id,
"worker_run_id": self.worker_run_id,
"confidence": element.get("confidence"),
}
for idx, element in enumerate(elements)
......@@ -301,38 +293,27 @@ class ElementMixin(object):
List children of an element.
:param element: Parent element to find children of.
:type element: Union[Element, CachedElement]
:param folder: Restrict to or exclude elements with folder types.
This parameter is not supported when caching is enabled.
:type folder: Optional[bool]
:param name: Restrict to elements whose name contain a substring (case-insensitive).
This parameter is not supported when caching is enabled.
:type name: Optional[str]
:param recursive: Look for elements recursively (grand-children, etc.)
This parameter is not supported when caching is enabled.
:type recursive: Optional[bool]
:param type: Restrict to elements with a specific type slug
This parameter is not supported when caching is enabled.
:type type: Optional[str]
:param with_classes: Include each element's classifications in the response.
This parameter is not supported when caching is enabled.
:type with_classes: Optional[bool]
:param with_corpus: Include each element's corpus in the response.
This parameter is not supported when caching is enabled.
:type with_corpus: Optional[bool]
:param with_has_children: Include the ``has_children`` attribute in the response,
indicating if this element has child elements of its own.
This parameter is not supported when caching is enabled.
:type with_has_children: Optional[bool]
:param with_zone: Include the ``zone`` attribute in the response,
holding the element's image and polygon.
This parameter is not supported when caching is enabled.
:type with_zone: Optional[bool]
:param worker_version: Restrict to elements created by a worker version with this UUID.
:type worker_version: Optional[Union[str, bool]]
:return: An iterable of dicts from the ``ListElementChildren`` API endpoint,
or an iterable of :class:`CachedElement` when caching is enabled.
:rtype: Union[Iterable[dict], Iterable[CachedElement]]
or an iterable of [CachedElement][arkindex_worker.cache.CachedElement] when caching is enabled.
"""
assert element and isinstance(
element, (Element, CachedElement)
......
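# Usage sketch (hypothetical `worker` and `element`): iterate over children,
# assuming the method is named `list_element_children` after the
# ListElementChildren endpoint. Most filters are unavailable when caching is on.
for child in worker.list_element_children(element, type="text_line", recursive=True):
    print(child)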
......@@ -3,8 +3,8 @@
ElementsWorker methods for entities.
"""
import os
from enum import Enum
from typing import Dict, Optional, Union
from peewee import IntegrityError
......@@ -28,29 +28,23 @@ class EntityType(Enum):
class EntityMixin(object):
"""
Mixin for the :class:`ElementsWorker` to add ``Entity`` helpers.
"""
def create_entity(
self, element, name, type, corpus=None, metas=dict(), validated=None
self,
element: Union[Element, CachedElement],
name: str,
type: EntityType,
metas=dict(),
validated=None,
):
"""
Create an entity on the given corpus.
If cache support is enabled, a :class:`CachedEntity` will also be created.
If cache support is enabled, a [CachedEntity][arkindex_worker.cache.CachedEntity] will also be created.
:param element: An element on which the entity will be reported with the :class:`Reporter`.
:param element: An element on which the entity will be reported with the [Reporter][arkindex_worker.reporting.Reporter].
This does not have any effect on the entity itself.
:type element: Element or CachedElement
:param name str: Name of the entity.
:param type EntityType: Type of the entity.
:param corpus: UUID of the corpus to create an entity on, or None to use the
value of the ``ARKINDEX_CORPUS_ID`` environment variable.
:type corpus: str or None
:param name: Name of the entity.
:param type: Type of the entity.
"""
if corpus is None:
corpus = os.environ.get("ARKINDEX_CORPUS_ID")
assert element and isinstance(
element, (Element, CachedElement)
), "element shouldn't be null and should be an Element or CachedElement"
......@@ -60,9 +54,6 @@ class EntityMixin(object):
assert type and isinstance(
type, EntityType
), "type shouldn't be null and should be of type EntityType"
assert corpus and isinstance(
corpus, str
), "corpus shouldn't be null and should be of type str"
if metas:
assert isinstance(metas, dict), "metas should be of type dict"
if validated is not None:
......@@ -78,8 +69,8 @@ class EntityMixin(object):
"type": type.value,
"metas": metas,
"validated": validated,
"corpus": corpus,
"worker_version": self.worker_version_id,
"corpus": self.corpus_id,
"worker_run_id": self.worker_run_id,
},
)
self.report.add_entity(element.id, entity["id"], type.value, name)
......@@ -94,7 +85,7 @@ class EntityMixin(object):
"name": name,
"validated": validated if validated is not None else False,
"metas": metas,
"worker_version_id": self.worker_version_id,
"worker_run_id": self.worker_run_id,
}
]
CachedEntity.insert_many(to_insert).execute()
......@@ -104,26 +95,29 @@ class EntityMixin(object):
return entity["id"]
def create_transcription_entity(
self, transcription, entity, offset, length, confidence=None
):
self,
transcription: Transcription,
entity: str,
offset: int,
length: int,
confidence: Optional[float] = None,
) -> Optional[Dict[str, Union[str, int]]]:
"""
Create a link between an existing entity and an existing transcription.
If cache support is enabled, a :class:`CachedTranscriptionEntity` will also be created.
If cache support is enabled, a [CachedTranscriptionEntity][arkindex_worker.cache.CachedTranscriptionEntity] will also be created.
:param transcription str: UUID of the existing transcription.
:param entity str: UUID of the existing entity.
:param offset int: Starting position of the entity in the transcription's text,
:param transcription: Transcription to create the entity on.
:param entity: UUID of the existing entity.
:param offset: Starting position of the entity in the transcription's text,
as a 0-based index.
:param length int: Length of the entity in the transcription's text.
:param length: Length of the entity in the transcription's text.
:param confidence: Optional confidence score, between 0.0 and 1.0.
:type confidence: float or None
:returns: A dict as returned by the ``CreateTranscriptionEntity`` API endpoint,
or None if the worker is in read-only mode.
:rtype: dict(str, str or int) or None
"""
assert transcription and isinstance(
transcription, str
), "transcription shouldn't be null and should be of type str"
transcription, Transcription
), "transcription shouldn't be null and should be a Transcription"
assert entity and isinstance(
entity, str
), "entity shouldn't be null and should be of type str"
......@@ -146,27 +140,27 @@ class EntityMixin(object):
"entity": entity,
"length": length,
"offset": offset,
"worker_version_id": self.worker_version_id,
"worker_run_id": self.worker_run_id,
}
if confidence is not None:
body["confidence"] = confidence
transcription_ent = self.request(
"CreateTranscriptionEntity",
id=transcription,
id=transcription.id,
body=body,
)
# TODO: Report transcription entity creation
self.report.add_transcription_entity(entity, transcription, transcription_ent)
if self.use_cache:
# Store transcription entity in local cache
try:
CachedTranscriptionEntity.create(
transcription=transcription,
transcription=transcription.id,
entity=entity,
offset=offset,
length=length,
worker_version_id=self.worker_version_id,
worker_run_id=self.worker_run_id,
confidence=confidence,
)
except IntegrityError as e:
......@@ -178,14 +172,14 @@ class EntityMixin(object):
def list_transcription_entities(
self,
transcription: Transcription,
worker_version: bool = None,
worker_version: Optional[Union[str, bool]] = None,
):
"""
List existing entities on a transcription
This method does not support cache
:param transcription Transcription: The transcription to list entities on.
:param worker_version str or bool: Restrict to entities created by a worker version with this UUID. Set to False to look for manually created transcriptions.
:param transcription: The transcription to list entities on.
:param worker_version: Restrict to entities created by a worker version with this UUID. Set to False to look for manually created entities.
"""
query_params = {}
assert transcription and isinstance(
......@@ -206,3 +200,28 @@ class EntityMixin(object):
return self.api_client.paginate(
"ListTranscriptionEntities", id=transcription.id, **query_params
)
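# Usage sketch (hypothetical `worker`, `transcription` and `entity_id`): link an
# existing entity to the 3-character span starting at offset 4 of the
# transcription's text.
worker.create_transcription_entity(
    transcription=transcription,
    entity=entity_id,
    offset=4,
    length=3,
    confidence=0.9,
)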
def list_corpus_entities(
self,
name: Optional[str] = None,
parent: Optional[Element] = None,
):
"""
List all entities in the worker's corpus
This method does not support cache
:param name: Filter entities by part of their name (case-insensitive)
:param parent: Restrict entities to those linked to all transcriptions of an element and all its descendants. Note that links to metadata are ignored.
"""
query_params = {}
if name is not None:
assert name and isinstance(name, str), "name should be a non-empty str"
query_params["name"] = name
if parent is not None:
assert isinstance(parent, Element), "parent should be of type Element"
query_params["parent"] = parent.id
return self.api_client.paginate(
"ListCorpusEntities", id=self.corpus_id, **query_params
)
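# Usage sketch (hypothetical `worker` and `page_element`): list corpus entities
# whose name contains "ada" and that are linked to transcriptions under the
# given element.
for entity in worker.list_corpus_entities(name="ada", parent=page_element):
    print(entity)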
......@@ -4,8 +4,10 @@ ElementsWorker methods for metadata.
"""
from enum import Enum
from typing import Dict, List, Optional, Union
from arkindex_worker import logger
from arkindex_worker.cache import CachedElement
from arkindex_worker.models import Element
......@@ -56,25 +58,27 @@ class MetaType(Enum):
class MetaDataMixin(object):
"""
Mixin for the :class:`ElementsWorker` to add ``MetaData`` helpers.
"""
def create_metadata(self, element, type, name, value, entity=None):
def create_metadata(
self,
element: Union[Element, CachedElement],
type: MetaType,
name: str,
value: str,
entity: Optional[str] = None,
) -> str:
"""
Create a metadata on the given element through API.
:param element Element: The element to create a metadata on.
:param type MetaType: Type of the metadata.
:param name str: Name of the metadata.
:param value str: Value of the metadata.
:param element: The element to create a metadata on.
:param type: Type of the metadata.
:param name: Name of the metadata.
:param value: Value of the metadata.
:param entity: UUID of an entity this metadata is related to.
:type entity: str or None
:returns str: UUID of the created metadata.
:returns: UUID of the created metadata.
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
element, (Element, CachedElement)
), "element shouldn't be null and should be of type Element or CachedElement"
assert type and isinstance(
type, MetaType
), "type shouldn't be null and should be of type MetaType"
......@@ -98,7 +102,7 @@ class MetaDataMixin(object):
"name": name,
"value": value,
"entity_id": entity,
"worker_version": self.worker_version_id,
"worker_run_id": self.worker_run_id,
},
)
self.report.add_metadata(element.id, metadata["id"], type.value, name)
......@@ -107,9 +111,13 @@ class MetaDataMixin(object):
def create_metadatas(
self,
element: Element,
metadatas: list,
):
element: Union[Element, CachedElement],
metadatas: List[
Dict[str, Union[MetaType, str, int, float, None]]
],
) -> List[Dict[str, str]]:
"""
Create multiple metadatas on an existing element.
This method does not support cache.
......@@ -122,8 +130,8 @@ class MetaDataMixin(object):
- entity_id : Union[str, None]
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
element, (Element, CachedElement)
), "element shouldn't be null and should be of type Element or CachedElement"
assert metadatas and isinstance(
metadatas, list
......@@ -144,7 +152,7 @@ class MetaDataMixin(object):
metadata.get("name"), str
), "name shouldn't be null and should be of type str"
assert metadata.get("value") and isinstance(
assert metadata.get("value") is not None and isinstance(
metadata.get("value"), (str, float, int)
), "value shouldn't be null and should be of type (str or float or int)"
......@@ -169,7 +177,6 @@ class MetaDataMixin(object):
"CreateMetaDataBulk",
id=element.id,
body={
"worker_version": self.worker_version_id,
"worker_run_id": self.worker_run_id,
"metadata_list": metas,
},
......@@ -180,15 +187,17 @@ class MetaDataMixin(object):
return created_metadatas
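# Usage sketch (hypothetical `worker` and `element`): attach several metadata in
# one request, assuming MetaType defines Text and Numeric members.
worker.create_metadatas(
    element,
    metadatas=[
        {"type": MetaType.Text, "name": "language", "value": "fra"},
        {"type": MetaType.Numeric, "name": "line_count", "value": 42},
    ],
)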
def list_metadata(self, element: Element):
def list_element_metadata(
self, element: Union[Element, CachedElement]
) -> List[Dict[str, str]]:
"""
List all metadata linked to an element.
This method does not support cache.
:param element Element: The element to list metadata on.
:param element: The element to list metadata on.
"""
assert element and isinstance(
element, Element
), "element shouldn't be null and should be of type Element"
element, (Element, CachedElement)
), "element shouldn't be null and should be of type Element or CachedElement"
return self.api_client.paginate("ListElementMetaData", id=element.id)
# -*- coding: utf-8 -*-
"""
BaseWorker methods for training.
"""
import hashlib
import os
import tarfile
import tempfile
from contextlib import contextmanager
from typing import NewType, Tuple
from pathlib import Path
from typing import NewType, Optional, Tuple
import requests
import zstandard as zstd
......@@ -15,15 +20,24 @@ from arkindex_worker import logger
CHUNK_SIZE = 1024
DirPath = NewType("DirPath", str)
"""Path to a directory"""
Hash = NewType("Hash", str)
"""MD5 Hash"""
FileSize = NewType("FileSize", int)
Archive = Tuple[DirPath, Hash, FileSize]
"""File size"""
@contextmanager
def create_archive(path: DirPath) -> Archive:
"""First create a tar archive, then compress to a zst archive.
Finally, get its hash and size
def create_archive(path: DirPath) -> Tuple[Path, Hash, FileSize, Hash]:
"""
Create a tar archive from the files at the given location then compress it to a zst archive.
Yield its location, its hash, its size and its content's hash.
:param path: Path to the directory whose files will be archived
:returns: The location of the created archive, its hash, its size and its content's hash
"""
assert path.is_dir(), "create_archive needs a directory"
......@@ -41,7 +55,9 @@ def create_archive(path: DirPath) -> Archive:
for p in path.glob("**/*"):
x = p.relative_to(path)
tar.add(p, arcname=x, recursive=False)
file_list.append(p)
# Only keep files when computing the hash
if p.is_file():
file_list.append(p)
# Sort by path
file_list.sort()
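# Usage sketch: build the compressed archive for a directory of model files and
# read back its location, hash, size and content hash. `model_files/` is a
# placeholder path; cleanup of the temporary archive is assumed to happen when
# the context manager exits.
with create_archive(Path("model_files/")) as (zst_path, hash_, size, archive_hash):
    print(zst_path, hash_, size, archive_hash)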
......@@ -76,17 +92,25 @@ def create_archive(path: DirPath) -> Archive:
class TrainingMixin(object):
"""
Mixin for the Training workers to add Model and ModelVersion helpers
"""
def publish_model_version(
self, model_path: DirPath, model_id: str, tag: str = None, description: str = ""
self,
model_path: DirPath,
model_id: str,
tag: Optional[str] = None,
description: Optional[str] = None,
configuration: Optional[dict] = {},
):
"""
This method creates a model archive and its associated hash,
to create a unique version that will be stored on a bucket and published on arkindex.
to create a unique version that will be stored on a bucket and published on Arkindex.
:param model_path: Path to the directory containing the model version's files.
:param model_id: ID of the model
:param tag: Tag of the model version
:param description: Description of the model version
:param configuration: Configuration of the model version
"""
if self.is_read_only:
logger.warning(
"Cannot publish a new model version as this worker is in read-only mode"
......@@ -118,7 +142,7 @@ class TrainingMixin(object):
# Update the model version with state, configuration parsed, tag, description (defaults to name of the worker)
self.update_model_version(
model_version_details=model_version_details,
model_version_details=model_version_details, configuration=configuration
)
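# Usage sketch (hypothetical `worker` and placeholder model UUID): publish a
# directory of trained weights as a new model version; tag, description and
# configuration are optional.
worker.publish_model_version(
    model_path=Path("model_files/"),
    model_id="11111111-1111-1111-1111-111111111111",
    tag="v1.0",
    description="Trained on a hypothetical corpus",
    configuration={"parameters": {"lang": "fra"}},
)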
def create_model_version(
......@@ -144,30 +168,29 @@ class TrainingMixin(object):
# Create a new model version with hash and size
try:
payload = {"hash": hash, "size": size, "archive_hash": archive_hash}
if tag:
payload["tag"] = tag
if description:
payload["description"] = description
model_version_details = self.request(
"CreateModelVersion",
id=model_id,
body={
"hash": hash,
"size": size,
"archive_hash": archive_hash,
"tag": tag,
"description": description,
},
body=payload,
)
logger.info(
f"Model version ({model_version_details['id']}) was created successfully"
)
except ErrorResponse as e:
if e.status_code >= 500:
model_version_details = (
e.content.get("hash") if hasattr(e, "content") else None
)
if e.status_code >= 500 or model_version_details is None:
logger.error(f"Failed to create model version: {e.content}")
model_version_details = e.content.get("hash")
raise e
# If an existing model version is in the Created state, it is returned as a dict.
# Otherwise, an error is returned as a list.
if model_version_details and isinstance(
model_version_details, (list, tuple)
):
if isinstance(model_version_details, (list, tuple)):
logger.error(model_version_details[0])
return
......@@ -201,7 +224,7 @@ class TrainingMixin(object):
def update_model_version(
self,
model_version_details: dict,
configuration: dict = {},
configuration: dict,
) -> None:
"""
Update the specified model version to the state `Available`, using the given configuration.
......
......@@ -4,7 +4,7 @@ ElementsWorker methods for transcriptions.
"""
from enum import Enum
from typing import Iterable, Optional, Union
from typing import Dict, Iterable, List, Optional, Union
from peewee import IntegrityError
......@@ -41,28 +41,22 @@ class TextOrientation(Enum):
class TranscriptionMixin(object):
"""
Mixin for the :class:`ElementsWorker` to provide ``Transcription`` helpers.
"""
def create_transcription(
self,
element,
text,
confidence,
orientation=TextOrientation.HorizontalLeftToRight,
):
element: Union[Element, CachedElement],
text: str,
confidence: float,
orientation: TextOrientation = TextOrientation.HorizontalLeftToRight,
) -> Optional[Dict[str, Union[str, float]]]:
"""
Create a transcription on the given element through the API.
:param element: Element to create a transcription on.
:type element: Element or CachedElement
:param text str: Text of the transcription.
:param confidence float: Confidence score, between 0 and 1.
:param orientation TextOrientation: Orientation of the transcription's text.
:param text: Text of the transcription.
:param confidence: Confidence score, between 0 and 1.
:param orientation: Orientation of the transcription's text.
:returns: A dict as returned by the ``CreateTranscription`` API endpoint,
or None if the worker is in read-only mode.
:rtype: Dict[str, Union[str, float]] or None
"""
assert element and isinstance(
element, (Element, CachedElement)
......@@ -88,7 +82,7 @@ class TranscriptionMixin(object):
id=element.id,
body={
"text": text,
"worker_version": self.worker_version_id,
"worker_run_id": self.worker_run_id,
"confidence": confidence,
"orientation": orientation.value,
},
......@@ -106,7 +100,7 @@ class TranscriptionMixin(object):
"text": created["text"],
"confidence": created["confidence"],
"orientation": created["orientation"],
"worker_version_id": self.worker_version_id,
"worker_run_id": self.worker_run_id,
}
]
CachedTranscription.insert_many(to_insert).execute()
......@@ -117,21 +111,24 @@ class TranscriptionMixin(object):
return created
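# Usage sketch (hypothetical `worker` and `element`): attach a transcription;
# orientation defaults to horizontal, left-to-right.
worker.create_transcription(element, text="Hello world", confidence=0.98)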
def create_transcriptions(self, transcriptions):
def create_transcriptions(
self,
transcriptions: List[Dict[str, Union[str, float, Optional[TextOrientation]]]],
) -> List[Dict[str, Union[str, float]]]:
"""
Create multiple transcriptions at once on existing elements through the API,
and creates :class:`CachedTranscription` instances if cache support is enabled.
and creates [CachedTranscription][arkindex_worker.cache.CachedTranscription] instances if cache support is enabled.
:param transcriptions: A list of dicts representing a transcription each, with the following keys:
element_id (str)
Required. UUID of the element to create this transcription on.
text (str)
Required. Text of the transcription.
confidence (float)
Required. Confidence score between 0 and 1.
orientation (:class:`TextOrientation`)
Optional. Orientation of the transcription's text.
element_id (str)
Required. UUID of the element to create this transcription on.
text (str)
Required. Text of the transcription.
confidence (float)
Required. Confidence score between 0 and 1.
orientation (TextOrientation)
Optional. Orientation of the transcription's text.
:returns: A list of dicts as returned in the ``transcriptions`` field by the ``CreateTranscriptions`` API endpoint.
"""
......@@ -170,10 +167,16 @@ class TranscriptionMixin(object):
if orientation:
transcription["orientation"] = orientation.value
if self.is_read_only:
logger.warning(
"Cannot create transcription as this worker is in read-only mode"
)
return
created_trs = self.request(
"CreateTranscriptions",
body={
"worker_version": self.worker_version_id,
"worker_run_id": self.worker_run_id,
"transcriptions": transcriptions_payload,
},
)["transcriptions"]
......@@ -191,7 +194,7 @@ class TranscriptionMixin(object):
"text": created_tr["text"],
"confidence": created_tr["confidence"],
"orientation": created_tr["orientation"],
"worker_version_id": self.worker_version_id,
"worker_run_id": self.worker_run_id,
}
for created_tr in created_trs
]
......@@ -203,23 +206,27 @@ class TranscriptionMixin(object):
return created_trs
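# Usage sketch (hypothetical `worker`): create transcriptions on several
# existing elements at once; element IDs below are placeholders, and
# HorizontalRightToLeft is assumed to be a TextOrientation member.
worker.create_transcriptions(
    transcriptions=[
        {
            "element_id": "11111111-1111-1111-1111-111111111111",
            "text": "First line",
            "confidence": 0.95,
        },
        {
            "element_id": "22222222-2222-2222-2222-222222222222",
            "text": "Second line",
            "confidence": 0.82,
            "orientation": TextOrientation.HorizontalRightToLeft,
        },
    ],
)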
def create_element_transcriptions(self, element, sub_element_type, transcriptions):
def create_element_transcriptions(
self,
element: Union[Element, CachedElement],
sub_element_type: str,
transcriptions: List[Dict[str, Union[str, float]]],
) -> List[Dict[str, Union[str, bool]]]:
"""
Create multiple elements and transcriptions at once on a single parent element through the API.
:param element: Element to create elements and transcriptions on.
:type element: Element or CachedElement
:param str sub_element_type: Slug of the element type to use for the new elements.
:param sub_element_type: Slug of the element type to use for the new elements.
:param transcriptions: A list of dicts representing an element and transcription each, with the following keys:
polygon (list(list(int or float)))
Required. Polygon of the element.
text (str)
Required. Text of the transcription.
confidence (float)
Required. Confidence score between 0 and 1.
orientation (:class:`TextOrientation`)
Optional. Orientation of the transcription's text.
polygon (list(list(int or float)))
Required. Polygon of the element.
text (str)
Required. Text of the transcription.
confidence (float)
Required. Confidence score between 0 and 1.
orientation ([TextOrientation][arkindex_worker.worker.transcription.TextOrientation])
Optional. Orientation of the transcription's text.
:returns: A list of dicts as returned by the ``CreateElementTranscriptions`` API endpoint.
"""
......@@ -282,7 +289,7 @@ class TranscriptionMixin(object):
id=element.id,
body={
"element_type": sub_element_type,
"worker_version": self.worker_version_id,
"worker_run_id": self.worker_run_id,
"transcriptions": transcriptions_payload,
"return_elements": True,
},
......@@ -319,7 +326,7 @@ class TranscriptionMixin(object):
"type": sub_element_type,
"image_id": element.image_id,
"polygon": transcription["polygon"],
"worker_version_id": self.worker_version_id,
"worker_run_id": self.worker_run_id,
}
)
......@@ -334,7 +341,7 @@ class TranscriptionMixin(object):
"orientation": transcription.get(
"orientation", TextOrientation.HorizontalLeftToRight
).value,
"worker_version_id": self.worker_version_id,
"worker_run_id": self.worker_run_id,
}
)
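# Usage sketch (hypothetical `worker` and parent `element`): create text_line
# children and their transcriptions in one call; each dict carries the child's
# polygon plus the transcription's text and confidence.
annotations = worker.create_element_transcriptions(
    element,
    sub_element_type="text_line",
    transcriptions=[
        {
            "polygon": [[0, 0], [0, 50], [800, 50], [800, 0], [0, 0]],
            "text": "Hello world",
            "confidence": 0.91,
        },
    ],
)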
......@@ -359,16 +366,11 @@ class TranscriptionMixin(object):
List transcriptions on an element.
:param element: The element to list transcriptions on.
:type element: Element or CachedElement
:param element_type: Restrict to transcriptions whose elements have an element type with this slug.
:type element_type: Optional[str]
:param recursive: Include transcriptions of any descendant of this element, recursively.
:type recursive: Optional[bool]
:param worker_version: Restrict to transcriptions created by a worker version with this UUID. Set to False to look for manually created transcriptions.
:type worker_version: Optional[Union[str, bool]]
:returns: An iterable of dicts representing each transcription,
or an iterable of CachedTranscription when cache support is enabled.
:rtype: Union[Iterable[dict], Iterable[CachedTranscription]]
"""
assert element and isinstance(
element, (Element, CachedElement)
......
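# Usage sketch (hypothetical `worker` and `element`): iterate over this worker
# version's own transcriptions on the element and its descendants, assuming the
# method is named `list_transcriptions` and the worker exposes worker_version_id.
for transcription in worker.list_transcriptions(
    element, recursive=True, worker_version=worker.worker_version_id
):
    print(transcription)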
......@@ -5,16 +5,12 @@ ElementsWorker methods for worker versions.
class WorkerVersionMixin(object):
"""
Mixin for the :class:`ElementsWorker` to provide ``WorkerVersion`` helpers.
"""
def get_worker_version(self, worker_version_id: str) -> dict:
"""
Retrieve a worker version, using the :class:`ElementsWorker`'s internal cache when possible.
Retrieve a worker version, using the [ElementsWorker][arkindex_worker.worker.ElementsWorker]'s internal cache when possible.
:param str worker_version_id: ID of the worker version to retrieve.
:returns dict: The requested worker version, as returned by the ``RetrieveWorkerVersion`` API endpoint.
:param worker_version_id: ID of the worker version to retrieve.
:returns: The requested worker version, as returned by the ``RetrieveWorkerVersion`` API endpoint.
"""
if worker_version_id is None:
raise ValueError("No worker version ID")
......@@ -32,8 +28,8 @@ class WorkerVersionMixin(object):
Retrieve the slug of the worker of a worker version, from a worker version UUID.
Uses a worker version from the internal cache if possible, otherwise makes an API request.
:param worker_version_id str: ID of the worker version to find a slug for.
:returns str: Slug of the worker of this worker version.
:param worker_version_id: ID of the worker version to find a slug for.
:returns: Slug of the worker of this worker version.
"""
worker_version = self.get_worker_version(worker_version_id)
return worker_version["worker"]["slug"]
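# Usage sketch (hypothetical `worker` and placeholder UUID): resolve the slug
# behind a worker version; repeated lookups are served from the internal cache.
slug = worker.get_worker_version_slug("11111111-1111-1111-1111-111111111111")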