Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
Commits on Source (89)
Showing
with 767 additions and 477 deletions
...@@ -145,9 +145,7 @@ pypi-publication: ...@@ -145,9 +145,7 @@ pypi-publication:
- pip install -e .[docs] - pip install -e .[docs]
script: script:
- cd docs - mkdocs build --strict --verbose
- make html
- mv _build/html ../public
docs-build: docs-build:
extends: .docs extends: .docs
......
...@@ -8,4 +8,4 @@ line_length = 88 ...@@ -8,4 +8,4 @@ line_length = 88
default_section=FIRSTPARTY default_section=FIRSTPARTY
known_first_party = arkindex,arkindex_common known_first_party = arkindex,arkindex_common
known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,responses,setuptools,sh,shapely,tenacity,yaml,zstandard known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,requests,responses,setuptools,sh,shapely,tenacity,yaml,zstandard
...@@ -11,7 +11,7 @@ repos: ...@@ -11,7 +11,7 @@ repos:
rev: 22.3.0 rev: 22.3.0
hooks: hooks:
- id: black - id: black
- repo: https://gitlab.com/pycqa/flake8 - repo: https://github.com/pycqa/flake8
rev: 3.8.3 rev: 3.8.3
hooks: hooks:
- id: flake8 - id: flake8
...@@ -23,7 +23,6 @@ repos: ...@@ -23,7 +23,6 @@ repos:
rev: v3.1.0 rev: v3.1.0
hooks: hooks:
- id: check-ast - id: check-ast
- id: check-docstring-first
- id: check-executables-have-shebangs - id: check-executables-have-shebangs
- id: check-merge-conflict - id: check-merge-conflict
- id: check-symlinks - id: check-symlinks
......
Arkindex base Worker # Arkindex base Worker
====================
An easy to use Python 3 high level API client, to build ML tasks. An easy to use Python 3 high level API client, to build ML tasks.
## Create a new worker using our template ## Create a new worker using our template
``` ```
pip install --user cookiecutter pip install --user cookiecutter
cookiecutter git@gitlab.com:arkindex/base-worker.git cookiecutter git@gitlab.com:teklia/workers/base-worker.git
``` ```
0.3.0-rc5 0.3.2-rc3
...@@ -10,6 +10,8 @@ reducing network usage. ...@@ -10,6 +10,8 @@ reducing network usage.
import json import json
import os import os
import sqlite3 import sqlite3
from pathlib import Path
from typing import Optional, Union
from peewee import ( from peewee import (
BooleanField, BooleanField,
...@@ -26,9 +28,9 @@ from peewee import ( ...@@ -26,9 +28,9 @@ from peewee import (
TextField, TextField,
UUIDField, UUIDField,
) )
from PIL import Image
from arkindex_worker import logger from arkindex_worker import logger
from arkindex_worker.image import open_image, polygon_bounding_box
db = SqliteDatabase(None) db = SqliteDatabase(None)
...@@ -65,6 +67,10 @@ class Version(Model): ...@@ -65,6 +67,10 @@ class Version(Model):
class CachedImage(Model): class CachedImage(Model):
"""
Cache image table
"""
id = UUIDField(primary_key=True) id = UUIDField(primary_key=True)
width = IntegerField() width = IntegerField()
height = IntegerField() height = IntegerField()
...@@ -76,6 +82,10 @@ class CachedImage(Model): ...@@ -76,6 +82,10 @@ class CachedImage(Model):
class CachedElement(Model): class CachedElement(Model):
"""
Cache element table
"""
id = UUIDField(primary_key=True) id = UUIDField(primary_key=True)
parent_id = UUIDField(null=True) parent_id = UUIDField(null=True)
type = CharField(max_length=50) type = CharField(max_length=50)
...@@ -84,26 +94,29 @@ class CachedElement(Model): ...@@ -84,26 +94,29 @@ class CachedElement(Model):
rotation_angle = IntegerField(default=0) rotation_angle = IntegerField(default=0)
mirrored = BooleanField(default=False) mirrored = BooleanField(default=False)
initial = BooleanField(default=False) initial = BooleanField(default=False)
# Needed to filter elements with cache
worker_version_id = UUIDField(null=True) worker_version_id = UUIDField(null=True)
worker_run_id = UUIDField(null=True)
confidence = FloatField(null=True) confidence = FloatField(null=True)
class Meta: class Meta:
database = db database = db
table_name = "elements" table_name = "elements"
def open_image(self, *args, max_size=None, **kwargs): def open_image(self, *args, max_size: Optional[int] = None, **kwargs) -> Image:
""" """
Open this element's image as a Pillow image. Open this element's image as a Pillow image.
This does not crop the image to the element's polygon. This does not crop the image to the element's polygon.
IIIF servers with maxWidth, maxHeight or maxArea restrictions on image size are not supported. IIIF servers with maxWidth, maxHeight or maxArea restrictions on image size are not supported.
:param \\*args: Positional arguments passed to :meth:`arkindex_worker.image.open_image` :param *args: Positional arguments passed to [arkindex_worker.image.open_image][]
:param max_size: Subresolution of the image. :param max_size: Subresolution of the image.
:type max_size: int or None :param **kwargs: Keyword arguments passed to [arkindex_worker.image.open_image][]
:param \\**kwargs: Keyword arguments passed to :meth:`arkindex_worker.image.open_image`
:raises ValueError: When this element does not have an image ID or a polygon. :raises ValueError: When this element does not have an image ID or a polygon.
:returns PIL.Image: A Pillow image. :return: A Pillow image.
""" """
from arkindex_worker.image import open_image, polygon_bounding_box
if not self.image_id or not self.polygon: if not self.image_id or not self.polygon:
raise ValueError(f"Element {self.id} has no image") raise ValueError(f"Element {self.id} has no image")
...@@ -152,12 +165,18 @@ class CachedElement(Model): ...@@ -152,12 +165,18 @@ class CachedElement(Model):
class CachedTranscription(Model): class CachedTranscription(Model):
"""
Cache transcription table
"""
id = UUIDField(primary_key=True) id = UUIDField(primary_key=True)
element = ForeignKeyField(CachedElement, backref="transcriptions") element = ForeignKeyField(CachedElement, backref="transcriptions")
text = TextField() text = TextField()
confidence = FloatField() confidence = FloatField()
orientation = CharField(max_length=50) orientation = CharField(max_length=50)
# Needed to filter transcriptions with cache
worker_version_id = UUIDField(null=True) worker_version_id = UUIDField(null=True)
worker_run_id = UUIDField(null=True)
class Meta: class Meta:
database = db database = db
...@@ -165,12 +184,16 @@ class CachedTranscription(Model): ...@@ -165,12 +184,16 @@ class CachedTranscription(Model):
class CachedClassification(Model): class CachedClassification(Model):
"""
Cache classification table
"""
id = UUIDField(primary_key=True) id = UUIDField(primary_key=True)
element = ForeignKeyField(CachedElement, backref="classifications") element = ForeignKeyField(CachedElement, backref="classifications")
class_name = TextField() class_name = TextField()
confidence = FloatField() confidence = FloatField()
state = CharField(max_length=10) state = CharField(max_length=10)
worker_version_id = UUIDField(null=True) worker_run_id = UUIDField(null=True)
class Meta: class Meta:
database = db database = db
...@@ -178,12 +201,16 @@ class CachedClassification(Model): ...@@ -178,12 +201,16 @@ class CachedClassification(Model):
class CachedEntity(Model): class CachedEntity(Model):
"""
Cache entity table
"""
id = UUIDField(primary_key=True) id = UUIDField(primary_key=True)
type = CharField(max_length=50) type = CharField(max_length=50)
name = TextField() name = TextField()
validated = BooleanField(default=False) validated = BooleanField(default=False)
metas = JSONField(null=True) metas = JSONField(null=True)
worker_version_id = UUIDField(null=True) worker_run_id = UUIDField(null=True)
class Meta: class Meta:
database = db database = db
...@@ -191,13 +218,17 @@ class CachedEntity(Model): ...@@ -191,13 +218,17 @@ class CachedEntity(Model):
class CachedTranscriptionEntity(Model): class CachedTranscriptionEntity(Model):
"""
Cache transcription entity table
"""
transcription = ForeignKeyField( transcription = ForeignKeyField(
CachedTranscription, backref="transcription_entities" CachedTranscription, backref="transcription_entities"
) )
entity = ForeignKeyField(CachedEntity, backref="transcription_entities") entity = ForeignKeyField(CachedEntity, backref="transcription_entities")
offset = IntegerField(constraints=[Check("offset >= 0")]) offset = IntegerField(constraints=[Check("offset >= 0")])
length = IntegerField(constraints=[Check("length > 0")]) length = IntegerField(constraints=[Check("length > 0")])
worker_version_id = UUIDField(null=True) worker_run_id = UUIDField(null=True)
confidence = FloatField(null=True) confidence = FloatField(null=True)
class Meta: class Meta:
...@@ -216,10 +247,14 @@ MODELS = [ ...@@ -216,10 +247,14 @@ MODELS = [
CachedEntity, CachedEntity,
CachedTranscriptionEntity, CachedTranscriptionEntity,
] ]
SQL_VERSION = 1 SQL_VERSION = 2
def init_cache_db(path): def init_cache_db(path: str):
"""
Create the cache database on the given path
:param path: Where the new database should be created
"""
db.init( db.init(
path, path,
pragmas={ pragmas={
...@@ -250,7 +285,12 @@ def create_version_table(): ...@@ -250,7 +285,12 @@ def create_version_table():
Version.create(version=SQL_VERSION) Version.create(version=SQL_VERSION)
def check_version(cache_path): def check_version(cache_path: Union[str, Path]):
"""
Check the validity of the SQLite version
:param cache_path: Path towards a local SQLite database
"""
with SqliteDatabase(cache_path) as provided_db: with SqliteDatabase(cache_path) as provided_db:
with provided_db.bind_ctx([Version]): with provided_db.bind_ctx([Version]):
try: try:
...@@ -263,7 +303,16 @@ def check_version(cache_path): ...@@ -263,7 +303,16 @@ def check_version(cache_path):
), f"The SQLite database {cache_path} does not have the correct cache version, it should be {SQL_VERSION}" ), f"The SQLite database {cache_path} does not have the correct cache version, it should be {SQL_VERSION}"
def retrieve_parents_cache_path(parent_ids, data_dir="/data", chunk=None): def retrieve_parents_cache_path(
parent_ids: list, data_dir: str = "/data", chunk: int = None
) -> list:
"""
Retrieve the cache database paths of the given parents in the data directory
:param parent_ids: List of element IDs to search
:param data_dir: Base folder in which to look for the paths
:param chunk: Index of the chunk of the db that might contain the paths
:return: The corresponding list of paths
"""
assert isinstance(parent_ids, list) assert isinstance(parent_ids, list)
assert os.path.isdir(data_dir) assert os.path.isdir(data_dir)
...@@ -288,9 +337,11 @@ def retrieve_parents_cache_path(parent_ids, data_dir="/data", chunk=None): ...@@ -288,9 +337,11 @@ def retrieve_parents_cache_path(parent_ids, data_dir="/data", chunk=None):
) )
def merge_parents_cache(paths, current_database): def merge_parents_cache(paths: list, current_database: str):
""" """
Merge all the potential parent task's databases into the existing local one Merge all the potential parent task's databases into the existing local one
:param paths: Path to cache databases
:param current_database: Path to the current database
""" """
assert os.path.exists(current_database) assert os.path.exists(current_database)
......
...@@ -6,10 +6,12 @@ import shutil ...@@ -6,10 +6,12 @@ import shutil
import time import time
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Optional, Union
import gitlab import gitlab
import requests import requests
import sh import sh
from gitlab.v4.objects import MergeRequest, ProjectMergeRequest
from arkindex_worker import logger from arkindex_worker import logger
...@@ -22,13 +24,13 @@ class GitlabHelper: ...@@ -22,13 +24,13 @@ class GitlabHelper:
def __init__( def __init__(
self, self,
project_id, project_id: str,
gitlab_url, gitlab_url: str,
gitlab_token, gitlab_token: str,
branch, branch: str,
rebase_wait_period=1, rebase_wait_period: Optional[int] = 1,
delete_source_branch=True, delete_source_branch: Optional[bool] = True,
max_rebase_tries=10, max_rebase_tries: Optional[int] = 10,
): ):
""" """
:param project_id: the id of the gitlab project :param project_id: the id of the gitlab project
...@@ -52,13 +54,13 @@ class GitlabHelper: ...@@ -52,13 +54,13 @@ class GitlabHelper:
self.project = self._api.projects.get(self.project_id) self.project = self._api.projects.get(self.project_id)
self.is_rebase_finished = False self.is_rebase_finished = False
def merge(self, branch_name, title) -> bool: def merge(self, branch_name: str, title: str) -> bool:
""" """
Create a merge request and try to merge. Create a merge request and try to merge.
Always rebase first to avoid conflicts from MRs made in parallel Always rebase first to avoid conflicts from MRs made in parallel
:param branch_name: source branch name :param branch_name: Source branch name
:param title: title of the merge request :param title: Title of the merge request
:return: was the branch successfully merged? :return: Whether the branch was successfully merged
""" """
mr = None mr = None
# always rebase first, because other workers might have merged already # always rebase first, because other workers might have merged already
...@@ -95,7 +97,14 @@ class GitlabHelper: ...@@ -95,7 +97,14 @@ class GitlabHelper:
return False return False
def _create_merge_request(self, branch_name, title): def _create_merge_request(self, branch_name: str, title: str) -> MergeRequest:
"""
Create a MergeRequest towards the branch with the given title
:param branch_name: Target branch of the merge request
:param title: Title of the merge request
:return: The created merge request
"""
logger.info(f"Creating a merge request for {branch_name}") logger.info(f"Creating a merge request for {branch_name}")
# retry_transient_error will retry the request on 50X errors # retry_transient_error will retry the request on 50X errors
# https://github.com/python-gitlab/python-gitlab/blob/265dbbdd37af88395574564aeb3fd0350288a18c/gitlab/__init__.py#L539 # https://github.com/python-gitlab/python-gitlab/blob/265dbbdd37af88395574564aeb3fd0350288a18c/gitlab/__init__.py#L539
...@@ -108,16 +117,24 @@ class GitlabHelper: ...@@ -108,16 +117,24 @@ class GitlabHelper:
) )
return mr return mr
def _get_merge_request(self, merge_request_id, include_rebase_in_progress=True): def _get_merge_request(
self, merge_request_id: Union[str, int], include_rebase_in_progress: bool = True
) -> ProjectMergeRequest:
"""
Retrieve a merge request by ID
:param merge_request_id: The ID of the merge request
:param include_rebase_in_progress: Whether the rebase in progress should be included
:return: The related merge request
"""
return self.project.mergerequests.get( return self.project.mergerequests.get(
merge_request_id, include_rebase_in_progress=include_rebase_in_progress merge_request_id, include_rebase_in_progress=include_rebase_in_progress
) )
def _wait_for_rebase_to_finish(self, merge_request_id) -> bool: def _wait_for_rebase_to_finish(self, merge_request_id: Union[str, int]) -> bool:
""" """
Poll the merge request until it has finished rebasing Poll the merge request until it has finished rebasing
:param merge_request_id: :param merge_request_id: The ID of the merge request
:return: rebase finished successfully? :return: Whether the rebase has finished successfully
""" """
logger.info("Checking if rebase has finished..") logger.info("Checking if rebase has finished..")
...@@ -134,10 +151,10 @@ class GitlabHelper: ...@@ -134,10 +151,10 @@ class GitlabHelper:
return False return False
def make_backup(path): def make_backup(path: str):
""" """
Create a backup file in the same directory with timestamp as suffix ".bak_{timestamp}" Create a backup file in the same directory with timestamp as suffix ".bak_{timestamp}"
:param path: file to be backed up :param path: Path to the file to be backed up
""" """
path = Path(path) path = Path(path)
if not path.exists(): if not path.exists():
...@@ -150,10 +167,10 @@ def make_backup(path): ...@@ -150,10 +167,10 @@ def make_backup(path):
def prepare_git_key( def prepare_git_key(
private_key, private_key: str,
known_hosts, known_hosts: str,
private_key_path="~/.ssh/id_ed25519", private_key_path: Optional[str] = "~/.ssh/id_ed25519",
known_hosts_path="~/.ssh/known_hosts", known_hosts_path: Optional[str] = "~/.ssh/known_hosts",
): ):
""" """
Prepare the git keys (put them into the correct place) so that git can be used. Prepare the git keys (put them into the correct place) so that git can be used.
...@@ -198,24 +215,25 @@ class GitHelper: ...@@ -198,24 +215,25 @@ class GitHelper:
""" """
A helper class for running git commands A helper class for running git commands
At the beginning of the workflow call `run_clone_in_background`. At the beginning of the workflow call [run_clone_in_background][arkindex_worker.git.GitHelper.run_clone_in_background].
When all the files are ready to be added to git then call When all the files are ready to be added to git then call
`save_files` to move the files in to the git repository [save_files][arkindex_worker.git.GitHelper.save_files] to move the files in to the git repository
and try to push them. and try to push them.
Pseudo code example: Examples
in worker.configure() configure the git helper and start the cloning: --------
``` in worker.configure() configure the git helper and start the cloning:
gitlab = GitlabHelper(...) ```
prepare_git_key(...) gitlab = GitlabHelper(...)
self.git_helper = GitHelper(workflow_id=workflow_id, gitlab_helper=gitlab, ...) prepare_git_key(...)
self.git_helper.run_clone_in_background() self.git_helper = GitHelper(workflow_id=workflow_id, gitlab_helper=gitlab, ...)
``` self.git_helper.run_clone_in_background()
```
at the end of the workflow (at the end of worker.run()) push the files to git:
``` at the end of the workflow (at the end of worker.run()) push the files to git:
self.git_helper.save_files(self.out_dir) ```
``` self.git_helper.save_files(self.out_dir)
```
""" """
def __init__( def __init__(
...@@ -294,6 +312,9 @@ class GitHelper: ...@@ -294,6 +312,9 @@ class GitHelper:
""" """
Move files in export_out_dir to the cloned git repository Move files in export_out_dir to the cloned git repository
and try to merge the created files if possible. and try to merge the created files if possible.
:param export_out_dir: Path to the files to be saved
:raises sh.ErrorReturnCode: _description_
:raises Exception: _description_
""" """
self._wait_for_clone_to_finish() self._wait_for_clone_to_finish()
...@@ -350,6 +371,8 @@ class GitHelper: ...@@ -350,6 +371,8 @@ class GitHelper:
""" """
Move all files in the export_out_dir to the git repository Move all files in the export_out_dir to the git repository
while keeping the same directory structure while keeping the same directory structure
:param export_out_dir: Path to the files to be moved
:return: Total count of moved files
""" """
file_count = 0 file_count = 0
file_names = [ file_names = [
......
...@@ -2,11 +2,12 @@ ...@@ -2,11 +2,12 @@
""" """
Helper methods to download and open IIIF images, and manage polygons. Helper methods to download and open IIIF images, and manage polygons.
""" """
import os import os
from collections import namedtuple from collections import namedtuple
from io import BytesIO from io import BytesIO
from math import ceil from math import ceil
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Union
import requests import requests
from PIL import Image from PIL import Image
...@@ -21,28 +22,38 @@ from tenacity import ( ...@@ -21,28 +22,38 @@ from tenacity import (
from arkindex_worker import logger from arkindex_worker import logger
# Avoid circular imports error when type checking
if TYPE_CHECKING:
from arkindex_worker.cache import CachedElement
from arkindex_worker.models import Element
# See http://docs.python-requests.org/en/master/user/advanced/#timeouts # See http://docs.python-requests.org/en/master/user/advanced/#timeouts
DOWNLOAD_TIMEOUT = (30, 60) DOWNLOAD_TIMEOUT = (30, 60)
BoundingBox = namedtuple("BoundingBox", ["x", "y", "width", "height"]) BoundingBox = namedtuple("BoundingBox", ["x", "y", "width", "height"])
def open_image(path, mode="RGB", rotation_angle=0, mirrored=False): def open_image(
path: Union[str, Path],
mode: Optional[str] = "RGB",
rotation_angle: Optional[int] = 0,
mirrored: Optional[bool] = False,
) -> Image:
""" """
Open an image from a path or a URL. Open an image from a path or a URL.
.. warning:: Prefer :meth:`Element.open_image` whenever possible. Warns:
Prefer [arkindex_worker.models.Element.open_image][] whenever possible.
:param path str: Path or URL to open the image from. :param path: Path or URL to open the image from.
This parameter will be interpreted as a URL when it has a `http` or `https` scheme This parameter will be interpreted as a URL when it has a `http` or `https` scheme
and no file exists with this path locally. and no file exists with this path locally.
:param mode str: Pillow mode for the image. See `the Pillow documentation :param mode: Pillow mode for the image. See [the Pillow documentation](https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes).
<https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes>`_. :param rotation_angle: Rotation angle to apply to the image, in degrees.
:param int rotation_angle: Rotation angle to apply to the image, in degrees.
If it is not a multiple of 90°, then the rotation can cause empty pixels of If it is not a multiple of 90°, then the rotation can cause empty pixels of
the mode's default color to be added for padding. the mode's default color to be added for padding.
:param bool mirrored: Whether or not to mirror the image horizontally. :param mirrored: Whether or not to mirror the image horizontally.
:returns PIL.Image: A Pillow image. :returns: A Pillow image.
""" """
if ( if (
path.startswith("http://") path.startswith("http://")
...@@ -68,12 +79,12 @@ def open_image(path, mode="RGB", rotation_angle=0, mirrored=False): ...@@ -68,12 +79,12 @@ def open_image(path, mode="RGB", rotation_angle=0, mirrored=False):
return image return image
def download_image(url): def download_image(url: str) -> Image:
""" """
Download an image and open it with Pillow. Download an image and open it with Pillow.
:param url str: URL of the image. :param url: URL of the image.
:returns PIL.Image: A Pillow image. :returns: A Pillow image.
""" """
assert url.startswith("http"), "Image URL must be HTTP(S)" assert url.startswith("http"), "Image URL must be HTTP(S)"
...@@ -110,13 +121,12 @@ def download_image(url): ...@@ -110,13 +121,12 @@ def download_image(url):
return image return image
def polygon_bounding_box(polygon): def polygon_bounding_box(polygon: List[List[Union[int, float]]]) -> BoundingBox:
""" """
Compute the rectangle bounding box of a polygon. Compute the rectangle bounding box of a polygon.
:param polygon: Polygon to get the bounding box of. :param polygon: Polygon to get the bounding box of.
:type polygon: list(list(int or float)) :returns: Bounding box of this polygon.
:returns BoundingBox: Bounding box of this polygon.
""" """
x_coords, y_coords = zip(*polygon) x_coords, y_coords = zip(*polygon)
x, y = min(x_coords), min(y_coords) x, y = min(x_coords), min(y_coords)
...@@ -144,12 +154,12 @@ def _retried_request(url): ...@@ -144,12 +154,12 @@ def _retried_request(url):
return resp return resp
def download_tiles(url): def download_tiles(url: str) -> Image:
""" """
Reconstruct a full IIIF image on servers that cannot serve the full-sized image, using tiles. Reconstruct a full IIIF image on servers that cannot serve the full-sized image, using tiles.
:param str url: URL of the image. :param url: URL of the image.
:returns PIL.Image: A Pillow image. :returns: A Pillow image.
""" """
if not url.endswith("/"): if not url.endswith("/"):
url += "/" url += "/"
...@@ -217,19 +227,19 @@ def download_tiles(url): ...@@ -217,19 +227,19 @@ def download_tiles(url):
return full_image return full_image
def trim_polygon(polygon, image_width: int, image_height: int): def trim_polygon(
polygon: List[List[int]], image_width: int, image_height: int
) -> List[List[int]]:
""" """
Trim a polygon to an image's boundaries, with non-negative coordinates. Trim a polygon to an image's boundaries, with non-negative coordinates.
:param polygon: A polygon to trim. :param polygon: A polygon to trim.
:type: list(list(int or float) or tuple(int or float)) or tuple(tuple(int or float) or list(int or float)) :param image_width: Width of the image.
:param image_width int: Width of the image. :param image_height: Height of the image.
:param image_height int: Height of the image.
:returns: A polygon trimmed to the image's bounds. :returns: A polygon trimmed to the image's bounds.
Some points may appear as missing, as the trimming can deduplicate points. Some points may appear as missing, as the trimming can deduplicate points.
The first and last point are always equal, to reproduce the behavior The first and last point are always equal, to reproduce the behavior
of the Arkindex backend. of the Arkindex backend.
:rtype: list(list(int or float))
:raises AssertionError: When argument types are invalid or when the trimmed polygon :raises AssertionError: When argument types are invalid or when the trimmed polygon
is entirely outside of the image's bounds. is entirely outside of the image's bounds.
""" """
...@@ -270,28 +280,28 @@ def trim_polygon(polygon, image_width: int, image_height: int): ...@@ -270,28 +280,28 @@ def trim_polygon(polygon, image_width: int, image_height: int):
return updated_polygon return updated_polygon
def revert_orientation(element, polygon, reverse: bool = False): def revert_orientation(
element: Union["Element", "CachedElement"],
polygon: List[List[Union[int, float]]],
reverse: Optional[bool] = False,
) -> List[List[int]]:
""" """
Update the coordinates of the polygon of a child element based on the orientation of Update the coordinates of the polygon of a child element based on the orientation of
its parent. its parent.
This method should be called before sending any polygon to Arkindex, to undo the possible This method should be called before sending any polygon to Arkindex, to undo the possible
orientation applied by :meth:`Element.open_image`. orientation applied by [arkindex_worker.models.Element.open_image][].
In some cases, we want to apply the parent's orientation on the child's polygon instead. This is done In some cases, we want to apply the parent's orientation on the child's polygon instead. This is done
by enabling `reverse=True`. by enabling `reverse=True`.
:param element: Parent element. :param element: Parent element.
:type element: Element or CachedElement
:param polygon: Polygon corresponding to the child element. :param polygon: Polygon corresponding to the child element.
:type polygon: list(list(int or float)) :param reverse: Whether we should revert or apply the parent's orientation.
:param mode: Whether we should revert (`revert`) or apply (`apply`) the parent's orientation.
:type mode: str
:return: A polygon with updated coordinates. :return: A polygon with updated coordinates.
:rtype: list(list(int))
""" """
from arkindex_worker.models import Element
from arkindex_worker.cache import CachedElement from arkindex_worker.cache import CachedElement
from arkindex_worker.models import Element
assert element and isinstance( assert element and isinstance(
element, (Element, CachedElement) element, (Element, CachedElement)
......
...@@ -5,12 +5,12 @@ Wrappers around API results to provide more convenient attribute access and IIIF ...@@ -5,12 +5,12 @@ Wrappers around API results to provide more convenient attribute access and IIIF
import tempfile import tempfile
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional from typing import Generator, List, Optional
from PIL import Image
from requests import HTTPError from requests import HTTPError
from arkindex_worker import logger from arkindex_worker import logger
from arkindex_worker.image import download_tiles, open_image, polygon_bounding_box
class MagicDict(dict): class MagicDict(dict):
...@@ -63,7 +63,12 @@ class Element(MagicDict): ...@@ -63,7 +63,12 @@ class Element(MagicDict):
Describes an Arkindex element. Describes an Arkindex element.
""" """
def resize_zone_url(self, size="full"): def resize_zone_url(self, size: str = "full") -> str:
"""
Compute the URL of the image corresponding to the size
:param size: Requested size
:return: The URL corresponding to the size
"""
if size == "full": if size == "full":
return self.zone.url return self.zone.url
else: else:
...@@ -71,13 +76,12 @@ class Element(MagicDict): ...@@ -71,13 +76,12 @@ class Element(MagicDict):
parts[-3] = size parts[-3] = size
return "/".join(parts) return "/".join(parts)
def image_url(self, size="full") -> Optional[str]: def image_url(self, size: str = "full") -> Optional[str]:
""" """
Build a URL to access the image. Build a URL to access the image.
When possible, will return the S3 URL for images, so an ML worker can bypass IIIF servers. When possible, will return the S3 URL for images, so an ML worker can bypass IIIF servers.
:param str size: Subresolution of the image, following the syntax of the IIIF resize parameter. :param size: Subresolution of the image, following the syntax of the IIIF resize parameter.
:returns: A URL to the image, or None if the element does not have an image. :returns: A URL to the image, or None if the element does not have an image.
:rtype: str or None
""" """
if not self.get("zone"): if not self.get("zone"):
return return
...@@ -89,12 +93,27 @@ class Element(MagicDict): ...@@ -89,12 +93,27 @@ class Element(MagicDict):
url += "/" url += "/"
return "{}full/{}/0/default.jpg".format(url, size) return "{}full/{}/0/default.jpg".format(url, size)
@property
def polygon(self) -> List[float]:
"""
Access an Element's polygon.
This is a shortcut to an Element's polygon, normally accessed via
its zone field via `zone.polygon`. This is mostly done
to facilitate access to this important field by matching
the [CachedElement][arkindex_worker.cache.CachedElement].polygon field.
"""
if not self.get("zone"):
raise ValueError("Element {} has no zone".format(self.id))
return self.zone.polygon
@property @property
def requires_tiles(self) -> bool: def requires_tiles(self) -> bool:
""" """
:return bool: Whether or not downloading and combining IIIF tiles will be necessary Whether or not downloading and combining IIIF tiles will be necessary
to retrieve this element's image. Will be False if the element has no image. to retrieve this element's image. Will be False if the element has no image.
""" """
from arkindex_worker.image import polygon_bounding_box
if not self.get("zone") or self.zone.image.get("s3_url"): if not self.get("zone") or self.zone.image.get("s3_url"):
return False return False
max_width = self.zone.image.server.max_width or float("inf") max_width = self.zone.image.server.max_width or float("inf")
...@@ -102,7 +121,13 @@ class Element(MagicDict): ...@@ -102,7 +121,13 @@ class Element(MagicDict):
bounding_box = polygon_bounding_box(self.zone.polygon) bounding_box = polygon_bounding_box(self.zone.polygon)
return bounding_box.width > max_width or bounding_box.height > max_height return bounding_box.width > max_width or bounding_box.height > max_height
def open_image(self, *args, max_size=None, use_full_image=False, **kwargs): def open_image(
self,
*args,
max_size: Optional[int] = None,
use_full_image: Optional[bool] = False,
**kwargs
) -> Image:
""" """
Open this element's image using Pillow, rotating and mirroring it according Open this element's image using Pillow, rotating and mirroring it according
to the ``rotation_angle`` and ``mirrored`` attributes. to the ``rotation_angle`` and ``mirrored`` attributes.
...@@ -111,12 +136,12 @@ class Element(MagicDict): ...@@ -111,12 +136,12 @@ class Element(MagicDict):
to bypass IIIF servers, the image will be cropped to the rectangle bounding box to bypass IIIF servers, the image will be cropped to the rectangle bounding box
of the ``zone.polygon`` attribute. of the ``zone.polygon`` attribute.
.. warning:: Warns:
----
This method implicitly applies the element's orientation to the image. This method implicitly applies the element's orientation to the image.
If your process uses the returned image to find more polygons and send them If your process uses the returned image to find more polygons and send them
back to Arkindex, use the :meth:`arkindex_worker.image.revert_orientation` back to Arkindex, use the [arkindex_worker.image.revert_orientation][]
helper to undo the orientation on all polygons before sending them, as the helper to undo the orientation on all polygons before sending them, as the
Arkindex API expects unoriented polygons. Arkindex API expects unoriented polygons.
...@@ -125,19 +150,24 @@ class Element(MagicDict): ...@@ -125,19 +150,24 @@ class Element(MagicDict):
:param max_size: The maximum size of the requested image. :param max_size: The maximum size of the requested image.
:type max_size: int or None :param use_full_image: Ignore the ``zone.polygon`` and always
:param bool use_full_image: Ignore the ``zone.polygon`` and always
retrieve the image without cropping. retrieve the image without cropping.
:param \\*args: Positional arguments passed to :meth:`arkindex_worker.image.open_image`. :param *args: Positional arguments passed to [arkindex_worker.image.open_image][].
:param \\**kwargs: Keyword arguments passed to :meth:`arkindex_worker.image.open_image`. :param **kwargs: Keyword arguments passed to [arkindex_worker.image.open_image][].
:raises ValueError: When the element does not have an image. :raises ValueError: When the element does not have an image.
:raises NotImplementedError: When the ``max_size`` parameter is set, :raises NotImplementedError: When the ``max_size`` parameter is set,
but the IIIF server's configuration requires downloading and combining tiles but the IIIF server's configuration requires downloading and combining tiles
to retrieve the image. to retrieve the image.
:raises NotImplementedError: When an S3 URL has been used to download the image, :raises NotImplementedError: When an S3 URL has been used to download the image,
but the URL has expired. Re-fetching the URL automatically is not supported. but the URL has expired. Re-fetching the URL automatically is not supported.
:return PIL.Image: A Pillow image. :return: A Pillow image.
""" """
from arkindex_worker.image import (
download_tiles,
open_image,
polygon_bounding_box,
)
if not self.get("zone"): if not self.get("zone"):
raise ValueError("Element {} has no zone".format(self.id)) raise ValueError("Element {} has no zone".format(self.id))
...@@ -197,18 +227,25 @@ class Element(MagicDict): ...@@ -197,18 +227,25 @@ class Element(MagicDict):
raise raise
@contextmanager @contextmanager
def open_image_tempfile(self, format="jpeg", *args, **kwargs): def open_image_tempfile(
self, format: Optional[str] = "jpeg", *args, **kwargs
) -> Generator[tempfile.NamedTemporaryFile, None, None]:
""" """
Get the element's image as a temporary file stored on the disk. Get the element's image as a temporary file stored on the disk.
To be used as a context manager:: To be used as a context manager.
with element.open_image_tempfile() as f: Example
... ----
```
with element.open_image_tempfile() as f:
...
```
:param format str: File format to use the store the image on the disk. :param format: File format to use the store the image on the disk.
Must be a format supported by Pillow. Must be a format supported by Pillow.
:param \\*args: Positional arguments passed to :meth:`arkindex_worker.image.open_image`. :param *args: Positional arguments passed to [arkindex_worker.image.open_image][].
:param \\**kwargs: Keyword arguments passed to :meth:`arkindex_worker.image.open_image`. :param **kwargs: Keyword arguments passed to [arkindex_worker.image.open_image][].
""" """
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
self.open_image(*args, **kwargs).save(f, format=format) self.open_image(*args, **kwargs).save(f, format=format)
...@@ -229,12 +266,3 @@ class Transcription(MagicDict): ...@@ -229,12 +266,3 @@ class Transcription(MagicDict):
def __str__(self): def __str__(self):
return "Transcription ({})".format(self.id) return "Transcription ({})".format(self.id)
class Corpus(MagicDict):
"""
Describes an Arkindex corpus.
"""
def __str__(self):
return "Corpus {} ({})".format(self.name, self.id)
...@@ -7,10 +7,14 @@ import json ...@@ -7,10 +7,14 @@ import json
import traceback import traceback
from collections import Counter from collections import Counter
from datetime import datetime from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Union
from uuid import UUID
from apistar.exceptions import ErrorResponse from apistar.exceptions import ErrorResponse
from arkindex_worker import logger from arkindex_worker import logger
from arkindex_worker.models import Transcription
class Reporter(object): class Reporter(object):
...@@ -19,7 +23,11 @@ class Reporter(object): ...@@ -19,7 +23,11 @@ class Reporter(object):
""" """
def __init__( def __init__(
self, name="Unknown worker", slug="unknown-slug", version=None, **kwargs self,
name: Optional[str] = "Unknown worker",
slug: Optional[str] = "unknown-slug",
version: Optional[str] = None,
**kwargs,
): ):
self.report_data = { self.report_data = {
"name": name, "name": name,
...@@ -46,57 +54,56 @@ class Reporter(object): ...@@ -46,57 +54,56 @@ class Reporter(object):
"classifications": {}, "classifications": {},
# Created entities ({"id": "", "type": "", "name": ""}) from this element # Created entities ({"id": "", "type": "", "name": ""}) from this element
"entities": [], "entities": [],
# Created transcription entities ({"transcription_id": "", "entity_id": ""}) from this element
"transcription_entities": [],
# Created metadata ({"id": "", "type": "", "name": ""}) from this element # Created metadata ({"id": "", "type": "", "name": ""}) from this element
"metadata": [], "metadata": [],
"errors": [], "errors": [],
}, },
) )
def process(self, element_id): def process(self, element_id: Union[str, UUID]):
""" """
Report that a specific element ID is being processed. Report that a specific element ID is being processed.
:param element_id: ID of the element being processed. :param element_id: ID of the element being processed.
:type element_id: str or uuid.UUID
""" """
# Just call the element initializer # Just call the element initializer
self._get_element(element_id) self._get_element(element_id)
def add_element(self, parent_id, type, type_count=1): def add_element(self, parent_id: Union[str, UUID], type: str, type_count: int = 1):
""" """
Report creating an element as a child of another. Report creating an element as a child of another.
:param parent_id: ID of the parent element. :param parent_id: ID of the parent element.
:type parent_id: str or uuid.UUID :param type: Slug of the type of the child element.
:param type str: Slug of the type of the child element. :param type_count: How many elements of this type were created.
:param type_count int: How many elements of this type were created. Defaults to 1.
""" """
elements = self._get_element(parent_id)["elements"] elements = self._get_element(parent_id)["elements"]
elements.setdefault(type, 0) elements.setdefault(type, 0)
elements[type] += type_count elements[type] += type_count
def add_classification(self, element_id, class_name): def add_classification(self, element_id: Union[str, UUID], class_name: str):
""" """
Report creating a classification on an element. Report creating a classification on an element.
:param element_id: ID of the element. :param element_id: ID of the element.
:type element_id: str or uuid.UUID :param class_name: Name of the ML class of the new classification.
:param class_name str: Name of the ML class of the new classification.
""" """
classifications = self._get_element(element_id)["classifications"] classifications = self._get_element(element_id)["classifications"]
classifications.setdefault(class_name, 0) classifications.setdefault(class_name, 0)
classifications[class_name] += 1 classifications[class_name] += 1
def add_classifications(self, element_id, classifications): def add_classifications(
self, element_id: Union[str, UUID], classifications: List[Dict[str, str]]
):
""" """
Report creating one or more classifications at once on an element. Report creating one or more classifications at once on an element.
:param element_id: ID of the element. :param element_id: ID of the element.
:type element_id: str or uuid.UUID
:param classifications: List of classifications. :param classifications: List of classifications.
Each classification is represented as a ``dict`` with a ``class_name`` key Each classification is represented as a ``dict`` with a ``class_name`` key
holding the name of the ML class being used. holding the name of the ML class being used.
:type classifications: List[Dict[str, str]]
""" """
assert isinstance( assert isinstance(
classifications, list classifications, list
...@@ -110,29 +117,57 @@ class Reporter(object): ...@@ -110,29 +117,57 @@ class Reporter(object):
) )
element["classifications"] = dict(counter) element["classifications"] = dict(counter)
def add_transcription(self, element_id, count=1): def add_transcription(self, element_id: Union[str, UUID], count=1):
""" """
Report creating a transcription on an element. Report creating a transcription on an element.
:param element_id: ID of the element. :param element_id: ID of the element.
:type element_id: str or uuid.UUID :param count: Number of transcriptions created at once
:param count int: Number of transcriptions created at once, defaults to 1.
""" """
self._get_element(element_id)["transcriptions"] += count self._get_element(element_id)["transcriptions"] += count
def add_entity(self, element_id, entity_id, type, name): def add_entity(
self,
element_id: Union[str, UUID],
entity_id: Union[str, UUID],
type: str,
name: str,
):
""" """
Report creating an entity on an element. Report creating an entity on an element.
:param element_id: ID of the element. :param element_id: ID of the element.
:type element_id: str or uuid.UUID :param entity_id: ID of the new entity.
:param entity_id str: ID of the new entity. :param type: Type of the entity.
:param type str: Type of the entity. :param name: Name of the entity.
:param name str: Name of the entity.
""" """
entities = self._get_element(element_id)["entities"] entities = self._get_element(element_id)["entities"]
entities.append({"id": entity_id, "type": type, "name": name}) entities.append({"id": entity_id, "type": type, "name": name})
def add_transcription_entity(
self,
entity_id: Union[str, UUID],
transcription: Transcription,
transcription_entity_id: Union[str, UUID],
):
"""
Report creating a transcription entity on an element.
:param entity_id: ID of the entity element.
:param transcription: Transcription to add the entity on
:param transcription_entity_id: ID of the transcription entity that is created.
"""
transcription_entities = self._get_element(transcription.element.id)[
"transcription_entities"
]
transcription_entities.append(
{
"transcription_id": transcription.id,
"entity_id": entity_id,
"transcription_entity_id": transcription_entity_id,
}
)
def add_entity_link(self, *args, **kwargs): def add_entity_link(self, *args, **kwargs):
""" """
Report creating an entity link. Not currently supported. Report creating an entity link. Not currently supported.
...@@ -149,26 +184,30 @@ class Reporter(object): ...@@ -149,26 +184,30 @@ class Reporter(object):
""" """
raise NotImplementedError raise NotImplementedError
def add_metadata(self, element_id, metadata_id, type, name): def add_metadata(
self,
element_id: Union[str, UUID],
metadata_id: Union[str, UUID],
type: str,
name: str,
):
""" """
Report creating a metadata from an element. Report creating a metadata from an element.
:param element_id: ID of the element. :param element_id: ID of the element.
:type element_id: str or uuid.UUID :param metadata_id: ID of the new metadata.
:param metadata_id str: ID of the new metadata. :param type: Type of the metadata.
:param type str: Type of the metadata. :param name: Name of the metadata.
:param name str: Name of the metadata.
""" """
metadata = self._get_element(element_id)["metadata"] metadata = self._get_element(element_id)["metadata"]
metadata.append({"id": metadata_id, "type": type, "name": name}) metadata.append({"id": metadata_id, "type": type, "name": name})
def error(self, element_id, exception): def error(self, element_id: Union[str, UUID], exception: Exception):
""" """
Report that a Python exception occurred when processing an element. Report that a Python exception occurred when processing an element.
:param element_id: ID of the element. :param element_id: ID of the element.
:type element_id: str or uuid.UUID :param exception: A Python exception.
:param exception Exception: A Python exception.
""" """
error_data = { error_data = {
"class": exception.__class__.__name__, "class": exception.__class__.__name__,
...@@ -186,12 +225,11 @@ class Reporter(object): ...@@ -186,12 +225,11 @@ class Reporter(object):
self._get_element(element_id)["errors"].append(error_data) self._get_element(element_id)["errors"].append(error_data)
def save(self, path): def save(self, path: Union[str, Path]):
""" """
Save the ML report to the specified path. Save the ML report to the specified path.
:param path: Path to save the ML report to. :param path: Path to save the ML report to.
:type path: str or pathlib.Path
""" """
logger.info(f"Saving ML report to {path}") logger.info(f"Saving ML report to {path}")
with open(path, "w") as f: with open(path, "w") as f:
......
...@@ -9,11 +9,15 @@ from timeit import default_timer ...@@ -9,11 +9,15 @@ from timeit import default_timer
class Timer(object): class Timer(object):
""" """
A context manager to help measure execution times. Example usage:: A context manager to help measure execution times.
with Timer() as t: Example
# do something interesting ---
print(t.delta) # X days, X:XX:XX ```
with Timer() as t:
# do something interesting
print(t.delta) # X days, X:XX:XX
```
""" """
def __init__(self): def __init__(self):
......
...@@ -8,6 +8,7 @@ import os ...@@ -8,6 +8,7 @@ import os
import sys import sys
import uuid import uuid
from enum import Enum from enum import Enum
from typing import Iterable, List, Union
from apistar.exceptions import ErrorResponse from apistar.exceptions import ErrorResponse
...@@ -66,7 +67,13 @@ class ElementsWorker( ...@@ -66,7 +67,13 @@ class ElementsWorker(
``arkindex.worker``, which provide helpers to read and write to the Arkindex API. ``arkindex.worker``, which provide helpers to read and write to the Arkindex API.
""" """
def __init__(self, description="Arkindex Elements Worker", support_cache=False): def __init__(
self, description: str = "Arkindex Elements Worker", support_cache: bool = False
):
"""
:param description: The worker's description
:param support_cache: Whether the worker supports cache
"""
super().__init__(description, support_cache) super().__init__(description, support_cache)
# Add mandatory argument to process elements # Add mandatory argument to process elements
...@@ -87,14 +94,13 @@ class ElementsWorker( ...@@ -87,14 +94,13 @@ class ElementsWorker(
self._worker_version_cache = {} self._worker_version_cache = {}
def list_elements(self): def list_elements(self) -> Union[Iterable[CachedElement], List[str]]:
""" """
List the elements to be processed, either from the CLI arguments or List the elements to be processed, either from the CLI arguments or
the cache database when enabled. the cache database when enabled.
:return: An iterable of :class:`CachedElement` when cache support is enabled, :return: An iterable of [CachedElement][arkindex_worker.cache.CachedElement] when cache support is enabled,
and a list of strings representing element IDs otherwise. and a list of strings representing element IDs otherwise.
:rtype: Iterable[CachedElement] or List[str]
""" """
assert not ( assert not (
self.args.elements_list and self.args.element self.args.elements_list and self.args.element
...@@ -121,12 +127,10 @@ class ElementsWorker( ...@@ -121,12 +127,10 @@ class ElementsWorker(
return out return out
@property @property
def store_activity(self): def store_activity(self) -> bool:
""" """
Whether or not WorkerActivity support has been enabled on the DataImport Whether or not WorkerActivity support has been enabled on the DataImport
used to run this worker. used to run this worker.
:rtype: bool
""" """
if self.args.dev: if self.args.dev:
return False return False
...@@ -136,6 +140,9 @@ class ElementsWorker( ...@@ -136,6 +140,9 @@ class ElementsWorker(
return self.process_information.get("activity_state") == "ready" return self.process_information.get("activity_state") == "ready"
def configure(self): def configure(self):
"""
Setup the worker using CLI arguments and environment variables.
"""
# CLI args are stored on the instance so that implementations can access them # CLI args are stored on the instance so that implementations can access them
self.args = self.parser.parse_args() self.args = self.parser.parse_args()
...@@ -153,8 +160,8 @@ class ElementsWorker( ...@@ -153,8 +160,8 @@ class ElementsWorker(
def run(self): def run(self):
""" """
Implements an Arkindex worker that goes through each element returned by Implements an Arkindex worker that goes through each element returned by
:meth:`list_elements`. It calls :meth:`process_element`, catching exceptions [list_elements][arkindex_worker.worker.ElementsWorker.list_elements]. It calls [process_element][arkindex_worker.worker.ElementsWorker.process_element], catching exceptions
and reporting them using the :class:`Reporter`, and handles saving the report and reporting them using the [Reporter][arkindex_worker.reporting.Reporter], and handles saving the report
once the process is complete as well as WorkerActivity updates when enabled. once the process is complete as well as WorkerActivity updates when enabled.
""" """
self.configure() self.configure()
...@@ -235,24 +242,24 @@ class ElementsWorker( ...@@ -235,24 +242,24 @@ class ElementsWorker(
if failed >= count: # Everything failed! if failed >= count: # Everything failed!
sys.exit(1) sys.exit(1)
def process_element(self, element): def process_element(self, element: Union[Element, CachedElement]):
""" """
Override this method to implement your worker and process a single Arkindex element at once. Override this method to implement your worker and process a single Arkindex element at once.
:param element: The element to process. :param element: The element to process.
Will be a CachedElement instance if cache support is enabled, Will be a CachedElement instance if cache support is enabled,
and an Element instance otherwise. and an Element instance otherwise.
:type element: Element or CachedElement
""" """
def update_activity(self, element_id, state) -> bool: def update_activity(
self, element_id: Union[str, uuid.UUID], state: ActivityState
) -> bool:
""" """
Update the WorkerActivity for this element and worker. Update the WorkerActivity for this element and worker.
:param element_id: ID of the element. :param element_id: ID of the element.
:type element_id: str or uuid.UUID :param state: New WorkerActivity state for this element.
:param state ActivityState: New WorkerActivity state for this element. :returns: True if the update has been successful or WorkerActivity support is disabled.
:returns bool: True if the update has been successful or WorkerActivity support is disabled.
False if the update has failed due to a conflict; this worker might have already processed False if the update has failed due to a conflict; this worker might have already processed
this element. this element.
""" """
......
...@@ -8,6 +8,7 @@ import json ...@@ -8,6 +8,7 @@ import json
import logging import logging
import os import os
from pathlib import Path from pathlib import Path
from typing import Optional
import gnupg import gnupg
import yaml import yaml
...@@ -32,11 +33,11 @@ from arkindex_worker.cache import ( ...@@ -32,11 +33,11 @@ from arkindex_worker.cache import (
) )
def _is_500_error(exc) -> bool: def _is_500_error(exc: Exception) -> bool:
""" """
Check if an Arkindex API error has a HTTP 5xx error code. Check if an Arkindex API error has a HTTP 5xx error code.
Used to retry most API calls in :class:`BaseWorker`. Used to retry most API calls in [BaseWorker][arkindex_worker.worker.base.BaseWorker].
:rtype: bool :param exc: Exception to check
""" """
if not isinstance(exc, ErrorResponse): if not isinstance(exc, ErrorResponse):
return False return False
...@@ -44,17 +45,27 @@ def _is_500_error(exc) -> bool: ...@@ -44,17 +45,27 @@ def _is_500_error(exc) -> bool:
return 500 <= exc.status_code < 600 return 500 <= exc.status_code < 600
class ModelNotFoundError(Exception):
"""
Exception raised when the path towards the model is invalid
"""
class BaseWorker(object): class BaseWorker(object):
""" """
Base class for Arkindex workers. Base class for Arkindex workers.
""" """
def __init__(self, description="Arkindex Base Worker", support_cache=False): def __init__(
self,
description: Optional[str] = "Arkindex Base Worker",
support_cache: Optional[bool] = False,
):
""" """
Initialize the worker. Initialize the worker.
:param description str: Description shown in the ``worker-...`` command line tool. :param description: Description shown in the ``worker-...`` command line tool.
:param support_cache bool: Whether or not this worker supports the cache database. :param support_cache: Whether or not this worker supports the cache database.
Override the constructor and set this parameter to start using the cache database. Override the constructor and set this parameter to start using the cache database.
""" """
...@@ -89,6 +100,12 @@ class BaseWorker(object): ...@@ -89,6 +100,12 @@ class BaseWorker(object):
action="store_true", action="store_true",
default=False, default=False,
) )
# To load models locally
self.parser.add_argument(
"--model-dir",
help=("The path to a local model's directory (development only)."),
type=Path,
)
# Call potential extra arguments # Call potential extra arguments
self.add_arguments() self.add_arguments()
...@@ -105,11 +122,10 @@ class BaseWorker(object): ...@@ -105,11 +122,10 @@ class BaseWorker(object):
self.work_dir = os.path.join(xdg_data_home, "arkindex") self.work_dir = os.path.join(xdg_data_home, "arkindex")
os.makedirs(self.work_dir, exist_ok=True) os.makedirs(self.work_dir, exist_ok=True)
self.worker_version_id = os.environ.get("WORKER_VERSION_ID") # Store task ID. This is only available when running in production
if not self.worker_version_id: # through a ponos agent
logger.warning( self.task_id = os.environ.get("PONOS_TASK")
"Missing WORKER_VERSION_ID environment variable, worker is in read-only mode"
)
self.worker_run_id = os.environ.get("ARKINDEX_WORKER_RUN_ID") self.worker_run_id = os.environ.get("ARKINDEX_WORKER_RUN_ID")
if not self.worker_run_id: if not self.worker_run_id:
logger.warning( logger.warning(
...@@ -119,7 +135,11 @@ class BaseWorker(object): ...@@ -119,7 +135,11 @@ class BaseWorker(object):
logger.info(f"Worker will use {self.work_dir} as working directory") logger.info(f"Worker will use {self.work_dir} as working directory")
self.process_information = None self.process_information = None
# corpus_id will be updated in configure() using the worker_run's corpus
# or in configure_for_developers() from the environment
self.corpus_id = None
self.user_configuration = {} self.user_configuration = {}
self.model_configuration = {}
self.support_cache = support_cache self.support_cache = support_cache
# use_cache will be updated in configure() if the cache is supported and if there # use_cache will be updated in configure() if the cache is supported and if there
# is at least one available sqlite database either given or in the parent tasks # is at least one available sqlite database either given or in the parent tasks
...@@ -133,22 +153,24 @@ class BaseWorker(object): ...@@ -133,22 +153,24 @@ class BaseWorker(object):
""" """
Whether or not the worker can publish data. Whether or not the worker can publish data.
:returns: False when dev mode is enabled with the ``--dev`` CLI argument, False when dev mode is enabled with the ``--dev`` CLI argument,
or when no worker version ID is provided. when no worker run ID is provided
:rtype: bool
""" """
return ( return self.args.dev or self.worker_run_id is None
self.args.dev
or self.worker_version_id is None
or self.worker_run_id is None
)
def setup_api_client(self): def setup_api_client(self):
"""
Create an ArkindexClient to make API requests towards Arkindex instances.
"""
# Build Arkindex API client from environment variables # Build Arkindex API client from environment variables
self.api_client = ArkindexClient(**options_from_env()) self.api_client = ArkindexClient(**options_from_env())
logger.debug(f"Setup Arkindex API client on {self.api_client.document.url}") logger.debug(f"Setup Arkindex API client on {self.api_client.document.url}")
def configure_for_developers(self): def configure_for_developers(self):
"""
Setup the necessary configuration needed when working in `read_only` mode.
This is the method called when running a worker locally.
"""
assert self.is_read_only assert self.is_read_only
# Setup logging level if verbose or if ARKINDEX_DEBUG is set to true # Setup logging level if verbose or if ARKINDEX_DEBUG is set to true
if self.args.verbose or os.environ.get("ARKINDEX_DEBUG"): if self.args.verbose or os.environ.get("ARKINDEX_DEBUG"):
...@@ -169,12 +191,20 @@ class BaseWorker(object): ...@@ -169,12 +191,20 @@ class BaseWorker(object):
required_secrets = [] required_secrets = []
logger.warning("Running without any extra configuration") logger.warning("Running without any extra configuration")
# Define corpus_id from environment
self.corpus_id = os.environ.get("ARKINDEX_CORPUS_ID")
if not self.corpus_id:
logger.warning(
"'ARKINDEX_CORPUS_ID' was not set in the environment. Any API request involving a `corpus_id` will fail."
)
# Load all required secrets # Load all required secrets
self.secrets = {name: self.load_secret(name) for name in required_secrets} self.secrets = {name: self.load_secret(name) for name in required_secrets}
def configure(self): def configure(self):
""" """
Configure worker using CLI args and environment variables. Setup the necessary configuration needed using CLI args and environment variables.
This is the method called when running a worker on Arkindex.
""" """
assert not self.is_read_only assert not self.is_read_only
# Setup logging level if verbose or if ARKINDEX_DEBUG is set to true # Setup logging level if verbose or if ARKINDEX_DEBUG is set to true
...@@ -188,15 +218,22 @@ class BaseWorker(object): ...@@ -188,15 +218,22 @@ class BaseWorker(object):
# Load process information # Load process information
self.process_information = worker_run["process"] self.process_information = worker_run["process"]
# Load corpus id
self.corpus_id = worker_run["process"]["corpus"]
# Load worker version information # Load worker version information
worker_version = worker_run["worker_version"] worker_version = worker_run["worker_version"]
# Store worker version id
self.worker_version_id = worker_version["id"]
self.worker_details = worker_version["worker"] self.worker_details = worker_version["worker"]
logger.info( logger.info(
f"Loaded worker {self.worker_details['name']} revision {worker_version['revision']['hash'][0:7]} from API" f"Loaded worker {self.worker_details['name']} revision {worker_version['revision']['hash'][0:7]} from API"
) )
# Retrieve initial configuration from API # Retrieve initial configuration from API
self.config = worker_version["configuration"].get("configuration") self.config = worker_version["configuration"].get("configuration", {})
if "user_configuration" in worker_version["configuration"]: if "user_configuration" in worker_version["configuration"]:
# Add default values (if set) to user_configuration # Add default values (if set) to user_configuration
for key, value in worker_version["configuration"][ for key, value in worker_version["configuration"][
...@@ -211,23 +248,30 @@ class BaseWorker(object): ...@@ -211,23 +248,30 @@ class BaseWorker(object):
# Load worker run configuration when available # Load worker run configuration when available
worker_configuration = worker_run.get("configuration") worker_configuration = worker_run.get("configuration")
self.user_configuration = ( if worker_configuration and worker_configuration.get("configuration"):
worker_configuration.get("configuration") if worker_configuration else None
)
if self.user_configuration:
logger.info("Loaded user configuration from WorkerRun") logger.info("Loaded user configuration from WorkerRun")
# if debug mode is set to true activate debug mode in logger self.user_configuration.update(worker_configuration.get("configuration"))
if self.user_configuration.get("debug"):
logger.setLevel(logging.DEBUG) # Load model version configuration when available
logger.debug("Debug output enabled") model_version = worker_run.get("model_version")
if model_version and model_version.get("configuration"):
logger.info("Loaded model version configuration from WorkerRun")
self.model_configuration.update(model_version.get("configuration"))
# if debug mode is set to true activate debug mode in logger
if self.user_configuration.get("debug"):
logger.setLevel(logging.DEBUG)
logger.debug("Debug output enabled")
def configure_cache(self): def configure_cache(self):
task_id = os.environ.get("PONOS_TASK") """
Setup the necessary attribute when using the cache system of `Base-Worker`.
"""
paths = None paths = None
if self.support_cache and self.args.database is not None: if self.support_cache and self.args.database is not None:
self.use_cache = True self.use_cache = True
elif self.support_cache and task_id: elif self.support_cache and self.task_id:
task = self.request("RetrieveTaskFromAgent", id=task_id) task = self.request("RetrieveTaskFromAgent", id=self.task_id)
paths = retrieve_parents_cache_path( paths = retrieve_parents_cache_path(
task["parents"], task["parents"],
data_dir=os.environ.get("PONOS_DATA", "/data"), data_dir=os.environ.get("PONOS_DATA", "/data"),
...@@ -242,7 +286,9 @@ class BaseWorker(object): ...@@ -242,7 +286,9 @@ class BaseWorker(object):
), f"Database in {self.args.database} does not exist" ), f"Database in {self.args.database} does not exist"
self.cache_path = self.args.database self.cache_path = self.args.database
else: else:
cache_dir = os.path.join(os.environ.get("PONOS_DATA", "/data"), task_id) cache_dir = os.path.join(
os.environ.get("PONOS_DATA", "/data"), self.task_id
)
assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}" assert os.path.isdir(cache_dir), f"Missing task cache in {cache_dir}"
self.cache_path = os.path.join(cache_dir, "db.sqlite") self.cache_path = os.path.join(cache_dir, "db.sqlite")
...@@ -261,11 +307,11 @@ class BaseWorker(object): ...@@ -261,11 +307,11 @@ class BaseWorker(object):
else: else:
logger.debug("Cache is disabled") logger.debug("Cache is disabled")
def load_secret(self, name): def load_secret(self, name: str):
""" """
Load a Ponos secret by name. Load a Ponos secret by name.
:param str name: Name of the Ponos secret. :param name: Name of the Ponos secret.
:raises Exception: When the secret cannot be loaded from the API nor the local secrets directory. :raises Exception: When the secret cannot be loaded from the API nor the local secrets directory.
""" """
secret = None secret = None
...@@ -312,6 +358,35 @@ class BaseWorker(object): ...@@ -312,6 +358,35 @@ class BaseWorker(object):
# By default give raw secret payload # By default give raw secret payload
return secret return secret
def find_model_directory(self) -> Path:
    """
    Find the local path to the model. This supports two modes:

    - the worker runs in ponos (``self.task_id`` is set): the current task
      work dir (``self.work_dir``) is returned
    - the worker runs locally, the developer may specify it using either
        - the `model_dir` configuration parameter
        - the `--model-dir` CLI parameter

    The configuration value takes precedence; the CLI value is used when the
    configuration has no `model_dir` key or holds a null value for it.

    :return: Path to the model on disk
    :raises ModelNotFoundError: When no path was provided, or when the
        provided path does not exist.
    """
    if self.task_id:
        # When running in production with ponos, the agent
        # downloads the model and sets it in the current task work dir
        return Path(self.work_dir)

    # A configuration that explicitly sets `model_dir` to null must not
    # mask a path given on the command line
    model_dir = self.config.get("model_dir")
    if model_dir is None:
        model_dir = self.args.model_dir

    if model_dir is None:
        raise ModelNotFoundError(
            "No path to the model was provided. "
            "Please provide model_dir either through configuration "
            "or as CLI argument."
        )

    model_dir = Path(model_dir)
    if not model_dir.exists():
        raise ModelNotFoundError(
            f"The path {model_dir} does not link to any directory"
        )
    return model_dir
@retry( @retry(
retry=retry_if_exception(_is_500_error), retry=retry_if_exception(_is_500_error),
wait=wait_exponential(multiplier=2, min=3), wait=wait_exponential(multiplier=2, min=3),
......
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
ElementsWorker methods for classifications and ML classes. ElementsWorker methods for classifications and ML classes.
""" """
import os from typing import Dict, List, Optional, Union
from uuid import UUID
from apistar.exceptions import ErrorResponse from apistar.exceptions import ErrorResponse
from peewee import IntegrityError from peewee import IntegrityError
...@@ -14,50 +15,38 @@ from arkindex_worker.models import Element ...@@ -14,50 +15,38 @@ from arkindex_worker.models import Element
class ClassificationMixin(object): class ClassificationMixin(object):
""" def load_corpus_classes(self):
Mixin for the :class:`ElementsWorker` to add ``MLClass`` and ``Classification`` helpers.
"""
def load_corpus_classes(self, corpus_id):
""" """
Load all ML classes for the given corpus ID and store them in the ``self.classes`` cache. Load all ML classes available in the worker's corpus and store them in the ``self.classes`` cache.
:param corpus_id str: ID of the corpus.
""" """
corpus_classes = self.api_client.paginate( corpus_classes = self.api_client.paginate(
"ListCorpusMLClasses", "ListCorpusMLClasses",
id=corpus_id, id=self.corpus_id,
) )
self.classes[corpus_id] = { self.classes[self.corpus_id] = {
ml_class["name"]: ml_class["id"] for ml_class in corpus_classes ml_class["name"]: ml_class["id"] for ml_class in corpus_classes
} }
logger.info(f"Loaded {len(self.classes[corpus_id])} ML classes") logger.info(f"Loaded {len(self.classes[self.corpus_id])} ML classes")
def get_ml_class_id(self, corpus_id, ml_class): def get_ml_class_id(self, ml_class: str) -> str:
""" """
Return the MLClass ID corresponding to the given class name on a specific corpus. Return the MLClass ID corresponding to the given class name on a specific corpus.
If no MLClass exists for this class name, a new one is created. If no MLClass exists for this class name, a new one is created.
:param ml_class: Name of the MLClass.
:param corpus_id: ID of the corpus, or None to use the ``ARKINDEX_CORPUS_ID`` environment variable. :returns: ID of the retrieved or created MLClass.
:type corpus_id: str or None
:param ml_class str: Name of the MLClass.
:returns str: ID of the retrieved or created MLClass.
""" """
if corpus_id is None: if not self.classes.get(self.corpus_id):
corpus_id = os.environ.get("ARKINDEX_CORPUS_ID") self.load_corpus_classes()
if not self.classes.get(corpus_id): ml_class_id = self.classes[self.corpus_id].get(ml_class)
self.load_corpus_classes(corpus_id)
ml_class_id = self.classes[corpus_id].get(ml_class)
if ml_class_id is None: if ml_class_id is None:
logger.info(f"Creating ML class {ml_class} on corpus {corpus_id}") logger.info(f"Creating ML class {ml_class} on corpus {self.corpus_id}")
try: try:
response = self.request( response = self.request(
"CreateMLClass", id=corpus_id, body={"name": ml_class} "CreateMLClass", id=self.corpus_id, body={"name": ml_class}
) )
ml_class_id = self.classes[corpus_id][ml_class] = response["id"] ml_class_id = self.classes[self.corpus_id][ml_class] = response["id"]
logger.debug(f"Created ML class {response['id']}") logger.debug(f"Created ML class {response['id']}")
except ErrorResponse as e: except ErrorResponse as e:
# Only reload for 400 errors # Only reload for 400 errors
...@@ -68,26 +57,53 @@ class ClassificationMixin(object): ...@@ -68,26 +57,53 @@ class ClassificationMixin(object):
logger.info( logger.info(
f"Reloading corpus classes to see if {ml_class} already exists" f"Reloading corpus classes to see if {ml_class} already exists"
) )
self.load_corpus_classes(corpus_id) self.load_corpus_classes()
assert ( assert (
ml_class in self.classes[corpus_id] ml_class in self.classes[self.corpus_id]
), "Missing class {ml_class} even after reloading" ), "Missing class {ml_class} even after reloading"
ml_class_id = self.classes[corpus_id][ml_class] ml_class_id = self.classes[self.corpus_id][ml_class]
return ml_class_id return ml_class_id
def retrieve_ml_class(self, ml_class_id: str) -> str:
    """
    Retrieve the name of the MLClass from its ID.

    The lookup is done in the ``self.classes`` cache of the worker's corpus,
    which is loaded first when it holds no entry for that corpus yet.

    :param ml_class_id: ID of the searched MLClass.
    :return: The MLClass's name
    :raises AssertionError: When no class with this ID is known for the corpus.
    """
    # Load the corpus' MLclasses if they are not available yet
    if self.corpus_id not in self.classes:
        self.load_corpus_classes()

    # The cache maps class names to IDs, so search it in reverse
    # and keep the first name bound to this ID
    ml_class_name = next(
        (
            name
            for name, class_id in self.classes[self.corpus_id].items()
            if class_id == ml_class_id
        ),
        None,
    )
    assert (
        ml_class_name is not None
    ), f"Missing class with id ({ml_class_id}) in corpus ({self.corpus_id})"
    return ml_class_name
def create_classification( def create_classification(
self, element, ml_class, confidence, high_confidence=False self,
): element: Union[Element, CachedElement],
ml_class: str,
confidence: float,
high_confidence: Optional[bool] = False,
) -> Dict[str, str]:
""" """
Create a classification on the given element through the API. Create a classification on the given element through the API.
:param element: The element to create a classification on. :param element: The element to create a classification on.
:type element: Element or CachedElement :param ml_class: Name of the MLClass to use.
:param ml_class str: Name of the MLClass to use. :param confidence: Confidence score for the classification. Must be between 0 and 1.
:param confidence float: Confidence score for the classification. Must be between 0 and 1. :param high_confidence: Whether or not the classification is of high confidence.
:param high_confidence bool: Whether or not the classification is of high confidence. :returns: The created classification, as returned by the ``CreateClassification`` API endpoint.
:returns dict: The created classification, as returned by the ``CreateClassification`` API endpoint.
""" """
assert element and isinstance( assert element and isinstance(
element, (Element, CachedElement) element, (Element, CachedElement)
...@@ -106,14 +122,13 @@ class ClassificationMixin(object): ...@@ -106,14 +122,13 @@ class ClassificationMixin(object):
"Cannot create classification as this worker is in read-only mode" "Cannot create classification as this worker is in read-only mode"
) )
return return
try: try:
created = self.request( created = self.request(
"CreateClassification", "CreateClassification",
body={ body={
"element": str(element.id), "element": str(element.id),
"ml_class": self.get_ml_class_id(None, ml_class), "ml_class": self.get_ml_class_id(ml_class),
"worker_version": self.worker_version_id, "worker_run_id": self.worker_run_id,
"confidence": confidence, "confidence": confidence,
"high_confidence": high_confidence, "high_confidence": high_confidence,
}, },
...@@ -129,7 +144,7 @@ class ClassificationMixin(object): ...@@ -129,7 +144,7 @@ class ClassificationMixin(object):
"class_name": ml_class, "class_name": ml_class,
"confidence": created["confidence"], "confidence": created["confidence"],
"state": created["state"], "state": created["state"],
"worker_version_id": self.worker_version_id, "worker_run_id": self.worker_run_id,
} }
] ]
CachedClassification.insert_many(to_insert).execute() CachedClassification.insert_many(to_insert).execute()
...@@ -139,15 +154,23 @@ class ClassificationMixin(object): ...@@ -139,15 +154,23 @@ class ClassificationMixin(object):
) )
except ErrorResponse as e: except ErrorResponse as e:
# Detect already existing classification # Detect already existing classification
if ( if e.status_code == 400 and "non_field_errors" in e.content:
e.status_code == 400 if (
and "non_field_errors" in e.content "The fields element, worker_version, ml_class must make a unique set."
and "The fields element, worker_version, ml_class must make a unique set." in e.content["non_field_errors"]
in e.content["non_field_errors"] ):
): logger.warning(
logger.warning( f"This worker version has already set {ml_class} on element {element.id}"
f"This worker version has already set {ml_class} on element {element.id}" )
) elif (
"The fields element, worker_run, ml_class must make a unique set."
in e.content["non_field_errors"]
):
logger.warning(
f"This worker run has already set {ml_class} on element {element.id}"
)
else:
raise
return return
# Propagate any other API error # Propagate any other API error
...@@ -157,25 +180,22 @@ class ClassificationMixin(object): ...@@ -157,25 +180,22 @@ class ClassificationMixin(object):
return created return created
def create_classifications(self, element, classifications): def create_classifications(
self,
element: Union[Element, CachedElement],
classifications: List[Dict[str, Union[str, float, bool]]],
) -> List[Dict[str, Union[str, float, bool]]]:
""" """
Create multiple classifications at once on the given element through the API. Create multiple classifications at once on the given element through the API.
:param element: The element to create classifications on. :param element: The element to create classifications on.
:type element: Element or CachedElement :param classifications: The classifications to create, a list of dicts. Each of them contains
:param classifications: The classifications to create, as a list of dicts with the following keys: a **ml_class_id** (str), the ID of the MLClass for this classification;
a **confidence** (float), the confidence score, between 0 and 1;
class_name (str) a **high_confidence** (bool), the high confidence state of the classification.
Name of the MLClass for this classification.
confidence (float)
Confidence score, between 0 and 1.
high_confidence (bool)
High confidence state of the classification.
:type classifications: List[Dict[str, Union[str, float, bool]]]
:returns: List of created classifications, as returned in the ``classifications`` field by :returns: List of created classifications, as returned in the ``classifications`` field by
the ``CreateClassifications`` API endpoint. the ``CreateClassifications`` API endpoint.
:rtype: List[Dict[str, Union[str, float, bool]]]
""" """
assert element and isinstance( assert element and isinstance(
element, (Element, CachedElement) element, (Element, CachedElement)
...@@ -185,10 +205,18 @@ class ClassificationMixin(object): ...@@ -185,10 +205,18 @@ class ClassificationMixin(object):
), "classifications shouldn't be null and should be of type list" ), "classifications shouldn't be null and should be of type list"
for index, classification in enumerate(classifications): for index, classification in enumerate(classifications):
class_name = classification.get("class_name") ml_class_id = classification.get("ml_class_id")
assert class_name and isinstance( assert ml_class_id and isinstance(
class_name, str ml_class_id, str
), f"Classification at index {index} in classifications: class_name shouldn't be null and should be of type str" ), f"Classification at index {index} in classifications: ml_class_id shouldn't be null and should be of type str"
# Make sure it's a valid UUID
try:
UUID(ml_class_id)
except ValueError:
raise ValueError(
f"Classification at index {index} in classifications: ml_class_id is not a valid uuid."
)
confidence = classification.get("confidence") confidence = classification.get("confidence")
assert ( assert (
...@@ -213,12 +241,13 @@ class ClassificationMixin(object): ...@@ -213,12 +241,13 @@ class ClassificationMixin(object):
"CreateClassifications", "CreateClassifications",
body={ body={
"parent": str(element.id), "parent": str(element.id),
"worker_version": self.worker_version_id, "worker_run_id": self.worker_run_id,
"classifications": classifications, "classifications": classifications,
}, },
)["classifications"] )["classifications"]
for created_cl in created_cls: for created_cl in created_cls:
created_cl["class_name"] = self.retrieve_ml_class(created_cl["ml_class"])
self.report.add_classification(element.id, created_cl["class_name"]) self.report.add_classification(element.id, created_cl["class_name"])
if self.use_cache: if self.use_cache:
...@@ -228,10 +257,10 @@ class ClassificationMixin(object): ...@@ -228,10 +257,10 @@ class ClassificationMixin(object):
{ {
"id": created_cl["id"], "id": created_cl["id"],
"element_id": element.id, "element_id": element.id,
"class_name": created_cl["class_name"], "class_name": created_cl.pop("class_name"),
"confidence": created_cl["confidence"], "confidence": created_cl["confidence"],
"state": created_cl["state"], "state": created_cl["state"],
"worker_version_id": self.worker_version_id, "worker_run_id": self.worker_run_id,
} }
for created_cl in created_cls for created_cl in created_cls
] ]
......
...@@ -2,17 +2,23 @@ ...@@ -2,17 +2,23 @@
""" """
ElementsWorker methods for elements and element types. ElementsWorker methods for elements and element types.
""" """
import uuid
from typing import Dict, Iterable, List, NamedTuple, Optional, Union from typing import Dict, Iterable, List, NamedTuple, Optional, Union
from peewee import IntegrityError from peewee import IntegrityError
from arkindex_worker import logger from arkindex_worker import logger
from arkindex_worker.cache import CachedElement, CachedImage from arkindex_worker.cache import CachedElement, CachedImage
from arkindex_worker.models import Corpus, Element from arkindex_worker.models import Element
class ElementType(NamedTuple):
"""
Arkindex Type of an element
"""
ElementType = NamedTuple("ElementType", name=str, slug=str, is_folder=bool) name: str
slug: str
is_folder: bool
class MissingTypeError(Exception): class MissingTypeError(Exception):
...@@ -22,15 +28,11 @@ class MissingTypeError(Exception): ...@@ -22,15 +28,11 @@ class MissingTypeError(Exception):
class ElementMixin(object): class ElementMixin(object):
""" def create_required_types(self, element_types: List[ElementType]):
Mixin for the :class:`ElementsWorker` to provide ``Element`` helpers.
"""
def create_required_types(self, corpus: Corpus, element_types: List[ElementType]):
"""Creates given element types in the corpus. """Creates given element types in the corpus.
:param Corpus corpus: The corpus to create types on. :param Corpus corpus: The corpus to create types on.
:param List[ElementType] element_types: The missing element types to create. :param element_types: The missing element types to create.
""" """
for element_type in element_types: for element_type in element_types:
self.request( self.request(
...@@ -39,47 +41,42 @@ class ElementMixin(object): ...@@ -39,47 +41,42 @@ class ElementMixin(object):
"slug": element_type.slug, "slug": element_type.slug,
"display_name": element_type.name, "display_name": element_type.name,
"folder": element_type.is_folder, "folder": element_type.is_folder,
"corpus": corpus.id, "corpus": self.corpus_id,
}, },
) )
logger.info(f"Created a new element type with slug {element_type.slug}") logger.info(f"Created a new element type with slug {element_type.slug}")
def check_required_types( def check_required_types(
self, corpus_id: str, *type_slugs: str, create_missing: bool = False self, *type_slugs: str, create_missing: bool = False
) -> bool: ) -> bool:
""" """
Check that a corpus has a list of required element types, Check that a corpus has a list of required element types,
and raise an exception if any of them are missing. and raise an exception if any of them are missing.
:param str corpus_id: ID of the corpus to check types on. :param *type_slugs: Type slugs to look for.
:param str \\*type_slugs: Type slugs to look for. :param create_missing: Whether missing types should be created.
:param bool create_missing: Whether missing types should be created. :returns: Whether all of the specified type slugs have been found.
:returns bool: True if all of the specified type slugs have been found.
:raises MissingTypeError: If any of the specified type slugs were not found. :raises MissingTypeError: If any of the specified type slugs were not found.
""" """
assert isinstance(
corpus_id, (uuid.UUID, str)
), "Corpus ID should be a string or UUID"
assert len(type_slugs), "At least one element type slug is required." assert len(type_slugs), "At least one element type slug is required."
assert all( assert all(
isinstance(slug, str) for slug in type_slugs isinstance(slug, str) for slug in type_slugs
), "Element type slugs must be strings." ), "Element type slugs must be strings."
corpus = Corpus(self.request("RetrieveCorpus", id=corpus_id)) corpus = self.request("RetrieveCorpus", id=self.corpus_id)
available_slugs = {element_type.slug for element_type in corpus.types} available_slugs = {element_type["slug"] for element_type in corpus["types"]}
missing_slugs = set(type_slugs) - available_slugs missing_slugs = set(type_slugs) - available_slugs
if missing_slugs: if missing_slugs:
if create_missing: if create_missing:
self.create_required_types( self.create_required_types(
corpus,
element_types=[ element_types=[
ElementType(slug, slug, False) for slug in missing_slugs ElementType(slug, slug, False) for slug in missing_slugs
], ],
) )
else: else:
raise MissingTypeError( raise MissingTypeError(
f'Element type(s) {", ".join(sorted(missing_slugs))} were not found in the {corpus.name} corpus ({corpus.id}).' f'Element type(s) {", ".join(sorted(missing_slugs))} were not found in the {corpus["name"]} corpus ({corpus["id"]}).'
) )
return True return True
...@@ -91,19 +88,17 @@ class ElementMixin(object): ...@@ -91,19 +88,17 @@ class ElementMixin(object):
name: str, name: str,
polygon: List[List[Union[int, float]]], polygon: List[List[Union[int, float]]],
confidence: Optional[float] = None, confidence: Optional[float] = None,
slim_output: bool = True, slim_output: Optional[bool] = True,
) -> str: ) -> str:
""" """
Create a child element on the given element through the API. Create a child element on the given element through the API.
:param Element element: The parent element. :param Element element: The parent element.
:param str type: Slug of the element type for this child element. :param type: Slug of the element type for this child element.
:param str name: Name of the child element. :param name: Name of the child element.
:param polygon: Polygon of the child element. :param polygon: Polygon of the child element.
:type polygon: list(list(int or float))
:param confidence: Optional confidence score, between 0.0 and 1.0. :param confidence: Optional confidence score, between 0.0 and 1.0.
:type confidence: float or None :returns: UUID of the created element.
:returns str: UUID of the created element.
""" """
assert element and isinstance( assert element and isinstance(
element, Element element, Element
...@@ -143,7 +138,7 @@ class ElementMixin(object): ...@@ -143,7 +138,7 @@ class ElementMixin(object):
"corpus": element.corpus.id, "corpus": element.corpus.id,
"polygon": polygon, "polygon": polygon,
"parent": element.id, "parent": element.id,
"worker_version": self.worker_version_id, "worker_run_id": self.worker_run_id,
"confidence": confidence, "confidence": confidence,
}, },
) )
...@@ -162,7 +157,6 @@ class ElementMixin(object): ...@@ -162,7 +157,6 @@ class ElementMixin(object):
Create child elements on the given element in a single API request. Create child elements on the given element in a single API request.
:param parent: Parent element for all the new child elements. The parent must have an image and a polygon. :param parent: Parent element for all the new child elements. The parent must have an image and a polygon.
:type parent: Element or CachedElement
:param elements: List of dicts, one per element. Each dict can have the following keys: :param elements: List of dicts, one per element. Each dict can have the following keys:
name (str) name (str)
...@@ -178,9 +172,7 @@ class ElementMixin(object): ...@@ -178,9 +172,7 @@ class ElementMixin(object):
confidence (float or None) confidence (float or None)
Optional confidence score, between 0.0 and 1.0. Optional confidence score, between 0.0 and 1.0.
:type elements: list(dict(str, Any))
:return: List of dicts, with each dict having a single key, ``id``, holding the UUID of each created element. :return: List of dicts, with each dict having a single key, ``id``, holding the UUID of each created element.
:rtype: list(dict(str, str))
""" """
if isinstance(parent, Element): if isinstance(parent, Element):
assert parent.get( assert parent.get(
...@@ -241,7 +233,7 @@ class ElementMixin(object): ...@@ -241,7 +233,7 @@ class ElementMixin(object):
"CreateElements", "CreateElements",
id=parent.id, id=parent.id,
body={ body={
"worker_version": self.worker_version_id, "worker_run_id": self.worker_run_id,
"elements": elements, "elements": elements,
}, },
) )
...@@ -273,7 +265,7 @@ class ElementMixin(object): ...@@ -273,7 +265,7 @@ class ElementMixin(object):
"type": element["type"], "type": element["type"],
"image_id": image_id, "image_id": image_id,
"polygon": element["polygon"], "polygon": element["polygon"],
"worker_version_id": self.worker_version_id, "worker_run_id": self.worker_run_id,
"confidence": element.get("confidence"), "confidence": element.get("confidence"),
} }
for idx, element in enumerate(elements) for idx, element in enumerate(elements)
...@@ -301,38 +293,27 @@ class ElementMixin(object): ...@@ -301,38 +293,27 @@ class ElementMixin(object):
List children of an element. List children of an element.
:param element: Parent element to find children of. :param element: Parent element to find children of.
:type element: Union[Element, CachedElement]
:param folder: Restrict to or exclude elements with folder types. :param folder: Restrict to or exclude elements with folder types.
This parameter is not supported when caching is enabled. This parameter is not supported when caching is enabled.
:type folder: Optional[bool]
:param name: Restrict to elements whose name contain a substring (case-insensitive). :param name: Restrict to elements whose name contain a substring (case-insensitive).
This parameter is not supported when caching is enabled. This parameter is not supported when caching is enabled.
:type name: Optional[str]
:param recursive: Look for elements recursively (grand-children, etc.) :param recursive: Look for elements recursively (grand-children, etc.)
This parameter is not supported when caching is enabled. This parameter is not supported when caching is enabled.
:type recursive: Optional[bool]
:param type: Restrict to elements with a specific type slug :param type: Restrict to elements with a specific type slug
This parameter is not supported when caching is enabled. This parameter is not supported when caching is enabled.
:type type: Optional[str]
:param with_classes: Include each element's classifications in the response. :param with_classes: Include each element's classifications in the response.
This parameter is not supported when caching is enabled. This parameter is not supported when caching is enabled.
:type with_classes: Optional[bool]
:param with_corpus: Include each element's corpus in the response. :param with_corpus: Include each element's corpus in the response.
This parameter is not supported when caching is enabled. This parameter is not supported when caching is enabled.
:type with_corpus: Optional[bool]
:param with_has_children: Include the ``has_children`` attribute in the response, :param with_has_children: Include the ``has_children`` attribute in the response,
indicating if this element has child elements of its own. indicating if this element has child elements of its own.
This parameter is not supported when caching is enabled. This parameter is not supported when caching is enabled.
:type with_has_children: Optional[bool]
:param with_zone: Include the ``zone`` attribute in the response, :param with_zone: Include the ``zone`` attribute in the response,
holding the element's image and polygon. holding the element's image and polygon.
This parameter is not supported when caching is enabled. This parameter is not supported when caching is enabled.
:type with_zone: Optional[bool]
:param worker_version: Restrict to elements created by a worker version with this UUID. :param worker_version: Restrict to elements created by a worker version with this UUID.
:type worker_version: Optional[Union[str, bool]]
:return: An iterable of dicts from the ``ListElementChildren`` API endpoint, :return: An iterable of dicts from the ``ListElementChildren`` API endpoint,
or an iterable of :class:`CachedElement` when caching is enabled. or an iterable of [CachedElement][arkindex_worker.cache.CachedElement] when caching is enabled.
:rtype: Union[Iterable[dict], Iterable[CachedElement]]
""" """
assert element and isinstance( assert element and isinstance(
element, (Element, CachedElement) element, (Element, CachedElement)
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
ElementsWorker methods for entities. ElementsWorker methods for entities.
""" """
import os
from enum import Enum from enum import Enum
from typing import Dict, Optional, Union
from peewee import IntegrityError from peewee import IntegrityError
...@@ -28,29 +28,23 @@ class EntityType(Enum): ...@@ -28,29 +28,23 @@ class EntityType(Enum):
class EntityMixin(object): class EntityMixin(object):
"""
Mixin for the :class:`ElementsWorker` to add ``Entity`` helpers.
"""
def create_entity( def create_entity(
self, element, name, type, corpus=None, metas=dict(), validated=None self,
element: Union[Element, CachedElement],
name: str,
type: EntityType,
metas=dict(),
validated=None,
): ):
""" """
Create an entity on the given corpus. Create an entity on the given corpus.
If cache support is enabled, a :class:`CachedEntity` will also be created. If cache support is enabled, a [CachedEntity][arkindex_worker.cache.CachedEntity] will also be created.
:param element: An element on which the entity will be reported with the :class:`Reporter`. :param element: An element on which the entity will be reported with the [Reporter][arkindex_worker.reporting.Reporter].
This does not have any effect on the entity itself. This does not have any effect on the entity itself.
:type element: Element or CachedElement :param name: Name of the entity.
:param name str: Name of the entity. :param type: Type of the entity.
:param type EntityType: Type of the entity.
:param corpus: UUID of the corpus to create an entity on, or None to use the
value of the ``ARKINDEX_CORPUS_ID`` environment variable.
:type corpus: str or None
""" """
if corpus is None:
corpus = os.environ.get("ARKINDEX_CORPUS_ID")
assert element and isinstance( assert element and isinstance(
element, (Element, CachedElement) element, (Element, CachedElement)
), "element shouldn't be null and should be an Element or CachedElement" ), "element shouldn't be null and should be an Element or CachedElement"
...@@ -60,9 +54,6 @@ class EntityMixin(object): ...@@ -60,9 +54,6 @@ class EntityMixin(object):
assert type and isinstance( assert type and isinstance(
type, EntityType type, EntityType
), "type shouldn't be null and should be of type EntityType" ), "type shouldn't be null and should be of type EntityType"
assert corpus and isinstance(
corpus, str
), "corpus shouldn't be null and should be of type str"
if metas: if metas:
assert isinstance(metas, dict), "metas should be of type dict" assert isinstance(metas, dict), "metas should be of type dict"
if validated is not None: if validated is not None:
...@@ -78,8 +69,8 @@ class EntityMixin(object): ...@@ -78,8 +69,8 @@ class EntityMixin(object):
"type": type.value, "type": type.value,
"metas": metas, "metas": metas,
"validated": validated, "validated": validated,
"corpus": corpus, "corpus": self.corpus_id,
"worker_version": self.worker_version_id, "worker_run_id": self.worker_run_id,
}, },
) )
self.report.add_entity(element.id, entity["id"], type.value, name) self.report.add_entity(element.id, entity["id"], type.value, name)
...@@ -94,7 +85,7 @@ class EntityMixin(object): ...@@ -94,7 +85,7 @@ class EntityMixin(object):
"name": name, "name": name,
"validated": validated if validated is not None else False, "validated": validated if validated is not None else False,
"metas": metas, "metas": metas,
"worker_version_id": self.worker_version_id, "worker_run_id": self.worker_run_id,
} }
] ]
CachedEntity.insert_many(to_insert).execute() CachedEntity.insert_many(to_insert).execute()
...@@ -104,26 +95,29 @@ class EntityMixin(object): ...@@ -104,26 +95,29 @@ class EntityMixin(object):
return entity["id"] return entity["id"]
def create_transcription_entity( def create_transcription_entity(
self, transcription, entity, offset, length, confidence=None self,
): transcription: Transcription,
entity: str,
offset: int,
length: int,
confidence: Optional[float] = None,
) -> Optional[Dict[str, Union[str, int]]]:
""" """
Create a link between an existing entity and an existing transcription. Create a link between an existing entity and an existing transcription.
If cache support is enabled, a :class:`CachedTranscriptionEntity` will also be created. If cache support is enabled, a `CachedTranscriptionEntity` will also be created.
:param transcription str: UUID of the existing transcription. :param transcription: Transcription to create the entity on.
:param entity str: UUID of the existing entity. :param entity: UUID of the existing entity.
:param offset int: Starting position of the entity in the transcription's text, :param offset: Starting position of the entity in the transcription's text,
as a 0-based index. as a 0-based index.
:param length int: Length of the entity in the transcription's text. :param length: Length of the entity in the transcription's text.
:param confidence: Optional confidence score between 0 or 1. :param confidence: Optional confidence score between 0 or 1.
:type confidence: float or None
:returns: A dict as returned by the ``CreateTranscriptionEntity`` API endpoint, :returns: A dict as returned by the ``CreateTranscriptionEntity`` API endpoint,
or None if the worker is in read-only mode. or None if the worker is in read-only mode.
:rtype: dict(str, str or int) or None
""" """
assert transcription and isinstance( assert transcription and isinstance(
transcription, str transcription, Transcription
), "transcription shouldn't be null and should be of type str" ), "transcription shouldn't be null and should be a Transcription"
assert entity and isinstance( assert entity and isinstance(
entity, str entity, str
), "entity shouldn't be null and should be of type str" ), "entity shouldn't be null and should be of type str"
...@@ -146,27 +140,27 @@ class EntityMixin(object): ...@@ -146,27 +140,27 @@ class EntityMixin(object):
"entity": entity, "entity": entity,
"length": length, "length": length,
"offset": offset, "offset": offset,
"worker_version_id": self.worker_version_id, "worker_run_id": self.worker_run_id,
} }
if confidence is not None: if confidence is not None:
body["confidence"] = confidence body["confidence"] = confidence
transcription_ent = self.request( transcription_ent = self.request(
"CreateTranscriptionEntity", "CreateTranscriptionEntity",
id=transcription, id=transcription.id,
body=body, body=body,
) )
# TODO: Report transcription entity creation self.report.add_transcription_entity(entity, transcription, transcription_ent)
if self.use_cache: if self.use_cache:
# Store transcription entity in local cache # Store transcription entity in local cache
try: try:
CachedTranscriptionEntity.create( CachedTranscriptionEntity.create(
transcription=transcription, transcription=transcription.id,
entity=entity, entity=entity,
offset=offset, offset=offset,
length=length, length=length,
worker_version_id=self.worker_version_id, worker_run_id=self.worker_run_id,
confidence=confidence, confidence=confidence,
) )
except IntegrityError as e: except IntegrityError as e:
...@@ -178,14 +172,14 @@ class EntityMixin(object): ...@@ -178,14 +172,14 @@ class EntityMixin(object):
def list_transcription_entities( def list_transcription_entities(
self, self,
transcription: Transcription, transcription: Transcription,
worker_version: bool = None, worker_version: Optional[Union[str, bool]] = None,
): ):
""" """
List existing entities on a transcription List existing entities on a transcription
This method does not support cache This method does not support cache
:param transcription Transcription: The transcription to list entities on. :param transcription: The transcription to list entities on.
:param worker_version str or bool: Restrict to entities created by a worker version with this UUID. Set to False to look for manually created transcriptions. :param worker_version: Restrict to entities created by a worker version with this UUID. Set to False to look for manually created transcriptions.
""" """
query_params = {} query_params = {}
assert transcription and isinstance( assert transcription and isinstance(
...@@ -206,3 +200,28 @@ class EntityMixin(object): ...@@ -206,3 +200,28 @@ class EntityMixin(object):
return self.api_client.paginate( return self.api_client.paginate(
"ListTranscriptionEntities", id=transcription.id, **query_params "ListTranscriptionEntities", id=transcription.id, **query_params
) )
def list_corpus_entities(
    self,
    name: Optional[str] = None,
    parent: Optional[Element] = None,
):
    """
    List all entities in the worker's corpus.

    This method does not support cache.

    :param name: Filter entities by part of their name (case-insensitive).
       When given, it must be a non-empty string.
    :param parent: Restrict entities to those linked to all transcriptions of an element
       and all its descendants. Note that links to metadata are ignored.
    :returns: Whatever ``self.api_client.paginate`` returns for the
       ``ListCorpusEntities`` endpoint called on ``self.corpus_id``.
    :raises AssertionError: When ``name`` is given but is empty or not a string,
       or when ``parent`` is given but is not an Element.
    """
    filters = {}

    # Each filter is optional: it is validated and forwarded only when provided.
    # The name is checked before the parent, so an invalid name is reported first.
    if name is not None:
        assert name and isinstance(name, str), "name should be of type str"
        filters.update(name=name)

    if parent is not None:
        assert isinstance(parent, Element), "parent should be of type Element"
        # The API receives the parent's identifier, not the Element object
        filters.update(parent=parent.id)

    return self.api_client.paginate(
        "ListCorpusEntities", id=self.corpus_id, **filters
    )
...@@ -4,8 +4,10 @@ ElementsWorker methods for metadata. ...@@ -4,8 +4,10 @@ ElementsWorker methods for metadata.
""" """
from enum import Enum from enum import Enum
from typing import Dict, List, Optional, Union
from arkindex_worker import logger from arkindex_worker import logger
from arkindex_worker.cache import CachedElement
from arkindex_worker.models import Element from arkindex_worker.models import Element
...@@ -56,25 +58,27 @@ class MetaType(Enum): ...@@ -56,25 +58,27 @@ class MetaType(Enum):
class MetaDataMixin(object): class MetaDataMixin(object):
""" def create_metadata(
Mixin for the :class:`ElementsWorker` to add ``MetaData`` helpers. self,
""" element: Union[Element, CachedElement],
type: MetaType,
def create_metadata(self, element, type, name, value, entity=None): name: str,
value: str,
entity: Optional[str] = None,
) -> str:
""" """
Create a metadata on the given element through API. Create a metadata on the given element through API.
:param element Element: The element to create a metadata on. :param element: The element to create a metadata on.
:param type MetaType: Type of the metadata. :param type: Type of the metadata.
:param name str: Name of the metadata. :param name: Name of the metadata.
:param value str: Value of the metadata. :param value: Value of the metadata.
:param entity: UUID of an entity this metadata is related to. :param entity: UUID of an entity this metadata is related to.
:type entity: str or None :returns: UUID of the created metadata.
:returns str: UUID of the created metadata.
""" """
assert element and isinstance( assert element and isinstance(
element, Element element, (Element, CachedElement)
), "element shouldn't be null and should be of type Element" ), "element shouldn't be null and should be of type Element or CachedElement"
assert type and isinstance( assert type and isinstance(
type, MetaType type, MetaType
), "type shouldn't be null and should be of type MetaType" ), "type shouldn't be null and should be of type MetaType"
...@@ -98,7 +102,7 @@ class MetaDataMixin(object): ...@@ -98,7 +102,7 @@ class MetaDataMixin(object):
"name": name, "name": name,
"value": value, "value": value,
"entity_id": entity, "entity_id": entity,
"worker_version": self.worker_version_id, "worker_run_id": self.worker_run_id,
}, },
) )
self.report.add_metadata(element.id, metadata["id"], type.value, name) self.report.add_metadata(element.id, metadata["id"], type.value, name)
...@@ -107,9 +111,13 @@ class MetaDataMixin(object): ...@@ -107,9 +111,13 @@ class MetaDataMixin(object):
def create_metadatas( def create_metadatas(
self, self,
element: Element, element: Union[Element, CachedElement],
metadatas: list, metadatas: List[
): Dict[
str, Union[MetaType, str, Union[str, Union[int, float]], Optional[str]]
]
],
) -> List[Dict[str, str]]:
""" """
Create multiple metadatas on an existing element. Create multiple metadatas on an existing element.
This method does not support cache. This method does not support cache.
...@@ -122,8 +130,8 @@ class MetaDataMixin(object): ...@@ -122,8 +130,8 @@ class MetaDataMixin(object):
- entity_id : Union[str, None] - entity_id : Union[str, None]
""" """
assert element and isinstance( assert element and isinstance(
element, Element element, (Element, CachedElement)
), "element shouldn't be null and should be of type Element" ), "element shouldn't be null and should be of type Element or CachedElement"
assert metadatas and isinstance( assert metadatas and isinstance(
metadatas, list metadatas, list
...@@ -144,7 +152,7 @@ class MetaDataMixin(object): ...@@ -144,7 +152,7 @@ class MetaDataMixin(object):
metadata.get("name"), str metadata.get("name"), str
), "name shouldn't be null and should be of type str" ), "name shouldn't be null and should be of type str"
assert metadata.get("value") and isinstance( assert metadata.get("value") is not None and isinstance(
metadata.get("value"), (str, float, int) metadata.get("value"), (str, float, int)
), "value shouldn't be null and should be of type (str or float or int)" ), "value shouldn't be null and should be of type (str or float or int)"
...@@ -169,7 +177,6 @@ class MetaDataMixin(object): ...@@ -169,7 +177,6 @@ class MetaDataMixin(object):
"CreateMetaDataBulk", "CreateMetaDataBulk",
id=element.id, id=element.id,
body={ body={
"worker_version": self.worker_version_id,
"worker_run_id": self.worker_run_id, "worker_run_id": self.worker_run_id,
"metadata_list": metas, "metadata_list": metas,
}, },
...@@ -180,15 +187,17 @@ class MetaDataMixin(object): ...@@ -180,15 +187,17 @@ class MetaDataMixin(object):
return created_metadatas return created_metadatas
def list_metadata(self, element: Element): def list_element_metadata(
self, element: Union[Element, CachedElement]
) -> List[Dict[str, str]]:
""" """
List all metadata linked to an element. List all metadata linked to an element.
This method does not support cache. This method does not support cache.
:param element Element: The element to list metadata on. :param element: The element to list metadata on.
""" """
assert element and isinstance( assert element and isinstance(
element, Element element, (Element, CachedElement)
), "element shouldn't be null and should be of type Element" ), "element shouldn't be null and should be of type Element or CachedElement"
return self.api_client.paginate("ListElementMetaData", id=element.id) return self.api_client.paginate("ListElementMetaData", id=element.id)
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
"""
BaseWorker methods for training.
"""
import hashlib import hashlib
import os import os
import tarfile import tarfile
import tempfile import tempfile
from contextlib import contextmanager from contextlib import contextmanager
from typing import NewType, Tuple from pathlib import Path
from typing import NewType, Optional, Tuple
import requests import requests
import zstandard as zstd import zstandard as zstd
...@@ -15,15 +20,24 @@ from arkindex_worker import logger ...@@ -15,15 +20,24 @@ from arkindex_worker import logger
CHUNK_SIZE = 1024 CHUNK_SIZE = 1024
DirPath = NewType("DirPath", str) DirPath = NewType("DirPath", str)
"""Path to a directory"""
Hash = NewType("Hash", str) Hash = NewType("Hash", str)
"""MD5 Hash"""
FileSize = NewType("FileSize", int) FileSize = NewType("FileSize", int)
Archive = Tuple[DirPath, Hash, FileSize] """File size"""
@contextmanager @contextmanager
def create_archive(path: DirPath) -> Archive: def create_archive(path: DirPath) -> Tuple[Path, Hash, FileSize, Hash]:
"""First create a tar archive, then compress to a zst archive. """
Finally, get its hash and size Create a tar archive from the files at the given location then compress it to a zst archive.
Yield its location, its hash, its size and its content's hash.
:param path: Create a compressed tar archive from the files
:returns: The location of the created archive, its hash, its size and its content's hash
""" """
assert path.is_dir(), "create_archive needs a directory" assert path.is_dir(), "create_archive needs a directory"
...@@ -41,7 +55,9 @@ def create_archive(path: DirPath) -> Archive: ...@@ -41,7 +55,9 @@ def create_archive(path: DirPath) -> Archive:
for p in path.glob("**/*"): for p in path.glob("**/*"):
x = p.relative_to(path) x = p.relative_to(path)
tar.add(p, arcname=x, recursive=False) tar.add(p, arcname=x, recursive=False)
file_list.append(p) # Only keep files when computing the hash
if p.is_file():
file_list.append(p)
# Sort by path # Sort by path
file_list.sort() file_list.sort()
...@@ -76,17 +92,25 @@ def create_archive(path: DirPath) -> Archive: ...@@ -76,17 +92,25 @@ def create_archive(path: DirPath) -> Archive:
class TrainingMixin(object): class TrainingMixin(object):
"""
Mixin for the Training workers to add Model and ModelVersion helpers
"""
def publish_model_version( def publish_model_version(
self, model_path: DirPath, model_id: str, tag: str = None, description: str = "" self,
model_path: DirPath,
model_id: str,
tag: Optional[str] = None,
description: Optional[str] = None,
configuration: Optional[dict] = {},
): ):
""" """
This method creates a model archive and its associated hash, This method creates a model archive and its associated hash,
to create a unique version that will be stored on a bucket and published on arkindex. to create a unique version that will be stored on a bucket and published on Arkindex.
:param model_path: Path to the directory containing the model version's files.
:param model_id: ID of the model
:param tag: Tag of the model version
:param description: Description of the model version
:param configuration: Configuration of the model version
""" """
if self.is_read_only: if self.is_read_only:
logger.warning( logger.warning(
"Cannot publish a new model version as this worker is in read-only mode" "Cannot publish a new model version as this worker is in read-only mode"
...@@ -118,7 +142,7 @@ class TrainingMixin(object): ...@@ -118,7 +142,7 @@ class TrainingMixin(object):
# Update the model version with state, configuration parsed, tag, description (defaults to name of the worker) # Update the model version with state, configuration parsed, tag, description (defaults to name of the worker)
self.update_model_version( self.update_model_version(
model_version_details=model_version_details, model_version_details=model_version_details, configuration=configuration
) )
def create_model_version( def create_model_version(
...@@ -144,30 +168,29 @@ class TrainingMixin(object): ...@@ -144,30 +168,29 @@ class TrainingMixin(object):
# Create a new model version with hash and size # Create a new model version with hash and size
try: try:
payload = {"hash": hash, "size": size, "archive_hash": archive_hash}
if tag:
payload["tag"] = tag
if description:
payload["description"] = description
model_version_details = self.request( model_version_details = self.request(
"CreateModelVersion", "CreateModelVersion",
id=model_id, id=model_id,
body={ body=payload,
"hash": hash,
"size": size,
"archive_hash": archive_hash,
"tag": tag,
"description": description,
},
) )
logger.info( logger.info(
f"Model version ({model_version_details['id']}) was created successfully" f"Model version ({model_version_details['id']}) was created successfully"
) )
except ErrorResponse as e: except ErrorResponse as e:
if e.status_code >= 500: model_version_details = (
e.content.get("hash") if hasattr(e, "content") else None
)
if e.status_code >= 500 or model_version_details is None:
logger.error(f"Failed to create model version: {e.content}") logger.error(f"Failed to create model version: {e.content}")
raise e
model_version_details = e.content.get("hash")
# If the existing model is in Created state, this model is returned as a dict. # If the existing model is in Created state, this model is returned as a dict.
# Else an error in a list is returned. # Else an error in a list is returned.
if model_version_details and isinstance( if isinstance(model_version_details, (list, tuple)):
model_version_details, (list, tuple)
):
logger.error(model_version_details[0]) logger.error(model_version_details[0])
return return
...@@ -201,7 +224,7 @@ class TrainingMixin(object): ...@@ -201,7 +224,7 @@ class TrainingMixin(object):
def update_model_version( def update_model_version(
self, self,
model_version_details: dict, model_version_details: dict,
configuration: dict = {}, configuration: dict,
) -> None: ) -> None:
""" """
Update the specified model version to the state `Available` and use the given information" Update the specified model version to the state `Available` and use the given information"
......
...@@ -4,7 +4,7 @@ ElementsWorker methods for transcriptions. ...@@ -4,7 +4,7 @@ ElementsWorker methods for transcriptions.
""" """
from enum import Enum from enum import Enum
from typing import Iterable, Optional, Union from typing import Dict, Iterable, List, Optional, Union
from peewee import IntegrityError from peewee import IntegrityError
...@@ -41,28 +41,22 @@ class TextOrientation(Enum): ...@@ -41,28 +41,22 @@ class TextOrientation(Enum):
class TranscriptionMixin(object): class TranscriptionMixin(object):
"""
Mixin for the :class:`ElementsWorker` to provide ``Transcription`` helpers.
"""
def create_transcription( def create_transcription(
self, self,
element, element: Union[Element, CachedElement],
text, text: str,
confidence, confidence: float,
orientation=TextOrientation.HorizontalLeftToRight, orientation: TextOrientation = TextOrientation.HorizontalLeftToRight,
): ) -> Optional[Dict[str, Union[str, float]]]:
""" """
Create a transcription on the given element through the API. Create a transcription on the given element through the API.
:param element: Element to create a transcription on. :param element: Element to create a transcription on.
:type element: Element or CachedElement :param text: Text of the transcription.
:param text str: Text of the transcription. :param confidence: Confidence score, between 0 and 1.
:param confidence float: Confidence score, between 0 and 1. :param orientation: Orientation of the transcription's text.
:param orientation TextOrientation: Orientation of the transcription's text.
:returns: A dict as returned by the ``CreateTranscription`` API endpoint, :returns: A dict as returned by the ``CreateTranscription`` API endpoint,
or None if the worker is in read-only mode. or None if the worker is in read-only mode.
:rtype: Dict[str, Union[str, float]] or None
""" """
assert element and isinstance( assert element and isinstance(
element, (Element, CachedElement) element, (Element, CachedElement)
...@@ -88,7 +82,7 @@ class TranscriptionMixin(object): ...@@ -88,7 +82,7 @@ class TranscriptionMixin(object):
id=element.id, id=element.id,
body={ body={
"text": text, "text": text,
"worker_version": self.worker_version_id, "worker_run_id": self.worker_run_id,
"confidence": confidence, "confidence": confidence,
"orientation": orientation.value, "orientation": orientation.value,
}, },
...@@ -106,7 +100,7 @@ class TranscriptionMixin(object): ...@@ -106,7 +100,7 @@ class TranscriptionMixin(object):
"text": created["text"], "text": created["text"],
"confidence": created["confidence"], "confidence": created["confidence"],
"orientation": created["orientation"], "orientation": created["orientation"],
"worker_version_id": self.worker_version_id, "worker_run_id": self.worker_run_id,
} }
] ]
CachedTranscription.insert_many(to_insert).execute() CachedTranscription.insert_many(to_insert).execute()
...@@ -117,21 +111,24 @@ class TranscriptionMixin(object): ...@@ -117,21 +111,24 @@ class TranscriptionMixin(object):
return created return created
def create_transcriptions(self, transcriptions): def create_transcriptions(
self,
transcriptions: List[Dict[str, Union[str, float, Optional[TextOrientation]]]],
) -> List[Dict[str, Union[str, float]]]:
""" """
Create multiple transcriptions at once on existing elements through the API, Create multiple transcriptions at once on existing elements through the API,
and creates :class:`CachedTranscription` instances if cache support is enabled. and creates [CachedTranscription][arkindex_worker.cache.CachedTranscription] instances if cache support is enabled.
:param transcriptions: A list of dicts representing a transcription each, with the following keys: :param transcriptions: A list of dicts representing a transcription each, with the following keys:
element_id (str) element_id (str)
Required. UUID of the element to create this transcription on. Required. UUID of the element to create this transcription on.
text (str) text (str)
Required. Text of the transcription. Required. Text of the transcription.
confidence (float) confidence (float)
Required. Confidence score between 0 and 1. Required. Confidence score between 0 and 1.
orientation (:class:`TextOrientation`) orientation (TextOrientation)
Optional. Orientation of the transcription's text. Optional. Orientation of the transcription's text.
:returns: A list of dicts as returned in the ``transcriptions`` field by the ``CreateTranscriptions`` API endpoint. :returns: A list of dicts as returned in the ``transcriptions`` field by the ``CreateTranscriptions`` API endpoint.
""" """
...@@ -170,10 +167,16 @@ class TranscriptionMixin(object): ...@@ -170,10 +167,16 @@ class TranscriptionMixin(object):
if orientation: if orientation:
transcription["orientation"] = orientation.value transcription["orientation"] = orientation.value
if self.is_read_only:
logger.warning(
"Cannot create transcription as this worker is in read-only mode"
)
return
created_trs = self.request( created_trs = self.request(
"CreateTranscriptions", "CreateTranscriptions",
body={ body={
"worker_version": self.worker_version_id, "worker_run_id": self.worker_run_id,
"transcriptions": transcriptions_payload, "transcriptions": transcriptions_payload,
}, },
)["transcriptions"] )["transcriptions"]
...@@ -191,7 +194,7 @@ class TranscriptionMixin(object): ...@@ -191,7 +194,7 @@ class TranscriptionMixin(object):
"text": created_tr["text"], "text": created_tr["text"],
"confidence": created_tr["confidence"], "confidence": created_tr["confidence"],
"orientation": created_tr["orientation"], "orientation": created_tr["orientation"],
"worker_version_id": self.worker_version_id, "worker_run_id": self.worker_run_id,
} }
for created_tr in created_trs for created_tr in created_trs
] ]
...@@ -203,23 +206,27 @@ class TranscriptionMixin(object): ...@@ -203,23 +206,27 @@ class TranscriptionMixin(object):
return created_trs return created_trs
def create_element_transcriptions(self, element, sub_element_type, transcriptions): def create_element_transcriptions(
self,
element: Union[Element, CachedElement],
sub_element_type: str,
transcriptions: List[Dict[str, Union[str, float]]],
) -> Dict[str, Union[str, bool]]:
""" """
Create multiple elements and transcriptions at once on a single parent element through the API. Create multiple elements and transcriptions at once on a single parent element through the API.
:param element: Element to create elements and transcriptions on. :param element: Element to create elements and transcriptions on.
:type element: Element or CachedElement :param sub_element_type: Slug of the element type to use for the new elements.
:param str sub_element_type: Slug of the element type to use for the new elements.
:param transcriptions: A list of dicts representing an element and transcription each, with the following keys: :param transcriptions: A list of dicts representing an element and transcription each, with the following keys:
polygon (list(list(int or float))) polygon (list(list(int or float)))
Required. Polygon of the element. Required. Polygon of the element.
text (str) text (str)
Required. Text of the transcription. Required. Text of the transcription.
confidence (float) confidence (float)
Required. Confidence score between 0 and 1. Required. Confidence score between 0 and 1.
orientation (:class:`TextOrientation`) orientation ([TextOrientation][arkindex_worker.worker.transcription.TextOrientation])
Optional. Orientation of the transcription's text. Optional. Orientation of the transcription's text.
:returns: A list of dicts as returned by the ``CreateElementTranscriptions`` API endpoint. :returns: A list of dicts as returned by the ``CreateElementTranscriptions`` API endpoint.
""" """
...@@ -282,7 +289,7 @@ class TranscriptionMixin(object): ...@@ -282,7 +289,7 @@ class TranscriptionMixin(object):
id=element.id, id=element.id,
body={ body={
"element_type": sub_element_type, "element_type": sub_element_type,
"worker_version": self.worker_version_id, "worker_run_id": self.worker_run_id,
"transcriptions": transcriptions_payload, "transcriptions": transcriptions_payload,
"return_elements": True, "return_elements": True,
}, },
...@@ -319,7 +326,7 @@ class TranscriptionMixin(object): ...@@ -319,7 +326,7 @@ class TranscriptionMixin(object):
"type": sub_element_type, "type": sub_element_type,
"image_id": element.image_id, "image_id": element.image_id,
"polygon": transcription["polygon"], "polygon": transcription["polygon"],
"worker_version_id": self.worker_version_id, "worker_run_id": self.worker_run_id,
} }
) )
...@@ -334,7 +341,7 @@ class TranscriptionMixin(object): ...@@ -334,7 +341,7 @@ class TranscriptionMixin(object):
"orientation": transcription.get( "orientation": transcription.get(
"orientation", TextOrientation.HorizontalLeftToRight "orientation", TextOrientation.HorizontalLeftToRight
).value, ).value,
"worker_version_id": self.worker_version_id, "worker_run_id": self.worker_run_id,
} }
) )
...@@ -359,16 +366,11 @@ class TranscriptionMixin(object): ...@@ -359,16 +366,11 @@ class TranscriptionMixin(object):
List transcriptions on an element. List transcriptions on an element.
:param element: The element to list transcriptions on. :param element: The element to list transcriptions on.
:type element: Element or CachedElement
:param element_type: Restrict to transcriptions whose elements have an element type with this slug. :param element_type: Restrict to transcriptions whose elements have an element type with this slug.
:type element_type: Optional[str]
:param recursive: Include transcriptions of any descendant of this element, recursively. :param recursive: Include transcriptions of any descendant of this element, recursively.
:type recursive: Optional[bool]
:param worker_version: Restrict to transcriptions created by a worker version with this UUID. Set to False to look for manually created transcriptions. :param worker_version: Restrict to transcriptions created by a worker version with this UUID. Set to False to look for manually created transcriptions.
:type worker_version: Optional[Union[str, bool]]
:returns: An iterable of dicts representing each transcription, :returns: An iterable of dicts representing each transcription,
or an iterable of CachedTranscription when cache support is enabled. or an iterable of CachedTranscription when cache support is enabled.
:rtype: Union[Iterable[dict], Iterable[CachedTranscription]]
""" """
assert element and isinstance( assert element and isinstance(
element, (Element, CachedElement) element, (Element, CachedElement)
......
...@@ -5,16 +5,12 @@ ElementsWorker methods for worker versions. ...@@ -5,16 +5,12 @@ ElementsWorker methods for worker versions.
class WorkerVersionMixin(object): class WorkerVersionMixin(object):
"""
Mixin for the :class:`ElementsWorker` to provide ``WorkerVersion`` helpers.
"""
def get_worker_version(self, worker_version_id: str) -> dict: def get_worker_version(self, worker_version_id: str) -> dict:
""" """
Retrieve a worker version, using the :class:`ElementsWorker`'s internal cache when possible. Retrieve a worker version, using the [ElementsWorker][arkindex_worker.worker.ElementsWorker]'s internal cache when possible.
:param str worker_version_id: ID of the worker version to retrieve. :param worker_version_id: ID of the worker version to retrieve.
:returns dict: The requested worker version, as returned by the ``RetrieveWorkerVersion`` API endpoint. :returns: The requested worker version, as returned by the ``RetrieveWorkerVersion`` API endpoint.
""" """
if worker_version_id is None: if worker_version_id is None:
raise ValueError("No worker version ID") raise ValueError("No worker version ID")
...@@ -32,8 +28,8 @@ class WorkerVersionMixin(object): ...@@ -32,8 +28,8 @@ class WorkerVersionMixin(object):
Retrieve the slug of the worker of a worker version, from a worker version UUID. Retrieve the slug of the worker of a worker version, from a worker version UUID.
Uses a worker version from the internal cache if possible, otherwise makes an API request. Uses a worker version from the internal cache if possible, otherwise makes an API request.
:param worker_version_id str: ID of the worker version to find a slug for. :param worker_version_id: ID of the worker version to find a slug for.
:returns str: Slug of the worker of this worker version. :returns: Slug of the worker of this worker version.
""" """
worker_version = self.get_worker_version(worker_version_id) worker_version = self.get_worker_version(worker_version_id)
return worker_version["worker"]["slug"] return worker_version["worker"]["slug"]