Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: workers/base-worker
Commits on Source (33)
Showing with 363 additions and 703 deletions
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.0.278
+    rev: v0.1.7
     hooks:
+      # Run the linter.
       - id: ruff
         args: [--fix, --exit-non-zero-on-fix]
         exclude: "^worker-{{cookiecutter.slug}}/"
-  - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 23.11.0
-    hooks:
-      - id: black
+      # Run the formatter.
+      - id: ruff-format
+        exclude: "^worker-{{cookiecutter.slug}}/"
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v4.5.0
     hooks:
       - id: check-ast
       - id: check-executables-have-shebangs
@@ -26,9 +26,11 @@ repos:
       - id: name-tests-test
         args: ['--django']
       - id: check-json
+      - id: check-toml
+        exclude: "^worker-{{cookiecutter.slug}}/pyproject.toml$"
       - id: requirements-txt-fixer
   - repo: https://github.com/codespell-project/codespell
-    rev: v2.2.2
+    rev: v2.2.6
     hooks:
       - id: codespell
         args: ['--write-changes']
......
MIT License

Copyright (c) 2023 Teklia

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
 .PHONY: release
 release:
-	$(eval version:=$(shell cat VERSION))
-	echo $(version)
-	git commit VERSION -m "Version $(version)"
+	# Grep the version from pyproject.toml, squeeze multiple spaces, delete double and single quotes, get 3rd val.
+	# This command tolerates multiple whitespace sequences around the version number.
+	$(eval version:=$(shell grep -m 1 version pyproject.toml | tr -s ' ' | tr -d '"' | tr -d "'" | cut -d' ' -f3))
+	echo Releasing version $(version)
+	git commit pyproject.toml -m "Version $(version)"
 	git tag $(version)
 	git push origin master $(version)
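For reference, a rough Python equivalent of the pipeline above (the exact `version = "..."` line format in pyproject.toml is an assumption based on the grep command):

import re
from pathlib import Path

# Matches a line such as:  version = "0.3.6"
text = Path("pyproject.toml").read_text()
match = re.search(r'^version\s*=\s*["\']?([^"\'\s]+)', text, re.MULTILINE)
print(match.group(1) if match else None)  # e.g. 0.3.6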
@@ -2,6 +2,12 @@
 An easy to use Python 3 high level API client, to build ML tasks.

+This is an open-source project, licensed using [the MIT license](https://opensource.org/license/mit/).
+
+## Documentation
+
+The [documentation](https://workers.arkindex.org/) is made with [Material for MkDocs](https://github.com/squidfunk/mkdocs-material) and is hosted by [GitLab Pages](https://docs.gitlab.com/ee/user/project/pages/).
+
 ## Create a new worker using our template

 ```
......
-0.3.6-rc1
-# -*- coding: utf-8 -*-
 import logging

 logging.basicConfig(
......
-# -*- coding: utf-8 -*-
 """
 Database mappings and helper methods for the experimental worker caching feature.
@@ -10,7 +9,6 @@ reducing network usage.
 import json
 import sqlite3
 from pathlib import Path
-from typing import Optional, Union

 from peewee import (
     SQL,
@@ -106,8 +104,8 @@ class CachedElement(Model):
     def open_image(
         self,
         *args,
-        max_width: Optional[int] = None,
-        max_height: Optional[int] = None,
+        max_width: int | None = None,
+        max_height: int | None = None,
         **kwargs,
     ) -> Image:
         """
@@ -145,17 +143,15 @@ class CachedElement(Model):
         if max_width is None and max_height is None:
             resize = "full"
         else:
-            # Do not resize for polygons that do not exactly match the images
-            # as the resize is made directly by the IIIF server using the box parameter
             if (
+                # Do not resize for polygons that do not exactly match the images
+                # as the resize is made directly by the IIIF server using the box parameter
                 bounding_box.width != self.image.width
                 or bounding_box.height != self.image.height
-            ):
-                resize = "full"
-
-            # Do not resize when the image is below the maximum size
-            elif (max_width is None or self.image.width <= max_width) and (
-                max_height is None or self.image.height <= max_height
+            ) or (
+                # Do not resize when the image is below the maximum size
+                (max_width is None or self.image.width <= max_width)
+                and (max_height is None or self.image.height <= max_height)
             ):
                 resize = "full"
             else:
@@ -319,22 +315,21 @@ def create_version_table():
     Version.create(version=SQL_VERSION)


-def check_version(cache_path: Union[str, Path]):
+def check_version(cache_path: str | Path):
     """
     Check the validity of the SQLite version

     :param cache_path: Path towards a local SQLite database
     """
-    with SqliteDatabase(cache_path) as provided_db:
-        with provided_db.bind_ctx([Version]):
-            try:
-                version = Version.get().version
-            except OperationalError:
-                version = None
+    with SqliteDatabase(cache_path) as provided_db, provided_db.bind_ctx([Version]):
+        try:
+            version = Version.get().version
+        except OperationalError:
+            version = None

     assert (
         version == SQL_VERSION
     ), f"The SQLite database {cache_path} does not have the correct cache version, it should be {SQL_VERSION}"


 def merge_parents_cache(paths: list, current_database: Path):
@@ -358,9 +353,8 @@ def merge_parents_cache(paths: list, current_database: Path):
         # Check that the parent cache uses a compatible version
         check_version(path)

-        with SqliteDatabase(path) as source:
-            with source.bind_ctx(MODELS):
-                source.create_tables(MODELS)
+        with SqliteDatabase(path) as source, source.bind_ctx(MODELS):
+            source.create_tables(MODELS)

         logger.info(f"Merging parent db {path} into {current_database}")
         statements = [
......
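Throughout this changeset, `typing.Optional` and `typing.Union` are replaced by PEP 604 union syntax, which requires Python 3.10 or later at runtime. A minimal sketch of the equivalence, with illustrative names:

from pathlib import Path
from typing import Optional, Union

# Old spelling, removed throughout this changeset
def check_version_old(cache_path: Union[str, Path], limit: Optional[int] = None):
    pass

# New PEP 604 spelling, requiring Python >= 3.10
def check_version_new(cache_path: str | Path, limit: int | None = None):
    pass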
# -*- coding: utf-8 -*-
"""
Helper classes for workers that interact with Git repositories and the GitLab API.
"""
import shutil
import time
from datetime import datetime
from pathlib import Path
from typing import Optional, Union
import gitlab
import requests
import sh
from gitlab.v4.objects import MergeRequest, ProjectMergeRequest
from arkindex_worker import logger
NOTHING_TO_COMMIT_MSG = "nothing to commit, working tree clean"
MR_HAS_CONFLICTS_ERROR_CODE = 406
class GitlabHelper:
"""Helper class to save files to GitLab repository"""
def __init__(
self,
project_id: str,
gitlab_url: str,
gitlab_token: str,
branch: str,
rebase_wait_period: Optional[int] = 1,
delete_source_branch: Optional[bool] = True,
max_rebase_tries: Optional[int] = 10,
):
"""
:param project_id: the id of the gitlab project
:param gitlab_url: gitlab server url
:param gitlab_token: gitlab private token of user with permission to accept merge requests
:param branch: name of the branch to where the exported branch will be merged
:param rebase_wait_period: seconds to wait between each poll to check whether rebase has finished
:param delete_source_branch: whether to delete the source branch after merging
:param max_rebase_tries: max number of tries to rebase when merging before giving up
"""
self.project_id = project_id
self.gitlab_url = gitlab_url
self.gitlab_token = str(gitlab_token).strip()
self.branch = branch
self.rebase_wait_period = rebase_wait_period
self.delete_source_branch = delete_source_branch
self.max_rebase_tries = max_rebase_tries
logger.info("Creating a Gitlab client")
self._api = gitlab.Gitlab(self.gitlab_url, private_token=self.gitlab_token)
self.project = self._api.projects.get(self.project_id)
self.is_rebase_finished = False
def merge(self, branch_name: str, title: str) -> bool:
"""
Create a merge request and try to merge.
Always rebase first to avoid conflicts from MRs made in parallel
:param branch_name: Source branch name
:param title: Title of the merge request
:return: Whether the branch was successfully merged
"""
mr = None
# always rebase first, because other workers might have merged already
for i in range(self.max_rebase_tries):
logger.info(f"Trying to merge, try nr: {i}")
try:
if mr is None:
mr = self._create_merge_request(branch_name, title)
mr.rebase()
rebase_success = self._wait_for_rebase_to_finish(mr.iid)
if not rebase_success:
logger.error("Rebase failed, won't be able to merge!")
return False
mr.merge(should_remove_source_branch=self.delete_source_branch)
logger.info("Merge successful")
return True
except gitlab.GitlabMRClosedError as e:
if e.response_code == MR_HAS_CONFLICTS_ERROR_CODE:
logger.info("Merge failed, trying to rebase and merge again.")
continue
else:
logger.error(f"Merge was not successful: {e}")
return False
except gitlab.GitlabError as e:
logger.error(f"Gitlab error: {e}")
if 400 <= e.response_code < 500:
# 4XX errors shouldn't be fixed by retrying
raise e
except requests.exceptions.ConnectionError as e:
logger.error(f"Server connection error, will wait and retry: {e}")
time.sleep(self.rebase_wait_period)
return False
def _create_merge_request(self, branch_name: str, title: str) -> MergeRequest:
"""
Create a MergeRequest towards the branch with the given title
:param branch_name: Target branch of the merge request
:param title: Title of the merge request
:return: The created merge request
"""
logger.info(f"Creating a merge request for {branch_name}")
# retry_transient_error will retry the request on 50X errors
# https://github.com/python-gitlab/python-gitlab/blob/265dbbdd37af88395574564aeb3fd0350288a18c/gitlab/__init__.py#L539
mr = self.project.mergerequests.create(
{
"source_branch": branch_name,
"target_branch": self.branch,
"title": title,
},
)
return mr
def _get_merge_request(
self, merge_request_id: Union[str, int], include_rebase_in_progress: bool = True
) -> ProjectMergeRequest:
"""
Retrieve a merge request by ID
:param merge_request_id: The ID of the merge request
:param include_rebase_in_progress: Whether the rebase in progress should be included
:return: The related merge request
"""
return self.project.mergerequests.get(
merge_request_id, include_rebase_in_progress=include_rebase_in_progress
)
def _wait_for_rebase_to_finish(self, merge_request_id: Union[str, int]) -> bool:
"""
Poll the merge request until it has finished rebasing
:param merge_request_id: The ID of the merge request
:return: Whether the rebase has finished successfully
"""
logger.info("Checking if rebase has finished..")
self.is_rebase_finished = False
while not self.is_rebase_finished:
time.sleep(self.rebase_wait_period)
mr = self._get_merge_request(merge_request_id)
self.is_rebase_finished = not mr.rebase_in_progress
if mr.merge_error is None:
logger.info("Rebase has finished")
return True
logger.error(f"Rebase failed: {mr.merge_error}")
return False
def make_backup(path: str):
"""
Create a backup of the file in the same directory, suffixed with a timestamp: ".bak_{timestamp}"
:param path: Path to the file to be backed up
"""
path = Path(path)
if not path.exists():
raise ValueError(f"No file to backup! File not found: {path}")
# timestamp with milliseconds
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
backup_path = Path(str(path) + f".bak_{timestamp}")
shutil.copy(path, backup_path)
logger.info(f"Made a backup {backup_path}")
def prepare_git_key(
private_key: str,
known_hosts: str,
private_key_path: Optional[str] = "~/.ssh/id_ed25519",
known_hosts_path: Optional[str] = "~/.ssh/known_hosts",
):
"""
Prepare the git keys (put them into the correct place) so that git can be used.
Fixes some whitespace problems that come from arkindex secrets store (Django admin).
Also creates a backup of the previous keys if they exist, to avoid losing the
original keys of the developers.
:param private_key: git private key contents
:param known_hosts: git known_hosts contents
:param private_key_path: path where to put the private key
:param known_hosts_path: path where to put the known_hosts
"""
# secrets admin UI seems to strip the trailing whitespace
# but git requires the key file to have a new line at the end
# the secrets also sometimes use CRLF line endings, which git doesn't accept
private_key = private_key.replace("\r", "") + "\n"
known_hosts = known_hosts.replace("\r", "") + "\n"
private_key_path = Path(private_key_path).expanduser()
known_hosts_path = Path(known_hosts_path).expanduser()
if private_key_path.exists():
if private_key_path.read_text() != private_key:
make_backup(private_key_path)
if known_hosts_path.exists():
if known_hosts_path.read_text() != known_hosts:
make_backup(known_hosts_path)
private_key_path.write_text(private_key)
# private key must be private, otherwise git will fail
# expecting octal for permissions
private_key_path.chmod(0o600)
known_hosts_path.write_text(known_hosts)
logger.info(f"Private key size after: {private_key_path.stat().st_size}")
logger.info(f"Known size after: {known_hosts_path.stat().st_size}")
class GitHelper:
"""
A helper class for running git commands
At the beginning of the workflow call [run_clone_in_background][arkindex_worker.git.GitHelper.run_clone_in_background].
When all the files are ready to be added to git, call
[save_files][arkindex_worker.git.GitHelper.save_files] to move the files into the git repository
and try to push them.
Examples
--------
In worker.configure(), configure the git helper and start the cloning:
```
gitlab = GitlabHelper(...)
prepare_git_key(...)
self.git_helper = GitHelper(workflow_id=workflow_id, gitlab_helper=gitlab, ...)
self.git_helper.run_clone_in_background()
```
At the end of the workflow (at the end of worker.run()), push the files to git:
```
self.git_helper.save_files(self.out_dir)
```
"""
def __init__(
self,
repo_url,
git_dir,
export_path,
workflow_id,
gitlab_helper: GitlabHelper,
git_clone_wait_period=1,
):
"""
:param repo_url: the url of the git repository where the export will be pushed
:param git_dir: the directory where to clone the git repository
:param export_path: the path inside the git repository where to put the exported files
:param workflow_id: the process id to see the workflow graph in the frontend
:param gitlab_helper: helper for gitlab
:param git_clone_wait_period: check if clone has finished every N seconds at the end of the workflow
"""
logger.info("Creating git helper")
self.repo_url = repo_url
self.git_dir = Path(git_dir)
self.export_path = self.git_dir / export_path
self.workflow_id = workflow_id
self.gitlab_helper = gitlab_helper
self.git_clone_wait_period = git_clone_wait_period
self.is_clone_finished = False
self.cmd = None
self.success = None
self.exit_code = None
self.git_dir.mkdir(parents=True, exist_ok=True)
# run git commands outside of the repository (no need to change dir)
self._git = sh.git.bake("-C", self.git_dir)
def _clone_done(self, cmd, success, exit_code):
"""
Method that is called when git clone has finished in the background
"""
logger.info("Finishing cloning")
self.cmd = cmd
self.success = success
self.exit_code = exit_code
self.is_clone_finished = True
if not success:
logger.error(f"Clone failed: {cmd} : {success} : {exit_code}")
logger.info("Cloning finished")
def run_clone_in_background(self):
"""
Clones the git repository in the background into the self.git_dir directory.
`self.is_clone_finished` can be used to know whether the cloning has finished
or not.
"""
logger.info(f"Starting clone {self.repo_url} in background")
cmd = sh.git.clone(
self.repo_url, self.git_dir, _bg=True, _done=self._clone_done
)
logger.info(f"Continuing clone {self.repo_url} in background")
return cmd
def _wait_for_clone_to_finish(self):
logger.info("Checking if cloning has finished..")
while not self.is_clone_finished:
time.sleep(self.git_clone_wait_period)
logger.info("Cloning has finished")
if not self.success:
logger.error("Clone was not a success")
logger.error(f"Clone error exit code: {str(self.exit_code)}")
raise ValueError("Clone was not a success")
def save_files(self, export_out_dir: Path):
"""
Move files in export_out_dir to the cloned git repository
and try to merge the created files if possible.
:param export_out_dir: Path to the files to be saved
:raises sh.ErrorReturnCode: When a git command fails, e.g. the commit fails for a reason other than having nothing to commit
:raises Exception: When merging the pushed branch through GitLab fails
"""
self._wait_for_clone_to_finish()
# move exported files to git directory
file_count = self._move_files_to_git(export_out_dir)
# use timestamp to avoid branch name conflicts with multiple chunks
current_timestamp = datetime.isoformat(datetime.now())
# ":" is not allowed in a branch name
branch_timestamp = current_timestamp.replace(":", ".")
# add files to a new branch
branch_name = f"workflow_{self.workflow_id}_{branch_timestamp}"
self._git.checkout("-b", branch_name)
self._git.add("-A")
try:
self._git.commit(
"-m",
f"Exported files from workflow: {self.workflow_id} at {current_timestamp}",
)
except sh.ErrorReturnCode as e:
if NOTHING_TO_COMMIT_MSG in str(e.stdout):
logger.warning("Nothing to commit (no changes)")
return
else:
logger.error(f"Commit failed:: {e}")
raise e
# count the number of lines in the output
wc_cmd_out = str(
sh.wc(self._git.show("--stat", "--name-status", "--oneline", "HEAD"), "-l")
)
# -1 because of the git command header
files_committed = int(wc_cmd_out.strip()) - 1
logger.info(f"Committed {files_committed} files")
if file_count != files_committed:
logger.warning(
f"Of {file_count} added files only {files_committed} were committed"
)
self._git.push("-u", "origin", "HEAD")
if self.gitlab_helper:
try:
self.gitlab_helper.merge(branch_name, f"Merge {branch_name}")
except Exception as e:
logger.error(f"Merge failed: {e}")
raise e
else:
logger.info(
"No gitlab_helper defined, not trying to merge the pushed branch"
)
def _move_files_to_git(self, export_out_dir: Path) -> int:
"""
Move all files in the export_out_dir to the git repository
while keeping the same directory structure
:param export_out_dir: Path to the files to be moved
:return: Total count of moved files
"""
file_count = 0
file_names = [
file_name for file_name in export_out_dir.rglob("*") if file_name.is_file()
]
for file in file_names:
rel_file_path = file.relative_to(export_out_dir)
out_file = self.export_path / rel_file_path
if not out_file.exists():
out_file.parent.mkdir(parents=True, exist_ok=True)
# rename does not work if the source and destination are not on the same mount,
# raising "OSError: [Errno 18] Invalid cross-device link"
shutil.copy(file, out_file)
file.unlink()
file_count += 1
logger.info(f"Moved {file_count} files")
return file_count
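The copy-then-unlink pair above works around cross-device renames; `shutil.move` bundles the same fallback in one call. A minimal sketch with illustrative paths:

import shutil
from pathlib import Path

src = Path("/tmp/export/page_1.xml")       # hypothetical exported file
dst = Path("/data/git/export/page_1.xml")  # hypothetical path on another mount

dst.parent.mkdir(parents=True, exist_ok=True)
# shutil.move falls back to copy + delete when a plain rename would raise
# "OSError: [Errno 18] Invalid cross-device link"
shutil.move(src, dst)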
-# -*- coding: utf-8 -*-
 """
 Helper methods to download and open IIIF images, and manage polygons.
 """
@@ -7,7 +6,7 @@ from collections import namedtuple
 from io import BytesIO
 from math import ceil
 from pathlib import Path
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING

 import requests
 from PIL import Image
@@ -42,9 +41,9 @@ IIIF_MAX = "max"
 def open_image(
     path: str,
-    mode: Optional[str] = "RGB",
-    rotation_angle: Optional[int] = 0,
-    mirrored: Optional[bool] = False,
+    mode: str | None = "RGB",
+    rotation_angle: int | None = 0,
+    mirrored: bool | None = False,
 ) -> Image:
     """
     Open an image from a path or a URL.
@@ -71,7 +70,7 @@ def open_image(
     else:
         try:
             image = Image.open(path)
-        except (IOError, ValueError):
+        except (OSError, ValueError):
             image = download_image(path)

     if image.mode != mode:
@@ -141,14 +140,14 @@ def download_image(url: str) -> Image:
     return image


-def polygon_bounding_box(polygon: List[List[Union[int, float]]]) -> BoundingBox:
+def polygon_bounding_box(polygon: list[list[int | float]]) -> BoundingBox:
     """
     Compute the rectangle bounding box of a polygon.

     :param polygon: Polygon to get the bounding box of.
     :returns: Bounding box of this polygon.
     """
-    x_coords, y_coords = zip(*polygon)
+    x_coords, y_coords = zip(*polygon, strict=True)
     x, y = min(x_coords), min(y_coords)
     width, height = max(x_coords) - x, max(y_coords) - y
     return BoundingBox(x, y, width, height)
@@ -248,8 +247,8 @@ def download_tiles(url: str) -> Image:
 def trim_polygon(
-    polygon: List[List[int]], image_width: int, image_height: int
-) -> List[List[int]]:
+    polygon: list[list[int]], image_width: int, image_height: int
+) -> list[list[int]]:
     """
     Trim a polygon to an image's boundaries, with non-negative coordinates.
@@ -265,10 +264,10 @@ def trim_polygon(
     """
     assert isinstance(
-        polygon, (list, tuple)
+        polygon, list | tuple
     ), "Input polygon must be a valid list or tuple of points."
     assert all(
-        isinstance(point, (list, tuple)) for point in polygon
+        isinstance(point, list | tuple) for point in polygon
     ), "Polygon points must be tuples or lists."
     assert all(
         len(point) == 2 for point in polygon
@@ -301,10 +300,10 @@ def trim_polygon(
 def revert_orientation(
-    element: Union["Element", "CachedElement"],
-    polygon: List[List[Union[int, float]]],
-    reverse: Optional[bool] = False,
-) -> List[List[int]]:
+    element: "Element | CachedElement",
+    polygon: list[list[int | float]],
+    reverse: bool = False,
+) -> list[list[int]]:
     """
     Update the coordinates of the polygon of a child element based on the orientation of
     its parent.
@@ -324,7 +323,7 @@ def revert_orientation(
     from arkindex_worker.models import Element

     assert element and isinstance(
-        element, (Element, CachedElement)
+        element, Element | CachedElement
     ), "element shouldn't be null and should be an Element or CachedElement"
     assert polygon and isinstance(
         polygon, list
......
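A worked example for `polygon_bounding_box` above; the `BoundingBox` namedtuple definition is an assumption matching the fields used here:

from collections import namedtuple

BoundingBox = namedtuple("BoundingBox", ["x", "y", "width", "height"])

def polygon_bounding_box(polygon):
    # strict=True (Python >= 3.10) raises ValueError if points have unequal lengths
    x_coords, y_coords = zip(*polygon, strict=True)
    x, y = min(x_coords), min(y_coords)
    return BoundingBox(x, y, max(x_coords) - x, max(y_coords) - y)

print(polygon_bounding_box([[1, 2], [4, 6], [2, 8]]))
# BoundingBox(x=1, y=2, width=3, height=6)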
-# -*- coding: utf-8 -*-
 """
 Wrappers around API results to provide more convenient attribute access and IIIF helpers.
 """
 import tempfile
+from collections.abc import Generator
 from contextlib import contextmanager
-from typing import Generator, List, Optional

 from PIL import Image
 from requests import HTTPError
@@ -34,10 +33,10 @@ class MagicDict(dict):
     def __getattr__(self, name):
         try:
             return self[name]
-        except KeyError:
+        except KeyError as e:
             raise AttributeError(
-                "{} object has no attribute '{}'".format(self.__class__.__name__, name)
-            )
+                f"{self.__class__.__name__} object has no attribute '{name}'"
+            ) from e

     def __setattr__(self, name, value):
         return super().__setitem__(name, value)
@@ -74,7 +73,7 @@ class Element(MagicDict):
         parts[-3] = size
         return "/".join(parts)

-    def image_url(self, size: str = "full") -> Optional[str]:
+    def image_url(self, size: str = "full") -> str | None:
         """
         Build an URL to access the image.
         When possible, will return the S3 URL for images, so an ML worker can bypass IIIF servers.
@@ -89,10 +88,10 @@ class Element(MagicDict):
         url = self.zone.image.url
         if not url.endswith("/"):
             url += "/"
-        return "{}full/{}/0/default.jpg".format(url, size)
+        return f"{url}full/{size}/0/default.jpg"

     @property
-    def polygon(self) -> List[float]:
+    def polygon(self) -> list[float]:
         """
         Access an Element's polygon.
         This is a shortcut to an Element's polygon, normally accessed via
@@ -101,7 +100,7 @@ class Element(MagicDict):
         the [CachedElement][arkindex_worker.cache.CachedElement].polygon field.
         """
         if not self.get("zone"):
-            raise ValueError("Element {} has no zone".format(self.id))
+            raise ValueError(f"Element {self.id} has no zone")
         return self.zone.polygon

     @property
@@ -122,11 +121,11 @@ class Element(MagicDict):
     def open_image(
         self,
         *args,
-        max_width: Optional[int] = None,
-        max_height: Optional[int] = None,
-        use_full_image: Optional[bool] = False,
+        max_width: int | None = None,
+        max_height: int | None = None,
+        use_full_image: bool | None = False,
         **kwargs,
-    ) -> Image:
+    ) -> Image.Image:
         """
         Open this element's image using Pillow, rotating and mirroring it according
         to the ``rotation_angle`` and ``mirrored`` attributes.
@@ -173,7 +172,7 @@ class Element(MagicDict):
         )

         if not self.get("zone"):
-            raise ValueError("Element {} has no zone".format(self.id))
+            raise ValueError(f"Element {self.id} has no zone")

         if self.requires_tiles:
             if max_width is None and max_height is None:
@@ -194,10 +193,7 @@ class Element(MagicDict):
         else:
             resize = f"{max_width or ''},{max_height or ''}"

-        if use_full_image:
-            url = self.image_url(resize)
-        else:
-            url = self.resize_zone_url(resize)
+        url = self.image_url(resize) if use_full_image else self.resize_zone_url(resize)

         try:
             return open_image(
@@ -215,13 +211,13 @@ class Element(MagicDict):
                 # This element uses an S3 URL: the URL may have expired.
                 # Call the API to get a fresh URL and try again
                 # TODO: this should be done by the worker
-                raise NotImplementedError
+                raise NotImplementedError from e
             return open_image(self.image_url(resize), *args, **kwargs)
         raise

     @contextmanager
     def open_image_tempfile(
-        self, format: Optional[str] = "jpeg", *args, **kwargs
+        self, format: str | None = "jpeg", *args, **kwargs
     ) -> Generator[tempfile.NamedTemporaryFile, None, None]:
         """
         Get the element's image as a temporary file stored on the disk.
@@ -249,7 +245,7 @@ class Element(MagicDict):
             type_name = self.type["display_name"]
         else:
             type_name = str(self.type)
-        return "{} {} ({})".format(type_name, self.name, self.id)
+        return f"{type_name} {self.name} ({self.id})"


 class ArkindexModel(MagicDict):
......
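The `MagicDict` change above chains the original `KeyError` via `from e`, so tracebacks keep the root cause. A self-contained sketch of the pattern, limited to the two methods shown here:

class MagicDict(dict):
    """A dict whose keys can also be read and written as attributes."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as e:
            raise AttributeError(
                f"{self.__class__.__name__} object has no attribute '{name}'"
            ) from e

    def __setattr__(self, name, value):
        return super().__setitem__(name, value)


element = MagicDict({"name": "page_1"})
print(element.name)     # page_1
element.name = "page_2"
print(element["name"])  # page_2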
-# -*- coding: utf-8 -*-
 import hashlib
 import logging
 import os
 import tarfile
 import tempfile
 from pathlib import Path
-from typing import Optional, Tuple, Union

 import zstandard
 import zstandard as zstd
@@ -16,7 +14,7 @@ CHUNK_SIZE = 1024
 """Chunk Size used for ZSTD compression"""


-def decompress_zst_archive(compressed_archive: Path) -> Tuple[int, Path]:
+def decompress_zst_archive(compressed_archive: Path) -> tuple[int, Path]:
     """
     Decompress a ZST-compressed tar archive in data dir. The tar archive is not extracted.
     This returns the path to the archive and the file descriptor.
@@ -29,18 +27,19 @@ def decompress_zst_archive(compressed_archive: Path) -> tuple[int, Path]:
     """
     dctx = zstandard.ZstdDecompressor()
     archive_fd, archive_path = tempfile.mkstemp(suffix=".tar")
+    archive_path = Path(archive_path)

     logger.debug(f"Uncompressing file to {archive_path}")
     try:
-        with open(compressed_archive, "rb") as compressed, open(
-            archive_path, "wb"
-        ) as decompressed:
+        with compressed_archive.open("rb") as compressed, archive_path.open(
+            "wb"
+        ) as decompressed:
             dctx.copy_stream(compressed, decompressed)
         logger.debug(f"Successfully uncompressed archive {compressed_archive}")
     except zstandard.ZstdError as e:
-        raise Exception(f"Couldn't uncompress archive: {e}")
+        raise Exception(f"Couldn't uncompress archive: {e}") from e

-    return archive_fd, Path(archive_path)
+    return archive_fd, archive_path


 def extract_tar_archive(archive_path: Path, destination: Path):
@@ -54,12 +53,12 @@ def extract_tar_archive(archive_path: Path, destination: Path):
         with tarfile.open(archive_path) as tar_archive:
             tar_archive.extractall(destination)
     except tarfile.ReadError as e:
-        raise Exception(f"Couldn't handle the decompressed Tar archive: {e}")
+        raise Exception(f"Couldn't handle the decompressed Tar archive: {e}") from e


 def extract_tar_zst_archive(
     compressed_archive: Path, destination: Path
-) -> Tuple[int, Path]:
+) -> tuple[int, Path]:
     """
     Extract a ZST-compressed tar archive's content to a specific destination
@@ -89,8 +88,8 @@ def close_delete_file(file_descriptor: int, file_path: Path):


 def zstd_compress(
-    source: Path, destination: Optional[Path] = None
-) -> Tuple[Union[int, None], Path, str]:
+    source: Path, destination: Path | None = None
+) -> tuple[int | None, Path, str]:
     """Compress a file using the Zstandard compression algorithm.

     :param source: Path to the file to compress.
@@ -117,13 +116,13 @@ def zstd_compress(
                 archive_file.write(compressed_chunk)
         logger.debug(f"Successfully compressed {source}")
     except zstandard.ZstdError as e:
-        raise Exception(f"Couldn't compress archive: {e}")
+        raise Exception(f"Couldn't compress archive: {e}") from e
     return file_d, destination, archive_hasher.hexdigest()


 def create_tar_archive(
-    path: Path, destination: Optional[Path] = None
-) -> Tuple[Union[int, None], Path, str]:
+    path: Path, destination: Path | None = None
+) -> tuple[int | None, Path, str]:
     """Create a tar archive using the content at specified location.

     :param path: Path to the file to archive
@@ -153,7 +152,7 @@ def create_tar_archive(
                 files.append(p)
         logger.debug(f"Successfully created Tar archive from files @ {path}")
     except tarfile.TarError as e:
-        raise Exception(f"Couldn't create Tar archive: {e}")
+        raise Exception(f"Couldn't create Tar archive: {e}") from e

     # Sort by path
     files.sort()
@@ -168,8 +167,8 @@ def create_tar_archive(


 def create_tar_zst_archive(
-    source: Path, destination: Optional[Path] = None
-) -> Tuple[Union[int, None], Path, str, str]:
+    source: Path, destination: Path | None = None
+) -> tuple[int | None, Path, str, str]:
     """Helper to create a TAR+ZST archive from a source folder.

     :param source: Path to the folder whose content should be archived.
......
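A hedged usage sketch of the archive helpers above (assuming they live in `arkindex_worker.utils`; paths are illustrative). `extract_tar_zst_archive` returns the descriptor and path of the intermediate tar so callers can dispose of it with `close_delete_file`:

from pathlib import Path

from arkindex_worker.utils import close_delete_file, extract_tar_zst_archive

archive = Path("corpus_export.tar.zst")  # hypothetical downloaded artifact
destination = Path("data")

# Decompress the ZST stream to a temporary tar, then extract its content
fd, tar_path = extract_tar_zst_archive(archive, destination)
# Close the descriptor from tempfile.mkstemp and remove the temporary tar
close_delete_file(fd, tar_path)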
-# -*- coding: utf-8 -*-
 """
 Base classes to implement Arkindex workers.
 """
+import contextlib
 import json
 import os
 import sys
 import uuid
+from collections.abc import Iterable, Iterator
 from enum import Enum
 from itertools import groupby
 from operator import itemgetter
 from pathlib import Path
-from typing import Iterable, Iterator, List, Tuple, Union

 from apistar.exceptions import ErrorResponse
@@ -102,7 +101,7 @@ class ElementsWorker(
         self._worker_version_cache = {}

-    def list_elements(self) -> Union[Iterable[CachedElement], List[str]]:
+    def list_elements(self) -> Iterable[CachedElement] | list[str]:
         """
         List the elements to be processed, either from the CLI arguments or
         the cache database when enabled.
@@ -227,21 +226,17 @@ class ElementsWorker(
                 )
                 if element:
                     # Try to update the activity to error state regardless of the response
-                    try:
+                    with contextlib.suppress(Exception):
                         self.update_activity(element.id, ActivityState.Error)
-                    except Exception:
-                        pass

         if failed:
             logger.error(
-                "Ran on {} elements: {} completed, {} failed".format(
-                    count, count - failed, failed
-                )
+                f"Ran on {count} elements: {count - failed} completed, {failed} failed"
             )
         if failed >= count:  # Everything failed!
             sys.exit(1)

-    def process_element(self, element: Union[Element, CachedElement]):
+    def process_element(self, element: Element | CachedElement):
         """
         Override this method to implement your worker and process a single Arkindex element at once.
@@ -251,7 +246,7 @@ class ElementsWorker(
         """

     def update_activity(
-        self, element_id: Union[str, uuid.UUID], state: ActivityState
+        self, element_id: str | uuid.UUID, state: ActivityState
     ) -> bool:
         """
         Update the WorkerActivity for this element and worker.
@@ -269,7 +264,7 @@ class ElementsWorker(
             return True

         assert element_id and isinstance(
-            element_id, (uuid.UUID, str)
+            element_id, uuid.UUID | str
         ), "element_id shouldn't be null and should be an UUID or str"
         assert isinstance(state, ActivityState), "state should be an ActivityState"
@@ -382,7 +377,7 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
     def list_dataset_elements_per_split(
         self, dataset: Dataset
-    ) -> Iterator[Tuple[str, List[Element]]]:
+    ) -> Iterator[tuple[str, list[Element]]]:
         """
         List the elements in the dataset, grouped by split, using the
         [list_dataset_elements][arkindex_worker.worker.dataset.DatasetMixin.list_dataset_elements] method.
@@ -392,8 +387,8 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
         """

         def format_split(
-            split: Tuple[str, Iterator[Tuple[str, Element]]]
-        ) -> Tuple[str, List[Element]]:
+            split: tuple[str, Iterator[tuple[str, Element]]],
+        ) -> tuple[str, list[Element]]:
             return (split[0], list(map(itemgetter(1), list(split[1]))))

         return map(
@@ -435,7 +430,7 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
         """
         self.configure()

-        datasets: List[Dataset] | List[str] = list(self.list_datasets())
+        datasets: list[Dataset] | list[str] = list(self.list_datasets())
         if not datasets:
             logger.warning("No datasets to process, stopping.")
             sys.exit(1)
@@ -445,6 +440,8 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
         failed = 0
         for i, item in enumerate(datasets, start=1):
             dataset = None
+            dataset_artifact = None
             try:
                 if not self.is_read_only:
                     # Just use the result of list_datasets as the dataset
@@ -470,7 +467,7 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
                     self.update_dataset_state(dataset, DatasetState.Building)
                 else:
                     logger.info(f"Downloading data for {dataset} ({i}/{count})")
-                    self.download_dataset_artifact(dataset)
+                    dataset_artifact = self.download_dataset_artifact(dataset)

                 # Process the dataset
                 self.process_dataset(dataset)
@@ -499,16 +496,16 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
                 )
                 if dataset and self.generator:
                     # Try to update the state to Error regardless of the response
-                    try:
+                    with contextlib.suppress(Exception):
                         self.update_dataset_state(dataset, DatasetState.Error)
-                    except Exception:
-                        pass
+            finally:
+                # Cleanup the dataset artifact if it was downloaded, no matter what
+                if dataset_artifact:
+                    dataset_artifact.unlink(missing_ok=True)

         if failed:
             logger.error(
-                "Ran on {} datasets: {} completed, {} failed".format(
-                    count, count - failed, failed
-                )
+                f"Ran on {count} datasets: {count - failed} completed, {failed} failed"
             )
         if failed >= count:  # Everything failed!
             sys.exit(1)
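Both try/except/pass blocks above are rewritten with `contextlib.suppress`, which is equivalent for this swallow-everything pattern; a self-contained sketch:

import contextlib

def update_activity(element_id: str, state: str) -> None:
    # Hypothetical stand-in for the API call, which may fail on network errors
    raise ConnectionError("API unreachable")

# Before: report the error state, ignoring any failure
try:
    update_activity("element-id", "error")
except Exception:
    pass

# After: identical behavior in one statement
with contextlib.suppress(Exception):
    update_activity("element-id", "error")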
-# -*- coding: utf-8 -*-
 """
 The base class for all Arkindex workers.
 """
@@ -9,7 +8,6 @@ import os
 import shutil
 from pathlib import Path
 from tempfile import mkdtemp
-from typing import List, Optional

 import gnupg
 import yaml
@@ -52,15 +50,15 @@ class ExtrasDirNotFoundError(Exception):
     """


-class BaseWorker(object):
+class BaseWorker:
     """
     Base class for Arkindex workers.
     """

     def __init__(
         self,
-        description: Optional[str] = "Arkindex Base Worker",
-        support_cache: Optional[bool] = False,
+        description: str | None = "Arkindex Base Worker",
+        support_cache: bool | None = False,
     ):
         """
         Initialize the worker.
@@ -353,7 +351,8 @@ class BaseWorker:
             try:
                 gpg = gnupg.GPG()
-                decrypted = gpg.decrypt_file(open(path, "rb"))
+                with path.open("rb") as gpg_file:
+                    decrypted = gpg.decrypt_file(gpg_file)
                 assert (
                     decrypted.ok
                 ), f"GPG error: {decrypted.status} - {decrypted.stderr}"
@@ -412,7 +411,7 @@ class BaseWorker:
         )
         return extras_dir

-    def find_parents_file_paths(self, filename: Path) -> List[Path]:
+    def find_parents_file_paths(self, filename: Path) -> list[Path]:
         """
         Find the paths of a specific file from the parent tasks.
         Only works if the task_parents attributes is updated, so if the cache is supported,
......
-# -*- coding: utf-8 -*-
 """
 ElementsWorker methods for classifications and ML classes.
 """
-from typing import Dict, List, Optional, Union
 from uuid import UUID

 from apistar.exceptions import ErrorResponse
@@ -14,7 +12,7 @@ from arkindex_worker.cache import CachedClassification, CachedElement
 from arkindex_worker.models import Element


-class ClassificationMixin(object):
+class ClassificationMixin:
     def load_corpus_classes(self):
         """
         Load all ML classes available in the worker's corpus and store them in the ``self.classes`` cache.
@@ -91,11 +89,11 @@ class ClassificationMixin:
     def create_classification(
         self,
-        element: Union[Element, CachedElement],
+        element: Element | CachedElement,
         ml_class: str,
         confidence: float,
-        high_confidence: Optional[bool] = False,
-    ) -> Dict[str, str]:
+        high_confidence: bool = False,
+    ) -> dict[str, str]:
         """
         Create a classification on the given element through the API.
@@ -106,7 +104,7 @@ class ClassificationMixin:
         :returns: The created classification, as returned by the ``CreateClassification`` API endpoint.
         """
         assert element and isinstance(
-            element, (Element, CachedElement)
+            element, Element | CachedElement
         ), "element shouldn't be null and should be an Element or CachedElement"
         assert ml_class and isinstance(
             ml_class, str
@@ -180,9 +178,9 @@ class ClassificationMixin:
     def create_classifications(
         self,
-        element: Union[Element, CachedElement],
-        classifications: List[Dict[str, Union[str, float, bool]]],
-    ) -> List[Dict[str, Union[str, float, bool]]]:
+        element: Element | CachedElement,
+        classifications: list[dict[str, str | float | bool]],
+    ) -> list[dict[str, str | float | bool]]:
         """
         Create multiple classifications at once on the given element through the API.
@@ -196,7 +194,7 @@ class ClassificationMixin:
         the ``CreateClassifications`` API endpoint.
         """
         assert element and isinstance(
-            element, (Element, CachedElement)
+            element, Element | CachedElement
         ), "element shouldn't be null and should be an Element or CachedElement"
         assert classifications and isinstance(
             classifications, list
@@ -204,17 +202,17 @@ class ClassificationMixin:
         for index, classification in enumerate(classifications):
             ml_class_id = classification.get("ml_class_id")
-            assert ml_class_id and isinstance(
-                ml_class_id, str
+            assert (
+                ml_class_id and isinstance(ml_class_id, str)
             ), f"Classification at index {index} in classifications: ml_class_id shouldn't be null and should be of type str"

             # Make sure it's a valid UUID
             try:
                 UUID(ml_class_id)
-            except ValueError:
+            except ValueError as e:
                 raise ValueError(
                     f"Classification at index {index} in classifications: ml_class_id is not a valid uuid."
-                )
+                ) from e

             confidence = classification.get("confidence")
             assert (
......
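A hedged usage sketch of `create_classification` inside a worker; the class slug and confidence value are illustrative:

from arkindex_worker.models import Element
from arkindex_worker.worker import ElementsWorker

class ClassifierWorker(ElementsWorker):
    def process_element(self, element: Element):
        # Create a single "handwritten" classification at 87% confidence;
        # high_confidence keeps its default of False
        self.create_classification(element, ml_class="handwritten", confidence=0.87)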
-# -*- coding: utf-8 -*-
 """
 BaseWorker methods for datasets.
 """
+from collections.abc import Iterator
 from enum import Enum
-from typing import Iterator, Tuple

 from arkindex_worker import logger
 from arkindex_worker.models import Dataset, Element
@@ -36,7 +35,7 @@ class DatasetState(Enum):
     """


-class DatasetMixin(object):
+class DatasetMixin:
     def list_process_datasets(self) -> Iterator[Dataset]:
         """
         List datasets associated to the worker's process. This helper is not available in developer mode.
@@ -51,7 +50,7 @@ class DatasetMixin:
         return map(Dataset, list(results))

-    def list_dataset_elements(self, dataset: Dataset) -> Iterator[Tuple[str, Element]]:
+    def list_dataset_elements(self, dataset: Dataset) -> Iterator[tuple[str, Element]]:
         """
         List elements in a dataset.
......
# -*- coding: utf-8 -*-
""" """
ElementsWorker methods for elements and element types. ElementsWorker methods for elements and element types.
""" """
from typing import Dict, Iterable, List, NamedTuple, Optional, Union from collections.abc import Iterable
from typing import NamedTuple
from uuid import UUID from uuid import UUID
from peewee import IntegrityError from peewee import IntegrityError
...@@ -28,8 +28,8 @@ class MissingTypeError(Exception): ...@@ -28,8 +28,8 @@ class MissingTypeError(Exception):
""" """
class ElementMixin(object): class ElementMixin:
def create_required_types(self, element_types: List[ElementType]): def create_required_types(self, element_types: list[ElementType]):
"""Creates given element types in the corpus. """Creates given element types in the corpus.
:param element_types: The missing element types to create. :param element_types: The missing element types to create.
...@@ -86,9 +86,10 @@ class ElementMixin(object): ...@@ -86,9 +86,10 @@ class ElementMixin(object):
element: Element, element: Element,
type: str, type: str,
name: str, name: str,
polygon: List[List[Union[int, float]]], polygon: list[list[int | float]] | None = None,
confidence: Optional[float] = None, confidence: float | None = None,
slim_output: Optional[bool] = True, image: str | None = None,
slim_output: bool = True,
) -> str: ) -> str:
""" """
Create a child element on the given element through the API. Create a child element on the given element through the API.
...@@ -96,8 +97,10 @@ class ElementMixin(object): ...@@ -96,8 +97,10 @@ class ElementMixin(object):
:param Element element: The parent element. :param Element element: The parent element.
:param type: Slug of the element type for this child element. :param type: Slug of the element type for this child element.
:param name: Name of the child element. :param name: Name of the child element.
:param polygon: Polygon of the child element. :param polygon: Optional polygon of the child element.
:param confidence: Optional confidence score, between 0.0 and 1.0. :param confidence: Optional confidence score, between 0.0 and 1.0.
:param image: Optional image ID of the child element.
:param slim_output: Whether to return the child ID or the full child.
:returns: UUID of the created element. :returns: UUID of the created element.
""" """
assert element and isinstance( assert element and isinstance(
...@@ -109,19 +112,29 @@ class ElementMixin(object): ...@@ -109,19 +112,29 @@ class ElementMixin(object):
assert name and isinstance( assert name and isinstance(
name, str name, str
), "name shouldn't be null and should be of type str" ), "name shouldn't be null and should be of type str"
assert polygon and isinstance( assert polygon is None or isinstance(
polygon, list polygon, list
), "polygon shouldn't be null and should be of type list" ), "polygon should be None or a list"
assert len(polygon) >= 3, "polygon should have at least three points" if polygon is not None:
assert all( assert len(polygon) >= 3, "polygon should have at least three points"
isinstance(point, list) and len(point) == 2 for point in polygon assert all(
), "polygon points should be lists of two items" isinstance(point, list) and len(point) == 2 for point in polygon
assert all( ), "polygon points should be lists of two items"
isinstance(coord, (int, float)) for point in polygon for coord in point assert all(
), "polygon points should be lists of two numbers" isinstance(coord, int | float) for point in polygon for coord in point
), "polygon points should be lists of two numbers"
assert confidence is None or ( assert confidence is None or (
isinstance(confidence, float) and 0 <= confidence <= 1 isinstance(confidence, float) and 0 <= confidence <= 1
), "confidence should be None or a float in [0..1] range" ), "confidence should be None or a float in [0..1] range"
assert image is None or isinstance(image, str), "image should be None or string"
if image is not None:
# Make sure it's a valid UUID
try:
UUID(image)
except ValueError as e:
raise ValueError("image is not a valid uuid.") from e
if polygon and image is None:
assert element.zone, "An image or a parent with an image is required to create an element with a polygon."
assert isinstance(slim_output, bool), "slim_output should be of type bool" assert isinstance(slim_output, bool), "slim_output should be of type bool"
if self.is_read_only: if self.is_read_only:
...@@ -133,7 +146,7 @@ class ElementMixin(object): ...@@ -133,7 +146,7 @@ class ElementMixin(object):
body={ body={
"type": type, "type": type,
"name": name, "name": name,
"image": element.zone.image.id, "image": image,
"corpus": element.corpus.id, "corpus": element.corpus.id,
"polygon": polygon, "polygon": polygon,
"parent": element.id, "parent": element.id,
...@@ -146,11 +159,9 @@ class ElementMixin(object): ...@@ -146,11 +159,9 @@ class ElementMixin(object):
def create_elements( def create_elements(
self, self,
parent: Union[Element, CachedElement], parent: Element | CachedElement,
elements: List[ elements: list[dict[str, str | list[list[int | float]] | float | None]],
Dict[str, Union[str, List[List[Union[int, float]]], float, None]] ) -> list[dict[str, str]]:
],
) -> List[Dict[str, str]]:
""" """
Create child elements on the given element in a single API request. Create child elements on the given element in a single API request.
...@@ -195,18 +206,18 @@ class ElementMixin(object): ...@@ -195,18 +206,18 @@ class ElementMixin(object):
), f"Element at index {index} in elements: Should be of type dict" ), f"Element at index {index} in elements: Should be of type dict"
name = element.get("name") name = element.get("name")
assert name and isinstance( assert (
name, str name and isinstance(name, str)
), f"Element at index {index} in elements: name shouldn't be null and should be of type str" ), f"Element at index {index} in elements: name shouldn't be null and should be of type str"
type = element.get("type") type = element.get("type")
assert type and isinstance( assert (
type, str type and isinstance(type, str)
), f"Element at index {index} in elements: type shouldn't be null and should be of type str" ), f"Element at index {index} in elements: type shouldn't be null and should be of type str"
polygon = element.get("polygon") polygon = element.get("polygon")
assert polygon and isinstance( assert (
polygon, list polygon and isinstance(polygon, list)
), f"Element at index {index} in elements: polygon shouldn't be null and should be of type list" ), f"Element at index {index} in elements: polygon shouldn't be null and should be of type list"
assert ( assert (
len(polygon) >= 3 len(polygon) >= 3
...@@ -215,12 +226,13 @@ class ElementMixin(object): ...@@ -215,12 +226,13 @@ class ElementMixin(object):
isinstance(point, list) and len(point) == 2 for point in polygon isinstance(point, list) and len(point) == 2 for point in polygon
), f"Element at index {index} in elements: polygon points should be lists of two items" ), f"Element at index {index} in elements: polygon points should be lists of two items"
assert all( assert all(
isinstance(coord, (int, float)) for point in polygon for coord in point isinstance(coord, int | float) for point in polygon for coord in point
), f"Element at index {index} in elements: polygon points should be lists of two numbers" ), f"Element at index {index} in elements: polygon points should be lists of two numbers"
confidence = element.get("confidence") confidence = element.get("confidence")
assert confidence is None or ( assert (
isinstance(confidence, float) and 0 <= confidence <= 1 confidence is None
or (isinstance(confidence, float) and 0 <= confidence <= 1)
), f"Element at index {index} in elements: confidence should be None or a float in [0..1] range" ), f"Element at index {index} in elements: confidence should be None or a float in [0..1] range"
if self.is_read_only: if self.is_read_only:
...@@ -271,8 +283,37 @@ class ElementMixin(object): ...@@ -271,8 +283,37 @@ class ElementMixin(object):
return created_ids return created_ids
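For context, a minimal usage sketch of this bulk endpoint, assuming an `ElementsWorker` subclass where `page` is the element being processed; the names and polygons below are hypothetical:

```python
# Minimal sketch: each child dict must pass the checks above, i.e. a
# non-empty name and type, a polygon of at least three [x, y] points,
# and an optional confidence in the [0..1] range.
created = self.create_elements(
    parent=page,
    elements=[
        {
            "name": "line 1",
            "type": "text_line",
            "polygon": [[0, 0], [100, 0], [100, 20], [0, 20]],
            "confidence": 0.9,
        },
        {
            "name": "line 2",
            "type": "text_line",
            "polygon": [[0, 20], [100, 20], [100, 40], [0, 40]],
            "confidence": None,
        },
    ],
)
# `created` is a list of dicts, each carrying the UUID of one new element.
```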
    def create_element_parent(
        self,
        parent: Element,
        child: Element,
    ) -> dict[str, str]:
        """
        Link an element to a parent through the API.

        :param parent: Parent element.
        :param child: Child element.
        :returns: A dict from the ``CreateElementParent`` API endpoint.
        """
        assert parent and isinstance(
            parent, Element
        ), "parent shouldn't be null and should be of type Element"
        assert child and isinstance(
            child, Element
        ), "child shouldn't be null and should be of type Element"

        if self.is_read_only:
            logger.warning("Cannot link elements as this worker is in read-only mode")
            return

        return self.request(
            "CreateElementParent",
            parent=parent.id,
            child=child.id,
        )
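A quick sketch of this new linking helper, from inside a worker subclass; `folder` and `page` are assumed to be existing `Element` instances:

```python
# In read-only mode the call logs a warning and returns None instead of
# hitting the API, so the result should be checked before use.
link = self.create_element_parent(parent=folder, child=page)
```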
    def partial_update_element(
        self, element: Element | CachedElement, **kwargs
    ) -> dict:
        """
        Partially updates an element through the API.
@@ -289,10 +330,10 @@ class ElementMixin(object):
        * *image* (``UUID``): Optional ID of the image of this element

        :returns: A dict from the ``PartialUpdateElement`` API endpoint.
        """
        assert element and isinstance(
            element, Element | CachedElement
        ), "element shouldn't be null and should be an Element or CachedElement"

        if "type" in kwargs:
@@ -309,7 +350,7 @@ class ElementMixin(object):
                isinstance(point, list) and len(point) == 2 for point in polygon
            ), "polygon points should be lists of two items"
            assert all(
                isinstance(coord, int | float) for point in polygon for coord in point
            ), "polygon points should be lists of two numbers"

        if "confidence" in kwargs:
@@ -363,21 +404,21 @@ class ElementMixin(object):
    def list_element_children(
        self,
        element: Element | CachedElement,
        folder: bool | None = None,
        name: str | None = None,
        recursive: bool | None = None,
        transcription_worker_version: str | bool | None = None,
        transcription_worker_run: str | bool | None = None,
        type: str | None = None,
        with_classes: bool | None = None,
        with_corpus: bool | None = None,
        with_metadata: bool | None = None,
        with_has_children: bool | None = None,
        with_zone: bool | None = None,
        worker_version: str | bool | None = None,
        worker_run: str | bool | None = None,
    ) -> Iterable[dict] | Iterable[CachedElement]:
        """
        List children of an element.
@@ -412,7 +453,7 @@ class ElementMixin(object):
            or an iterable of [CachedElement][arkindex_worker.cache.CachedElement] when caching is enabled.
        """
        assert element and isinstance(
            element, Element | CachedElement
        ), "element shouldn't be null and should be an Element or CachedElement"
        query_params = {}
        if folder is not None:
@@ -426,7 +467,7 @@ class ElementMixin(object):
            query_params["recursive"] = recursive
        if transcription_worker_version is not None:
            assert isinstance(
                transcription_worker_version, str | bool
            ), "transcription_worker_version should be of type str or bool"
            if isinstance(transcription_worker_version, bool):
                assert (
@@ -435,7 +476,7 @@ class ElementMixin(object):
            query_params["transcription_worker_version"] = transcription_worker_version
        if transcription_worker_run is not None:
            assert isinstance(
                transcription_worker_run, str | bool
            ), "transcription_worker_run should be of type str or bool"
            if isinstance(transcription_worker_run, bool):
                assert (
@@ -466,7 +507,7 @@ class ElementMixin(object):
            query_params["with_zone"] = with_zone
        if worker_version is not None:
            assert isinstance(
                worker_version, str | bool
            ), "worker_version should be of type str or bool"
            if isinstance(worker_version, bool):
                assert (
@@ -475,7 +516,7 @@ class ElementMixin(object):
            query_params["worker_version"] = worker_version
        if worker_run is not None:
            assert isinstance(
                worker_run, str | bool
            ), "worker_run should be of type str or bool"
            if isinstance(worker_run, bool):
                assert (
@@ -485,11 +526,14 @@ class ElementMixin(object):

        if self.use_cache:
            # Checking that we only received query_params handled by the cache
            assert (
                set(query_params.keys())
                <= {
                    "type",
                    "worker_version",
                    "worker_run",
                }
            ), "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'"
            query = CachedElement.select().where(CachedElement.parent_id == element.id)
            if type:
@@ -522,21 +566,21 @@ class ElementMixin(object):
    def list_element_parents(
        self,
        element: Element | CachedElement,
        folder: bool | None = None,
        name: str | None = None,
        recursive: bool | None = None,
        transcription_worker_version: str | bool | None = None,
        transcription_worker_run: str | bool | None = None,
        type: str | None = None,
        with_classes: bool | None = None,
        with_corpus: bool | None = None,
        with_metadata: bool | None = None,
        with_has_children: bool | None = None,
        with_zone: bool | None = None,
        worker_version: str | bool | None = None,
        worker_run: str | bool | None = None,
    ) -> Iterable[dict] | Iterable[CachedElement]:
        """
        List parents of an element.
@@ -571,7 +615,7 @@ class ElementMixin(object):
            or an iterable of [CachedElement][arkindex_worker.cache.CachedElement] when caching is enabled.
        """
        assert element and isinstance(
            element, Element | CachedElement
        ), "element shouldn't be null and should be an Element or CachedElement"
        query_params = {}
        if folder is not None:
@@ -585,7 +629,7 @@ class ElementMixin(object):
            query_params["recursive"] = recursive
        if transcription_worker_version is not None:
            assert isinstance(
                transcription_worker_version, str | bool
            ), "transcription_worker_version should be of type str or bool"
            if isinstance(transcription_worker_version, bool):
                assert (
@@ -594,7 +638,7 @@ class ElementMixin(object):
            query_params["transcription_worker_version"] = transcription_worker_version
        if transcription_worker_run is not None:
            assert isinstance(
                transcription_worker_run, str | bool
            ), "transcription_worker_run should be of type str or bool"
            if isinstance(transcription_worker_run, bool):
                assert (
@@ -625,7 +669,7 @@ class ElementMixin(object):
            query_params["with_zone"] = with_zone
        if worker_version is not None:
            assert isinstance(
                worker_version, str | bool
            ), "worker_version should be of type str or bool"
            if isinstance(worker_version, bool):
                assert (
@@ -634,7 +678,7 @@ class ElementMixin(object):
            query_params["worker_version"] = worker_version
        if worker_run is not None:
            assert isinstance(
                worker_run, str | bool
            ), "worker_run should be of type str or bool"
            if isinstance(worker_run, bool):
                assert (
@@ -644,11 +688,14 @@ class ElementMixin(object):

        if self.use_cache:
            # Checking that we only received query_params handled by the cache
            assert (
                set(query_params.keys())
                <= {
                    "type",
                    "worker_version",
                    "worker_run",
                }
            ), "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'"
            parent_ids = CachedElement.select(CachedElement.parent_id).where(
                CachedElement.id == element.id
"""
ElementsWorker methods for entities.
"""
from operator import itemgetter
from typing import TypedDict

from peewee import IntegrityError
@@ -12,16 +11,13 @@ from arkindex_worker import logger
from arkindex_worker.cache import CachedEntity, CachedTranscriptionEntity
from arkindex_worker.models import Element, Transcription
class Entity(TypedDict):
    name: str
    type_id: str
    length: int
    offset: int
    confidence: float | None
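With the TypedDict now declared as a class, a payload can still be written (and type-checked) as a plain dict; the values below are hypothetical:

```python
# Hypothetical values: "Paris" spans 5 characters starting at offset 12
# of the transcription text; type_id is an assumed entity type UUID.
entity: Entity = {
    "name": "Paris",
    "type_id": "11111111-1111-1111-1111-111111111111",
    "length": 5,
    "offset": 12,
    "confidence": 0.85,
}
```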
class MissingEntityType(Exception):
@@ -31,9 +27,9 @@ class MissingEntityType(Exception):
    """


class EntityMixin:
    def check_required_entity_types(
        self, entity_types: list[str], create_missing: bool = True
    ):
        """Checks that every entity type needed is available in the corpus.
        Missing ones may be created automatically if needed.
@@ -71,7 +67,7 @@ class EntityMixin(object):
        self,
        name: str,
        type: str,
        metas=None,
        validated=None,
    ):
        """
@@ -87,6 +83,7 @@ class EntityMixin(object):
        assert type and isinstance(
            type, str
        ), "type shouldn't be null and should be of type str"
        metas = metas or {}
        if metas:
            assert isinstance(metas, dict), "metas should be of type dict"
        if validated is not None:
@@ -140,8 +137,8 @@ class EntityMixin(object):
        entity: str,
        offset: int,
        length: int,
        confidence: float | None = None,
    ) -> dict[str, str | int] | None:
        """
        Create a link between an existing entity and an existing transcription.
        If cache support is enabled, a `CachedTranscriptionEntity` will also be created.
@@ -211,8 +208,8 @@ class EntityMixin(object):
    def create_transcription_entities(
        self,
        transcription: Transcription,
        entities: list[Entity],
    ) -> list[dict[str, str]]:
        """
        Create multiple entities attached to a transcription in a single API request.
@@ -250,13 +247,13 @@ class EntityMixin(object):
            ), f"Entity at index {index} in entities: Should be of type dict"

            name = entity.get("name")
            assert (
                name and isinstance(name, str)
            ), f"Entity at index {index} in entities: name shouldn't be null and should be of type str"

            type_id = entity.get("type_id")
            assert (
                type_id and isinstance(type_id, str)
            ), f"Entity at index {index} in entities: type_id shouldn't be null and should be of type str"

            offset = entity.get("offset")
@@ -270,8 +267,9 @@ class EntityMixin(object):
            ), f"Entity at index {index} in entities: length shouldn't be null and should be a strictly positive integer"

            confidence = entity.get("confidence")
            assert (
                confidence is None
                or (isinstance(confidence, float) and 0 <= confidence <= 1)
            ), f"Entity at index {index} in entities: confidence should be None or a float in [0..1] range"

        assert len(entities) == len(
@@ -298,7 +296,7 @@ class EntityMixin(object):
    def list_transcription_entities(
        self,
        transcription: Transcription,
        worker_version: str | bool | None = None,
    ):
        """
        List existing entities on a transcription
@@ -314,7 +312,7 @@ class EntityMixin(object):
        if worker_version is not None:
            assert isinstance(
                worker_version, str | bool
            ), "worker_version should be of type str or bool"
            if isinstance(worker_version, bool):
@@ -329,12 +327,11 @@ class EntityMixin(object):
    def list_corpus_entities(
        self,
        name: str | None = None,
        parent: Element | None = None,
    ):
        """
        List all entities in the worker's corpus and store them in the ``self.entities`` cache.

        :param name: Filter entities by part of their name (case-insensitive)
        :param parent: Restrict entities to those linked to all transcriptions of an element and all its descendants. Note that links to metadata are ignored.
        """
@@ -348,8 +345,14 @@ class EntityMixin(object):
            assert isinstance(parent, Element), "parent should be of type Element"
            query_params["parent"] = parent.id
        self.entities = {
            entity["id"]: entity
            for entity in self.api_client.paginate(
                "ListCorpusEntities", id=self.corpus_id, **query_params
            )
        }
        logger.info(
            f"Loaded {len(self.entities)} entities in corpus ({self.corpus_id})"
        )
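Since the method now populates `self.entities` instead of returning a paginator, call sites would read the cache after the call; a sketch from inside a worker subclass:

```python
# Entities are stored on the worker keyed by their UUID, rather than
# yielded lazily as before.
self.list_corpus_entities(name="paris")
for entity_id, entity in self.entities.items():
    print(entity_id, entity["name"])
```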
    def list_corpus_entity_types(
"""
ElementsWorker methods for metadata.
"""
from enum import Enum

from arkindex_worker import logger
from arkindex_worker.cache import CachedElement
@@ -57,14 +55,14 @@ class MetaType(Enum):
""" """
class MetaDataMixin(object): class MetaDataMixin:
def create_metadata( def create_metadata(
self, self,
element: Union[Element, CachedElement], element: Element | CachedElement,
type: MetaType, type: MetaType,
name: str, name: str,
value: str, value: str,
entity: Optional[str] = None, entity: str | None = None,
) -> str: ) -> str:
""" """
Create a metadata on the given element through API. Create a metadata on the given element through API.
...@@ -77,7 +75,7 @@ class MetaDataMixin(object): ...@@ -77,7 +75,7 @@ class MetaDataMixin(object):
:returns: UUID of the created metadata. :returns: UUID of the created metadata.
""" """
assert element and isinstance( assert element and isinstance(
element, (Element, CachedElement) element, Element | CachedElement
), "element shouldn't be null and should be of type Element or CachedElement" ), "element shouldn't be null and should be of type Element or CachedElement"
assert type and isinstance( assert type and isinstance(
type, MetaType type, MetaType
@@ -110,26 +108,22 @@ class MetaDataMixin(object):
    def create_metadatas(
        self,
        element: Element | CachedElement,
        metadatas: list[dict[str, MetaType | str | int | float | None]],
    ) -> list[dict[str, str]]:
        """
        Create multiple metadata on an existing element.
        This method does not support cache.

        :param element: The element to create multiple metadata on.
        :param metadatas: The list of dict whose keys are the following:
            - type: MetaType
            - name: str
            - value: str | int | float
            - entity_id: str | None
        """
        assert element and isinstance(
            element, Element | CachedElement
        ), "element shouldn't be null and should be of type Element or CachedElement"

        assert metadatas and isinstance(
@@ -152,7 +146,7 @@ class MetaDataMixin(object):
            ), "name shouldn't be null and should be of type str"

            assert metadata.get("value") is not None and isinstance(
                metadata.get("value"), str | float | int
            ), "value shouldn't be null and should be of type (str or float or int)"

            assert metadata.get("entity_id") is None or isinstance(
@@ -172,7 +166,7 @@ class MetaDataMixin(object):
            logger.warning("Cannot create metadata as this worker is in read-only mode")
            return

        created_metadata_list = self.request(
            "CreateMetaDataBulk",
            id=element.id,
            body={
@@ -181,11 +175,11 @@ class MetaDataMixin(object):
            },
        )["metadata_list"]

        return created_metadata_list
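A usage sketch matching the payload shape validated above; it assumes `MetaType` is imported from this module and that `element` is an existing `Element` or `CachedElement`:

```python
# Assumed enum members (Text, Numeric); values may be str, int or float,
# and entity_id is optional.
metadata_list = self.create_metadatas(
    element=element,
    metadatas=[
        {"type": MetaType.Text, "name": "language", "value": "fra", "entity_id": None},
        {"type": MetaType.Numeric, "name": "page", "value": 42, "entity_id": None},
    ],
)
```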
    def list_element_metadata(
        self, element: Element | CachedElement
    ) -> list[dict[str, str]]:
        """
        List all metadata linked to an element.
        This method does not support cache.
@@ -193,7 +187,7 @@ class MetaDataMixin(object):
        :param element: The element to list metadata on.
        """
        assert element and isinstance(
            element, Element | CachedElement
        ), "element shouldn't be null and should be of type Element or CachedElement"
        return self.api_client.paginate("ListElementMetaData", id=element.id)
"""
BaseWorker methods for tasks.
"""
import uuid
from collections.abc import Iterator

from apistar.compat import DownloadedFile

from arkindex_worker.models import Artifact


class TaskMixin:
    def list_artifacts(self, task_id: uuid.UUID) -> Iterator[Artifact]:
        """
        List artifacts associated to a task.
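A sketch of listing the artifacts of a sibling task from inside a worker; the task ID is hypothetical, and the printed field is assumed to be part of the artifact payload:

```python
import uuid

# Hypothetical task ID; list_artifacts yields Artifact objects built from
# the ListArtifacts endpoint responses.
task_id = uuid.UUID("22222222-2222-2222-2222-222222222222")
for artifact in self.list_artifacts(task_id):
    print(artifact.path)
```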
"""
BaseWorker methods for training.
"""
@@ -6,7 +5,7 @@ BaseWorker methods for training.
import functools
from contextlib import contextmanager
from pathlib import Path
from typing import NewType
from uuid import UUID

import requests
@@ -26,7 +25,7 @@ FileSize = NewType("FileSize", int)
@contextmanager
def create_archive(path: DirPath) -> tuple[Path, Hash, FileSize, Hash]:
    """
    Create a tar archive from the files at the given location then compress it to a zst archive.
@@ -72,7 +71,7 @@ def skip_if_read_only(func):
    return wrapper
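`create_archive` is a context manager yielding the compressed archive path plus its hashes and size, in the tuple order of the signature above; a sketch, assuming a local `model/` directory:

```python
from pathlib import Path

# The context manager yields (archive path, md5 hash, file size, hash of
# the uncompressed tar) and cleans the temporary archive up on exit.
with create_archive(path=Path("model/")) as (
    zst_archive_path,
    hash,
    size,
    archive_hash,
):
    print(f"Compressed archive at {zst_archive_path}: {size} bytes, md5 {hash}")
```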
class TrainingMixin:
    """
    A mixin helper to create a new model version easily.
    You may use `publish_model_version` to publish a ready model version directly, or
@@ -87,10 +86,10 @@ class TrainingMixin(object):
        self,
        model_path: DirPath,
        model_id: str,
        tag: str | None = None,
        description: str | None = None,
        configuration: dict | None = None,
        parent: str | UUID | None = None,
    ):
        """
        Publish a unique version of a model in Arkindex, identified by its hash.
@@ -105,6 +104,7 @@ class TrainingMixin(object):
        :param parent: ID of the parent model version
        """
        configuration = configuration or {}
        if not self.model_version:
            self.create_model_version(
                model_id=model_id,
@@ -161,10 +161,10 @@ class TrainingMixin(object):
    def create_model_version(
        self,
        model_id: str,
        tag: str | None = None,
        description: str | None = None,
        configuration: dict | None = None,
        parent: str | UUID | None = None,
    ):
        """
        Create a new version of the specified model with its base attributes.
@@ -176,6 +176,8 @@ class TrainingMixin(object):
        :param parent: ID of the parent model version
        """
        assert not self.model_version, "A model version has already been created."
        configuration = configuration or {}
        self.model_version = self.request(
            "CreateModelVersion",
            id=model_id,
@@ -186,6 +188,7 @@ class TrainingMixin(object):
                parent=parent,
            ),
        )

        logger.info(
            f"Model version ({self.model_version['id']}) was successfully created"
        )
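With `configuration` now defaulting to `None` instead of a mutable `{}`, the incremental flow looks like this; the model ID and description are hypothetical:

```python
# The None -> {} normalization above means callers can simply omit the
# configuration; the same default applies to publish_model_version.
self.create_model_version(
    model_id="33333333-3333-3333-3333-333333333333",
    tag="0.1.0",
)
self.update_model_version(description="Trained on corpus X")
```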
@@ -193,10 +196,10 @@ class TrainingMixin(object):
    @skip_if_read_only
    def update_model_version(
        self,
        tag: str | None = None,
        description: str | None = None,
        configuration: dict | None = None,
        parent: str | UUID | None = None,
    ):
        """
        Update the current model version with the given attributes.
@@ -235,9 +238,7 @@ class TrainingMixin(object):
        ), "The model is already marked as available."
        s3_put_url = self.model_version.get("s3_put_url")
        assert s3_put_url, "S3 PUT URL is not set, please ensure you have the right to validate a model version."

        logger.info("Uploading to s3...")
        # Upload the archive on s3
@@ -263,9 +264,7 @@ class TrainingMixin(object):
        :param size: The size of the uploaded archive
        :param archive_hash: MD5 hash of the uploaded archive
        """
        assert self.model_version, "You must create the model version and upload its archive before validating it."
        try:
            self.model_version = self.request(
                "ValidateModelVersion",