
Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: workers/base-worker
Commits on Source (33)
Showing 363 additions and 703 deletions
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.0.278
rev: v0.1.7
hooks:
# Run the linter.
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
exclude: "^worker-{{cookiecutter.slug}}/"
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.11.0
hooks:
- id: black
# Run the formatter.
- id: ruff-format
exclude: "^worker-{{cookiecutter.slug}}/"
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: check-ast
- id: check-executables-have-shebangs
......@@ -26,9 +26,11 @@ repos:
- id: name-tests-test
args: ['--django']
- id: check-json
- id: check-toml
exclude: "^worker-{{cookiecutter.slug}}/pyproject.toml$"
- id: requirements-txt-fixer
- repo: https://github.com/codespell-project/codespell
rev: v2.2.2
rev: v2.2.6
hooks:
- id: codespell
args: ['--write-changes']
......
MIT License
Copyright (c) 2023 Teklia
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
.PHONY: release
release:
$(eval version:=$(shell cat VERSION))
echo $(version)
git commit VERSION -m "Version $(version)"
# Grep the version from pyproject.toml, squeeze multiple spaces, delete double and single quotes, get 3rd val.
# This command tolerates multiple whitespace sequences around the version number.
$(eval version:=$(shell grep -m 1 version pyproject.toml | tr -s ' ' | tr -d '"' | tr -d "'" | cut -d' ' -f3))
echo Releasing version $(version)
git commit pyproject.toml -m "Version $(version)"
git tag $(version)
git push origin master $(version)
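The release target now reads the version from `pyproject.toml` instead of a separate `VERSION` file. A hedged Python sketch of an equivalent extraction (not part of the repository; assumes Python 3.11+ and a standard `[project]` table):

```
import tomllib
from pathlib import Path

def read_project_version(pyproject: Path = Path("pyproject.toml")) -> str:
    """Return the version string declared in pyproject.toml."""
    with pyproject.open("rb") as f:
        data = tomllib.load(f)
    # Assumes the version is declared under the standard [project] table.
    return data["project"]["version"]

if __name__ == "__main__":
    print(f"Releasing version {read_project_version()}")
```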
......@@ -2,6 +2,12 @@
An easy-to-use, high-level Python 3 API client to build ML tasks.
This is an open-source project, licensed using [the MIT license](https://opensource.org/license/mit/).
## Documentation
The [documentation](https://workers.arkindex.org/) is made with [Material for MkDocs](https://github.com/squidfunk/mkdocs-material) and is hosted by [GitLab Pages](https://docs.gitlab.com/ee/user/project/pages/).
## Create a new worker using our template
```
......
0.3.6-rc1
# -*- coding: utf-8 -*-
import logging
logging.basicConfig(
......
# -*- coding: utf-8 -*-
"""
Database mappings and helper methods for the experimental worker caching feature.
......@@ -10,7 +9,6 @@ reducing network usage.
import json
import sqlite3
from pathlib import Path
from typing import Optional, Union
from peewee import (
SQL,
......@@ -106,8 +104,8 @@ class CachedElement(Model):
def open_image(
self,
*args,
max_width: Optional[int] = None,
max_height: Optional[int] = None,
max_width: int | None = None,
max_height: int | None = None,
**kwargs,
) -> Image:
"""
......@@ -145,17 +143,15 @@ class CachedElement(Model):
if max_width is None and max_height is None:
resize = "full"
else:
# Do not resize for polygons that do not exactly match the images
# as the resize is made directly by the IIIF server using the box parameter
if (
# Do not resize for polygons that do not exactly match the images
# as the resize is made directly by the IIIF server using the box parameter
bounding_box.width != self.image.width
or bounding_box.height != self.image.height
):
resize = "full"
# Do not resize when the image is below the maximum size
elif (max_width is None or self.image.width <= max_width) and (
max_height is None or self.image.height <= max_height
):
elif (
# Do not resize when the image is below the maximum size
(max_width is None or self.image.width <= max_width)
and (max_height is None or self.image.height <= max_height)
):
resize = "full"
else:
......@@ -319,22 +315,21 @@ def create_version_table():
Version.create(version=SQL_VERSION)
def check_version(cache_path: Union[str, Path]):
def check_version(cache_path: str | Path):
"""
Check the validity of the SQLite version
:param cache_path: Path towards a local SQLite database
"""
with SqliteDatabase(cache_path) as provided_db:
with provided_db.bind_ctx([Version]):
try:
version = Version.get().version
except OperationalError:
version = None
with SqliteDatabase(cache_path) as provided_db, provided_db.bind_ctx([Version]):
try:
version = Version.get().version
except OperationalError:
version = None
assert (
version == SQL_VERSION
), f"The SQLite database {cache_path} does not have the correct cache version, it should be {SQL_VERSION}"
assert (
version == SQL_VERSION
), f"The SQLite database {cache_path} does not have the correct cache version, it should be {SQL_VERSION}"
def merge_parents_cache(paths: list, current_database: Path):
......@@ -358,9 +353,8 @@ def merge_parents_cache(paths: list, current_database: Path):
# Check that the parent cache uses a compatible version
check_version(path)
with SqliteDatabase(path) as source:
with source.bind_ctx(MODELS):
source.create_tables(MODELS)
with SqliteDatabase(path) as source, source.bind_ctx(MODELS):
source.create_tables(MODELS)
logger.info(f"Merging parent db {path} into {current_database}")
statements = [
......
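`check_version` and `merge_parents_cache` now open each database and bind the models in a single `with` statement. A minimal, runnable sketch of that combined context-manager pattern, using stand-in managers rather than the real `SqliteDatabase` / `bind_ctx` objects:

```
from contextlib import contextmanager

@contextmanager
def open_database(path):
    # Stand-in for SqliteDatabase(path): open on enter, close on exit.
    print(f"open {path}")
    try:
        yield f"connection to {path}"
    finally:
        print(f"close {path}")

@contextmanager
def bind_models(db):
    # Stand-in for db.bind_ctx(MODELS): bind models while the block runs.
    print(f"bind models on {db}")
    yield db

# One statement, two context managers: equivalent to two nested `with` blocks.
with open_database("parent_cache.sqlite") as db, bind_models(db):
    print("run queries against the bound models here")
```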
# -*- coding: utf-8 -*-
"""
Helper classes for workers that interact with Git repositories and the GitLab API.
"""
import shutil
import time
from datetime import datetime
from pathlib import Path
from typing import Optional, Union
import gitlab
import requests
import sh
from gitlab.v4.objects import MergeRequest, ProjectMergeRequest
from arkindex_worker import logger
NOTHING_TO_COMMIT_MSG = "nothing to commit, working tree clean"
MR_HAS_CONFLICTS_ERROR_CODE = 406
class GitlabHelper:
"""Helper class to save files to GitLab repository"""
def __init__(
self,
project_id: str,
gitlab_url: str,
gitlab_token: str,
branch: str,
rebase_wait_period: Optional[int] = 1,
delete_source_branch: Optional[bool] = True,
max_rebase_tries: Optional[int] = 10,
):
"""
:param project_id: the id of the gitlab project
:param gitlab_url: gitlab server url
:param gitlab_token: gitlab private token of user with permission to accept merge requests
:param branch: name of the branch to where the exported branch will be merged
:param rebase_wait_period: seconds to wait between each poll to check whether rebase has finished
:param delete_source_branch: should delete the source branch after merging?
:param max_rebase_tries: max number of tries to rebase when merging before giving up
"""
self.project_id = project_id
self.gitlab_url = gitlab_url
self.gitlab_token = str(gitlab_token).strip()
self.branch = branch
self.rebase_wait_period = rebase_wait_period
self.delete_source_branch = delete_source_branch
self.max_rebase_tries = max_rebase_tries
logger.info("Creating a Gitlab client")
self._api = gitlab.Gitlab(self.gitlab_url, private_token=self.gitlab_token)
self.project = self._api.projects.get(self.project_id)
self.is_rebase_finished = False
def merge(self, branch_name: str, title: str) -> bool:
"""
Create a merge request and try to merge.
Always rebase first to avoid conflicts from MRs made in parallel
:param branch_name: Source branch name
:param title: Title of the merge request
:return: Whether the branch was successfully merged
"""
mr = None
# always rebase first, because other workers might have merged already
for i in range(self.max_rebase_tries):
logger.info(f"Trying to merge, try nr: {i}")
try:
if mr is None:
mr = self._create_merge_request(branch_name, title)
mr.rebase()
rebase_success = self._wait_for_rebase_to_finish(mr.iid)
if not rebase_success:
logger.error("Rebase failed, won't be able to merge!")
return False
mr.merge(should_remove_source_branch=self.delete_source_branch)
logger.info("Merge successful")
return True
except gitlab.GitlabMRClosedError as e:
if e.response_code == MR_HAS_CONFLICTS_ERROR_CODE:
logger.info("Merge failed, trying to rebase and merge again.")
continue
else:
logger.error(f"Merge was not successful: {e}")
return False
except gitlab.GitlabError as e:
logger.error(f"Gitlab error: {e}")
if 400 <= e.response_code < 500:
# 4XX errors shouldn't be fixed by retrying
raise e
except requests.exceptions.ConnectionError as e:
logger.error(f"Server connection error, will wait and retry: {e}")
time.sleep(self.rebase_wait_period)
return False
def _create_merge_request(self, branch_name: str, title: str) -> MergeRequest:
"""
Create a MergeRequest towards the branch with the given title
:param branch_name: Target branch of the merge request
:param title: Title of the merge request
:return: The created merge request
"""
logger.info(f"Creating a merge request for {branch_name}")
# retry_transient_error will retry the request on 50X errors
# https://github.com/python-gitlab/python-gitlab/blob/265dbbdd37af88395574564aeb3fd0350288a18c/gitlab/__init__.py#L539
mr = self.project.mergerequests.create(
{
"source_branch": branch_name,
"target_branch": self.branch,
"title": title,
},
)
return mr
def _get_merge_request(
self, merge_request_id: Union[str, int], include_rebase_in_progress: bool = True
) -> ProjectMergeRequest:
"""
Retrieve a merge request by ID
:param merge_request_id: The ID of the merge request
:param include_rebase_in_progress: Whether the rebase in progress should be included
:return: The related merge request
"""
return self.project.mergerequests.get(
merge_request_id, include_rebase_in_progress=include_rebase_in_progress
)
def _wait_for_rebase_to_finish(self, merge_request_id: Union[str, int]) -> bool:
"""
Poll the merge request until it has finished rebasing
:param merge_request_id: The ID of the merge request
:return: Whether the rebase has finished successfully
"""
logger.info("Checking if rebase has finished..")
self.is_rebase_finished = False
while not self.is_rebase_finished:
time.sleep(self.rebase_wait_period)
mr = self._get_merge_request(merge_request_id)
self.is_rebase_finished = not mr.rebase_in_progress
if mr.merge_error is None:
logger.info("Rebase has finished")
return True
logger.error(f"Rebase failed: {mr.merge_error}")
return False
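A hedged usage sketch of `GitlabHelper`, based only on the constructor and `merge()` signatures above; every value is a placeholder and a reachable GitLab instance is assumed:

```
from arkindex_worker.git import GitlabHelper

# All values below are placeholders.
gitlab = GitlabHelper(
    project_id="1234",
    gitlab_url="https://gitlab.example.com",
    gitlab_token="glpat-placeholder",
    branch="master",
)

# merge() creates the MR, rebases (retrying on conflicts) and returns a boolean.
if gitlab.merge("workflow_1234_2023-12-01T10.00.00", "Merge exported files"):
    print("Branch merged")
else:
    print("Merge did not succeed")
```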
def make_backup(path: str):
"""
Create a backup file in the same directory with timestamp as suffix ".bak_{timestamp}"
:param path: Path to the file to be backed up
"""
path = Path(path)
if not path.exists():
raise ValueError(f"No file to backup! File not found: {path}")
# timestamp with milliseconds
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
backup_path = Path(str(path) + f".bak_{timestamp}")
shutil.copy(path, backup_path)
logger.info(f"Made a backup {backup_path}")
def prepare_git_key(
private_key: str,
known_hosts: str,
private_key_path: Optional[str] = "~/.ssh/id_ed25519",
known_hosts_path: Optional[str] = "~/.ssh/known_hosts",
):
"""
Prepare the git keys (put them in the correct place) so that git can be used.
Fixes some whitespace problems that come from arkindex secrets store (Django admin).
Also creates a backup of the previous keys if they exist, to avoid losing the
original keys of the developers.
:param private_key: git private key contents
:param known_hosts: git known_hosts contents
:param private_key_path: path where to put the private key
:param known_hosts_path: path where to put the known_hosts
"""
# secrets admin UI seems to strip the trailing whitespace
# but git requires the key file to have a new line at the end
# for some reason the secrets use CRLF line endings, but git doesn't like that
private_key = private_key.replace("\r", "") + "\n"
known_hosts = known_hosts.replace("\r", "") + "\n"
private_key_path = Path(private_key_path).expanduser()
known_hosts_path = Path(known_hosts_path).expanduser()
if private_key_path.exists():
if private_key_path.read_text() != private_key:
make_backup(private_key_path)
if known_hosts_path.exists():
if known_hosts_path.read_text() != known_hosts:
make_backup(known_hosts_path)
private_key_path.write_text(private_key)
# private key must be private, otherwise git will fail
# expecting octal for permissions
private_key_path.chmod(0o600)
known_hosts_path.write_text(known_hosts)
logger.info(f"Private key size after: {private_key_path.stat().st_size}")
logger.info(f"Known size after: {known_hosts_path.stat().st_size}")
class GitHelper:
"""
A helper class for running git commands
At the beginning of the workflow call [run_clone_in_background][arkindex_worker.git.GitHelper.run_clone_in_background].
When all the files are ready to be added to git then call
[save_files][arkindex_worker.git.GitHelper.save_files] to move the files into the git repository
and try to push them.
Examples
--------
in worker.configure() configure the git helper and start the cloning:
```
gitlab = GitlabHelper(...)
prepare_git_key(...)
self.git_helper = GitHelper(workflow_id=workflow_id, gitlab_helper=gitlab, ...)
self.git_helper.run_clone_in_background()
```
at the end of the workflow (at the end of worker.run()) push the files to git:
```
self.git_helper.save_files(self.out_dir)
```
"""
def __init__(
self,
repo_url,
git_dir,
export_path,
workflow_id,
gitlab_helper: GitlabHelper,
git_clone_wait_period=1,
):
"""
:param repo_url: the url of the git repository where the export will be pushed
:param git_dir: the directory where to clone the git repository
:param export_path: the path inside the git repository where to put the exported files
:param workflow_id: the process id to see the workflow graph in the frontend
:param gitlab_helper: helper for gitlab
:param git_clone_wait_period: check if clone has finished every N seconds at the end of the workflow
"""
logger.info("Creating git helper")
self.repo_url = repo_url
self.git_dir = Path(git_dir)
self.export_path = self.git_dir / export_path
self.workflow_id = workflow_id
self.gitlab_helper = gitlab_helper
self.git_clone_wait_period = git_clone_wait_period
self.is_clone_finished = False
self.cmd = None
self.success = None
self.exit_code = None
self.git_dir.mkdir(parents=True, exist_ok=True)
# run git commands outside of the repository (no need to change dir)
self._git = sh.git.bake("-C", self.git_dir)
def _clone_done(self, cmd, success, exit_code):
"""
Method that is called when git clone has finished in the background
"""
logger.info("Finishing cloning")
self.cmd = cmd
self.success = success
self.exit_code = exit_code
self.is_clone_finished = True
if not success:
logger.error(f"Clone failed: {cmd} : {success} : {exit_code}")
logger.info("Cloning finished")
def run_clone_in_background(self):
"""
Clones the git repository in the background into the self.git_dir directory.
`self.is_clone_finished` can be used to know whether the cloning has finished
or not.
"""
logger.info(f"Starting clone {self.repo_url} in background")
cmd = sh.git.clone(
self.repo_url, self.git_dir, _bg=True, _done=self._clone_done
)
logger.info(f"Continuing clone {self.repo_url} in background")
return cmd
def _wait_for_clone_to_finish(self):
logger.info("Checking if cloning has finished..")
while not self.is_clone_finished:
time.sleep(self.git_clone_wait_period)
logger.info("Cloning has finished")
if not self.success:
logger.error("Clone was not a success")
logger.error(f"Clone error exit code: {str(self.exit_code)}")
raise ValueError("Clone was not a success")
def save_files(self, export_out_dir: Path):
"""
Move files in export_out_dir to the cloned git repository
and try to merge the created files if possible.
:param export_out_dir: Path to the files to be saved
:raises sh.ErrorReturnCode: _description_
:raises Exception: _description_
"""
self._wait_for_clone_to_finish()
# move exported files to git directory
file_count = self._move_files_to_git(export_out_dir)
# use timestamp to avoid branch name conflicts with multiple chunks
current_timestamp = datetime.isoformat(datetime.now())
# ":" is not allowed in a branch name
branch_timestamp = current_timestamp.replace(":", ".")
# add files to a new branch
branch_name = f"workflow_{self.workflow_id}_{branch_timestamp}"
self._git.checkout("-b", branch_name)
self._git.add("-A")
try:
self._git.commit(
"-m",
f"Exported files from workflow: {self.workflow_id} at {current_timestamp}",
)
except sh.ErrorReturnCode as e:
if NOTHING_TO_COMMIT_MSG in str(e.stdout):
logger.warning("Nothing to commit (no changes)")
return
else:
logger.error(f"Commit failed: {e}")
raise e
# count the number of lines in the output
wc_cmd_out = str(
sh.wc(self._git.show("--stat", "--name-status", "--oneline", "HEAD"), "-l")
)
# -1 because of the git command header
files_committed = int(wc_cmd_out.strip()) - 1
logger.info(f"Committed {files_committed} files")
if file_count != files_committed:
logger.warning(
f"Of {file_count} added files only {files_committed} were committed"
)
self._git.push("-u", "origin", "HEAD")
if self.gitlab_helper:
try:
self.gitlab_helper.merge(branch_name, f"Merge {branch_name}")
except Exception as e:
logger.error(f"Merge failed: {e}")
raise e
else:
logger.info(
"No gitlab_helper defined, not trying to merge the pushed branch"
)
def _move_files_to_git(self, export_out_dir: Path) -> int:
"""
Move all files in the export_out_dir to the git repository
while keeping the same directory structure
:param export_out_dir: Path to the files to be moved
:return: Total count of moved files
"""
file_count = 0
file_names = [
file_name for file_name in export_out_dir.rglob("*") if file_name.is_file()
]
for file in file_names:
rel_file_path = file.relative_to(export_out_dir)
out_file = self.export_path / rel_file_path
if not out_file.exists():
out_file.parent.mkdir(parents=True, exist_ok=True)
# rename does not work if the source and destination are not on the same mounts
# it will give an error: "OSError: [Errno 18] Invalid cross-device link:"
shutil.copy(file, out_file)
file.unlink()
file_count += 1
logger.info(f"Moved {file_count} files")
return file_count
# -*- coding: utf-8 -*-
"""
Helper methods to download and open IIIF images, and manage polygons.
"""
......@@ -7,7 +6,7 @@ from collections import namedtuple
from io import BytesIO
from math import ceil
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING
import requests
from PIL import Image
......@@ -42,9 +41,9 @@ IIIF_MAX = "max"
def open_image(
path: str,
mode: Optional[str] = "RGB",
rotation_angle: Optional[int] = 0,
mirrored: Optional[bool] = False,
mode: str | None = "RGB",
rotation_angle: int | None = 0,
mirrored: bool | None = False,
) -> Image:
"""
Open an image from a path or a URL.
......@@ -71,7 +70,7 @@ def open_image(
else:
try:
image = Image.open(path)
except (IOError, ValueError):
except (OSError, ValueError):
image = download_image(path)
if image.mode != mode:
......@@ -141,14 +140,14 @@ def download_image(url: str) -> Image:
return image
def polygon_bounding_box(polygon: List[List[Union[int, float]]]) -> BoundingBox:
def polygon_bounding_box(polygon: list[list[int | float]]) -> BoundingBox:
"""
Compute the rectangle bounding box of a polygon.
:param polygon: Polygon to get the bounding box of.
:returns: Bounding box of this polygon.
"""
x_coords, y_coords = zip(*polygon)
x_coords, y_coords = zip(*polygon, strict=True)
x, y = min(x_coords), min(y_coords)
width, height = max(x_coords) - x, max(y_coords) - y
return BoundingBox(x, y, width, height)
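For illustration, a small worked example of the bounding box computation above, assuming the helper is importable from `arkindex_worker.image` (values chosen arbitrarily):

```
from arkindex_worker.image import polygon_bounding_box

# An arbitrary quadrilateral given as [x, y] points.
polygon = [[10, 20], [110, 25], [100, 220], [15, 210]]
box = polygon_bounding_box(polygon)
# x = min(10, 110, 100, 15) = 10, y = min(20, 25, 220, 210) = 20
# width = 110 - 10 = 100,     height = 220 - 20 = 200
print(box)  # BoundingBox(x=10, y=20, width=100, height=200)
```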
......@@ -248,8 +247,8 @@ def download_tiles(url: str) -> Image:
def trim_polygon(
polygon: List[List[int]], image_width: int, image_height: int
) -> List[List[int]]:
polygon: list[list[int]], image_width: int, image_height: int
) -> list[list[int]]:
"""
Trim a polygon to an image's boundaries, with non-negative coordinates.
......@@ -265,10 +264,10 @@ def trim_polygon(
"""
assert isinstance(
polygon, (list, tuple)
polygon, list | tuple
), "Input polygon must be a valid list or tuple of points."
assert all(
isinstance(point, (list, tuple)) for point in polygon
isinstance(point, list | tuple) for point in polygon
), "Polygon points must be tuples or lists."
assert all(
len(point) == 2 for point in polygon
......@@ -301,10 +300,10 @@ def trim_polygon(
def revert_orientation(
element: Union["Element", "CachedElement"],
polygon: List[List[Union[int, float]]],
reverse: Optional[bool] = False,
) -> List[List[int]]:
element: "Element | CachedElement",
polygon: list[list[int | float]],
reverse: bool = False,
) -> list[list[int]]:
"""
Update the coordinates of the polygon of a child element based on the orientation of
its parent.
......@@ -324,7 +323,7 @@ def revert_orientation(
from arkindex_worker.models import Element
assert element and isinstance(
element, (Element, CachedElement)
element, Element | CachedElement
), "element shouldn't be null and should be an Element or CachedElement"
assert polygon and isinstance(
polygon, list
......
# -*- coding: utf-8 -*-
"""
Wrappers around API results to provide more convenient attribute access and IIIF helpers.
"""
import tempfile
from collections.abc import Generator
from contextlib import contextmanager
from typing import Generator, List, Optional
from PIL import Image
from requests import HTTPError
......@@ -34,10 +33,10 @@ class MagicDict(dict):
def __getattr__(self, name):
try:
return self[name]
except KeyError:
except KeyError as e:
raise AttributeError(
"{} object has no attribute '{}'".format(self.__class__.__name__, name)
)
f"{self.__class__.__name__} object has no attribute '{name}'"
) from e
def __setattr__(self, name, value):
return super().__setitem__(name, value)
......@@ -74,7 +73,7 @@ class Element(MagicDict):
parts[-3] = size
return "/".join(parts)
def image_url(self, size: str = "full") -> Optional[str]:
def image_url(self, size: str = "full") -> str | None:
"""
Build a URL to access the image.
When possible, will return the S3 URL for images, so an ML worker can bypass IIIF servers.
......@@ -89,10 +88,10 @@ class Element(MagicDict):
url = self.zone.image.url
if not url.endswith("/"):
url += "/"
return "{}full/{}/0/default.jpg".format(url, size)
return f"{url}full/{size}/0/default.jpg"
@property
def polygon(self) -> List[float]:
def polygon(self) -> list[float]:
"""
Access an Element's polygon.
This is a shortcut to an Element's polygon, normally accessed via
......@@ -101,7 +100,7 @@ class Element(MagicDict):
the [CachedElement][arkindex_worker.cache.CachedElement].polygon field.
"""
if not self.get("zone"):
raise ValueError("Element {} has no zone".format(self.id))
raise ValueError(f"Element {self.id} has no zone")
return self.zone.polygon
@property
......@@ -122,11 +121,11 @@ class Element(MagicDict):
def open_image(
self,
*args,
max_width: Optional[int] = None,
max_height: Optional[int] = None,
use_full_image: Optional[bool] = False,
max_width: int | None = None,
max_height: int | None = None,
use_full_image: bool | None = False,
**kwargs,
) -> Image:
) -> Image.Image:
"""
Open this element's image using Pillow, rotating and mirroring it according
to the ``rotation_angle`` and ``mirrored`` attributes.
......@@ -173,7 +172,7 @@ class Element(MagicDict):
)
if not self.get("zone"):
raise ValueError("Element {} has no zone".format(self.id))
raise ValueError(f"Element {self.id} has no zone")
if self.requires_tiles:
if max_width is None and max_height is None:
......@@ -194,10 +193,7 @@ class Element(MagicDict):
else:
resize = f"{max_width or ''},{max_height or ''}"
if use_full_image:
url = self.image_url(resize)
else:
url = self.resize_zone_url(resize)
url = self.image_url(resize) if use_full_image else self.resize_zone_url(resize)
try:
return open_image(
......@@ -215,13 +211,13 @@ class Element(MagicDict):
# This element uses an S3 URL: the URL may have expired.
# Call the API to get a fresh URL and try again
# TODO: this should be done by the worker
raise NotImplementedError
raise NotImplementedError from e
return open_image(self.image_url(resize), *args, **kwargs)
raise
@contextmanager
def open_image_tempfile(
self, format: Optional[str] = "jpeg", *args, **kwargs
self, format: str | None = "jpeg", *args, **kwargs
) -> Generator[tempfile.NamedTemporaryFile, None, None]:
"""
Get the element's image as a temporary file stored on the disk.
......@@ -249,7 +245,7 @@ class Element(MagicDict):
type_name = self.type["display_name"]
else:
type_name = str(self.type)
return "{} {} ({})".format(type_name, self.name, self.id)
return f"{type_name} {self.name} ({self.id})"
class ArkindexModel(MagicDict):
......
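A hedged sketch of the `open_image` wrapper above; the helper name is illustrative and the element is assumed to have a zone with an image:

```
from PIL import Image
from arkindex_worker.models import Element

def open_thumbnail(element: Element, max_width: int = 1000) -> Image.Image:
    # Assumes the element has a zone with an image; the IIIF server does the
    # resizing when max_width/max_height are given (see open_image above).
    return element.open_image(max_width=max_width, use_full_image=False)
```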
# -*- coding: utf-8 -*-
import hashlib
import logging
import os
import tarfile
import tempfile
from pathlib import Path
from typing import Optional, Tuple, Union
import zstandard
import zstandard as zstd
......@@ -16,7 +14,7 @@ CHUNK_SIZE = 1024
"""Chunk Size used for ZSTD compression"""
def decompress_zst_archive(compressed_archive: Path) -> Tuple[int, Path]:
def decompress_zst_archive(compressed_archive: Path) -> tuple[int, Path]:
"""
Decompress a ZST-compressed tar archive in data dir. The tar archive is not extracted.
This returns the path to the archive and the file descriptor.
......@@ -29,18 +27,19 @@ def decompress_zst_archive(compressed_archive: Path) -> Tuple[int, Path]:
"""
dctx = zstandard.ZstdDecompressor()
archive_fd, archive_path = tempfile.mkstemp(suffix=".tar")
archive_path = Path(archive_path)
logger.debug(f"Uncompressing file to {archive_path}")
try:
with open(compressed_archive, "rb") as compressed, open(
archive_path, "wb"
with compressed_archive.open("rb") as compressed, archive_path.open(
"wb"
) as decompressed:
dctx.copy_stream(compressed, decompressed)
logger.debug(f"Successfully uncompressed archive {compressed_archive}")
except zstandard.ZstdError as e:
raise Exception(f"Couldn't uncompress archive: {e}")
raise Exception(f"Couldn't uncompress archive: {e}") from e
return archive_fd, Path(archive_path)
return archive_fd, archive_path
def extract_tar_archive(archive_path: Path, destination: Path):
......@@ -54,12 +53,12 @@ def extract_tar_archive(archive_path: Path, destination: Path):
with tarfile.open(archive_path) as tar_archive:
tar_archive.extractall(destination)
except tarfile.ReadError as e:
raise Exception(f"Couldn't handle the decompressed Tar archive: {e}")
raise Exception(f"Couldn't handle the decompressed Tar archive: {e}") from e
def extract_tar_zst_archive(
compressed_archive: Path, destination: Path
) -> Tuple[int, Path]:
) -> tuple[int, Path]:
"""
Extract a ZST-compressed tar archive's content to a specific destination
......@@ -89,8 +88,8 @@ def close_delete_file(file_descriptor: int, file_path: Path):
def zstd_compress(
source: Path, destination: Optional[Path] = None
) -> Tuple[Union[int, None], Path, str]:
source: Path, destination: Path | None = None
) -> tuple[int | None, Path, str]:
"""Compress a file using the Zstandard compression algorithm.
:param source: Path to the file to compress.
......@@ -117,13 +116,13 @@ def zstd_compress(
archive_file.write(compressed_chunk)
logger.debug(f"Successfully compressed {source}")
except zstandard.ZstdError as e:
raise Exception(f"Couldn't compress archive: {e}")
raise Exception(f"Couldn't compress archive: {e}") from e
return file_d, destination, archive_hasher.hexdigest()
def create_tar_archive(
path: Path, destination: Optional[Path] = None
) -> Tuple[Union[int, None], Path, str]:
path: Path, destination: Path | None = None
) -> tuple[int | None, Path, str]:
"""Create a tar archive using the content at specified location.
:param path: Path to the file to archive
......@@ -153,7 +152,7 @@ def create_tar_archive(
files.append(p)
logger.debug(f"Successfully created Tar archive from files @ {path}")
except tarfile.TarError as e:
raise Exception(f"Couldn't create Tar archive: {e}")
raise Exception(f"Couldn't create Tar archive: {e}") from e
# Sort by path
files.sort()
......@@ -168,8 +167,8 @@ def create_tar_archive(
def create_tar_zst_archive(
source: Path, destination: Optional[Path] = None
) -> Tuple[Union[int, None], Path, str, str]:
source: Path, destination: Path | None = None
) -> tuple[int | None, Path, str, str]:
"""Helper to create a TAR+ZST archive from a source folder.
:param source: Path to the folder whose content should be archived.
......
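A hedged usage sketch of the compression helpers above, assuming they are importable from `arkindex_worker.utils`; paths are placeholders and the two returned strings are the hexadecimal digests described by the signatures:

```
from pathlib import Path

from arkindex_worker.utils import (
    close_delete_file,
    create_tar_zst_archive,
    extract_tar_zst_archive,
)

source = Path("exports")               # placeholder folder to archive
destination = Path("exports.tar.zst")  # placeholder archive path

# Returns a file descriptor (or None), the archive path and two hex digests.
fd, archive_path, tar_hash, zst_hash = create_tar_zst_archive(source, destination)
print(f"Created {archive_path}")

# Extract the archive elsewhere; the returned descriptor/path refer to the
# intermediate tar file, which can then be closed and deleted.
tar_fd, tar_path = extract_tar_zst_archive(archive_path, Path("restored"))
close_delete_file(tar_fd, tar_path)
```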
# -*- coding: utf-8 -*-
"""
Base classes to implement Arkindex workers.
"""
import contextlib
import json
import os
import sys
import uuid
from collections.abc import Iterable, Iterator
from enum import Enum
from itertools import groupby
from operator import itemgetter
from pathlib import Path
from typing import Iterable, Iterator, List, Tuple, Union
from apistar.exceptions import ErrorResponse
......@@ -102,7 +101,7 @@ class ElementsWorker(
self._worker_version_cache = {}
def list_elements(self) -> Union[Iterable[CachedElement], List[str]]:
def list_elements(self) -> Iterable[CachedElement] | list[str]:
"""
List the elements to be processed, either from the CLI arguments or
the cache database when enabled.
......@@ -227,21 +226,17 @@ class ElementsWorker(
)
if element:
# Try to update the activity to error state regardless of the response
try:
with contextlib.suppress(Exception):
self.update_activity(element.id, ActivityState.Error)
except Exception:
pass
if failed:
logger.error(
"Ran on {} elements: {} completed, {} failed".format(
count, count - failed, failed
)
f"Ran on {count} elements: {count - failed} completed, {failed} failed"
)
if failed >= count: # Everything failed!
sys.exit(1)
def process_element(self, element: Union[Element, CachedElement]):
def process_element(self, element: Element | CachedElement):
"""
Override this method to implement your worker and process a single Arkindex element at once.
......@@ -251,7 +246,7 @@ class ElementsWorker(
"""
def update_activity(
self, element_id: Union[str, uuid.UUID], state: ActivityState
self, element_id: str | uuid.UUID, state: ActivityState
) -> bool:
"""
Update the WorkerActivity for this element and worker.
......@@ -269,7 +264,7 @@ class ElementsWorker(
return True
assert element_id and isinstance(
element_id, (uuid.UUID, str)
element_id, uuid.UUID | str
), "element_id shouldn't be null and should be an UUID or str"
assert isinstance(state, ActivityState), "state should be an ActivityState"
......@@ -382,7 +377,7 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
def list_dataset_elements_per_split(
self, dataset: Dataset
) -> Iterator[Tuple[str, List[Element]]]:
) -> Iterator[tuple[str, list[Element]]]:
"""
List the elements in the dataset, grouped by split, using the
[list_dataset_elements][arkindex_worker.worker.dataset.DatasetMixin.list_dataset_elements] method.
......@@ -392,8 +387,8 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
"""
def format_split(
split: Tuple[str, Iterator[Tuple[str, Element]]]
) -> Tuple[str, List[Element]]:
split: tuple[str, Iterator[tuple[str, Element]]],
) -> tuple[str, list[Element]]:
return (split[0], list(map(itemgetter(1), list(split[1]))))
return map(
......@@ -435,7 +430,7 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
"""
self.configure()
datasets: List[Dataset] | List[str] = list(self.list_datasets())
datasets: list[Dataset] | list[str] = list(self.list_datasets())
if not datasets:
logger.warning("No datasets to process, stopping.")
sys.exit(1)
......@@ -445,6 +440,8 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
failed = 0
for i, item in enumerate(datasets, start=1):
dataset = None
dataset_artifact = None
try:
if not self.is_read_only:
# Just use the result of list_datasets as the dataset
......@@ -470,7 +467,7 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
self.update_dataset_state(dataset, DatasetState.Building)
else:
logger.info(f"Downloading data for {dataset} ({i}/{count})")
self.download_dataset_artifact(dataset)
dataset_artifact = self.download_dataset_artifact(dataset)
# Process the dataset
self.process_dataset(dataset)
......@@ -499,16 +496,16 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
)
if dataset and self.generator:
# Try to update the state to Error regardless of the response
try:
with contextlib.suppress(Exception):
self.update_dataset_state(dataset, DatasetState.Error)
except Exception:
pass
finally:
# Cleanup the dataset artifact if it was downloaded, no matter what
if dataset_artifact:
dataset_artifact.unlink(missing_ok=True)
if failed:
logger.error(
"Ran on {} datasets: {} completed, {} failed".format(
count, count - failed, failed
)
f"Ran on {count} datasets: {count - failed} completed, {failed} failed"
)
if failed >= count: # Everything failed!
sys.exit(1)
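For reference, a minimal sketch of a worker built on the `ElementsWorker` class above, assuming it is exposed by the `arkindex_worker.worker` package; the description and processing body are placeholders:

```
from arkindex_worker.models import Element
from arkindex_worker.worker import ElementsWorker


class DemoWorker(ElementsWorker):
    def process_element(self, element: Element):
        # Called once per element returned by list_elements (CLI arguments
        # or the cache database when enabled).
        print(f"Processing {element}")


if __name__ == "__main__":
    DemoWorker(description="Demo worker").run()
```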
# -*- coding: utf-8 -*-
"""
The base class for all Arkindex workers.
"""
......@@ -9,7 +8,6 @@ import os
import shutil
from pathlib import Path
from tempfile import mkdtemp
from typing import List, Optional
import gnupg
import yaml
......@@ -52,15 +50,15 @@ class ExtrasDirNotFoundError(Exception):
"""
class BaseWorker(object):
class BaseWorker:
"""
Base class for Arkindex workers.
"""
def __init__(
self,
description: Optional[str] = "Arkindex Base Worker",
support_cache: Optional[bool] = False,
description: str | None = "Arkindex Base Worker",
support_cache: bool | None = False,
):
"""
Initialize the worker.
......@@ -353,7 +351,8 @@ class BaseWorker(object):
try:
gpg = gnupg.GPG()
decrypted = gpg.decrypt_file(open(path, "rb"))
with path.open("rb") as gpg_file:
decrypted = gpg.decrypt_file(gpg_file)
assert (
decrypted.ok
), f"GPG error: {decrypted.status} - {decrypted.stderr}"
......@@ -412,7 +411,7 @@ class BaseWorker(object):
)
return extras_dir
def find_parents_file_paths(self, filename: Path) -> List[Path]:
def find_parents_file_paths(self, filename: Path) -> list[Path]:
"""
Find the paths of a specific file from the parent tasks.
Only works if the task_parents attribute is updated, so if the cache is supported,
......
# -*- coding: utf-8 -*-
"""
ElementsWorker methods for classifications and ML classes.
"""
from typing import Dict, List, Optional, Union
from uuid import UUID
from apistar.exceptions import ErrorResponse
......@@ -14,7 +12,7 @@ from arkindex_worker.cache import CachedClassification, CachedElement
from arkindex_worker.models import Element
class ClassificationMixin(object):
class ClassificationMixin:
def load_corpus_classes(self):
"""
Load all ML classes available in the worker's corpus and store them in the ``self.classes`` cache.
......@@ -91,11 +89,11 @@ class ClassificationMixin(object):
def create_classification(
self,
element: Union[Element, CachedElement],
element: Element | CachedElement,
ml_class: str,
confidence: float,
high_confidence: Optional[bool] = False,
) -> Dict[str, str]:
high_confidence: bool = False,
) -> dict[str, str]:
"""
Create a classification on the given element through the API.
......@@ -106,7 +104,7 @@ class ClassificationMixin(object):
:returns: The created classification, as returned by the ``CreateClassification`` API endpoint.
"""
assert element and isinstance(
element, (Element, CachedElement)
element, Element | CachedElement
), "element shouldn't be null and should be an Element or CachedElement"
assert ml_class and isinstance(
ml_class, str
......@@ -180,9 +178,9 @@ class ClassificationMixin(object):
def create_classifications(
self,
element: Union[Element, CachedElement],
classifications: List[Dict[str, Union[str, float, bool]]],
) -> List[Dict[str, Union[str, float, bool]]]:
element: Element | CachedElement,
classifications: list[dict[str, str | float | bool]],
) -> list[dict[str, str | float | bool]]:
"""
Create multiple classifications at once on the given element through the API.
......@@ -196,7 +194,7 @@ class ClassificationMixin(object):
the ``CreateClassifications`` API endpoint.
"""
assert element and isinstance(
element, (Element, CachedElement)
element, Element | CachedElement
), "element shouldn't be null and should be an Element or CachedElement"
assert classifications and isinstance(
classifications, list
......@@ -204,17 +202,17 @@ class ClassificationMixin(object):
for index, classification in enumerate(classifications):
ml_class_id = classification.get("ml_class_id")
assert ml_class_id and isinstance(
ml_class_id, str
assert (
ml_class_id and isinstance(ml_class_id, str)
), f"Classification at index {index} in classifications: ml_class_id shouldn't be null and should be of type str"
# Make sure it's a valid UUID
try:
UUID(ml_class_id)
except ValueError:
except ValueError as e:
raise ValueError(
f"Classification at index {index} in classifications: ml_class_id is not a valid uuid."
)
) from e
confidence = classification.get("confidence")
assert (
......
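A hedged sketch of `create_classification`, assuming the mixin method is available on `ElementsWorker`; the class name and score are placeholders and the ML class must already exist in the corpus (see `load_corpus_classes` above):

```
from arkindex_worker.models import Element
from arkindex_worker.worker import ElementsWorker


def classify(worker: ElementsWorker, element: Element) -> None:
    # Placeholder class name and confidence; high_confidence marks the
    # classification as validated-by-confidence.
    created = worker.create_classification(
        element,
        ml_class="handwritten",
        confidence=0.92,
        high_confidence=True,
    )
    print(f"Created classification {created}")
```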
# -*- coding: utf-8 -*-
"""
BaseWorker methods for datasets.
"""
from collections.abc import Iterator
from enum import Enum
from typing import Iterator, Tuple
from arkindex_worker import logger
from arkindex_worker.models import Dataset, Element
......@@ -36,7 +35,7 @@ class DatasetState(Enum):
"""
class DatasetMixin(object):
class DatasetMixin:
def list_process_datasets(self) -> Iterator[Dataset]:
"""
List datasets associated to the worker's process. This helper is not available in developer mode.
......@@ -51,7 +50,7 @@ class DatasetMixin(object):
return map(Dataset, list(results))
def list_dataset_elements(self, dataset: Dataset) -> Iterator[Tuple[str, Element]]:
def list_dataset_elements(self, dataset: Dataset) -> Iterator[tuple[str, Element]]:
"""
List elements in a dataset.
......
# -*- coding: utf-8 -*-
"""
ElementsWorker methods for elements and element types.
"""
from typing import Dict, Iterable, List, NamedTuple, Optional, Union
from collections.abc import Iterable
from typing import NamedTuple
from uuid import UUID
from peewee import IntegrityError
......@@ -28,8 +28,8 @@ class MissingTypeError(Exception):
"""
class ElementMixin(object):
def create_required_types(self, element_types: List[ElementType]):
class ElementMixin:
def create_required_types(self, element_types: list[ElementType]):
"""Creates given element types in the corpus.
:param element_types: The missing element types to create.
......@@ -86,9 +86,10 @@ class ElementMixin(object):
element: Element,
type: str,
name: str,
polygon: List[List[Union[int, float]]],
confidence: Optional[float] = None,
slim_output: Optional[bool] = True,
polygon: list[list[int | float]] | None = None,
confidence: float | None = None,
image: str | None = None,
slim_output: bool = True,
) -> str:
"""
Create a child element on the given element through the API.
......@@ -96,8 +97,10 @@ class ElementMixin(object):
:param Element element: The parent element.
:param type: Slug of the element type for this child element.
:param name: Name of the child element.
:param polygon: Polygon of the child element.
:param polygon: Optional polygon of the child element.
:param confidence: Optional confidence score, between 0.0 and 1.0.
:param image: Optional image ID of the child element.
:param slim_output: Whether to return the child ID or the full child.
:returns: UUID of the created element.
"""
assert element and isinstance(
......@@ -109,19 +112,29 @@ class ElementMixin(object):
assert name and isinstance(
name, str
), "name shouldn't be null and should be of type str"
assert polygon and isinstance(
assert polygon is None or isinstance(
polygon, list
), "polygon shouldn't be null and should be of type list"
assert len(polygon) >= 3, "polygon should have at least three points"
assert all(
isinstance(point, list) and len(point) == 2 for point in polygon
), "polygon points should be lists of two items"
assert all(
isinstance(coord, (int, float)) for point in polygon for coord in point
), "polygon points should be lists of two numbers"
), "polygon should be None or a list"
if polygon is not None:
assert len(polygon) >= 3, "polygon should have at least three points"
assert all(
isinstance(point, list) and len(point) == 2 for point in polygon
), "polygon points should be lists of two items"
assert all(
isinstance(coord, int | float) for point in polygon for coord in point
), "polygon points should be lists of two numbers"
assert confidence is None or (
isinstance(confidence, float) and 0 <= confidence <= 1
), "confidence should be None or a float in [0..1] range"
assert image is None or isinstance(image, str), "image should be None or string"
if image is not None:
# Make sure it's a valid UUID
try:
UUID(image)
except ValueError as e:
raise ValueError("image is not a valid uuid.") from e
if polygon and image is None:
assert element.zone, "An image or a parent with an image is required to create an element with a polygon."
assert isinstance(slim_output, bool), "slim_output should be of type bool"
if self.is_read_only:
......@@ -133,7 +146,7 @@ class ElementMixin(object):
body={
"type": type,
"name": name,
"image": element.zone.image.id,
"image": image,
"corpus": element.corpus.id,
"polygon": polygon,
"parent": element.id,
......@@ -146,11 +159,9 @@ class ElementMixin(object):
def create_elements(
self,
parent: Union[Element, CachedElement],
elements: List[
Dict[str, Union[str, List[List[Union[int, float]]], float, None]]
],
) -> List[Dict[str, str]]:
parent: Element | CachedElement,
elements: list[dict[str, str | list[list[int | float]] | float | None]],
) -> list[dict[str, str]]:
"""
Create child elements on the given element in a single API request.
......@@ -195,18 +206,18 @@ class ElementMixin(object):
), f"Element at index {index} in elements: Should be of type dict"
name = element.get("name")
assert name and isinstance(
name, str
assert (
name and isinstance(name, str)
), f"Element at index {index} in elements: name shouldn't be null and should be of type str"
type = element.get("type")
assert type and isinstance(
type, str
assert (
type and isinstance(type, str)
), f"Element at index {index} in elements: type shouldn't be null and should be of type str"
polygon = element.get("polygon")
assert polygon and isinstance(
polygon, list
assert (
polygon and isinstance(polygon, list)
), f"Element at index {index} in elements: polygon shouldn't be null and should be of type list"
assert (
len(polygon) >= 3
......@@ -215,12 +226,13 @@ class ElementMixin(object):
isinstance(point, list) and len(point) == 2 for point in polygon
), f"Element at index {index} in elements: polygon points should be lists of two items"
assert all(
isinstance(coord, (int, float)) for point in polygon for coord in point
isinstance(coord, int | float) for point in polygon for coord in point
), f"Element at index {index} in elements: polygon points should be lists of two numbers"
confidence = element.get("confidence")
assert confidence is None or (
isinstance(confidence, float) and 0 <= confidence <= 1
assert (
confidence is None
or (isinstance(confidence, float) and 0 <= confidence <= 1)
), f"Element at index {index} in elements: confidence should be None or a float in [0..1] range"
if self.is_read_only:
......@@ -271,8 +283,37 @@ class ElementMixin(object):
return created_ids
def create_element_parent(
self,
parent: Element,
child: Element,
) -> dict[str, str]:
"""
Link an element to a parent through the API.
:param parent: Parent element.
:param child: Child element.
:returns: A dict from the ``CreateElementParent`` API endpoint.
"""
assert parent and isinstance(
parent, Element
), "parent shouldn't be null and should be of type Element"
assert child and isinstance(
child, Element
), "child shouldn't be null and should be of type Element"
if self.is_read_only:
logger.warning("Cannot link elements as this worker is in read-only mode")
return
return self.request(
"CreateElementParent",
parent=parent.id,
child=child.id,
)
def partial_update_element(
self, element: Union[Element, CachedElement], **kwargs
self, element: Element | CachedElement, **kwargs
) -> dict:
"""
Partially updates an element through the API.
......@@ -289,10 +330,10 @@ class ElementMixin(object):
* *image* (``UUID``): Optional ID of the image of this element
:returns: A dict from the ``PartialUpdateElement`` API endpoint,
:returns: A dict from the ``PartialUpdateElement`` API endpoint.
"""
assert element and isinstance(
element, (Element, CachedElement)
element, Element | CachedElement
), "element shouldn't be null and should be an Element or CachedElement"
if "type" in kwargs:
......@@ -309,7 +350,7 @@ class ElementMixin(object):
isinstance(point, list) and len(point) == 2 for point in polygon
), "polygon points should be lists of two items"
assert all(
isinstance(coord, (int, float)) for point in polygon for coord in point
isinstance(coord, int | float) for point in polygon for coord in point
), "polygon points should be lists of two numbers"
if "confidence" in kwargs:
......@@ -363,21 +404,21 @@ class ElementMixin(object):
def list_element_children(
self,
element: Union[Element, CachedElement],
folder: Optional[bool] = None,
name: Optional[str] = None,
recursive: Optional[bool] = None,
transcription_worker_version: Optional[Union[str, bool]] = None,
transcription_worker_run: Optional[Union[str, bool]] = None,
type: Optional[str] = None,
with_classes: Optional[bool] = None,
with_corpus: Optional[bool] = None,
with_metadata: Optional[bool] = None,
with_has_children: Optional[bool] = None,
with_zone: Optional[bool] = None,
worker_version: Optional[Union[str, bool]] = None,
worker_run: Optional[Union[str, bool]] = None,
) -> Union[Iterable[dict], Iterable[CachedElement]]:
element: Element | CachedElement,
folder: bool | None = None,
name: str | None = None,
recursive: bool | None = None,
transcription_worker_version: str | bool | None = None,
transcription_worker_run: str | bool | None = None,
type: str | None = None,
with_classes: bool | None = None,
with_corpus: bool | None = None,
with_metadata: bool | None = None,
with_has_children: bool | None = None,
with_zone: bool | None = None,
worker_version: str | bool | None = None,
worker_run: str | bool | None = None,
) -> Iterable[dict] | Iterable[CachedElement]:
"""
List children of an element.
......@@ -412,7 +453,7 @@ class ElementMixin(object):
or an iterable of [CachedElement][arkindex_worker.cache.CachedElement] when caching is enabled.
"""
assert element and isinstance(
element, (Element, CachedElement)
element, Element | CachedElement
), "element shouldn't be null and should be an Element or CachedElement"
query_params = {}
if folder is not None:
......@@ -426,7 +467,7 @@ class ElementMixin(object):
query_params["recursive"] = recursive
if transcription_worker_version is not None:
assert isinstance(
transcription_worker_version, (str, bool)
transcription_worker_version, str | bool
), "transcription_worker_version should be of type str or bool"
if isinstance(transcription_worker_version, bool):
assert (
......@@ -435,7 +476,7 @@ class ElementMixin(object):
query_params["transcription_worker_version"] = transcription_worker_version
if transcription_worker_run is not None:
assert isinstance(
transcription_worker_run, (str, bool)
transcription_worker_run, str | bool
), "transcription_worker_run should be of type str or bool"
if isinstance(transcription_worker_run, bool):
assert (
......@@ -466,7 +507,7 @@ class ElementMixin(object):
query_params["with_zone"] = with_zone
if worker_version is not None:
assert isinstance(
worker_version, (str, bool)
worker_version, str | bool
), "worker_version should be of type str or bool"
if isinstance(worker_version, bool):
assert (
......@@ -475,7 +516,7 @@ class ElementMixin(object):
query_params["worker_version"] = worker_version
if worker_run is not None:
assert isinstance(
worker_run, (str, bool)
worker_run, str | bool
), "worker_run should be of type str or bool"
if isinstance(worker_run, bool):
assert (
......@@ -485,11 +526,14 @@ class ElementMixin(object):
if self.use_cache:
# Checking that we only received query_params handled by the cache
assert set(query_params.keys()) <= {
"type",
"worker_version",
"worker_run",
}, "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'"
assert (
set(query_params.keys())
<= {
"type",
"worker_version",
"worker_run",
}
), "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'"
query = CachedElement.select().where(CachedElement.parent_id == element.id)
if type:
......@@ -522,21 +566,21 @@ class ElementMixin(object):
def list_element_parents(
self,
element: Union[Element, CachedElement],
folder: Optional[bool] = None,
name: Optional[str] = None,
recursive: Optional[bool] = None,
transcription_worker_version: Optional[Union[str, bool]] = None,
transcription_worker_run: Optional[Union[str, bool]] = None,
type: Optional[str] = None,
with_classes: Optional[bool] = None,
with_corpus: Optional[bool] = None,
with_metadata: Optional[bool] = None,
with_has_children: Optional[bool] = None,
with_zone: Optional[bool] = None,
worker_version: Optional[Union[str, bool]] = None,
worker_run: Optional[Union[str, bool]] = None,
) -> Union[Iterable[dict], Iterable[CachedElement]]:
element: Element | CachedElement,
folder: bool | None = None,
name: str | None = None,
recursive: bool | None = None,
transcription_worker_version: str | bool | None = None,
transcription_worker_run: str | bool | None = None,
type: str | None = None,
with_classes: bool | None = None,
with_corpus: bool | None = None,
with_metadata: bool | None = None,
with_has_children: bool | None = None,
with_zone: bool | None = None,
worker_version: str | bool | None = None,
worker_run: str | bool | None = None,
) -> Iterable[dict] | Iterable[CachedElement]:
"""
List parents of an element.
......@@ -571,7 +615,7 @@ class ElementMixin(object):
or an iterable of [CachedElement][arkindex_worker.cache.CachedElement] when caching is enabled.
"""
assert element and isinstance(
element, (Element, CachedElement)
element, Element | CachedElement
), "element shouldn't be null and should be an Element or CachedElement"
query_params = {}
if folder is not None:
......@@ -585,7 +629,7 @@ class ElementMixin(object):
query_params["recursive"] = recursive
if transcription_worker_version is not None:
assert isinstance(
transcription_worker_version, (str, bool)
transcription_worker_version, str | bool
), "transcription_worker_version should be of type str or bool"
if isinstance(transcription_worker_version, bool):
assert (
......@@ -594,7 +638,7 @@ class ElementMixin(object):
query_params["transcription_worker_version"] = transcription_worker_version
if transcription_worker_run is not None:
assert isinstance(
transcription_worker_run, (str, bool)
transcription_worker_run, str | bool
), "transcription_worker_run should be of type str or bool"
if isinstance(transcription_worker_run, bool):
assert (
......@@ -625,7 +669,7 @@ class ElementMixin(object):
query_params["with_zone"] = with_zone
if worker_version is not None:
assert isinstance(
worker_version, (str, bool)
worker_version, str | bool
), "worker_version should be of type str or bool"
if isinstance(worker_version, bool):
assert (
......@@ -634,7 +678,7 @@ class ElementMixin(object):
query_params["worker_version"] = worker_version
if worker_run is not None:
assert isinstance(
worker_run, (str, bool)
worker_run, str | bool
), "worker_run should be of type str or bool"
if isinstance(worker_run, bool):
assert (
......@@ -644,11 +688,14 @@ class ElementMixin(object):
if self.use_cache:
# Checking that we only received query_params handled by the cache
assert set(query_params.keys()) <= {
"type",
"worker_version",
"worker_run",
}, "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'"
assert (
set(query_params.keys())
<= {
"type",
"worker_version",
"worker_run",
}
), "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'"
parent_ids = CachedElement.select(CachedElement.parent_id).where(
CachedElement.id == element.id
......
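A hedged sketch of batch element creation with `create_elements`, assuming the mixin methods are available on `ElementsWorker`; the type slug, polygon and confidence are placeholders following the assertions above:

```
from arkindex_worker.models import Element
from arkindex_worker.worker import ElementsWorker


def create_demo_line(worker: ElementsWorker, page: Element) -> None:
    # Each dict follows the keys checked above (name, type, polygon, confidence).
    created = worker.create_elements(
        parent=page,
        elements=[
            {
                "name": "Line 1",
                "type": "text_line",
                "polygon": [[0, 0], [0, 50], [200, 50], [200, 0]],
                "confidence": 0.95,
            }
        ],
    )
    print(f"Created {len(created)} elements under {page.id}")
```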
# -*- coding: utf-8 -*-
"""
ElementsWorker methods for entities.
"""
from operator import itemgetter
from typing import Dict, List, Optional, TypedDict, Union
from typing import TypedDict
from peewee import IntegrityError
......@@ -12,16 +11,13 @@ from arkindex_worker import logger
from arkindex_worker.cache import CachedEntity, CachedTranscriptionEntity
from arkindex_worker.models import Element, Transcription
Entity = TypedDict(
"Entity",
{
"name": str,
"type_id": str,
"length": int,
"offset": int,
"confidence": Optional[float],
},
)
class Entity(TypedDict):
name: str
type_id: str
length: int
offset: int
confidence: float | None
class MissingEntityType(Exception):
......@@ -31,9 +27,9 @@ class MissingEntityType(Exception):
"""
class EntityMixin(object):
class EntityMixin:
def check_required_entity_types(
self, entity_types: List[str], create_missing: bool = True
self, entity_types: list[str], create_missing: bool = True
):
"""Checks that every entity type needed is available in the corpus.
Missing ones may be created automatically if needed.
......@@ -71,7 +67,7 @@ class EntityMixin(object):
self,
name: str,
type: str,
metas=dict(),
metas=None,
validated=None,
):
"""
......@@ -87,6 +83,7 @@ class EntityMixin(object):
assert type and isinstance(
type, str
), "type shouldn't be null and should be of type str"
metas = metas or {}
if metas:
assert isinstance(metas, dict), "metas should be of type dict"
if validated is not None:
......@@ -140,8 +137,8 @@ class EntityMixin(object):
entity: str,
offset: int,
length: int,
confidence: Optional[float] = None,
) -> Optional[Dict[str, Union[str, int]]]:
confidence: float | None = None,
) -> dict[str, str | int] | None:
"""
Create a link between an existing entity and an existing transcription.
If cache support is enabled, a `CachedTranscriptionEntity` will also be created.
......@@ -211,8 +208,8 @@ class EntityMixin(object):
def create_transcription_entities(
self,
transcription: Transcription,
entities: List[Entity],
) -> List[Dict[str, str]]:
entities: list[Entity],
) -> list[dict[str, str]]:
"""
Create multiple entities attached to a transcription in a single API request.
......@@ -250,13 +247,13 @@ class EntityMixin(object):
), f"Entity at index {index} in entities: Should be of type dict"
name = entity.get("name")
assert name and isinstance(
name, str
assert (
name and isinstance(name, str)
), f"Entity at index {index} in entities: name shouldn't be null and should be of type str"
type_id = entity.get("type_id")
assert type_id and isinstance(
type_id, str
assert (
type_id and isinstance(type_id, str)
), f"Entity at index {index} in entities: type_id shouldn't be null and should be of type str"
offset = entity.get("offset")
......@@ -270,8 +267,9 @@ class EntityMixin(object):
), f"Entity at index {index} in entities: length shouldn't be null and should be a strictly positive integer"
confidence = entity.get("confidence")
assert confidence is None or (
isinstance(confidence, float) and 0 <= confidence <= 1
assert (
confidence is None
or (isinstance(confidence, float) and 0 <= confidence <= 1)
), f"Entity at index {index} in entities: confidence should be None or a float in [0..1] range"
assert len(entities) == len(
......@@ -298,7 +296,7 @@ class EntityMixin(object):
def list_transcription_entities(
self,
transcription: Transcription,
worker_version: Optional[Union[str, bool]] = None,
worker_version: str | bool | None = None,
):
"""
List existing entities on a transcription
......@@ -314,7 +312,7 @@ class EntityMixin(object):
if worker_version is not None:
assert isinstance(
worker_version, (str, bool)
worker_version, str | bool
), "worker_version should be of type str or bool"
if isinstance(worker_version, bool):
......@@ -329,12 +327,11 @@ class EntityMixin(object):
def list_corpus_entities(
self,
name: Optional[str] = None,
parent: Optional[Element] = None,
name: str | None = None,
parent: Element | None = None,
):
"""
List all entities in the worker's corpus
This method does not support cache
List all entities in the worker's corpus and store them in the ``self.entities`` cache.
:param name: Filter entities by part of their name (case-insensitive)
:param parent: Restrict entities to those linked to all transcriptions of an element and all its descendants. Note that links to metadata are ignored.
"""
......@@ -348,8 +345,14 @@ class EntityMixin(object):
assert isinstance(parent, Element), "parent should be of type Element"
query_params["parent"] = parent.id
return self.api_client.paginate(
"ListCorpusEntities", id=self.corpus_id, **query_params
self.entities = {
entity["id"]: entity
for entity in self.api_client.paginate(
"ListCorpusEntities", id=self.corpus_id, **query_params
)
}
logger.info(
f"Loaded {len(self.entities)} entities in corpus ({self.corpus_id})"
)
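# Illustrative sketch, not part of the diff: with this change the method fills the
# `self.entities` cache (keyed by entity ID) instead of returning a paginator, so
# callers read the cache after the call. The filter value is hypothetical.
worker.list_corpus_entities(name="doe")  # case-insensitive name filter
for entity_id, entity in worker.entities.items():
    print(entity_id, entity["name"])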
def list_corpus_entity_types(
......
# -*- coding: utf-8 -*-
"""
ElementsWorker methods for metadata.
"""
from enum import Enum
from typing import Dict, List, Optional, Union
from arkindex_worker import logger
from arkindex_worker.cache import CachedElement
......@@ -57,14 +55,14 @@ class MetaType(Enum):
"""
class MetaDataMixin(object):
class MetaDataMixin:
def create_metadata(
self,
element: Union[Element, CachedElement],
element: Element | CachedElement,
type: MetaType,
name: str,
value: str,
entity: Optional[str] = None,
entity: str | None = None,
) -> str:
"""
Create a metadata on the given element through the API.
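# Illustrative sketch, not part of the diff: attaching one metadata to an element.
# `worker` and `element` stand for pre-existing objects, and MetaType.Text is
# assumed to be one of the enum members defined above. The returned value is the
# UUID of the created metadata, per the docstring.
metadata_id = worker.create_metadata(
    element=element,
    type=MetaType.Text,
    name="folio",
    value="42r",
)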
......@@ -77,7 +75,7 @@ class MetaDataMixin(object):
:returns: UUID of the created metadata.
"""
assert element and isinstance(
element, (Element, CachedElement)
element, Element | CachedElement
), "element shouldn't be null and should be of type Element or CachedElement"
assert type and isinstance(
type, MetaType
......@@ -110,26 +108,22 @@ class MetaDataMixin(object):
def create_metadatas(
self,
element: Union[Element, CachedElement],
metadatas: List[
Dict[
str, Union[MetaType, str, Union[str, Union[int, float]], Optional[str]]
]
],
) -> List[Dict[str, str]]:
element: Element | CachedElement,
metadatas: list[dict[str, MetaType | str | int | float | None]],
) -> list[dict[str, str]]:
"""
Create multiple metadatas on an existing element.
Create multiple metadata on an existing element.
This method does not support cache.
:param element: The element to create multiple metadata on.
:param metadatas: A list of dicts whose keys are the following:
- type : MetaType
- name : str
- value : Union[str, Union[int, float]]
- entity_id : Union[str, None]
- type: MetaType
- name: str
- value: str | int | float
- entity_id: str | None
"""
assert element and isinstance(
element, (Element, CachedElement)
element, Element | CachedElement
), "element shouldn't be null and should be of type Element or CachedElement"
assert metadatas and isinstance(
......@@ -152,7 +146,7 @@ class MetaDataMixin(object):
), "name shouldn't be null and should be of type str"
assert metadata.get("value") is not None and isinstance(
metadata.get("value"), (str, float, int)
metadata.get("value"), str | float | int
), "value shouldn't be null and should be of type (str or float or int)"
assert metadata.get("entity_id") is None or isinstance(
......@@ -172,7 +166,7 @@ class MetaDataMixin(object):
logger.warning("Cannot create metadata as this worker is in read-only mode")
return
created_metadatas = self.request(
created_metadata_list = self.request(
"CreateMetaDataBulk",
id=element.id,
body={
......@@ -181,11 +175,11 @@ class MetaDataMixin(object):
},
)["metadata_list"]
return created_metadatas
return created_metadata_list
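# Illustrative sketch, not part of the diff: the bulk counterpart of
# create_metadata. Each dict uses the keys documented above (type, name, value and
# an optional entity_id); values may be str, int or float per the check above.
created = worker.create_metadatas(
    element=element,
    metadatas=[
        {"type": MetaType.Text, "name": "folio", "value": "42r", "entity_id": None},
        {"type": MetaType.Text, "name": "scribe", "value": "John Doe", "entity_id": None},
    ],
)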
def list_element_metadata(
self, element: Union[Element, CachedElement]
) -> List[Dict[str, str]]:
self, element: Element | CachedElement
) -> list[dict[str, str]]:
"""
List all metadata linked to an element.
This method does not support cache.
......@@ -193,7 +187,7 @@ class MetaDataMixin(object):
:param element: The element to list metadata on.
"""
assert element and isinstance(
element, (Element, CachedElement)
element, Element | CachedElement
), "element shouldn't be null and should be of type Element or CachedElement"
return self.api_client.paginate("ListElementMetaData", id=element.id)
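# Illustrative sketch, not part of the diff: the call returns a paginator over the
# element's metadata, so it is usually consumed with a loop or list(). The
# "name"/"value" keys are assumptions about the returned payload.
for metadata in worker.list_element_metadata(element):
    print(metadata["name"], metadata["value"])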
# -*- coding: utf-8 -*-
"""
BaseWorker methods for tasks.
"""
import uuid
from typing import Iterator
from collections.abc import Iterator
from apistar.compat import DownloadedFile
from arkindex_worker.models import Artifact
class TaskMixin(object):
class TaskMixin:
def list_artifacts(self, task_id: uuid.UUID) -> Iterator[Artifact]:
"""
List artifacts associated with a task.
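# Illustrative sketch, not part of the diff: iterating over the artifacts of a
# task, typically a dependency of the current process. The task ID is hypothetical
# and the `path` field on artifacts is an assumption.
import uuid

for artifact in worker.list_artifacts(uuid.UUID("33333333-3333-3333-3333-333333333333")):
    print(artifact.path)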
......
# -*- coding: utf-8 -*-
"""
BaseWorker methods for training.
"""
......@@ -6,7 +5,7 @@ BaseWorker methods for training.
import functools
from contextlib import contextmanager
from pathlib import Path
from typing import NewType, Optional, Tuple, Union
from typing import NewType
from uuid import UUID
import requests
......@@ -26,7 +25,7 @@ FileSize = NewType("FileSize", int)
@contextmanager
def create_archive(path: DirPath) -> Tuple[Path, Hash, FileSize, Hash]:
def create_archive(path: DirPath) -> tuple[Path, Hash, FileSize, Hash]:
"""
Create a tar archive from the files at the given location, then compress it into a zst archive.
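# Illustrative sketch, not part of the diff: the context manager yields a tuple
# matching the annotation above (archive path, hash, size, archive hash); variable
# names and the model directory are illustrative.
from pathlib import Path

with create_archive(Path("models/my_model")) as (zst_archive_path, tar_hash, size, zst_hash):
    print(zst_archive_path, size)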
......@@ -72,7 +71,7 @@ def skip_if_read_only(func):
return wrapper
class TrainingMixin(object):
class TrainingMixin:
"""
A mixin helper to create a new model version easily.
You may use `publish_model_version` to publish a ready model version directly, or
......@@ -87,10 +86,10 @@ class TrainingMixin(object):
self,
model_path: DirPath,
model_id: str,
tag: Optional[str] = None,
description: Optional[str] = None,
configuration: Optional[dict] = {},
parent: Optional[Union[str, UUID]] = None,
tag: str | None = None,
description: str | None = None,
configuration: dict | None = None,
parent: str | UUID | None = None,
):
"""
Publish a unique version of a model in Arkindex, identified by its hash.
......@@ -105,6 +104,7 @@ class TrainingMixin(object):
:param parent: ID of the parent model version
"""
configuration = configuration or {}
if not self.model_version:
self.create_model_version(
model_id=model_id,
......@@ -161,10 +161,10 @@ class TrainingMixin(object):
def create_model_version(
self,
model_id: str,
tag: Optional[str] = None,
description: Optional[str] = None,
configuration: Optional[dict] = {},
parent: Optional[Union[str, UUID]] = None,
tag: str | None = None,
description: str | None = None,
configuration: dict | None = None,
parent: str | UUID | None = None,
):
"""
Create a new version of the specified model with its base attributes.
......@@ -176,6 +176,8 @@ class TrainingMixin(object):
:param parent: ID of the parent model version
"""
assert not self.model_version, "A model version has already been created."
configuration = configuration or {}
self.model_version = self.request(
"CreateModelVersion",
id=model_id,
......@@ -186,6 +188,7 @@ class TrainingMixin(object):
parent=parent,
),
)
logger.info(
f"Model version ({self.model_version['id']}) was successfully created"
)
......@@ -193,10 +196,10 @@ class TrainingMixin(object):
@skip_if_read_only
def update_model_version(
self,
tag: Optional[str] = None,
description: Optional[str] = None,
configuration: Optional[dict] = None,
parent: Optional[Union[str, UUID]] = None,
tag: str | None = None,
description: str | None = None,
configuration: dict | None = None,
parent: str | UUID | None = None,
):
"""
Update the current model version with the given attributes.
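# Illustrative sketch, not part of the diff: driving the version lifecycle by hand
# instead of using publish_model_version. The model ID and attributes are
# hypothetical.
worker.create_model_version(
    model_id="44444444-4444-4444-4444-444444444444",
    tag="1.0.0",
    description="First trained version",
)
# The current version can still be amended afterwards.
worker.update_model_version(description="First trained version, fixed typo")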
......@@ -235,9 +238,7 @@ class TrainingMixin(object):
), "The model is already marked as available."
s3_put_url = self.model_version.get("s3_put_url")
assert (
s3_put_url
), "S3 PUT URL is not set, please ensure you have the right to validate a model version."
assert s3_put_url, "S3 PUT URL is not set, please ensure you have the right to validate a model version."
logger.info("Uploading to s3...")
# Upload the archive on s3
......@@ -263,9 +264,7 @@ class TrainingMixin(object):
:param size: The size of the uploaded archive
:param archive_hash: MD5 hash of the uploaded archive
"""
assert (
self.model_version
), "You must create the model version and upload its archive before validating it."
assert self.model_version, "You must create the model version and upload its archive before validating it."
try:
self.model_version = self.request(
"ValidateModelVersion",
......
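# Illustrative sketch, not part of the diff: the one-call path named in the mixin
# docstring, publishing a directory of trained files as a new model version.
# Paths, IDs and the configuration content are hypothetical.
from pathlib import Path

worker.publish_model_version(
    model_path=Path("models/my_model"),
    model_id="44444444-4444-4444-4444-444444444444",
    tag="1.0.0",
    description="First trained version",
    configuration={"language": "fr"},  # optional; defaults to an empty dict since this change
)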