diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index d369f51f05fe3234cc440dc161d89eb15c846834..aca7f2e70d36cf34947c02d1b446c6cc5c19f7ca 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -55,8 +55,7 @@ test: - tox -- --junitxml=test-report.xml --durations=50 test-cookiecutter: - # Needed till next release - image: python:3.11 + image: python:slim stage: test cache: @@ -68,6 +67,7 @@ test-cookiecutter: PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip" PRE_COMMIT_HOME: "$CI_PROJECT_DIR/.cache/pre-commit" ARKINDEX_API_SCHEMA_URL: schema.yml + DEBIAN_FRONTEND: non-interactive except: - schedules @@ -75,6 +75,9 @@ test-cookiecutter: before_script: - pip install cookiecutter tox pre-commit + # Install curl and git + - apt-get update -q -y && apt-get install -q -y --no-install-recommends curl git + # Configure git to be able to commit in the hook - git config --global user.email "crasher@teklia.com" - git config --global user.name "Crash Test" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 05b268629f29aa8eb15c501ced4b907fa164cd65..c9517d2683591cce8b1d1d4ce3054778d920760b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,17 +1,17 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.0.278 + rev: v0.1.7 hooks: + # Run the linter. - id: ruff args: [--fix, --exit-non-zero-on-fix] exclude: "^worker-{{cookiecutter.slug}}/" - - repo: https://github.com/psf/black-pre-commit-mirror - rev: 23.11.0 - hooks: - - id: black + # Run the formatter. + - id: ruff-format + exclude: "^worker-{{cookiecutter.slug}}/" - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: check-ast - id: check-executables-have-shebangs @@ -26,9 +26,11 @@ repos: - id: name-tests-test args: ['--django'] - id: check-json + - id: check-toml + exclude: "^worker-{{cookiecutter.slug}}/pyproject.toml$" - id: requirements-txt-fixer - repo: https://github.com/codespell-project/codespell - rev: v2.2.2 + rev: v2.2.6 hooks: - id: codespell args: ['--write-changes'] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..66632cb5a2f25c743b5d4355e63eeaaa2361775c --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Teklia + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/Makefile b/Makefile index 38fb14c85e9fa608935a032004bfad2dc5d550d9..f9322fd1664aa6f32a3259456bb19a2e285859f3 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,10 @@ .PHONY: release release: - $(eval version:=$(shell cat VERSION)) - echo $(version) - git commit VERSION -m "Version $(version)" + # Grep the version from pyproject.toml, squeeze multiple spaces, delete double and single quotes, get 3rd val. + # This command tolerates multiple whitespace sequences around the version number. + $(eval version:=$(shell grep -m 1 version pyproject.toml | tr -s ' ' | tr -d '"' | tr -d "'" | cut -d' ' -f3)) + echo Releasing version $(version) + git commit pyproject.toml -m "Version $(version)" git tag $(version) git push origin master $(version) diff --git a/README.md b/README.md index b2687a19a5c327d190eb73194fc49c505b82a535..566ad5532d79b9cad5c11c856c0e2658bf98f8cd 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,12 @@ An easy to use Python 3 high level API client, to build ML tasks. +This is an open-source project, licensed using [the MIT license](https://opensource.org/license/mit/). + +## Documentation + +The [documentation](https://workers.arkindex.org/) is made with [Material for MkDocs](https://github.com/squidfunk/mkdocs-material) and is hosted by [GitLab Pages](https://docs.gitlab.com/ee/user/project/pages/). + ## Create a new worker using our template ``` diff --git a/VERSION b/VERSION deleted file mode 100644 index a16e482cecfd0c46a05a57eef1b7b9d467374de9..0000000000000000000000000000000000000000 --- a/VERSION +++ /dev/null @@ -1 +0,0 @@ -0.3.6-rc1 diff --git a/arkindex_worker/__init__.py b/arkindex_worker/__init__.py index b74e88890b314ce492612711659c84f01a5a3035..22615ee6f3929be5498aacad02530c8b5bd54b5b 100644 --- a/arkindex_worker/__init__.py +++ b/arkindex_worker/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import logging logging.basicConfig( diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py index d887bd78bf09f7c2a9910199bb99e68f0a2dc2fc..c757a5ef2dfa13723ffa2cdc558ab2c27a4a470b 100644 --- a/arkindex_worker/cache.py +++ b/arkindex_worker/cache.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ Database mappings and helper methods for the experimental worker caching feature. @@ -10,7 +9,6 @@ reducing network usage. 
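The new release target in the Makefile reads the version from pyproject.toml instead of the removed VERSION file. As a rough illustration of what the grep/tr/cut pipeline yields, here is a Python sketch of the same parsing, assuming pyproject.toml carries a single version = "X.Y.Z" line (the file itself is not shown in this diff and the value below is hypothetical):

```
# Rough Python equivalent of the Makefile pipeline:
#   grep -m 1 version pyproject.toml | tr -s ' ' | tr -d '"' | tr -d "'" | cut -d' ' -f3
# Assumes a line such as: version = "0.3.6"
from pathlib import Path


def read_version(pyproject: Path = Path("pyproject.toml")) -> str:
    for line in pyproject.read_text().splitlines():
        if "version" in line:  # grep -m 1: stop at the first matching line
            # squeeze whitespace, drop quotes, keep the third field
            fields = " ".join(line.split()).replace('"', "").replace("'", "").split(" ")
            return fields[2]
    raise ValueError("no version line found")


print(read_version())
```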
import json import sqlite3 from pathlib import Path -from typing import Optional, Union from peewee import ( SQL, @@ -106,8 +104,8 @@ class CachedElement(Model): def open_image( self, *args, - max_width: Optional[int] = None, - max_height: Optional[int] = None, + max_width: int | None = None, + max_height: int | None = None, **kwargs, ) -> Image: """ @@ -145,17 +143,15 @@ class CachedElement(Model): if max_width is None and max_height is None: resize = "full" else: - # Do not resize for polygons that do not exactly match the images - # as the resize is made directly by the IIIF server using the box parameter if ( + # Do not resize for polygons that do not exactly match the images + # as the resize is made directly by the IIIF server using the box parameter bounding_box.width != self.image.width or bounding_box.height != self.image.height - ): - resize = "full" - - # Do not resize when the image is below the maximum size - elif (max_width is None or self.image.width <= max_width) and ( - max_height is None or self.image.height <= max_height + ) or ( + # Do not resize when the image is below the maximum size + (max_width is None or self.image.width <= max_width) + and (max_height is None or self.image.height <= max_height) ): resize = "full" else: @@ -319,22 +315,21 @@ def create_version_table(): Version.create(version=SQL_VERSION) -def check_version(cache_path: Union[str, Path]): +def check_version(cache_path: str | Path): """ Check the validity of the SQLite version :param cache_path: Path towards a local SQLite database """ - with SqliteDatabase(cache_path) as provided_db: - with provided_db.bind_ctx([Version]): - try: - version = Version.get().version - except OperationalError: - version = None + with SqliteDatabase(cache_path) as provided_db, provided_db.bind_ctx([Version]): + try: + version = Version.get().version + except OperationalError: + version = None - assert ( - version == SQL_VERSION - ), f"The SQLite database {cache_path} does not have the correct cache version, it should be {SQL_VERSION}" + assert ( + version == SQL_VERSION + ), f"The SQLite database {cache_path} does not have the correct cache version, it should be {SQL_VERSION}" def merge_parents_cache(paths: list, current_database: Path): @@ -358,9 +353,8 @@ def merge_parents_cache(paths: list, current_database: Path): # Check that the parent cache uses a compatible version check_version(path) - with SqliteDatabase(path) as source: - with source.bind_ctx(MODELS): - source.create_tables(MODELS) + with SqliteDatabase(path) as source, source.bind_ctx(MODELS): + source.create_tables(MODELS) logger.info(f"Merging parent db {path} into {current_database}") statements = [ diff --git a/arkindex_worker/git.py b/arkindex_worker/git.py deleted file mode 100644 index f30b781898cd06b360ca9544126fd17056aae8c3..0000000000000000000000000000000000000000 --- a/arkindex_worker/git.py +++ /dev/null @@ -1,392 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Helper classes for workers that interact with Git repositories and the GitLab API. 
-""" -import shutil -import time -from datetime import datetime -from pathlib import Path -from typing import Optional, Union - -import gitlab -import requests -import sh -from gitlab.v4.objects import MergeRequest, ProjectMergeRequest - -from arkindex_worker import logger - -NOTHING_TO_COMMIT_MSG = "nothing to commit, working tree clean" -MR_HAS_CONFLICTS_ERROR_CODE = 406 - - -class GitlabHelper: - """Helper class to save files to GitLab repository""" - - def __init__( - self, - project_id: str, - gitlab_url: str, - gitlab_token: str, - branch: str, - rebase_wait_period: Optional[int] = 1, - delete_source_branch: Optional[bool] = True, - max_rebase_tries: Optional[int] = 10, - ): - """ - :param project_id: the id of the gitlab project - :param gitlab_url: gitlab server url - :param gitlab_token: gitlab private token of user with permission to accept merge requests - :param branch: name of the branch to where the exported branch will be merged - :param rebase_wait_period: seconds to wait between each poll to check whether rebase has finished - :param delete_source_branch: should delete the source branch after merging? - :param max_rebase_tries: max number of tries to rebase when merging before giving up - """ - self.project_id = project_id - self.gitlab_url = gitlab_url - self.gitlab_token = str(gitlab_token).strip() - self.branch = branch - self.rebase_wait_period = rebase_wait_period - self.delete_source_branch = delete_source_branch - self.max_rebase_tries = max_rebase_tries - - logger.info("Creating a Gitlab client") - self._api = gitlab.Gitlab(self.gitlab_url, private_token=self.gitlab_token) - self.project = self._api.projects.get(self.project_id) - self.is_rebase_finished = False - - def merge(self, branch_name: str, title: str) -> bool: - """ - Create a merge request and try to merge. 
- Always rebase first to avoid conflicts from MRs made in parallel - :param branch_name: Source branch name - :param title: Title of the merge request - :return: Whether the branch was successfully merged - """ - mr = None - # always rebase first, because other workers might have merged already - for i in range(self.max_rebase_tries): - logger.info(f"Trying to merge, try nr: {i}") - try: - if mr is None: - mr = self._create_merge_request(branch_name, title) - - mr.rebase() - rebase_success = self._wait_for_rebase_to_finish(mr.iid) - if not rebase_success: - logger.error("Rebase failed, won't be able to merge!") - return False - - mr.merge(should_remove_source_branch=self.delete_source_branch) - logger.info("Merge successful") - return True - except gitlab.GitlabMRClosedError as e: - if e.response_code == MR_HAS_CONFLICTS_ERROR_CODE: - logger.info("Merge failed, trying to rebase and merge again.") - continue - else: - logger.error(f"Merge was not successful: {e}") - return False - except gitlab.GitlabError as e: - logger.error(f"Gitlab error: {e}") - if 400 <= e.response_code < 500: - # 4XX errors shouldn't be fixed by retrying - raise e - except requests.exceptions.ConnectionError as e: - logger.error(f"Server connection error, will wait and retry: {e}") - time.sleep(self.rebase_wait_period) - - return False - - def _create_merge_request(self, branch_name: str, title: str) -> MergeRequest: - """ - Create a MergeRequest towards the branch with the given title - - :param branch_name: Target branch of the merge request - :param title: Title of the merge request - :return: The created merge request - """ - logger.info(f"Creating a merge request for {branch_name}") - # retry_transient_error will retry the request on 50X errors - # https://github.com/python-gitlab/python-gitlab/blob/265dbbdd37af88395574564aeb3fd0350288a18c/gitlab/__init__.py#L539 - mr = self.project.mergerequests.create( - { - "source_branch": branch_name, - "target_branch": self.branch, - "title": title, - }, - ) - return mr - - def _get_merge_request( - self, merge_request_id: Union[str, int], include_rebase_in_progress: bool = True - ) -> ProjectMergeRequest: - """ - Retrieve a merge request by ID - :param merge_request_id: The ID of the merge request - :param include_rebase_in_progress: Whether the rebase in progree should be included - :return: The related merge request - """ - return self.project.mergerequests.get( - merge_request_id, include_rebase_in_progress=include_rebase_in_progress - ) - - def _wait_for_rebase_to_finish(self, merge_request_id: Union[str, int]) -> bool: - """ - Poll the merge request until it has finished rebasing - :param merge_request_id: The ID of the merge request - :return: Whether the rebase has finished successfully - """ - - logger.info("Checking if rebase has finished..") - self.is_rebase_finished = False - while not self.is_rebase_finished: - time.sleep(self.rebase_wait_period) - mr = self._get_merge_request(merge_request_id) - self.is_rebase_finished = not mr.rebase_in_progress - if mr.merge_error is None: - logger.info("Rebase has finished") - return True - - logger.error(f"Rebase failed: {mr.merge_error}") - return False - - -def make_backup(path: str): - """ - Create a backup file in the same directory with timestamp as suffix ".bak_{timestamp}" - :param path: Path to the file to be backed up - """ - path = Path(path) - if not path.exists(): - raise ValueError(f"No file to backup! 
File not found: {path}") - # timestamp with milliseconds - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] - backup_path = Path(str(path) + f".bak_{timestamp}") - shutil.copy(path, backup_path) - logger.info(f"Made a backup {backup_path}") - - -def prepare_git_key( - private_key: str, - known_hosts: str, - private_key_path: Optional[str] = "~/.ssh/id_ed25519", - known_hosts_path: Optional[str] = "~/.ssh/known_hosts", -): - """ - Prepare the git keys (put them in to the correct place) so that git could be used. - Fixes some whitespace problems that come from arkindex secrets store (Django admin). - - Also creates a backup of the previous keys if they exist, to avoid losing the - original keys of the developers. - - :param private_key: git private key contents - :param known_hosts: git known_hosts contents - :param private_key_path: path where to put the private key - :param known_hosts_path: path where to put the known_hosts - """ - # secrets admin UI seems to strip the trailing whitespace - # but git requires the key file to have a new line at the end - # for some reason uses CRLF line endings, but git doesn't like that - private_key = private_key.replace("\r", "") + "\n" - known_hosts = known_hosts.replace("\r", "") + "\n" - - private_key_path = Path(private_key_path).expanduser() - known_hosts_path = Path(known_hosts_path).expanduser() - - if private_key_path.exists(): - if private_key_path.read_text() != private_key: - make_backup(private_key_path) - - if known_hosts_path.exists(): - if known_hosts_path.read_text() != known_hosts: - make_backup(known_hosts_path) - - private_key_path.write_text(private_key) - # private key must be private, otherwise git will fail - # expecting octal for permissions - private_key_path.chmod(0o600) - known_hosts_path.write_text(known_hosts) - - logger.info(f"Private key size after: {private_key_path.stat().st_size}") - logger.info(f"Known size after: {known_hosts_path.stat().st_size}") - - -class GitHelper: - """ - A helper class for running git commands - - At the beginning of the workflow call [run_clone_in_background][arkindex_worker.git.GitHelper.run_clone_in_background]. - When all the files are ready to be added to git then call - [save_files][arkindex_worker.git.GitHelper.save_files] to move the files in to the git repository - and try to push them. - - Examples - -------- - in worker.configure() configure the git helper and start the cloning: - ``` - gitlab = GitlabHelper(...) - prepare_git_key(...) - self.git_helper = GitHelper(workflow_id=workflow_id, gitlab_helper=gitlab, ...) 
- self.git_helper.run_clone_in_background() - ``` - - at the end of the workflow (at the end of worker.run()) push the files to git: - ``` - self.git_helper.save_files(self.out_dir) - ``` - """ - - def __init__( - self, - repo_url, - git_dir, - export_path, - workflow_id, - gitlab_helper: GitlabHelper, - git_clone_wait_period=1, - ): - """ - - :param repo_url: the url of the git repository where the export will be pushed - :param git_dir: the directory where to clone the git repository - :param export_path: the path inside the git repository where to put the exported files - :param workflow_id: the process id to see the workflow graph in the frontend - :param gitlab_helper: helper for gitlab - :param git_clone_wait_period: check if clone has finished every N seconds at the end of the workflow - """ - logger.info("Creating git helper") - self.repo_url = repo_url - self.git_dir = Path(git_dir) - self.export_path = self.git_dir / export_path - self.workflow_id = workflow_id - self.gitlab_helper = gitlab_helper - self.git_clone_wait_period = git_clone_wait_period - self.is_clone_finished = False - self.cmd = None - self.success = None - self.exit_code = None - - self.git_dir.mkdir(parents=True, exist_ok=True) - # run git commands outside of the repository (no need to change dir) - self._git = sh.git.bake("-C", self.git_dir) - - def _clone_done(self, cmd, success, exit_code): - """ - Method that is called when git clone has finished in the background - """ - logger.info("Finishing cloning") - self.cmd = cmd - self.success = success - self.exit_code = exit_code - self.is_clone_finished = True - if not success: - logger.error(f"Clone failed: {cmd} : {success} : {exit_code}") - logger.info("Cloning finished") - - def run_clone_in_background(self): - """ - Clones the git repository in the background in to the self.git_dir directory. - - `self.is_clone_finished` can be used to know whether the cloning has finished - or not. - """ - logger.info(f"Starting clone {self.repo_url} in background") - cmd = sh.git.clone( - self.repo_url, self.git_dir, _bg=True, _done=self._clone_done - ) - logger.info(f"Continuing clone {self.repo_url} in background") - return cmd - - def _wait_for_clone_to_finish(self): - logger.info("Checking if cloning has finished..") - while not self.is_clone_finished: - time.sleep(self.git_clone_wait_period) - logger.info("Cloning has finished") - - if not self.success: - logger.error("Clone was not a success") - logger.error(f"Clone error exit code: {str(self.exit_code)}") - raise ValueError("Clone was not a success") - - def save_files(self, export_out_dir: Path): - """ - Move files in export_out_dir to the cloned git repository - and try to merge the created files if possible. 
- :param export_out_dir: Path to the files to be saved - :raises sh.ErrorReturnCode: _description_ - :raises Exception: _description_ - """ - self._wait_for_clone_to_finish() - - # move exported files to git directory - file_count = self._move_files_to_git(export_out_dir) - - # use timestamp to avoid branch name conflicts with multiple chunks - current_timestamp = datetime.isoformat(datetime.now()) - # ":" is not allowed in a branch name - branch_timestamp = current_timestamp.replace(":", ".") - # add files to a new branch - branch_name = f"workflow_{self.workflow_id}_{branch_timestamp}" - self._git.checkout("-b", branch_name) - self._git.add("-A") - try: - self._git.commit( - "-m", - f"Exported files from workflow: {self.workflow_id} at {current_timestamp}", - ) - except sh.ErrorReturnCode as e: - if NOTHING_TO_COMMIT_MSG in str(e.stdout): - logger.warning("Nothing to commit (no changes)") - return - else: - logger.error(f"Commit failed:: {e}") - raise e - - # count the number of lines in the output - wc_cmd_out = str( - sh.wc(self._git.show("--stat", "--name-status", "--oneline", "HEAD"), "-l") - ) - # -1 because the of the git command header - files_committed = int(wc_cmd_out.strip()) - 1 - logger.info(f"Committed {files_committed} files") - if file_count != files_committed: - logger.warning( - f"Of {file_count} added files only {files_committed} were committed" - ) - - self._git.push("-u", "origin", "HEAD") - - if self.gitlab_helper: - try: - self.gitlab_helper.merge(branch_name, f"Merge {branch_name}") - except Exception as e: - logger.error(f"Merge failed: {e}") - raise e - else: - logger.info( - "No gitlab_helper defined, not trying to merge the pushed branch" - ) - - def _move_files_to_git(self, export_out_dir: Path) -> int: - """ - Move all files in the export_out_dir to the git repository - while keeping the same directory structure - :param export_out_dir: Path to the files to be moved - :return: Total count of moved files - """ - file_count = 0 - file_names = [ - file_name for file_name in export_out_dir.rglob("*") if file_name.is_file() - ] - for file in file_names: - rel_file_path = file.relative_to(export_out_dir) - out_file = self.export_path / rel_file_path - if not out_file.exists(): - out_file.parent.mkdir(parents=True, exist_ok=True) - # rename does not work if the source and destination are not on the same mounts - # it will give an error: "OSError: [Errno 18] Invalid cross-device link:" - shutil.copy(file, out_file) - file.unlink() - file_count += 1 - logger.info(f"Moved {file_count} files") - return file_count diff --git a/arkindex_worker/image.py b/arkindex_worker/image.py index 09cd953bfcd52eee48a73b12310b22ccd3763842..d289bed9685b7706304c243ff15ef92a959c42b9 100644 --- a/arkindex_worker/image.py +++ b/arkindex_worker/image.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ Helper methods to download and open IIIF images, and manage polygons. """ @@ -7,7 +6,7 @@ from collections import namedtuple from io import BytesIO from math import ceil from pathlib import Path -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING import requests from PIL import Image @@ -42,9 +41,9 @@ IIIF_MAX = "max" def open_image( path: str, - mode: Optional[str] = "RGB", - rotation_angle: Optional[int] = 0, - mirrored: Optional[bool] = False, + mode: str | None = "RGB", + rotation_angle: int | None = 0, + mirrored: bool | None = False, ) -> Image: """ Open an image from a path or a URL. 
@@ -71,7 +70,7 @@ def open_image( else: try: image = Image.open(path) - except (IOError, ValueError): + except (OSError, ValueError): image = download_image(path) if image.mode != mode: @@ -141,14 +140,14 @@ def download_image(url: str) -> Image: return image -def polygon_bounding_box(polygon: List[List[Union[int, float]]]) -> BoundingBox: +def polygon_bounding_box(polygon: list[list[int | float]]) -> BoundingBox: """ Compute the rectangle bounding box of a polygon. :param polygon: Polygon to get the bounding box of. :returns: Bounding box of this polygon. """ - x_coords, y_coords = zip(*polygon) + x_coords, y_coords = zip(*polygon, strict=True) x, y = min(x_coords), min(y_coords) width, height = max(x_coords) - x, max(y_coords) - y return BoundingBox(x, y, width, height) @@ -248,8 +247,8 @@ def download_tiles(url: str) -> Image: def trim_polygon( - polygon: List[List[int]], image_width: int, image_height: int -) -> List[List[int]]: + polygon: list[list[int]], image_width: int, image_height: int +) -> list[list[int]]: """ Trim a polygon to an image's boundaries, with non-negative coordinates. @@ -265,10 +264,10 @@ def trim_polygon( """ assert isinstance( - polygon, (list, tuple) + polygon, list | tuple ), "Input polygon must be a valid list or tuple of points." assert all( - isinstance(point, (list, tuple)) for point in polygon + isinstance(point, list | tuple) for point in polygon ), "Polygon points must be tuples or lists." assert all( len(point) == 2 for point in polygon @@ -301,10 +300,10 @@ def trim_polygon( def revert_orientation( - element: Union["Element", "CachedElement"], - polygon: List[List[Union[int, float]]], - reverse: Optional[bool] = False, -) -> List[List[int]]: + element: "Element | CachedElement", + polygon: list[list[int | float]], + reverse: bool = False, +) -> list[list[int]]: """ Update the coordinates of the polygon of a child element based on the orientation of its parent. @@ -324,7 +323,7 @@ def revert_orientation( from arkindex_worker.models import Element assert element and isinstance( - element, (Element, CachedElement) + element, Element | CachedElement ), "element shouldn't be null and should be an Element or CachedElement" assert polygon and isinstance( polygon, list diff --git a/arkindex_worker/models.py b/arkindex_worker/models.py index 98ca54a343951f6d68e136fe8edbeacb4ed2c496..50a9b923b64349ee9c5901191e18c90e7a8d38d8 100644 --- a/arkindex_worker/models.py +++ b/arkindex_worker/models.py @@ -1,11 +1,10 @@ -# -*- coding: utf-8 -*- """ Wrappers around API results to provide more convenient attribute access and IIIF helpers. """ import tempfile +from collections.abc import Generator from contextlib import contextmanager -from typing import Generator, List, Optional from PIL import Image from requests import HTTPError @@ -34,10 +33,10 @@ class MagicDict(dict): def __getattr__(self, name): try: return self[name] - except KeyError: + except KeyError as e: raise AttributeError( - "{} object has no attribute '{}'".format(self.__class__.__name__, name) - ) + f"{self.__class__.__name__} object has no attribute '{name}'" + ) from e def __setattr__(self, name, value): return super().__setitem__(name, value) @@ -74,7 +73,7 @@ class Element(MagicDict): parts[-3] = size return "/".join(parts) - def image_url(self, size: str = "full") -> Optional[str]: + def image_url(self, size: str = "full") -> str | None: """ Build an URL to access the image. When possible, will return the S3 URL for images, so an ML worker can bypass IIIF servers. 
@@ -89,10 +88,10 @@ class Element(MagicDict): url = self.zone.image.url if not url.endswith("/"): url += "/" - return "{}full/{}/0/default.jpg".format(url, size) + return f"{url}full/{size}/0/default.jpg" @property - def polygon(self) -> List[float]: + def polygon(self) -> list[float]: """ Access an Element's polygon. This is a shortcut to an Element's polygon, normally accessed via @@ -101,7 +100,7 @@ class Element(MagicDict): the [CachedElement][arkindex_worker.cache.CachedElement].polygon field. """ if not self.get("zone"): - raise ValueError("Element {} has no zone".format(self.id)) + raise ValueError(f"Element {self.id} has no zone") return self.zone.polygon @property @@ -122,11 +121,11 @@ class Element(MagicDict): def open_image( self, *args, - max_width: Optional[int] = None, - max_height: Optional[int] = None, - use_full_image: Optional[bool] = False, + max_width: int | None = None, + max_height: int | None = None, + use_full_image: bool | None = False, **kwargs, - ) -> Image: + ) -> Image.Image: """ Open this element's image using Pillow, rotating and mirroring it according to the ``rotation_angle`` and ``mirrored`` attributes. @@ -173,7 +172,7 @@ class Element(MagicDict): ) if not self.get("zone"): - raise ValueError("Element {} has no zone".format(self.id)) + raise ValueError(f"Element {self.id} has no zone") if self.requires_tiles: if max_width is None and max_height is None: @@ -194,10 +193,7 @@ class Element(MagicDict): else: resize = f"{max_width or ''},{max_height or ''}" - if use_full_image: - url = self.image_url(resize) - else: - url = self.resize_zone_url(resize) + url = self.image_url(resize) if use_full_image else self.resize_zone_url(resize) try: return open_image( @@ -215,13 +211,13 @@ class Element(MagicDict): # This element uses an S3 URL: the URL may have expired. # Call the API to get a fresh URL and try again # TODO: this should be done by the worker - raise NotImplementedError + raise NotImplementedError from e return open_image(self.image_url(resize), *args, **kwargs) raise @contextmanager def open_image_tempfile( - self, format: Optional[str] = "jpeg", *args, **kwargs + self, format: str | None = "jpeg", *args, **kwargs ) -> Generator[tempfile.NamedTemporaryFile, None, None]: """ Get the element's image as a temporary file stored on the disk. @@ -249,7 +245,7 @@ class Element(MagicDict): type_name = self.type["display_name"] else: type_name = str(self.type) - return "{} {} ({})".format(type_name, self.name, self.id) + return f"{type_name} {self.name} ({self.id})" class ArkindexModel(MagicDict): diff --git a/arkindex_worker/utils.py b/arkindex_worker/utils.py index 1f5c54aa044ce0073bb71657a4499983597c04d1..34709c5f58cfe14a7c3ff6a3a96d5e2879cf7648 100644 --- a/arkindex_worker/utils.py +++ b/arkindex_worker/utils.py @@ -1,11 +1,9 @@ -# -*- coding: utf-8 -*- import hashlib import logging import os import tarfile import tempfile from pathlib import Path -from typing import Optional, Tuple, Union import zstandard import zstandard as zstd @@ -16,7 +14,7 @@ CHUNK_SIZE = 1024 """Chunk Size used for ZSTD compression""" -def decompress_zst_archive(compressed_archive: Path) -> Tuple[int, Path]: +def decompress_zst_archive(compressed_archive: Path) -> tuple[int, Path]: """ Decompress a ZST-compressed tar archive in data dir. The tar archive is not extracted. This returns the path to the archive and the file descriptor. 
@@ -29,18 +27,19 @@ def decompress_zst_archive(compressed_archive: Path) -> Tuple[int, Path]: """ dctx = zstandard.ZstdDecompressor() archive_fd, archive_path = tempfile.mkstemp(suffix=".tar") + archive_path = Path(archive_path) logger.debug(f"Uncompressing file to {archive_path}") try: - with open(compressed_archive, "rb") as compressed, open( - archive_path, "wb" + with compressed_archive.open("rb") as compressed, archive_path.open( + "wb" ) as decompressed: dctx.copy_stream(compressed, decompressed) logger.debug(f"Successfully uncompressed archive {compressed_archive}") except zstandard.ZstdError as e: - raise Exception(f"Couldn't uncompressed archive: {e}") + raise Exception(f"Couldn't uncompressed archive: {e}") from e - return archive_fd, Path(archive_path) + return archive_fd, archive_path def extract_tar_archive(archive_path: Path, destination: Path): @@ -54,12 +53,12 @@ def extract_tar_archive(archive_path: Path, destination: Path): with tarfile.open(archive_path) as tar_archive: tar_archive.extractall(destination) except tarfile.ReadError as e: - raise Exception(f"Couldn't handle the decompressed Tar archive: {e}") + raise Exception(f"Couldn't handle the decompressed Tar archive: {e}") from e def extract_tar_zst_archive( compressed_archive: Path, destination: Path -) -> Tuple[int, Path]: +) -> tuple[int, Path]: """ Extract a ZST-compressed tar archive's content to a specific destination @@ -89,8 +88,8 @@ def close_delete_file(file_descriptor: int, file_path: Path): def zstd_compress( - source: Path, destination: Optional[Path] = None -) -> Tuple[Union[int, None], Path, str]: + source: Path, destination: Path | None = None +) -> tuple[int | None, Path, str]: """Compress a file using the Zstandard compression algorithm. :param source: Path to the file to compress. @@ -117,13 +116,13 @@ def zstd_compress( archive_file.write(compressed_chunk) logger.debug(f"Successfully compressed {source}") except zstandard.ZstdError as e: - raise Exception(f"Couldn't compress archive: {e}") + raise Exception(f"Couldn't compress archive: {e}") from e return file_d, destination, archive_hasher.hexdigest() def create_tar_archive( - path: Path, destination: Optional[Path] = None -) -> Tuple[Union[int, None], Path, str]: + path: Path, destination: Path | None = None +) -> tuple[int | None, Path, str]: """Create a tar archive using the content at specified location. :param path: Path to the file to archive @@ -153,7 +152,7 @@ def create_tar_archive( files.append(p) logger.debug(f"Successfully created Tar archive from files @ {path}") except tarfile.TarError as e: - raise Exception(f"Couldn't create Tar archive: {e}") + raise Exception(f"Couldn't create Tar archive: {e}") from e # Sort by path files.sort() @@ -168,8 +167,8 @@ def create_tar_archive( def create_tar_zst_archive( - source: Path, destination: Optional[Path] = None -) -> Tuple[Union[int, None], Path, str, str]: + source: Path, destination: Path | None = None +) -> tuple[int | None, Path, str, str]: """Helper to create a TAR+ZST archive from a source folder. :param source: Path to the folder whose content should be archived. diff --git a/arkindex_worker/worker/__init__.py b/arkindex_worker/worker/__init__.py index 66f13e559a6822b1f365e54d35c0a8a0f1f76038..25e239172a41f46b09ae1ccb51a47e3a193f25a8 100644 --- a/arkindex_worker/worker/__init__.py +++ b/arkindex_worker/worker/__init__.py @@ -1,17 +1,16 @@ -# -*- coding: utf-8 -*- """ Base classes to implement Arkindex workers. 
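The archive helpers above keep the same calling convention after the annotation update: the ZST helpers hand back a file descriptor and a path so callers can clean up with close_delete_file. A minimal usage sketch, with placeholder paths and the assumption that the returned descriptor and path point at the files to delete once they are no longer needed:

```
from pathlib import Path

from arkindex_worker.utils import (
    close_delete_file,
    create_tar_zst_archive,
    extract_tar_zst_archive,
)

# Placeholder paths for this sketch
archive = Path("/tmp/dataset.tar.zst")
workdir = Path("/tmp/dataset")
workdir.mkdir(parents=True, exist_ok=True)

# Extract a .tar.zst archive, then release and remove the intermediate tar file
fd, tar_path = extract_tar_zst_archive(archive, workdir)
close_delete_file(fd, tar_path)

# Re-create a .tar.zst archive from the folder; the two trailing strings are
# hashes, per the tuple[int | None, Path, str, str] return annotation
zst_fd, zst_path, *hashes = create_tar_zst_archive(workdir)

# Once the archive has been used (e.g. uploaded), release and delete it as well
close_delete_file(zst_fd, zst_path)
```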
""" - +import contextlib import json import os import sys import uuid +from collections.abc import Iterable, Iterator from enum import Enum from itertools import groupby from operator import itemgetter from pathlib import Path -from typing import Iterable, Iterator, List, Tuple, Union from apistar.exceptions import ErrorResponse @@ -102,7 +101,7 @@ class ElementsWorker( self._worker_version_cache = {} - def list_elements(self) -> Union[Iterable[CachedElement], List[str]]: + def list_elements(self) -> Iterable[CachedElement] | list[str]: """ List the elements to be processed, either from the CLI arguments or the cache database when enabled. @@ -227,21 +226,17 @@ class ElementsWorker( ) if element: # Try to update the activity to error state regardless of the response - try: + with contextlib.suppress(Exception): self.update_activity(element.id, ActivityState.Error) - except Exception: - pass if failed: logger.error( - "Ran on {} elements: {} completed, {} failed".format( - count, count - failed, failed - ) + f"Ran on {count} elements: {count - failed} completed, {failed} failed" ) if failed >= count: # Everything failed! sys.exit(1) - def process_element(self, element: Union[Element, CachedElement]): + def process_element(self, element: Element | CachedElement): """ Override this method to implement your worker and process a single Arkindex element at once. @@ -251,7 +246,7 @@ class ElementsWorker( """ def update_activity( - self, element_id: Union[str, uuid.UUID], state: ActivityState + self, element_id: str | uuid.UUID, state: ActivityState ) -> bool: """ Update the WorkerActivity for this element and worker. @@ -269,7 +264,7 @@ class ElementsWorker( return True assert element_id and isinstance( - element_id, (uuid.UUID, str) + element_id, uuid.UUID | str ), "element_id shouldn't be null and should be an UUID or str" assert isinstance(state, ActivityState), "state should be an ActivityState" @@ -382,7 +377,7 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin): def list_dataset_elements_per_split( self, dataset: Dataset - ) -> Iterator[Tuple[str, List[Element]]]: + ) -> Iterator[tuple[str, list[Element]]]: """ List the elements in the dataset, grouped by split, using the [list_dataset_elements][arkindex_worker.worker.dataset.DatasetMixin.list_dataset_elements] method. 
@@ -392,8 +387,8 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin): """ def format_split( - split: Tuple[str, Iterator[Tuple[str, Element]]] - ) -> Tuple[str, List[Element]]: + split: tuple[str, Iterator[tuple[str, Element]]], + ) -> tuple[str, list[Element]]: return (split[0], list(map(itemgetter(1), list(split[1])))) return map( @@ -435,7 +430,7 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin): """ self.configure() - datasets: List[Dataset] | List[str] = list(self.list_datasets()) + datasets: list[Dataset] | list[str] = list(self.list_datasets()) if not datasets: logger.warning("No datasets to process, stopping.") sys.exit(1) @@ -445,6 +440,8 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin): failed = 0 for i, item in enumerate(datasets, start=1): dataset = None + dataset_artifact = None + try: if not self.is_read_only: # Just use the result of list_datasets as the dataset @@ -470,7 +467,7 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin): self.update_dataset_state(dataset, DatasetState.Building) else: logger.info(f"Downloading data for {dataset} ({i}/{count})") - self.download_dataset_artifact(dataset) + dataset_artifact = self.download_dataset_artifact(dataset) # Process the dataset self.process_dataset(dataset) @@ -499,16 +496,16 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin): ) if dataset and self.generator: # Try to update the state to Error regardless of the response - try: + with contextlib.suppress(Exception): self.update_dataset_state(dataset, DatasetState.Error) - except Exception: - pass + finally: + # Cleanup the dataset artifact if it was downloaded, no matter what + if dataset_artifact: + dataset_artifact.unlink(missing_ok=True) if failed: logger.error( - "Ran on {} datasets: {} completed, {} failed".format( - count, count - failed, failed - ) + f"Ran on {count} datasets: {count - failed} completed, {failed} failed" ) if failed >= count: # Everything failed! sys.exit(1) diff --git a/arkindex_worker/worker/base.py b/arkindex_worker/worker/base.py index 461d168ee9418d1638b1e83a683a31736f8f21ef..573f1bd1dcb9e39e41264ef1d218d4c38ac9f5b7 100644 --- a/arkindex_worker/worker/base.py +++ b/arkindex_worker/worker/base.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ The base class for all Arkindex workers. """ @@ -9,7 +8,6 @@ import os import shutil from pathlib import Path from tempfile import mkdtemp -from typing import List, Optional import gnupg import yaml @@ -52,15 +50,15 @@ class ExtrasDirNotFoundError(Exception): """ -class BaseWorker(object): +class BaseWorker: """ Base class for Arkindex workers. """ def __init__( self, - description: Optional[str] = "Arkindex Base Worker", - support_cache: Optional[bool] = False, + description: str | None = "Arkindex Base Worker", + support_cache: bool | None = False, ): """ Initialize the worker. 
@@ -217,6 +215,9 @@ class BaseWorker(object): # Define model_version_id from environment self.model_version_id = os.environ.get("ARKINDEX_MODEL_VERSION_ID") + # Define model_details from environment + self.model_details = {"id": os.environ.get("ARKINDEX_MODEL_ID")} + # Load all required secrets self.secrets = {name: self.load_secret(Path(name)) for name in required_secrets} @@ -259,6 +260,9 @@ class BaseWorker(object): # Set model_version ID as worker attribute self.model_version_id = model_version.get("id") + # Set model details as worker attribute + self.model_details = model_version.get("model") + # Retrieve initial configuration from API self.config = worker_version["configuration"].get("configuration", {}) if "user_configuration" in worker_version["configuration"]: @@ -347,7 +351,8 @@ class BaseWorker(object): try: gpg = gnupg.GPG() - decrypted = gpg.decrypt_file(open(path, "rb")) + with path.open("rb") as gpg_file: + decrypted = gpg.decrypt_file(gpg_file) assert ( decrypted.ok ), f"GPG error: {decrypted.status} - {decrypted.stderr}" @@ -406,7 +411,7 @@ class BaseWorker(object): ) return extras_dir - def find_parents_file_paths(self, filename: Path) -> List[Path]: + def find_parents_file_paths(self, filename: Path) -> list[Path]: """ Find the paths of a specific file from the parent tasks. Only works if the task_parents attributes is updated, so if the cache is supported, diff --git a/arkindex_worker/worker/classification.py b/arkindex_worker/worker/classification.py index e28cefe7604acc525cea6b9824de683bf222df4e..37f42dd894b5c390361ca31b771071685e138d28 100644 --- a/arkindex_worker/worker/classification.py +++ b/arkindex_worker/worker/classification.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- """ ElementsWorker methods for classifications and ML classes. """ -from typing import Dict, List, Optional, Union from uuid import UUID from apistar.exceptions import ErrorResponse @@ -14,7 +12,7 @@ from arkindex_worker.cache import CachedClassification, CachedElement from arkindex_worker.models import Element -class ClassificationMixin(object): +class ClassificationMixin: def load_corpus_classes(self): """ Load all ML classes available in the worker's corpus and store them in the ``self.classes`` cache. @@ -91,11 +89,11 @@ class ClassificationMixin(object): def create_classification( self, - element: Union[Element, CachedElement], + element: Element | CachedElement, ml_class: str, confidence: float, - high_confidence: Optional[bool] = False, - ) -> Dict[str, str]: + high_confidence: bool = False, + ) -> dict[str, str]: """ Create a classification on the given element through the API. @@ -106,7 +104,7 @@ class ClassificationMixin(object): :returns: The created classification, as returned by the ``CreateClassification`` API endpoint. """ assert element and isinstance( - element, (Element, CachedElement) + element, Element | CachedElement ), "element shouldn't be null and should be an Element or CachedElement" assert ml_class and isinstance( ml_class, str @@ -180,9 +178,9 @@ class ClassificationMixin(object): def create_classifications( self, - element: Union[Element, CachedElement], - classifications: List[Dict[str, Union[str, float, bool]]], - ) -> List[Dict[str, Union[str, float, bool]]]: + element: Element | CachedElement, + classifications: list[dict[str, str | float | bool]], + ) -> list[dict[str, str | float | bool]]: """ Create multiple classifications at once on the given element through the API. 
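The reworked create_classification signature keeps the same calling convention: an Element or CachedElement, an ML class name, a confidence and an optional high_confidence flag. A short sketch of a worker using it, where the class name is a placeholder:

```
from arkindex_worker.models import Element
from arkindex_worker.worker import ElementsWorker


class ClassifierWorker(ElementsWorker):
    def process_element(self, element: Element) -> None:
        # "table" is a placeholder ML class name for this sketch
        self.create_classification(
            element,
            ml_class="table",
            confidence=0.97,
            high_confidence=True,
        )


if __name__ == "__main__":
    ClassifierWorker(description="Toy classifier").run()
```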
@@ -196,7 +194,7 @@ class ClassificationMixin(object): the ``CreateClassifications`` API endpoint. """ assert element and isinstance( - element, (Element, CachedElement) + element, Element | CachedElement ), "element shouldn't be null and should be an Element or CachedElement" assert classifications and isinstance( classifications, list @@ -204,17 +202,17 @@ class ClassificationMixin(object): for index, classification in enumerate(classifications): ml_class_id = classification.get("ml_class_id") - assert ml_class_id and isinstance( - ml_class_id, str + assert ( + ml_class_id and isinstance(ml_class_id, str) ), f"Classification at index {index} in classifications: ml_class_id shouldn't be null and should be of type str" # Make sure it's a valid UUID try: UUID(ml_class_id) - except ValueError: + except ValueError as e: raise ValueError( f"Classification at index {index} in classifications: ml_class_id is not a valid uuid." - ) + ) from e confidence = classification.get("confidence") assert ( diff --git a/arkindex_worker/worker/dataset.py b/arkindex_worker/worker/dataset.py index 34caceaac7de283b24e0b7e003da0f15c94cb3a7..be088ef422d33aeab9a26d26a7cee1fd1461fa21 100644 --- a/arkindex_worker/worker/dataset.py +++ b/arkindex_worker/worker/dataset.py @@ -1,10 +1,9 @@ -# -*- coding: utf-8 -*- """ BaseWorker methods for datasets. """ +from collections.abc import Iterator from enum import Enum -from typing import Iterator, Tuple from arkindex_worker import logger from arkindex_worker.models import Dataset, Element @@ -36,7 +35,7 @@ class DatasetState(Enum): """ -class DatasetMixin(object): +class DatasetMixin: def list_process_datasets(self) -> Iterator[Dataset]: """ List datasets associated to the worker's process. This helper is not available in developer mode. @@ -51,7 +50,7 @@ class DatasetMixin(object): return map(Dataset, list(results)) - def list_dataset_elements(self, dataset: Dataset) -> Iterator[Tuple[str, Element]]: + def list_dataset_elements(self, dataset: Dataset) -> Iterator[tuple[str, Element]]: """ List elements in a dataset. diff --git a/arkindex_worker/worker/element.py b/arkindex_worker/worker/element.py index 1effe4a99e95605b0e71315a635f4dc2f3402a08..a1f09849af8066630b3a8e496b716dc4d3269b69 100644 --- a/arkindex_worker/worker/element.py +++ b/arkindex_worker/worker/element.py @@ -1,8 +1,8 @@ -# -*- coding: utf-8 -*- """ ElementsWorker methods for elements and element types. """ -from typing import Dict, Iterable, List, NamedTuple, Optional, Union +from collections.abc import Iterable +from typing import NamedTuple from uuid import UUID from peewee import IntegrityError @@ -28,8 +28,8 @@ class MissingTypeError(Exception): """ -class ElementMixin(object): - def create_required_types(self, element_types: List[ElementType]): +class ElementMixin: + def create_required_types(self, element_types: list[ElementType]): """Creates given element types in the corpus. :param element_types: The missing element types to create. @@ -86,9 +86,10 @@ class ElementMixin(object): element: Element, type: str, name: str, - polygon: List[List[Union[int, float]]], - confidence: Optional[float] = None, - slim_output: Optional[bool] = True, + polygon: list[list[int | float]] | None = None, + confidence: float | None = None, + image: str | None = None, + slim_output: bool = True, ) -> str: """ Create a child element on the given element through the API. @@ -96,8 +97,10 @@ class ElementMixin(object): :param Element element: The parent element. :param type: Slug of the element type for this child element. 
:param name: Name of the child element. - :param polygon: Polygon of the child element. + :param polygon: Optional polygon of the child element. :param confidence: Optional confidence score, between 0.0 and 1.0. + :param image: Optional image ID of the child element. + :param slim_output: Whether to return the child ID or the full child. :returns: UUID of the created element. """ assert element and isinstance( @@ -109,19 +112,29 @@ class ElementMixin(object): assert name and isinstance( name, str ), "name shouldn't be null and should be of type str" - assert polygon and isinstance( + assert polygon is None or isinstance( polygon, list - ), "polygon shouldn't be null and should be of type list" - assert len(polygon) >= 3, "polygon should have at least three points" - assert all( - isinstance(point, list) and len(point) == 2 for point in polygon - ), "polygon points should be lists of two items" - assert all( - isinstance(coord, (int, float)) for point in polygon for coord in point - ), "polygon points should be lists of two numbers" + ), "polygon should be None or a list" + if polygon is not None: + assert len(polygon) >= 3, "polygon should have at least three points" + assert all( + isinstance(point, list) and len(point) == 2 for point in polygon + ), "polygon points should be lists of two items" + assert all( + isinstance(coord, int | float) for point in polygon for coord in point + ), "polygon points should be lists of two numbers" assert confidence is None or ( isinstance(confidence, float) and 0 <= confidence <= 1 ), "confidence should be None or a float in [0..1] range" + assert image is None or isinstance(image, str), "image should be None or string" + if image is not None: + # Make sure it's a valid UUID + try: + UUID(image) + except ValueError as e: + raise ValueError("image is not a valid uuid.") from e + if polygon and image is None: + assert element.zone, "An image or a parent with an image is required to create an element with a polygon." assert isinstance(slim_output, bool), "slim_output should be of type bool" if self.is_read_only: @@ -133,7 +146,7 @@ class ElementMixin(object): body={ "type": type, "name": name, - "image": element.zone.image.id, + "image": image, "corpus": element.corpus.id, "polygon": polygon, "parent": element.id, @@ -146,11 +159,9 @@ class ElementMixin(object): def create_elements( self, - parent: Union[Element, CachedElement], - elements: List[ - Dict[str, Union[str, List[List[Union[int, float]]], float, None]] - ], - ) -> List[Dict[str, str]]: + parent: Element | CachedElement, + elements: list[dict[str, str | list[list[int | float]] | float | None]], + ) -> list[dict[str, str]]: """ Create child elements on the given element in a single API request. 
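With the change to create_sub_element, the polygon becomes optional and a separate image UUID can be passed; when a polygon is given without an image, the parent element must carry a zone. A usage sketch with placeholder type slugs and coordinates:

```
from arkindex_worker.models import Element
from arkindex_worker.worker import ElementsWorker


class SegmenterWorker(ElementsWorker):
    def process_element(self, element: Element) -> None:
        # Polygon only: the child is positioned on the parent element's image
        line_id = self.create_sub_element(
            element,
            type="text_line",  # placeholder type slug
            name="1",
            polygon=[[10, 10], [100, 10], [100, 40], [10, 40]],
            confidence=0.85,
        )

        # No polygon and no image: the child is created without a zone
        section_id = self.create_sub_element(element, type="section", name="Header")

        print(line_id, section_id)


if __name__ == "__main__":
    SegmenterWorker(description="Toy segmenter").run()
```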
@@ -195,18 +206,18 @@ class ElementMixin(object): ), f"Element at index {index} in elements: Should be of type dict" name = element.get("name") - assert name and isinstance( - name, str + assert ( + name and isinstance(name, str) ), f"Element at index {index} in elements: name shouldn't be null and should be of type str" type = element.get("type") - assert type and isinstance( - type, str + assert ( + type and isinstance(type, str) ), f"Element at index {index} in elements: type shouldn't be null and should be of type str" polygon = element.get("polygon") - assert polygon and isinstance( - polygon, list + assert ( + polygon and isinstance(polygon, list) ), f"Element at index {index} in elements: polygon shouldn't be null and should be of type list" assert ( len(polygon) >= 3 @@ -215,12 +226,13 @@ class ElementMixin(object): isinstance(point, list) and len(point) == 2 for point in polygon ), f"Element at index {index} in elements: polygon points should be lists of two items" assert all( - isinstance(coord, (int, float)) for point in polygon for coord in point + isinstance(coord, int | float) for point in polygon for coord in point ), f"Element at index {index} in elements: polygon points should be lists of two numbers" confidence = element.get("confidence") - assert confidence is None or ( - isinstance(confidence, float) and 0 <= confidence <= 1 + assert ( + confidence is None + or (isinstance(confidence, float) and 0 <= confidence <= 1) ), f"Element at index {index} in elements: confidence should be None or a float in [0..1] range" if self.is_read_only: @@ -271,8 +283,37 @@ class ElementMixin(object): return created_ids + def create_element_parent( + self, + parent: Element, + child: Element, + ) -> dict[str, str]: + """ + Link an element to a parent through the API. + + :param parent: Parent element. + :param child: Child element. + :returns: A dict from the ``CreateElementParent`` API endpoint. + """ + assert parent and isinstance( + parent, Element + ), "parent shouldn't be null and should be of type Element" + assert child and isinstance( + child, Element + ), "child shouldn't be null and should be of type Element" + + if self.is_read_only: + logger.warning("Cannot link elements as this worker is in read-only mode") + return + + return self.request( + "CreateElementParent", + parent=parent.id, + child=child.id, + ) + def partial_update_element( - self, element: Union[Element, CachedElement], **kwargs + self, element: Element | CachedElement, **kwargs ) -> dict: """ Partially updates an element through the API. @@ -289,10 +330,10 @@ class ElementMixin(object): * *image* (``UUID``): Optional ID of the image of this element - :returns: A dict from the ``PartialUpdateElement`` API endpoint, + :returns: A dict from the ``PartialUpdateElement`` API endpoint. 
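create_element_parent is new in this diff and only needs the two Element objects, from which it reads the ids. A sketch of linking an existing element under the one being processed, where the child UUID is a placeholder:

```
from arkindex_worker.models import Element
from arkindex_worker.worker import ElementsWorker


class LinkerWorker(ElementsWorker):
    def process_element(self, element: Element) -> None:
        # Hypothetical UUID of an existing element to attach under the current one;
        # only .id is read here, so a minimal Element dict is enough for this sketch
        orphan = Element({"id": "11111111-1111-1111-1111-111111111111"})

        # Calls the CreateElementParent endpoint with both ids
        self.create_element_parent(parent=element, child=orphan)


if __name__ == "__main__":
    LinkerWorker(description="Toy linker").run()
```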
""" assert element and isinstance( - element, (Element, CachedElement) + element, Element | CachedElement ), "element shouldn't be null and should be an Element or CachedElement" if "type" in kwargs: @@ -309,7 +350,7 @@ class ElementMixin(object): isinstance(point, list) and len(point) == 2 for point in polygon ), "polygon points should be lists of two items" assert all( - isinstance(coord, (int, float)) for point in polygon for coord in point + isinstance(coord, int | float) for point in polygon for coord in point ), "polygon points should be lists of two numbers" if "confidence" in kwargs: @@ -363,21 +404,21 @@ class ElementMixin(object): def list_element_children( self, - element: Union[Element, CachedElement], - folder: Optional[bool] = None, - name: Optional[str] = None, - recursive: Optional[bool] = None, - transcription_worker_version: Optional[Union[str, bool]] = None, - transcription_worker_run: Optional[Union[str, bool]] = None, - type: Optional[str] = None, - with_classes: Optional[bool] = None, - with_corpus: Optional[bool] = None, - with_metadata: Optional[bool] = None, - with_has_children: Optional[bool] = None, - with_zone: Optional[bool] = None, - worker_version: Optional[Union[str, bool]] = None, - worker_run: Optional[Union[str, bool]] = None, - ) -> Union[Iterable[dict], Iterable[CachedElement]]: + element: Element | CachedElement, + folder: bool | None = None, + name: str | None = None, + recursive: bool | None = None, + transcription_worker_version: str | bool | None = None, + transcription_worker_run: str | bool | None = None, + type: str | None = None, + with_classes: bool | None = None, + with_corpus: bool | None = None, + with_metadata: bool | None = None, + with_has_children: bool | None = None, + with_zone: bool | None = None, + worker_version: str | bool | None = None, + worker_run: str | bool | None = None, + ) -> Iterable[dict] | Iterable[CachedElement]: """ List children of an element. @@ -412,7 +453,7 @@ class ElementMixin(object): or an iterable of [CachedElement][arkindex_worker.cache.CachedElement] when caching is enabled. 
""" assert element and isinstance( - element, (Element, CachedElement) + element, Element | CachedElement ), "element shouldn't be null and should be an Element or CachedElement" query_params = {} if folder is not None: @@ -426,7 +467,7 @@ class ElementMixin(object): query_params["recursive"] = recursive if transcription_worker_version is not None: assert isinstance( - transcription_worker_version, (str, bool) + transcription_worker_version, str | bool ), "transcription_worker_version should be of type str or bool" if isinstance(transcription_worker_version, bool): assert ( @@ -435,7 +476,7 @@ class ElementMixin(object): query_params["transcription_worker_version"] = transcription_worker_version if transcription_worker_run is not None: assert isinstance( - transcription_worker_run, (str, bool) + transcription_worker_run, str | bool ), "transcription_worker_run should be of type str or bool" if isinstance(transcription_worker_run, bool): assert ( @@ -466,7 +507,7 @@ class ElementMixin(object): query_params["with_zone"] = with_zone if worker_version is not None: assert isinstance( - worker_version, (str, bool) + worker_version, str | bool ), "worker_version should be of type str or bool" if isinstance(worker_version, bool): assert ( @@ -475,7 +516,7 @@ class ElementMixin(object): query_params["worker_version"] = worker_version if worker_run is not None: assert isinstance( - worker_run, (str, bool) + worker_run, str | bool ), "worker_run should be of type str or bool" if isinstance(worker_run, bool): assert ( @@ -485,11 +526,14 @@ class ElementMixin(object): if self.use_cache: # Checking that we only received query_params handled by the cache - assert set(query_params.keys()) <= { - "type", - "worker_version", - "worker_run", - }, "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'" + assert ( + set(query_params.keys()) + <= { + "type", + "worker_version", + "worker_run", + } + ), "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'" query = CachedElement.select().where(CachedElement.parent_id == element.id) if type: @@ -522,21 +566,21 @@ class ElementMixin(object): def list_element_parents( self, - element: Union[Element, CachedElement], - folder: Optional[bool] = None, - name: Optional[str] = None, - recursive: Optional[bool] = None, - transcription_worker_version: Optional[Union[str, bool]] = None, - transcription_worker_run: Optional[Union[str, bool]] = None, - type: Optional[str] = None, - with_classes: Optional[bool] = None, - with_corpus: Optional[bool] = None, - with_metadata: Optional[bool] = None, - with_has_children: Optional[bool] = None, - with_zone: Optional[bool] = None, - worker_version: Optional[Union[str, bool]] = None, - worker_run: Optional[Union[str, bool]] = None, - ) -> Union[Iterable[dict], Iterable[CachedElement]]: + element: Element | CachedElement, + folder: bool | None = None, + name: str | None = None, + recursive: bool | None = None, + transcription_worker_version: str | bool | None = None, + transcription_worker_run: str | bool | None = None, + type: str | None = None, + with_classes: bool | None = None, + with_corpus: bool | None = None, + with_metadata: bool | None = None, + with_has_children: bool | None = None, + with_zone: bool | None = None, + worker_version: str | bool | None = None, + worker_run: str | bool | None = None, + ) -> Iterable[dict] | Iterable[CachedElement]: """ List parents of an element. 
@@ -571,7 +615,7 @@ class ElementMixin(object): or an iterable of [CachedElement][arkindex_worker.cache.CachedElement] when caching is enabled. """ assert element and isinstance( - element, (Element, CachedElement) + element, Element | CachedElement ), "element shouldn't be null and should be an Element or CachedElement" query_params = {} if folder is not None: @@ -585,7 +629,7 @@ class ElementMixin(object): query_params["recursive"] = recursive if transcription_worker_version is not None: assert isinstance( - transcription_worker_version, (str, bool) + transcription_worker_version, str | bool ), "transcription_worker_version should be of type str or bool" if isinstance(transcription_worker_version, bool): assert ( @@ -594,7 +638,7 @@ class ElementMixin(object): query_params["transcription_worker_version"] = transcription_worker_version if transcription_worker_run is not None: assert isinstance( - transcription_worker_run, (str, bool) + transcription_worker_run, str | bool ), "transcription_worker_run should be of type str or bool" if isinstance(transcription_worker_run, bool): assert ( @@ -625,7 +669,7 @@ class ElementMixin(object): query_params["with_zone"] = with_zone if worker_version is not None: assert isinstance( - worker_version, (str, bool) + worker_version, str | bool ), "worker_version should be of type str or bool" if isinstance(worker_version, bool): assert ( @@ -634,7 +678,7 @@ class ElementMixin(object): query_params["worker_version"] = worker_version if worker_run is not None: assert isinstance( - worker_run, (str, bool) + worker_run, str | bool ), "worker_run should be of type str or bool" if isinstance(worker_run, bool): assert ( @@ -644,11 +688,14 @@ class ElementMixin(object): if self.use_cache: # Checking that we only received query_params handled by the cache - assert set(query_params.keys()) <= { - "type", - "worker_version", - "worker_run", - }, "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'" + assert ( + set(query_params.keys()) + <= { + "type", + "worker_version", + "worker_run", + } + ), "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'" parent_ids = CachedElement.select(CachedElement.parent_id).where( CachedElement.id == element.id diff --git a/arkindex_worker/worker/entity.py b/arkindex_worker/worker/entity.py index 539cc94844d7984ba377fc18901a5ac652f7ee52..bb52291c82b3b3805f82b0f0a2f08ee4a049adfa 100644 --- a/arkindex_worker/worker/entity.py +++ b/arkindex_worker/worker/entity.py @@ -1,10 +1,9 @@ -# -*- coding: utf-8 -*- """ ElementsWorker methods for entities. 
""" from operator import itemgetter -from typing import Dict, List, Optional, TypedDict, Union +from typing import TypedDict from peewee import IntegrityError @@ -12,16 +11,13 @@ from arkindex_worker import logger from arkindex_worker.cache import CachedEntity, CachedTranscriptionEntity from arkindex_worker.models import Element, Transcription -Entity = TypedDict( - "Entity", - { - "name": str, - "type_id": str, - "length": int, - "offset": int, - "confidence": Optional[float], - }, -) + +class Entity(TypedDict): + name: str + type_id: str + length: int + offset: int + confidence: float | None class MissingEntityType(Exception): @@ -31,9 +27,9 @@ class MissingEntityType(Exception): """ -class EntityMixin(object): +class EntityMixin: def check_required_entity_types( - self, entity_types: List[str], create_missing: bool = True + self, entity_types: list[str], create_missing: bool = True ): """Checks that every entity type needed is available in the corpus. Missing ones may be created automatically if needed. @@ -71,7 +67,7 @@ class EntityMixin(object): self, name: str, type: str, - metas=dict(), + metas=None, validated=None, ): """ @@ -87,6 +83,7 @@ class EntityMixin(object): assert type and isinstance( type, str ), "type shouldn't be null and should be of type str" + metas = metas or {} if metas: assert isinstance(metas, dict), "metas should be of type dict" if validated is not None: @@ -140,8 +137,8 @@ class EntityMixin(object): entity: str, offset: int, length: int, - confidence: Optional[float] = None, - ) -> Optional[Dict[str, Union[str, int]]]: + confidence: float | None = None, + ) -> dict[str, str | int] | None: """ Create a link between an existing entity and an existing transcription. If cache support is enabled, a `CachedTranscriptionEntity` will also be created. @@ -211,8 +208,8 @@ class EntityMixin(object): def create_transcription_entities( self, transcription: Transcription, - entities: List[Entity], - ) -> List[Dict[str, str]]: + entities: list[Entity], + ) -> list[dict[str, str]]: """ Create multiple entities attached to a transcription in a single API request. 
@@ -250,13 +247,13 @@ class EntityMixin(object): ), f"Entity at index {index} in entities: Should be of type dict" name = entity.get("name") - assert name and isinstance( - name, str + assert ( + name and isinstance(name, str) ), f"Entity at index {index} in entities: name shouldn't be null and should be of type str" type_id = entity.get("type_id") - assert type_id and isinstance( - type_id, str + assert ( + type_id and isinstance(type_id, str) ), f"Entity at index {index} in entities: type_id shouldn't be null and should be of type str" offset = entity.get("offset") @@ -270,8 +267,9 @@ class EntityMixin(object): ), f"Entity at index {index} in entities: length shouldn't be null and should be a strictly positive integer" confidence = entity.get("confidence") - assert confidence is None or ( - isinstance(confidence, float) and 0 <= confidence <= 1 + assert ( + confidence is None + or (isinstance(confidence, float) and 0 <= confidence <= 1) ), f"Entity at index {index} in entities: confidence should be None or a float in [0..1] range" assert len(entities) == len( @@ -298,7 +296,7 @@ class EntityMixin(object): def list_transcription_entities( self, transcription: Transcription, - worker_version: Optional[Union[str, bool]] = None, + worker_version: str | bool | None = None, ): """ List existing entities on a transcription @@ -314,7 +312,7 @@ class EntityMixin(object): if worker_version is not None: assert isinstance( - worker_version, (str, bool) + worker_version, str | bool ), "worker_version should be of type str or bool" if isinstance(worker_version, bool): @@ -329,12 +327,11 @@ class EntityMixin(object): def list_corpus_entities( self, - name: Optional[str] = None, - parent: Optional[Element] = None, + name: str | None = None, + parent: Element | None = None, ): """ - List all entities in the worker's corpus - This method does not support cache + List all entities in the worker's corpus and store them in the ``self.entities`` cache. :param name: Filter entities by part of their name (case-insensitive) :param parent: Restrict entities to those linked to all transcriptions of an element and all its descendants. Note that links to metadata are ignored. """ @@ -348,8 +345,14 @@ class EntityMixin(object): assert isinstance(parent, Element), "parent should be of type Element" query_params["parent"] = parent.id - return self.api_client.paginate( - "ListCorpusEntities", id=self.corpus_id, **query_params + self.entities = { + entity["id"]: entity + for entity in self.api_client.paginate( + "ListCorpusEntities", id=self.corpus_id, **query_params + ) + } + logger.info( + f"Loaded {len(self.entities)} entities in corpus ({self.corpus_id})" ) def list_corpus_entity_types( diff --git a/arkindex_worker/worker/metadata.py b/arkindex_worker/worker/metadata.py index e2c6f46b42cad1193843caa3dc3ee991d18c40fd..136da1547836151351dab41a34269001b40a4d82 100644 --- a/arkindex_worker/worker/metadata.py +++ b/arkindex_worker/worker/metadata.py @@ -1,10 +1,8 @@ -# -*- coding: utf-8 -*- """ ElementsWorker methods for metadata. 
""" from enum import Enum -from typing import Dict, List, Optional, Union from arkindex_worker import logger from arkindex_worker.cache import CachedElement @@ -57,14 +55,14 @@ class MetaType(Enum): """ -class MetaDataMixin(object): +class MetaDataMixin: def create_metadata( self, - element: Union[Element, CachedElement], + element: Element | CachedElement, type: MetaType, name: str, value: str, - entity: Optional[str] = None, + entity: str | None = None, ) -> str: """ Create a metadata on the given element through API. @@ -77,7 +75,7 @@ class MetaDataMixin(object): :returns: UUID of the created metadata. """ assert element and isinstance( - element, (Element, CachedElement) + element, Element | CachedElement ), "element shouldn't be null and should be of type Element or CachedElement" assert type and isinstance( type, MetaType @@ -110,26 +108,22 @@ class MetaDataMixin(object): def create_metadatas( self, - element: Union[Element, CachedElement], - metadatas: List[ - Dict[ - str, Union[MetaType, str, Union[str, Union[int, float]], Optional[str]] - ] - ], - ) -> List[Dict[str, str]]: + element: Element | CachedElement, + metadatas: list[dict[str, MetaType | str | int | float | None]], + ) -> list[dict[str, str]]: """ - Create multiple metadatas on an existing element. + Create multiple metadata on an existing element. This method does not support cache. :param element: The element to create multiple metadata on. :param metadatas: The list of dict whose keys are the following: - - type : MetaType - - name : str - - value : Union[str, Union[int, float]] - - entity_id : Union[str, None] + - type: MetaType + - name: str + - value: str | int | float + - entity_id: str | None """ assert element and isinstance( - element, (Element, CachedElement) + element, Element | CachedElement ), "element shouldn't be null and should be of type Element or CachedElement" assert metadatas and isinstance( @@ -152,7 +146,7 @@ class MetaDataMixin(object): ), "name shouldn't be null and should be of type str" assert metadata.get("value") is not None and isinstance( - metadata.get("value"), (str, float, int) + metadata.get("value"), str | float | int ), "value shouldn't be null and should be of type (str or float or int)" assert metadata.get("entity_id") is None or isinstance( @@ -172,7 +166,7 @@ class MetaDataMixin(object): logger.warning("Cannot create metadata as this worker is in read-only mode") return - created_metadatas = self.request( + created_metadata_list = self.request( "CreateMetaDataBulk", id=element.id, body={ @@ -181,11 +175,11 @@ class MetaDataMixin(object): }, )["metadata_list"] - return created_metadatas + return created_metadata_list def list_element_metadata( - self, element: Union[Element, CachedElement] - ) -> List[Dict[str, str]]: + self, element: Element | CachedElement + ) -> list[dict[str, str]]: """ List all metadata linked to an element. This method does not support cache. @@ -193,7 +187,7 @@ class MetaDataMixin(object): :param element: The element to list metadata on. 
""" assert element and isinstance( - element, (Element, CachedElement) + element, Element | CachedElement ), "element shouldn't be null and should be of type Element or CachedElement" return self.api_client.paginate("ListElementMetaData", id=element.id) diff --git a/arkindex_worker/worker/task.py b/arkindex_worker/worker/task.py index e4c2f6e23caa800a9eda6a76b758de262fbf9e53..4a19b17b41a90dcd9f332c11dc4be073ced383ee 100644 --- a/arkindex_worker/worker/task.py +++ b/arkindex_worker/worker/task.py @@ -1,17 +1,16 @@ -# -*- coding: utf-8 -*- """ BaseWorker methods for tasks. """ import uuid -from typing import Iterator +from collections.abc import Iterator from apistar.compat import DownloadedFile from arkindex_worker.models import Artifact -class TaskMixin(object): +class TaskMixin: def list_artifacts(self, task_id: uuid.UUID) -> Iterator[Artifact]: """ List artifacts associated to a task. diff --git a/arkindex_worker/worker/training.py b/arkindex_worker/worker/training.py index 5e08e5048766d15f26ba456cfd7346233356d4b1..ebc6fd429d688849c3a6ef2216b8cdb61a7deaf3 100644 --- a/arkindex_worker/worker/training.py +++ b/arkindex_worker/worker/training.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ BaseWorker methods for training. """ @@ -6,7 +5,7 @@ BaseWorker methods for training. import functools from contextlib import contextmanager from pathlib import Path -from typing import NewType, Optional, Tuple, Union +from typing import NewType from uuid import UUID import requests @@ -26,7 +25,7 @@ FileSize = NewType("FileSize", int) @contextmanager -def create_archive(path: DirPath) -> Tuple[Path, Hash, FileSize, Hash]: +def create_archive(path: DirPath) -> tuple[Path, Hash, FileSize, Hash]: """ Create a tar archive from the files at the given location then compress it to a zst archive. @@ -72,7 +71,7 @@ def skip_if_read_only(func): return wrapper -class TrainingMixin(object): +class TrainingMixin: """ A mixin helper to create a new model version easily. You may use `publish_model_version` to publish a ready model version directly, or @@ -87,10 +86,10 @@ class TrainingMixin(object): self, model_path: DirPath, model_id: str, - tag: Optional[str] = None, - description: Optional[str] = None, - configuration: Optional[dict] = {}, - parent: Optional[Union[str, UUID]] = None, + tag: str | None = None, + description: str | None = None, + configuration: dict | None = None, + parent: str | UUID | None = None, ): """ Publish a unique version of a model in Arkindex, identified by its hash. @@ -105,6 +104,7 @@ class TrainingMixin(object): :param parent: ID of the parent model version """ + configuration = configuration or {} if not self.model_version: self.create_model_version( model_id=model_id, @@ -161,10 +161,10 @@ class TrainingMixin(object): def create_model_version( self, model_id: str, - tag: Optional[str] = None, - description: Optional[str] = None, - configuration: Optional[dict] = {}, - parent: Optional[Union[str, UUID]] = None, + tag: str | None = None, + description: str | None = None, + configuration: dict | None = None, + parent: str | UUID | None = None, ): """ Create a new version of the specified model with its base attributes. @@ -176,6 +176,8 @@ class TrainingMixin(object): :param parent: ID of the parent model version """ assert not self.model_version, "A model version has already been created." 
+ + configuration = configuration or {} self.model_version = self.request( "CreateModelVersion", id=model_id, @@ -186,6 +188,7 @@ class TrainingMixin(object): parent=parent, ), ) + logger.info( f"Model version ({self.model_version['id']}) was successfully created" ) @@ -193,10 +196,10 @@ class TrainingMixin(object): @skip_if_read_only def update_model_version( self, - tag: Optional[str] = None, - description: Optional[str] = None, - configuration: Optional[dict] = None, - parent: Optional[Union[str, UUID]] = None, + tag: str | None = None, + description: str | None = None, + configuration: dict | None = None, + parent: str | UUID | None = None, ): """ Update the current model version with the given attributes. @@ -235,9 +238,7 @@ class TrainingMixin(object): ), "The model is already marked as available." s3_put_url = self.model_version.get("s3_put_url") - assert ( - s3_put_url - ), "S3 PUT URL is not set, please ensure you have the right to validate a model version." + assert s3_put_url, "S3 PUT URL is not set, please ensure you have the right to validate a model version." logger.info("Uploading to s3...") # Upload the archive on s3 @@ -263,9 +264,7 @@ class TrainingMixin(object): :param size: The size of the uploaded archive :param archive_hash: MD5 hash of the uploaded archive """ - assert ( - self.model_version - ), "You must create the model version and upload its archive before validating it." + assert self.model_version, "You must create the model version and upload its archive before validating it." try: self.model_version = self.request( "ValidateModelVersion", diff --git a/arkindex_worker/worker/transcription.py b/arkindex_worker/worker/transcription.py index 7ce96c689596486bdcdab3251d4a73c3c6877a9c..ec3781a925a0da6003c7f7bcf1d734a8c87c1d55 100644 --- a/arkindex_worker/worker/transcription.py +++ b/arkindex_worker/worker/transcription.py @@ -1,10 +1,9 @@ -# -*- coding: utf-8 -*- """ ElementsWorker methods for transcriptions. """ +from collections.abc import Iterable from enum import Enum -from typing import Dict, Iterable, List, Optional, Union from peewee import IntegrityError @@ -40,14 +39,14 @@ class TextOrientation(Enum): """ -class TranscriptionMixin(object): +class TranscriptionMixin: def create_transcription( self, - element: Union[Element, CachedElement], + element: Element | CachedElement, text: str, confidence: float, orientation: TextOrientation = TextOrientation.HorizontalLeftToRight, - ) -> Optional[Dict[str, Union[str, float]]]: + ) -> dict[str, str | float] | None: """ Create a transcription on the given element through the API. @@ -59,7 +58,7 @@ class TranscriptionMixin(object): or None if the worker is in read-only mode. """ assert element and isinstance( - element, (Element, CachedElement) + element, Element | CachedElement ), "element shouldn't be null and should be an Element or CachedElement" assert text and isinstance( text, str @@ -111,8 +110,8 @@ class TranscriptionMixin(object): def create_transcriptions( self, - transcriptions: List[Dict[str, Union[str, float, Optional[TextOrientation]]]], - ) -> List[Dict[str, Union[str, float]]]: + transcriptions: list[dict[str, str | float | TextOrientation | None]], + ) -> list[dict[str, str | float]]: """ Create multiple transcriptions at once on existing elements through the API, and creates [CachedTranscription][arkindex_worker.cache.CachedTranscription] instances if cache support is enabled. 
@@ -140,13 +139,13 @@ class TranscriptionMixin(object): for index, transcription in enumerate(transcriptions_payload): element_id = transcription.get("element_id") - assert element_id and isinstance( - element_id, str + assert ( + element_id and isinstance(element_id, str) ), f"Transcription at index {index} in transcriptions: element_id shouldn't be null and should be of type str" text = transcription.get("text") - assert text and isinstance( - text, str + assert ( + text and isinstance(text, str) ), f"Transcription at index {index} in transcriptions: text shouldn't be null and should be of type str" confidence = transcription.get("confidence") @@ -159,8 +158,8 @@ class TranscriptionMixin(object): orientation = transcription.get( "orientation", TextOrientation.HorizontalLeftToRight ) - assert orientation and isinstance( - orientation, TextOrientation + assert ( + orientation and isinstance(orientation, TextOrientation) ), f"Transcription at index {index} in transcriptions: orientation shouldn't be null and should be of type TextOrientation" if orientation: transcription["orientation"] = orientation.value @@ -203,10 +202,10 @@ class TranscriptionMixin(object): def create_element_transcriptions( self, - element: Union[Element, CachedElement], + element: Element | CachedElement, sub_element_type: str, - transcriptions: List[Dict[str, Union[str, float]]], - ) -> Dict[str, Union[str, bool]]: + transcriptions: list[dict[str, str | float]], + ) -> dict[str, str | bool]: """ Create multiple elements and transcriptions at once on a single parent element through the API. @@ -228,7 +227,7 @@ class TranscriptionMixin(object): :returns: A list of dicts as returned by the ``CreateElementTranscriptions`` API endpoint. """ assert element and isinstance( - element, (Element, CachedElement) + element, Element | CachedElement ), "element shouldn't be null and should be an Element or CachedElement" assert sub_element_type and isinstance( sub_element_type, str @@ -242,8 +241,8 @@ class TranscriptionMixin(object): for index, transcription in enumerate(transcriptions_payload): text = transcription.get("text") - assert text and isinstance( - text, str + assert ( + text and isinstance(text, str) ), f"Transcription at index {index} in transcriptions: text shouldn't be null and should be of type str" confidence = transcription.get("confidence") @@ -256,15 +255,15 @@ class TranscriptionMixin(object): orientation = transcription.get( "orientation", TextOrientation.HorizontalLeftToRight ) - assert orientation and isinstance( - orientation, TextOrientation + assert ( + orientation and isinstance(orientation, TextOrientation) ), f"Transcription at index {index} in transcriptions: orientation shouldn't be null and should be of type TextOrientation" if orientation: transcription["orientation"] = orientation.value polygon = transcription.get("polygon") - assert polygon and isinstance( - polygon, list + assert ( + polygon and isinstance(polygon, list) ), f"Transcription at index {index} in transcriptions: polygon shouldn't be null and should be of type list" assert ( len(polygon) >= 3 @@ -273,12 +272,16 @@ class TranscriptionMixin(object): isinstance(point, list) and len(point) == 2 for point in polygon ), f"Transcription at index {index} in transcriptions: polygon points should be lists of two items" assert all( - isinstance(coord, (int, float)) for point in polygon for coord in point + isinstance(coord, int | float) for point in polygon for coord in point ), f"Transcription at index {index} in transcriptions: polygon 
points should be lists of two numbers" element_confidence = transcription.get("element_confidence") - assert element_confidence is None or ( - isinstance(element_confidence, float) and 0 <= element_confidence <= 1 + assert ( + element_confidence is None + or ( + isinstance(element_confidence, float) + and 0 <= element_confidence <= 1 + ) ), f"Transcription at index {index} in transcriptions: element_confidence should be either null or a float in [0..1] range" if self.is_read_only: @@ -359,11 +362,11 @@ class TranscriptionMixin(object): def list_transcriptions( self, - element: Union[Element, CachedElement], - element_type: Optional[str] = None, - recursive: Optional[bool] = None, - worker_version: Optional[Union[str, bool]] = None, - ) -> Union[Iterable[dict], Iterable[CachedTranscription]]: + element: Element | CachedElement, + element_type: str | None = None, + recursive: bool | None = None, + worker_version: str | bool | None = None, + ) -> Iterable[dict] | Iterable[CachedTranscription]: """ List transcriptions on an element. @@ -375,7 +378,7 @@ class TranscriptionMixin(object): or an iterable of CachedTranscription when cache support is enabled. """ assert element and isinstance( - element, (Element, CachedElement) + element, Element | CachedElement ), "element shouldn't be null and should be an Element or CachedElement" query_params = {} if element_type: @@ -386,7 +389,7 @@ class TranscriptionMixin(object): query_params["recursive"] = recursive if worker_version is not None: assert isinstance( - worker_version, (str, bool) + worker_version, str | bool ), "worker_version should be of type str or bool" if isinstance(worker_version, bool): assert ( diff --git a/arkindex_worker/worker/version.py b/arkindex_worker/worker/version.py index 85d6f8d72b323c01478b68b509b4113cdca0cb86..d71e4509e2ff08a9be2365b69acc60cc14b50866 100644 --- a/arkindex_worker/worker/version.py +++ b/arkindex_worker/worker/version.py @@ -1,10 +1,9 @@ -# -*- coding: utf-8 -*- """ ElementsWorker methods for worker versions. """ -class WorkerVersionMixin(object): +class WorkerVersionMixin: def get_worker_version(self, worker_version_id: str) -> dict: """ Retrieve a worker version, using the [ElementsWorker][arkindex_worker.worker.ElementsWorker]'s internal cache when possible. 
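# ---- Illustrative sketch (not part of the diff above) ----
# The transcription hunks above spell out the bulk payload checks: each entry
# needs a non-empty `text`, a `confidence` between 0 and 1, and a `polygon` of
# at least three [x, y] points with int or float coordinates. A hedged sketch,
# assuming `page` is an Element the worker already retrieved and "text_line"
# is an existing element type slug.
from arkindex_worker.models import Element
from arkindex_worker.worker import ElementsWorker


def publish_lines(worker: ElementsWorker, page: Element):
    transcriptions = [
        {
            "text": "Hello world",
            "confidence": 0.92,
            "polygon": [[0, 0], [120, 0], [120, 40], [0, 40]],
        },
    ]
    return worker.create_element_transcriptions(
        element=page,
        sub_element_type="text_line",
        transcriptions=transcriptions,
    )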
diff --git a/cookiecutter.json b/cookiecutter.json index eac2407f8894b2aaeae407962e09a1b1ffc9d96e..f99fb896ef79cbcbd79e07e75383a39391041c71 100644 --- a/cookiecutter.json +++ b/cookiecutter.json @@ -3,6 +3,6 @@ "name": "Demo", "description": "Demo ML worker for Arkindex", "worker_type": "demo", - "author": "", - "email": "" + "author": "John Doe", + "email": "john.doe@company.com" } diff --git a/demo.py b/demo.py index 3c1dc21195dfbb02350c7c09a6dd89041af5d386..2e7e830b352a855d896124760f3e4c540a6f9e4f 100644 --- a/demo.py +++ b/demo.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from arkindex_worker.worker import ElementsWorker diff --git a/docs-requirements.txt b/docs-requirements.txt index f3d033630f8cf8807608f6c0a6b231bc87e2c59d..7338647490dee386a728a7dc6983fef5cd8999d2 100644 --- a/docs-requirements.txt +++ b/docs-requirements.txt @@ -1,7 +1,7 @@ -black==23.11.0 +black==23.12.0 doc8==1.1.1 mkdocs==1.5.3 -mkdocs-material==9.4.8 +mkdocs-material==9.5.2 mkdocstrings==0.24.0 -mkdocstrings-python==1.7.3 +mkdocstrings-python==1.7.5 recommonmark==0.7.1 diff --git a/docs/contents/implem/configure.md b/docs/contents/implem/configure.md index 1614984d69e148e20cd7f35ca677ff29897ef9fb..cdb425e497a13e6be8c061a18425f74ac8e95fff 100644 --- a/docs/contents/implem/configure.md +++ b/docs/contents/implem/configure.md @@ -115,6 +115,9 @@ Many attributes are set on the worker during at the configuration stage. Here is `model_version_id` : The ID of the model version linked to the current `WorkerRun` object on Arkindex. You may set it in developer mode via the `ARKINDEX_MODEL_VERSION_ID` environment variable. +`model_details` +: The details of the model for the model version linked to the current `WorkerRun` object on Arkindex. You may populate it in developer mode via the `ARKINDEX_MODEL_ID` environment variable. + `process_information` : The details about the process parent to this worker execution. Only set in Arkindex mode. diff --git a/docs/contents/workers/ci/index.md b/docs/contents/workers/ci/index.md index 26f0adfd61bcd5ee17be9ba343e02960eacd8896..ef802c261fb84b428762f6f97cf589698397e85e 100644 --- a/docs/contents/workers/ci/index.md +++ b/docs/contents/workers/ci/index.md @@ -24,7 +24,8 @@ At Teklia, we use a simple version of [Git Flow][gitflow]: - Developments should happen in branches, with merge requests to enable code review and Gitlab CI pipelines. - Project maintainers should use Git tags to create official releases, by - updating the `VERSION` file and using the same version string as the tag name. + updating the `project.version` key of the `pyproject.toml` file and using + the same version string as the tag name. This process is reflected the template's `.gitlab-ci.yml` file. diff --git a/docs/contents/workers/create.md b/docs/contents/workers/create.md index bfc0cfcb2fe5bc9c23c509601f72819888593ab0..2ba3db492440370900c27a2ee108f9661ce88f7e 100644 --- a/docs/contents/workers/create.md +++ b/docs/contents/workers/create.md @@ -131,9 +131,9 @@ to get a basic structure for your worker. Cookiecutter will ask you for several options: `slug` -: A slug for the worker. This should use lowercase alphanumeric characters or - underscores to meet the code formatting requirements that the template - automatically enforces via [black]. +: A slug for the worker. This should use lowercase alphanumeric characters, + underscores or hyphens to meet the code formatting requirements that the + template automatically enforces via [black]. `name` : A name for the worker, purely used for display purposes. 
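# ---- Illustrative sketch (not part of the diff above) ----
# `model_details` (documented in the configure.md hunk above) exposes the model
# attached to the current WorkerRun; in developer mode it can be populated via
# the ARKINDEX_MODEL_ID environment variable. A hedged sketch of reading it
# after configuration, assuming the usual ARKINDEX_* variables are set.
from arkindex_worker.worker import ElementsWorker

if __name__ == "__main__":
    worker = ElementsWorker()
    worker.configure()
    if worker.model_details:
        print(worker.model_details["id"], worker.model_details["name"])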
@@ -159,6 +159,16 @@ Cookiecutter will ask you for several options: `email` : Your e-mail address. This will be used to contact you if any administrative need arise +Cookiecutter will also automatically normalize your worker's `slug` in new parameters: + +`__package` +: The name of the Python package for your worker, generated by normalizing the `slug` + with characters' lowering and replacing underscores with hyphens. + +`__module` +: The name of the Python module for your worker, generated by normalizing the `slug` + with characters' lowering and replacing hyphens with underscores. + ### Pushing to GitLab This section guides you through pushing the newly created worker from your @@ -169,7 +179,7 @@ This section assumes you have Maintainer or Owner access to the GitLab project. #### To push to GitLab 1. Enter the newly created directory, starting in `worker-` and ending with your - worker's slug. + worker's `slug`. 2. Add your GitLab project as a Git remote: diff --git a/docs/contents/workers/run-local.md b/docs/contents/workers/run-local.md index 31d3935690ce9123f275661b317f5e624d02aa8c..a14c58e31f3c7e30d7b75302cb316d7a279e1103 100644 --- a/docs/contents/workers/run-local.md +++ b/docs/contents/workers/run-local.md @@ -115,6 +115,6 @@ in the browser's address bar when browsing an element on Arkindex. 1. Activate the Python environment: run `workon X` where `X` is the name of your Python environment. -2. Run `worker-X`, where `X` is the slug of your worker, followed by +2. Run `worker-X`, where `X` is the `__package` name of your worker, followed by `--element=Y` where `Y` is the ID of an element. You can repeat `--element` as many times as you need to process multiple elements. diff --git a/docs/contents/workers/template-structure.md b/docs/contents/workers/template-structure.md index 425e9da379a73f138e702751a7c1e03dce16eecc..5d4b38d704ab9e1312d4c6ffdca18bf62380aef9 100644 --- a/docs/contents/workers/template-structure.md +++ b/docs/contents/workers/template-structure.md @@ -53,8 +53,8 @@ package, a Docker build, with the best development practices: `setup.py` : Configures the worker's Python package. -`VERSION` -: Official version number of your worker. Defaults to `0.1.0`. +`pyproject.toml` +: Configures the worker's Python package. `ci/build.sh` : Script that gets run by [CI](ci/index.md) pipelines @@ -68,10 +68,10 @@ package, a Docker build, with the best development practices: TODO: For more information, see [Writing tests for your worker](tests). --> -`worker_[slug]/__init__.py` +`worker_[__module]/__init__.py` : Declares the folder as a Python package. -`worker_[slug]/worker.py` +`worker_[__module]/worker.py` : The core part of the worker. This is where you can write code that processes Arkindex elements. diff --git a/docs/contents/workers/user_configuration/model_config.png b/docs/contents/workers/user_configuration/model_config.png new file mode 100644 index 0000000000000000000000000000000000000000..622d8f989aa8fb77407b85e08697fabbdc3249d5 Binary files /dev/null and b/docs/contents/workers/user_configuration/model_config.png differ diff --git a/docs/contents/workers/yaml.md b/docs/contents/workers/yaml.md index 62534d8ec49efd23ed4807e7c6ce52c1c3656ddd..b7018c7c0573e3501e214f30065c5ce80c6eb573 100644 --- a/docs/contents/workers/yaml.md +++ b/docs/contents/workers/yaml.md @@ -54,7 +54,7 @@ All attributes are optional unless explicitly specified. : Mandatory. Name of the worker, for display purposes. `slug` -: Mandatory. Slug of this worker. 
The slug must be unique across the repository and must only hold alphanumerical characters, underscores or dashes. +: Mandatory. Slug of this worker. The slug must be unique across the repository and must only hold alphanumerical characters, underscores or hyphens. `type` : Mandatory. Type of the worker, for display purposes only. Some common values @@ -80,7 +80,16 @@ include: : This worker does not support GPUs. It may run on a host that has a GPU, but it will ignore it. `model_usage` -: Boolean. Whether or not this worker requires a model version to run. Defaults to `false`. +: Whether or not this worker requires a model version to run. Defaults to `disabled`. May take one of the following values: + + `required` + : This worker requires a model version, and will only be run on processes with a model. + + `supported` + : This worker supports a model version, but may run on any processes, including those without model. + + `disabled` + : This worker does not support model version. It may run on a process that has a model, but it will ignore it. `docker` : Regroups Docker-related configuration attributes: @@ -137,6 +146,7 @@ A parameter is defined using the following settings: - `enum` - `list` - `dict` + - `model` `default` : Optional. A default value for the parameter. Must be of the defined parameter `type`. @@ -272,7 +282,7 @@ Which will result in the following display for the user: #### Dictionary parameters -Dictionary-type parameters must be defined using a `title`, the `dict` `type`. You can also set a `default` value for this parameter, which must be one a dictionary, as well as make it a `required` parameter, which prevents users from leaving it blank. You can use dictionary parameters for example to specify a correspondence between the classes that are predicted by a worker and the elements that are created on Arkindex from these predictions. +Dictionary-type parameters must be defined using a `title` and the `dict` `type`. You can also set a `default` value for this parameter, which must be a dictionary, as well as make it a `required` parameter, which prevents users from leaving it blank. You can use dictionary parameters for example to specify a correspondence between the classes that are predicted by a worker and the elements that are created on Arkindex from these predictions. Dictionary-type parameters only accept strings as values. @@ -293,6 +303,26 @@ Which will result in the following display for the user:  +#### Model parameters + +Model-type parameters must be defined using a `title` and the `model` type. You can also set a `default` value for this parameter, which must be the UUID of an existing Model, and make it a `required` parameter, which prevents users from leaving it blank. You can use a model parameter to specify to which Model the Model Version that is created by a Training process will be attached. + +Model-type parameters only accept Model UUIDs as values. + +In the configuration form, model parameters are displayed as an input field. Users can select a model from a list of available Models: what they type into the input field filters that list, allowing them to search for a model using its name or UUID. 
+ +For example, a model-type parameter can be defined like this: + +```yaml +model_param: + title: Training Model + type: model +``` + +Which will result in the following display for the user: + + + #### Example user_configuration ```yaml @@ -318,6 +348,9 @@ user_configuration: - 23 - 56 title: Another Parameter + a_model_parameter: + type: model + title: Model to train ``` #### Fallback to free JSON input diff --git a/docs/ref/git.md b/docs/ref/git.md deleted file mode 100644 index 56ddfee01d26b9eb53c1dc8dfd58541d52869127..0000000000000000000000000000000000000000 --- a/docs/ref/git.md +++ /dev/null @@ -1,3 +0,0 @@ -# Git & Gitlab support - -::: arkindex_worker.git diff --git a/docs/releases.md b/docs/releases.md index a701e85729e6d3727cbb01d1f4721657d3d5065c..62556775f42ce550a810d8874125c0cba1bab0d6 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -115,7 +115,7 @@ Released on **8 November 2022** • View on [Gitlab](https://gitlab.teklia.co - A new version of the cache was released with the updated Django models. - Improvements to our Machine Learning training API to allow workers to use models published on Arkindex. - Support workers that have no configuration. - - Allow publishing metadatas with falsy but non-null values. + - Allow publishing metadata with falsy but non-null values. - Add `.polygon` attribute shortcut on `Element`. - Add a major test speedup on our worker template. - Support cache usage on our metadata API endpoint helpers. diff --git a/hooks/pre-commit b/hooks/pre-commit index 6205d988f3dd1ad7e32a3f4036aa472e8fe827c4..96d3a8acbba92c76efdc071aba2a9e1a1f3e6f38 100755 --- a/hooks/pre-commit +++ b/hooks/pre-commit @@ -1,50 +1,20 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- +#!/usr/bin/env bash # File generated by pre-commit: https://pre-commit.com # ID: 138fd403232d2ddd5efb44317e38bf03 -import os -import sys - -# we try our best, but the shebang of this script is difficult to determine: -# - macos doesn't ship with python3 -# - windows executables are almost always `python.exe` -# therefore we continue to support python2 for this small script -if sys.version_info < (3, 3): - from distutils.spawn import find_executable as which -else: - from shutil import which - -# work around https://github.com/Homebrew/homebrew-core/issues/30445 -os.environ.pop("__PYVENV_LAUNCHER__", None) # start templated -INSTALL_PYTHON = "/usr/bin/python3" -ARGS = [ - "hook-impl", - "--config=.pre-commit-config.yaml", - "--hook-type=pre-commit", - "--skip-on-missing-config", -] +INSTALL_PYTHON=/usr/bin/python3 +ARGS=(hook-impl --config=.pre-commit-config.yaml --hook-type=pre-commit --skip-on-missing-config) # end templated -ARGS.extend(("--hook-dir", os.path.realpath(os.path.dirname(__file__)))) -ARGS.append("--") -ARGS.extend(sys.argv[1:]) - -DONE = "`pre-commit` not found. Did you forget to activate your virtualenv?" 
-if os.access(INSTALL_PYTHON, os.X_OK): - CMD = [INSTALL_PYTHON, "-mpre_commit"] -elif which("pre-commit"): - CMD = ["pre-commit"] -else: - raise SystemExit(DONE) -CMD.extend(ARGS) -if sys.platform == "win32": # https://bugs.python.org/issue19124 - import subprocess +HERE="$(cd "$(dirname "$0")" && pwd)" +ARGS+=(--hook-dir "$HERE" -- "$@") - if sys.version_info < (3, 7): # https://bugs.python.org/issue25942 - raise SystemExit(subprocess.Popen(CMD).wait()) - else: - raise SystemExit(subprocess.call(CMD)) -else: - os.execvp(CMD[0], CMD) +if [ -x "$INSTALL_PYTHON" ]; then + exec "$INSTALL_PYTHON" -mpre_commit "${ARGS[@]}" +elif command -v pre-commit > /dev/null; then + exec pre-commit "${ARGS[@]}" +else + echo '`pre-commit` not found. Did you forget to activate your virtualenv?' 1>&2 + exit 1 +fi diff --git a/hooks/pre_gen_project.py b/hooks/pre_gen_project.py new file mode 100644 index 0000000000000000000000000000000000000000..34c0da07a48a5d33104652c6c5ce5aa9f0a51b54 --- /dev/null +++ b/hooks/pre_gen_project.py @@ -0,0 +1,3 @@ +# Normalize the slug to generate __package and __module private variables +{{cookiecutter.update({"__package": cookiecutter.slug.lower().replace("_", "-")})}} # noqa: F821 +{{cookiecutter.update({"__module": cookiecutter.slug.lower().replace("-", "_")})}} # noqa: F821 diff --git a/mkdocs.yml b/mkdocs.yml index 343e83cab2174890bd2badc3c666b9a763c41933..324af6346c7712062edff09eab5b495539af6ba5 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -36,7 +36,6 @@ plugins: - search - autorefs - mkdocstrings: - custom_templates: templates handlers: python: import: # enable auto refs to the doc @@ -90,7 +89,6 @@ nav: - Transcription: ref/api/transcription.md - WorkerVersion: ref/api/worker_version.md - Models: ref/models.md - - Git & Gitlab support: ref/git.md - Image utilities: ref/image.md - Cache: ref/cache.md - Utils: ref/utils.md diff --git a/pyproject.toml b/pyproject.toml index debb3e7c4e6d4b9e74facaada51df8523e87f42b..6d757073c64fbe646c2ca3057060ce3442459b67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,14 +1,82 @@ +[build-system] +requires = ["setuptools >= 61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "arkindex-base-worker" +version = "0.3.6-rc5" +description = "Base Worker to easily build Arkindex ML workflows" +license = { file = "LICENSE" } +dynamic = ["dependencies", "optional-dependencies"] +authors = [ + { name = "Teklia", email = "contact@teklia.com" }, +] +maintainers = [ + { name = "Teklia", email = "contact@teklia.com" }, +] +requires-python = ">=3.10" +readme = { file = "README.md", content-type = "text/markdown" } +keywords = ["python"] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: MIT License", + # Specify the Python versions you support here. 
+ "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + # Topics + "Topic :: Text Processing :: Linguistic", +] + +[project.urls] +Homepage = "https://workers.arkindex.org" +Documentation = "https://workers.arkindex.org" +Repository = "https://gitlab.teklia.com/workers/base-worker" +"Bug Tracker" = "https://gitlab.teklia.com/workers/base-worker/issues" +Authors = "https://teklia.com" + +[tool.setuptools.dynamic] +dependencies = { file = ["requirements.txt"] } +optional-dependencies = { docs = { file = ["docs-requirements.txt"] } } + [tool.ruff] exclude = [".git", "__pycache__"] ignore = ["E501"] -select = ["E", "F", "T1", "W", "I"] +select = [ + # pycodestyle + "E", + "W", + # Pyflakes + "F", + # Flake8 Debugger + "T1", + # Isort + "I", + # Implicit Optional + "RUF013", + # Invalid pyproject.toml + "RUF200", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # flake8-pytest-style + "PT", + # flake8-use-pathlib + "PTH", +] + +[tool.ruff.per-file-ignores] +# Ignore `pytest-composite-assertion` rules of `flake8-pytest-style` linter for non-test files +"arkindex_worker/**/*.py" = ["PT018"] [tool.ruff.isort] known-first-party = ["arkindex", "arkindex_common", "arkindex_worker"] known-third-party = [ "PIL", "apistar", - "gitlab", "gnupg", "peewee", "playhouse", @@ -16,7 +84,6 @@ known-third-party = [ "requests", "responses", "setuptools", - "sh", "shapely", "tenacity", "yaml", diff --git a/requirements.txt b/requirements.txt index e61242c38850d2998ad1f1f8f6d6807d730e5caf..946db4f302d5ebcc9abf243e960bef03e0ddde7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,8 @@ arkindex-client==1.0.14 peewee==3.17.0 Pillow==10.1.0 -pymdown-extensions==10.3.1 -python-gitlab==4.1.1 -python-gnupg==0.5.1 -sh==2.0.6 +pymdown-extensions==10.5 +python-gnupg==0.5.2 shapely==2.0.2 tenacity==8.2.3 zstandard==0.22.0 diff --git a/setup.py b/setup.py index 2c28725a422eb5c1c6dd635375876f03243a5a0e..ca9ba4a1b5b004d42d4c001f0f8d8b8e7338f102 100644 --- a/setup.py +++ b/setup.py @@ -1,28 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- -from pathlib import Path - from setuptools import find_packages, setup - -def requirements(path: Path): - assert path.exists(), "Missing requirements {}".format(path) - with path.open() as f: - return list(map(str.strip, f.read().splitlines())) - - -with open("VERSION") as f: - VERSION = f.read() - -setup( - name="arkindex-base-worker", - version=VERSION, - description="Base Worker to easily build Arkindex ML workflows", - author="Teklia", - author_email="contact@teklia.com", - url="https://teklia.com", - python_requires=">=3.7", - install_requires=requirements(Path("requirements.txt")), - extras_require={"docs": requirements(Path("docs-requirements.txt"))}, - packages=find_packages(), -) +setup(packages=find_packages()) diff --git a/tests/conftest.py b/tests/conftest.py index ad0c0f264ae4fcd19cf9423c51e82bc18a242437..4155fbe1be7032087b855e6da7f823f7bab2b077 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import hashlib import json import os @@ -19,10 +18,10 @@ from arkindex_worker.cache import ( CachedImage, CachedTranscription, Version, + create_tables, create_version_table, init_cache_db, ) -from arkindex_worker.git import GitHelper, GitlabHelper from arkindex_worker.models import Artifact, Dataset from arkindex_worker.worker import BaseWorker, DatasetWorker, ElementsWorker from arkindex_worker.worker.dataset 
import DatasetState @@ -37,7 +36,7 @@ __yaml_cache = {} @pytest.fixture(autouse=True) -def disable_sleep(monkeypatch): +def _disable_sleep(monkeypatch): """ Do not sleep at all in between API executions when errors occur in unit tests. @@ -46,8 +45,8 @@ def disable_sleep(monkeypatch): monkeypatch.setattr(time, "sleep", lambda x: None) -@pytest.fixture -def cache_yaml(monkeypatch): +@pytest.fixture() +def _cache_yaml(monkeypatch): """ Cache all calls to yaml.safe_load in order to speedup every test cases that load the OpenAPI schema @@ -75,7 +74,7 @@ def cache_yaml(monkeypatch): @pytest.fixture(autouse=True) -def setup_api(responses, monkeypatch, cache_yaml): +def _setup_api(responses, monkeypatch, _cache_yaml): # Always use the environment variable first schema_url = os.environ.get("ARKINDEX_API_SCHEMA_URL") if schema_url is None: @@ -106,13 +105,13 @@ def setup_api(responses, monkeypatch, cache_yaml): @pytest.fixture(autouse=True) -def give_env_variable(request, monkeypatch): +def _give_env_variable(monkeypatch): """Defines required environment variables""" monkeypatch.setenv("ARKINDEX_WORKER_RUN_ID", "56785678-5678-5678-5678-567856785678") -@pytest.fixture -def mock_worker_run_api(responses): +@pytest.fixture() +def _mock_worker_run_api(responses): """Provide a mock API response to get worker run information""" payload = { "id": "56785678-5678-5678-5678-567856785678", @@ -140,7 +139,7 @@ def mock_worker_run_api(responses): "docker_image_name": None, "state": "created", "gpu_usage": "disabled", - "model_usage": False, + "model_usage": "disabled", "worker": { "id": "deadbeef-1234-5678-1234-worker", "name": "Fake worker", @@ -180,8 +179,8 @@ def mock_worker_run_api(responses): ) -@pytest.fixture -def mock_worker_run_no_revision_api(responses): +@pytest.fixture() +def _mock_worker_run_no_revision_api(responses): """Provide a mock API response to get worker run not linked to a revision information""" payload = { "id": "56785678-5678-5678-5678-567856785678", @@ -207,7 +206,7 @@ def mock_worker_run_no_revision_api(responses): "docker_image_name": None, "state": "created", "gpu_usage": "disabled", - "model_usage": False, + "model_usage": "disabled", "worker": { "id": "deadbeef-1234-5678-1234-worker", "name": "Fake worker", @@ -247,8 +246,8 @@ def mock_worker_run_no_revision_api(responses): ) -@pytest.fixture -def mock_activity_calls(responses): +@pytest.fixture() +def _mock_activity_calls(responses): """ Mock responses when updating the activity state for multiple element of the same version """ @@ -259,8 +258,8 @@ def mock_activity_calls(responses): ) -@pytest.fixture -def mock_elements_worker(monkeypatch, mock_worker_run_api): +@pytest.fixture() +def mock_elements_worker(monkeypatch, _mock_worker_run_api): """Build and configure an ElementsWorker with fixed CLI parameters to avoid issues with pytest""" monkeypatch.setattr(sys, "argv", ["worker"]) worker = ElementsWorker() @@ -268,7 +267,7 @@ def mock_elements_worker(monkeypatch, mock_worker_run_api): return worker -@pytest.fixture +@pytest.fixture() def mock_elements_worker_read_only(monkeypatch): """Build and configure an ElementsWorker with fixed CLI parameters to avoid issues with pytest""" monkeypatch.setattr(sys, "argv", ["worker", "--dev"]) @@ -277,7 +276,7 @@ def mock_elements_worker_read_only(monkeypatch): return worker -@pytest.fixture +@pytest.fixture() def mock_elements_worker_with_list(monkeypatch, responses, mock_elements_worker): """ Mock a worker instance to list and retrieve a single element @@ -298,8 +297,19 @@ def 
mock_elements_worker_with_list(monkeypatch, responses, mock_elements_worker) return mock_elements_worker -@pytest.fixture -def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_run_api): +@pytest.fixture() +def mock_cache_db(tmp_path): + cache_path = tmp_path / "db.sqlite" + + init_cache_db(cache_path) + create_version_table() + create_tables() + + return cache_path + + +@pytest.fixture() +def mock_base_worker_with_cache(monkeypatch, _mock_worker_run_api): """Build a BaseWorker using SQLite cache, also mocking a PONOS_TASK""" monkeypatch.setattr(sys, "argv", ["worker"]) @@ -309,13 +319,10 @@ def mock_base_worker_with_cache(mocker, monkeypatch, mock_worker_run_api): return worker -@pytest.fixture -def mock_elements_worker_with_cache(monkeypatch, mock_worker_run_api, tmp_path): +@pytest.fixture() +def mock_elements_worker_with_cache(monkeypatch, mock_cache_db, _mock_worker_run_api): """Build and configure an ElementsWorker using SQLite cache with fixed CLI parameters to avoid issues with pytest""" - cache_path = tmp_path / "db.sqlite" - init_cache_db(cache_path) - create_version_table() - monkeypatch.setattr(sys, "argv", ["worker", "-d", str(cache_path)]) + monkeypatch.setattr(sys, "argv", ["worker", "-d", str(mock_cache_db)]) worker = ElementsWorker(support_cache=True) worker.configure() @@ -323,35 +330,34 @@ def mock_elements_worker_with_cache(monkeypatch, mock_worker_run_api, tmp_path): return worker -@pytest.fixture +@pytest.fixture() def fake_page_element(): - with open(FIXTURES_DIR / "page_element.json", "r") as f: - return json.load(f) + return json.loads((FIXTURES_DIR / "page_element.json").read_text()) -@pytest.fixture +@pytest.fixture() def fake_ufcn_worker_version(): - with open(FIXTURES_DIR / "ufcn_line_historical_worker_version.json", "r") as f: - return json.load(f) + return json.loads( + (FIXTURES_DIR / "ufcn_line_historical_worker_version.json").read_text() + ) -@pytest.fixture +@pytest.fixture() def fake_transcriptions_small(): - with open(FIXTURES_DIR / "line_transcriptions_small.json", "r") as f: - return json.load(f) + return json.loads((FIXTURES_DIR / "line_transcriptions_small.json").read_text()) -@pytest.fixture +@pytest.fixture() def model_file_dir(): return SAMPLES_DIR / "model_files" -@pytest.fixture +@pytest.fixture() def model_file_dir_with_subfolder(): return SAMPLES_DIR / "root_folder" -@pytest.fixture +@pytest.fixture() def fake_dummy_worker(): api_client = MockApiClient() worker = ElementsWorker() @@ -359,34 +365,8 @@ def fake_dummy_worker(): return worker -@pytest.fixture -def fake_git_helper(mocker): - gitlab_helper = mocker.MagicMock() - return GitHelper( - "repo_url", - "/tmp/git_test/foo/", - "/tmp/test/path/", - "tmp_workflow_id", - gitlab_helper, - ) - - -@pytest.fixture -def fake_gitlab_helper_factory(): - # have to set up the responses, before creating the client - def run(): - return GitlabHelper( - "balsac_exporter/balsac-exported-xmls-testing", - "https://gitlab.com", - "<GITLAB_TOKEN>", - "gitlab_branch", - ) - - return run - - -@pytest.fixture -def mock_cached_elements(): +@pytest.fixture() +def _mock_cached_elements(mock_cache_db): """Insert few elements in local cache""" CachedElement.create( id=UUID("99999999-9999-9999-9999-999999999999"), @@ -430,8 +410,8 @@ def mock_cached_elements(): assert CachedElement.select().count() == 5 -@pytest.fixture -def mock_cached_images(): +@pytest.fixture() +def _mock_cached_images(mock_cache_db): """Insert few elements in local cache""" CachedImage.create( 
id=UUID("99999999-9999-9999-9999-999999999999"), @@ -442,8 +422,8 @@ def mock_cached_images(): assert CachedImage.select().count() == 1 -@pytest.fixture -def mock_cached_transcriptions(): +@pytest.fixture() +def _mock_cached_transcriptions(mock_cache_db): """Insert few transcriptions in local cache, on a shared element""" CachedElement.create( id=UUID("11111111-1111-1111-1111-111111111111"), @@ -529,7 +509,7 @@ def mock_cached_transcriptions(): ) -@pytest.fixture(scope="function") +@pytest.fixture() def mock_databases(tmp_path): """ Initialize several temporary databases @@ -612,7 +592,7 @@ def mock_databases(tmp_path): return out -@pytest.fixture +@pytest.fixture() def default_dataset(): return Dataset( **{ @@ -630,8 +610,8 @@ def default_dataset(): ) -@pytest.fixture -def mock_dataset_worker(monkeypatch, mocker, mock_worker_run_api): +@pytest.fixture() +def mock_dataset_worker(monkeypatch, mocker, _mock_worker_run_api): monkeypatch.setenv("PONOS_TASK", "my_task") mocker.patch.object(sys, "argv", ["worker"]) @@ -644,7 +624,7 @@ def mock_dataset_worker(monkeypatch, mocker, mock_worker_run_api): return dataset_worker -@pytest.fixture +@pytest.fixture() def mock_dev_dataset_worker(mocker): mocker.patch.object( sys, @@ -668,7 +648,7 @@ def mock_dev_dataset_worker(mocker): return dataset_worker -@pytest.fixture +@pytest.fixture() def default_artifact(): return Artifact( **{ diff --git a/tests/test_base_worker.py b/tests/test_base_worker.py index 9547df6f6e87127fbcf01db35ab2b645589e6e06..75accdd3998cc50439acc6982df7a52a84f7c0b9 100644 --- a/tests/test_base_worker.py +++ b/tests/test_base_worker.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import json import logging import sys @@ -15,7 +14,7 @@ from arkindex_worker.worker.base import ExtrasDirNotFoundError from tests.conftest import FIXTURES_DIR -def test_init_default_local_share(monkeypatch): +def test_init_default_local_share(): worker = BaseWorker() assert worker.work_dir == Path("~/.local/share/arkindex").expanduser() @@ -29,7 +28,7 @@ def test_init_default_xdg_data_home(monkeypatch): assert str(worker.work_dir) == f"{path}/arkindex" -def test_init_with_local_cache(monkeypatch): +def test_init_with_local_cache(): worker = BaseWorker(support_cache=True) assert worker.work_dir == Path("~/.local/share/arkindex").expanduser() @@ -72,7 +71,8 @@ def test_init_var_worker_local_file(monkeypatch, tmp_path): config.unlink() -def test_cli_default(mocker, mock_worker_run_api): +@pytest.mark.usefixtures("_mock_worker_run_api") +def test_cli_default(mocker): worker = BaseWorker() assert logger.level == logging.NOTSET @@ -91,7 +91,8 @@ def test_cli_default(mocker, mock_worker_run_api): logger.setLevel(logging.NOTSET) -def test_cli_arg_verbose_given(mocker, mock_worker_run_api): +@pytest.mark.usefixtures("_mock_worker_run_api") +def test_cli_arg_verbose_given(mocker): worker = BaseWorker() assert logger.level == logging.NOTSET @@ -110,7 +111,8 @@ def test_cli_arg_verbose_given(mocker, mock_worker_run_api): logger.setLevel(logging.NOTSET) -def test_cli_envvar_debug_given(mocker, monkeypatch, mock_worker_run_api): +@pytest.mark.usefixtures("_mock_worker_run_api") +def test_cli_envvar_debug_given(mocker, monkeypatch): worker = BaseWorker() assert logger.level == logging.NOTSET @@ -129,7 +131,7 @@ def test_cli_envvar_debug_given(mocker, monkeypatch, mock_worker_run_api): logger.setLevel(logging.NOTSET) -def test_configure_dev_mode(mocker, monkeypatch): +def test_configure_dev_mode(mocker): """ Configuring a worker in developer mode avoid retrieving process 
information """ @@ -145,7 +147,7 @@ def test_configure_dev_mode(mocker, monkeypatch): assert worker.user_configuration == {} -def test_configure_worker_run(mocker, monkeypatch, responses, caplog): +def test_configure_worker_run(mocker, responses, caplog): # Capture log messages caplog.set_level(logging.INFO) @@ -214,9 +216,8 @@ def test_configure_worker_run(mocker, monkeypatch, responses, caplog): assert worker.user_configuration == {"a": "b"} -def test_configure_worker_run_no_revision( - mocker, monkeypatch, mock_worker_run_no_revision_api, caplog -): +@pytest.mark.usefixtures("_mock_worker_run_no_revision_api") +def test_configure_worker_run_no_revision(mocker, caplog): worker = BaseWorker() mocker.patch.object(sys, "argv", ["worker"]) @@ -234,11 +235,7 @@ def test_configure_worker_run_no_revision( ] -def test_configure_user_configuration_defaults( - mocker, - monkeypatch, - responses, -): +def test_configure_user_configuration_defaults(mocker, responses): worker = BaseWorker() mocker.patch.object(sys, "argv") worker.args = worker.parser.parse_args() @@ -300,8 +297,8 @@ def test_configure_user_configuration_defaults( } -@pytest.mark.parametrize("debug", (True, False)) -def test_configure_user_config_debug(mocker, monkeypatch, responses, debug): +@pytest.mark.parametrize("debug", [True, False]) +def test_configure_user_config_debug(mocker, responses, debug): worker = BaseWorker() mocker.patch.object(sys, "argv", ["worker"]) assert logger.level == logging.NOTSET @@ -347,7 +344,7 @@ def test_configure_user_config_debug(mocker, monkeypatch, responses, debug): logger.setLevel(logging.NOTSET) -def test_configure_worker_run_missing_conf(mocker, monkeypatch, responses): +def test_configure_worker_run_missing_conf(mocker, responses): worker = BaseWorker() mocker.patch.object(sys, "argv", ["worker"]) @@ -392,7 +389,7 @@ def test_configure_worker_run_missing_conf(mocker, monkeypatch, responses): assert worker.user_configuration == {} -def test_configure_worker_run_no_worker_run_conf(mocker, monkeypatch, responses): +def test_configure_worker_run_no_worker_run_conf(mocker, responses): """ No configuration is provided but should not crash """ @@ -434,7 +431,7 @@ def test_configure_worker_run_no_worker_run_conf(mocker, monkeypatch, responses) assert worker.user_configuration == {} -def test_configure_load_model_configuration(mocker, monkeypatch, responses): +def test_configure_load_model_configuration(mocker, responses): worker = BaseWorker() mocker.patch.object(sys, "argv", ["worker"]) payload = { @@ -454,7 +451,10 @@ def test_configure_load_model_configuration(mocker, monkeypatch, responses): "configuration": None, "model_version": { "id": "12341234-1234-1234-1234-123412341234", - "name": "Model version 1337", + "model": { + "id": "43214321-4321-4321-4321-432143214321", + "name": "Model 1337", + }, "configuration": { "param1": "value1", "param2": 2, @@ -489,6 +489,10 @@ def test_configure_load_model_configuration(mocker, monkeypatch, responses): "param3": None, } assert worker.model_version_id == "12341234-1234-1234-1234-123412341234" + assert worker.model_details == { + "id": "43214321-4321-4321-4321-432143214321", + "name": "Model 1337", + } def test_load_missing_secret(): @@ -578,7 +582,7 @@ def test_load_local_secret(monkeypatch, tmp_path): secret.write_text("this is a local secret value", encoding="utf-8") # Mock GPG decryption - class GpgDecrypt(object): + class GpgDecrypt: def __init__(self, fd): self.ok = True self.data = fd.read() @@ -631,15 +635,15 @@ def 
test_find_extras_directory_from_config(monkeypatch): @pytest.mark.parametrize( - "extras_path, exists, error", - ( - [ + ("extras_path", "exists", "error"), + [ + ( None, True, "No path to the directory for extra files was provided. Please provide extras_dir either through configuration or as CLI argument.", - ], - ["extra_files", False, "The path extra_files does not link to any directory"], - ), + ), + ("extra_files", False, "The path extra_files does not link to any directory"), + ], ) def test_find_extras_directory_not_found(monkeypatch, extras_path, exists, error): if extras_path: @@ -666,7 +670,9 @@ def test_find_parents_file_paths(responses, mock_base_worker_with_cache, tmp_pat ) filename = Path("my_file.txt") - for parent_id, content in zip(["first", "third"], ["Some text", "Other text"]): + for parent_id, content in zip( + ["first", "third"], ["Some text", "Other text"], strict=True + ): (tmp_path / parent_id).mkdir() file_path = tmp_path / parent_id / filename with file_path.open("w", encoding="utf-8") as f: @@ -742,7 +748,7 @@ def test_corpus_id_not_set_read_only_mode( with pytest.raises( Exception, match="Missing ARKINDEX_CORPUS_ID environment variable" ): - mock_elements_worker_read_only.corpus_id + _ = mock_elements_worker_read_only.corpus_id def test_corpus_id_set_read_only_mode( diff --git a/tests/test_cache.py b/tests/test_cache.py index adcef102fcb4144d1c1a5ae1f3971d71239f6d45..0d441dd24f98a9fadc373df06211efbf5525fa2c 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from pathlib import Path from uuid import UUID @@ -31,22 +30,20 @@ def test_init(tmp_path): def test_create_tables_existing_table(tmp_path): - db_path = f"{tmp_path}/db.sqlite" + db_path = tmp_path / "db.sqlite" # Create the tables once… init_cache_db(db_path) create_tables() db.close() - with open(db_path, "rb") as before_file: - before = before_file.read() + before = db_path.read_bytes() # Create them again init_cache_db(db_path) create_tables() - with open(db_path, "rb") as after_file: - after = after_file.read() + after = db_path.read_bytes() assert before == after, "Existing table structure was modified" @@ -56,6 +53,9 @@ def test_create_tables(tmp_path): init_cache_db(db_path) create_tables() + # WARNING: If you are updating this schema following a development you have made + # in base-worker, make sure to upgrade the arkindex_worker.cache.SQL_VERSION in + # the same merge request as your changes. 
expected_schema = """CREATE TABLE "classifications" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "class_name" TEXT NOT NULL, "confidence" REAL NOT NULL, "state" VARCHAR(10) NOT NULL, "worker_run_id" TEXT, FOREIGN KEY ("element_id") REFERENCES "elements" ("id")) CREATE TABLE "dataset_elements" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "dataset_id" TEXT NOT NULL, "set_name" VARCHAR(255) NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"), FOREIGN KEY ("dataset_id") REFERENCES "datasets" ("id")) CREATE TABLE "datasets" ("id" TEXT NOT NULL PRIMARY KEY, "name" VARCHAR(255) NOT NULL, "state" VARCHAR(255) NOT NULL DEFAULT 'open', "sets" TEXT NOT NULL) @@ -144,7 +144,17 @@ def test_check_version_same_version(tmp_path): @pytest.mark.parametrize( - "image_width,image_height,polygon_x,polygon_y,polygon_width,polygon_height,max_width,max_height,expected_url", + ( + "image_width", + "image_height", + "polygon_x", + "polygon_y", + "polygon_width", + "polygon_height", + "max_width", + "max_height", + "expected_url", + ), [ # No max_size: no resize ( diff --git a/tests/test_dataset_worker.py b/tests/test_dataset_worker.py index ca7feb0ab2f7a4477f131e718c6b7e9fe8d0f5df..0aca6d6502cc9538f3fb2f401e55271f29633995 100644 --- a/tests/test_dataset_worker.py +++ b/tests/test_dataset_worker.py @@ -413,7 +413,7 @@ def test_list_datasets(responses, mock_dataset_worker): ] -@pytest.mark.parametrize("generator", (True, False)) +@pytest.mark.parametrize("generator", [True, False]) def test_run_no_datasets(mocker, caplog, mock_dataset_worker, generator): mocker.patch("arkindex_worker.worker.DatasetWorker.list_datasets", return_value=[]) mock_dataset_worker.generator = generator @@ -428,7 +428,7 @@ def test_run_no_datasets(mocker, caplog, mock_dataset_worker, generator): @pytest.mark.parametrize( - "generator, error", + ("generator", "error"), [ (True, "When generating a new dataset, its state should be Open."), (False, "When processing an existing dataset, its state should be Complete."), @@ -657,7 +657,7 @@ def test_run_no_downloaded_artifact_error( @pytest.mark.parametrize( - "generator, state", [(True, DatasetState.Open), (False, DatasetState.Complete)] + ("generator", "state"), [(True, DatasetState.Open), (False, DatasetState.Complete)] ) def test_run( mocker, @@ -749,7 +749,7 @@ def test_run( @pytest.mark.parametrize( - "generator, state", [(True, DatasetState.Open), (False, DatasetState.Complete)] + ("generator", "state"), [(True, DatasetState.Open), (False, DatasetState.Complete)] ) def test_run_read_only( mocker, diff --git a/tests/test_element.py b/tests/test_element.py index 225fbb5754005d395962c20c1de18146d04ba19c..9fdca1842f2ed158f1857a94a8c38c250159dee0 100644 --- a/tests/test_element.py +++ b/tests/test_element.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import pytest from requests import HTTPError diff --git a/tests/test_elements_worker/__init__.py b/tests/test_elements_worker/__init__.py index c553aa50d0829fd17d5fc92966d1cfe2bce27824..c1e5a038ae034b11e8d3a95481fc869b03d69175 100644 --- a/tests/test_elements_worker/__init__.py +++ b/tests/test_elements_worker/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # API calls during worker configuration BASE_API_CALLS = [ ( diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py index 5b224c82dcd7b8196d9f979c7498bb1f669f5a7e..aec222298765cd4d6d09408bb84fb036704752ee 100644 --- a/tests/test_elements_worker/test_classifications.py +++ 
b/tests/test_elements_worker/test_classifications.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import json import re from uuid import UUID, uuid4 diff --git a/tests/test_elements_worker/test_cli.py b/tests/test_elements_worker/test_cli.py index ed2dfbe1ab7c6e146436c4918283a6cd203c1ebf..f52978fdedad40c8335f7f74d60868e19d55f5bc 100644 --- a/tests/test_elements_worker/test_cli.py +++ b/tests/test_elements_worker/test_cli.py @@ -1,8 +1,7 @@ -# -*- coding: utf-8 -*- import json -import os import sys import tempfile +from pathlib import Path from uuid import UUID import pytest @@ -10,49 +9,53 @@ import pytest from arkindex_worker.worker import ElementsWorker -def test_cli_default(monkeypatch, mock_worker_run_api): +@pytest.mark.usefixtures("_mock_worker_run_api") +def test_cli_default(monkeypatch): _, path = tempfile.mkstemp() - with open(path, "w") as f: - json.dump( + path = Path(path) + path.write_text( + json.dumps( [ {"id": "volumeid", "type": "volume"}, {"id": "pageid", "type": "page"}, {"id": "actid", "type": "act"}, {"id": "surfaceid", "type": "surface"}, ], - f, ) + ) monkeypatch.setenv("TASK_ELEMENTS", path) monkeypatch.setattr(sys, "argv", ["worker"]) worker = ElementsWorker() worker.configure() - assert worker.args.elements_list.name == path + assert worker.args.elements_list.name == str(path) assert not worker.args.element - os.unlink(path) + path.unlink() -def test_cli_arg_elements_list_given(mocker, mock_worker_run_api): +@pytest.mark.usefixtures("_mock_worker_run_api") +def test_cli_arg_elements_list_given(mocker): _, path = tempfile.mkstemp() - with open(path, "w") as f: - json.dump( + path = Path(path) + path.write_text( + json.dumps( [ {"id": "volumeid", "type": "volume"}, {"id": "pageid", "type": "page"}, {"id": "actid", "type": "act"}, {"id": "surfaceid", "type": "surface"}, ], - f, ) + ) - mocker.patch.object(sys, "argv", ["worker", "--elements-list", path]) + mocker.patch.object(sys, "argv", ["worker", "--elements-list", str(path)]) worker = ElementsWorker() worker.configure() - assert worker.args.elements_list.name == path + assert worker.args.elements_list.name == str(path) assert not worker.args.element - os.unlink(path) + path.unlink() def test_cli_arg_element_one_given_not_uuid(mocker): @@ -62,7 +65,8 @@ def test_cli_arg_element_one_given_not_uuid(mocker): worker.configure() -def test_cli_arg_element_one_given(mocker, mock_worker_run_api): +@pytest.mark.usefixtures("_mock_worker_run_api") +def test_cli_arg_element_one_given(mocker): mocker.patch.object( sys, "argv", ["worker", "--element", "12341234-1234-1234-1234-123412341234"] ) @@ -74,7 +78,8 @@ def test_cli_arg_element_one_given(mocker, mock_worker_run_api): assert not worker.args.elements_list -def test_cli_arg_element_many_given(mocker, mock_worker_run_api): +@pytest.mark.usefixtures("_mock_worker_run_api") +def test_cli_arg_element_many_given(mocker): mocker.patch.object( sys, "argv", diff --git a/tests/test_elements_worker/test_dataset.py b/tests/test_elements_worker/test_dataset.py index ddfae978689db2787095a3ffb4f93b57dce39934..7c80dd3a43568c627fae69403d829d94a9e65c54 100644 --- a/tests/test_elements_worker/test_dataset.py +++ b/tests/test_elements_worker/test_dataset.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import json import logging @@ -107,8 +106,8 @@ def test_list_process_datasets( @pytest.mark.parametrize( - "payload, error", - ( + ("payload", "error"), + [ # Dataset ( {"dataset": None}, @@ -118,7 +117,7 @@ def test_list_process_datasets( {"dataset": "not Dataset type"}, "dataset shouldn't be null and 
should be a Dataset", ), - ), + ], ) def test_list_dataset_elements_wrong_param_dataset(mock_dataset_worker, payload, error): with pytest.raises(AssertionError, match=error): @@ -265,8 +264,8 @@ def test_list_dataset_elements( @pytest.mark.parametrize( - "payload, error", - ( + ("payload", "error"), + [ # Dataset ( {"dataset": None}, @@ -276,7 +275,7 @@ def test_list_dataset_elements( {"dataset": "not dataset type"}, "dataset shouldn't be null and should be a Dataset", ), - ), + ], ) def test_update_dataset_state_wrong_param_dataset( mock_dataset_worker, default_dataset, payload, error @@ -292,8 +291,8 @@ def test_update_dataset_state_wrong_param_dataset( @pytest.mark.parametrize( - "payload, error", - ( + ("payload", "error"), + [ # DatasetState ( {"state": None}, @@ -303,7 +302,7 @@ def test_update_dataset_state_wrong_param_dataset( {"state": "not dataset type"}, "state shouldn't be null and should be a str from DatasetState", ), - ), + ], ) def test_update_dataset_state_wrong_param_state( mock_dataset_worker, default_dataset, payload, error diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py index d6db94b49382a5a681c995f32c733fb041407ade..2445f230f085a8fedb656992ad0f9c03aa77dd11 100644 --- a/tests/test_elements_worker/test_elements.py +++ b/tests/test_elements_worker/test_elements.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import json import re from argparse import Namespace @@ -429,19 +428,7 @@ def test_create_sub_element_wrong_name(mock_elements_worker): def test_create_sub_element_wrong_polygon(mock_elements_worker): elt = Element({"zone": None}) - with pytest.raises( - AssertionError, match="polygon shouldn't be null and should be of type list" - ): - mock_elements_worker.create_sub_element( - element=elt, - type="something", - name="0", - polygon=None, - ) - - with pytest.raises( - AssertionError, match="polygon shouldn't be null and should be of type list" - ): + with pytest.raises(AssertionError, match="polygon should be None or a list"): mock_elements_worker.create_sub_element( element=elt, type="something", @@ -505,6 +492,42 @@ def test_create_sub_element_wrong_confidence(mock_elements_worker, confidence): ) +@pytest.mark.parametrize( + ("image", "error_type", "error_message"), + [ + (1, AssertionError, "image should be None or string"), + ("not a uuid", ValueError, "image is not a valid uuid."), + ], +) +def test_create_sub_element_wrong_image( + mock_elements_worker, image, error_type, error_message +): + with pytest.raises(error_type, match=re.escape(error_message)): + mock_elements_worker.create_sub_element( + element=Element({"zone": None}), + type="something", + name="blah", + polygon=[[0, 0], [0, 10], [10, 10], [10, 0], [0, 0]], + image=image, + ) + + +def test_create_sub_element_wrong_image_and_polygon(mock_elements_worker): + with pytest.raises( + AssertionError, + match=re.escape( + "An image or a parent with an image is required to create an element with a polygon." 
+ ), + ): + mock_elements_worker.create_sub_element( + element=Element({"zone": None}), + type="something", + name="blah", + polygon=[[0, 0], [0, 10], [10, 10], [10, 0], [0, 0]], + image=None, + ) + + def test_create_sub_element_api_error(responses, mock_elements_worker): elt = Element( { @@ -581,7 +604,7 @@ def test_create_sub_element(responses, mock_elements_worker, slim_output): assert json.loads(responses.calls[-1].request.body) == { "type": "something", "name": "0", - "image": "22222222-2222-2222-2222-222222222222", + "image": None, "corpus": "11111111-1111-1111-1111-111111111111", "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], "parent": "12341234-1234-1234-1234-123412341234", @@ -626,7 +649,7 @@ def test_create_sub_element_confidence(responses, mock_elements_worker): assert json.loads(responses.calls[-1].request.body) == { "type": "something", "name": "0", - "image": "22222222-2222-2222-2222-222222222222", + "image": None, "corpus": "11111111-1111-1111-1111-111111111111", "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], "parent": "12341234-1234-1234-1234-123412341234", @@ -1219,8 +1242,96 @@ def test_create_elements_integrity_error( @pytest.mark.parametrize( - "payload, error", - ( + ("params", "error_message"), + [ + ( + {"parent": None, "child": None}, + "parent shouldn't be null and should be of type Element", + ), + ( + {"parent": "not an element", "child": None}, + "parent shouldn't be null and should be of type Element", + ), + ( + {"parent": Element(zone=None), "child": None}, + "child shouldn't be null and should be of type Element", + ), + ( + {"parent": Element(zone=None), "child": "not an element"}, + "child shouldn't be null and should be of type Element", + ), + ], +) +def test_create_element_parent_invalid_params( + mock_elements_worker, params, error_message +): + with pytest.raises(AssertionError, match=re.escape(error_message)): + mock_elements_worker.create_element_parent(**params) + + +def test_create_element_parent_api_error(responses, mock_elements_worker): + parent = Element({"id": "12341234-1234-1234-1234-123412341234"}) + child = Element({"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}) + responses.add( + responses.POST, + "http://testserver/api/v1/element/497f6eca-6276-4993-bfeb-53cbbbba6f08/parent/12341234-1234-1234-1234-123412341234/", + status=500, + ) + + with pytest.raises(ErrorResponse): + mock_elements_worker.create_element_parent( + parent=parent, + child=child, + ) + + assert len(responses.calls) == len(BASE_API_CALLS) + 5 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ + # We retry 5 times the API call + ( + "POST", + "http://testserver/api/v1/element/497f6eca-6276-4993-bfeb-53cbbbba6f08/parent/12341234-1234-1234-1234-123412341234/", + ), + ] * 5 + + +def test_create_element_parent(responses, mock_elements_worker): + parent = Element({"id": "12341234-1234-1234-1234-123412341234"}) + child = Element({"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}) + responses.add( + responses.POST, + "http://testserver/api/v1/element/497f6eca-6276-4993-bfeb-53cbbbba6f08/parent/12341234-1234-1234-1234-123412341234/", + status=200, + json={ + "parent": "12341234-1234-1234-1234-123412341234", + "child": "497f6eca-6276-4993-bfeb-53cbbbba6f08", + }, + ) + + created_element_parent = mock_elements_worker.create_element_parent( + parent=parent, + child=child, + ) + + assert len(responses.calls) == len(BASE_API_CALLS) + 1 + assert [ + (call.request.method, call.request.url) for call in responses.calls + ] == BASE_API_CALLS + [ 
+ ( + "POST", + "http://testserver/api/v1/element/497f6eca-6276-4993-bfeb-53cbbbba6f08/parent/12341234-1234-1234-1234-123412341234/", + ), + ] + assert created_element_parent == { + "parent": "12341234-1234-1234-1234-123412341234", + "child": "497f6eca-6276-4993-bfeb-53cbbbba6f08", + } + + +@pytest.mark.parametrize( + ("payload", "error"), + [ # Element ( {"element": None}, @@ -1230,7 +1341,7 @@ def test_create_elements_integrity_error( {"element": "not element type"}, "element shouldn't be null and should be an Element or CachedElement", ), - ), + ], ) def test_partial_update_element_wrong_param_element( mock_elements_worker, payload, error @@ -1247,12 +1358,12 @@ def test_partial_update_element_wrong_param_element( @pytest.mark.parametrize( - "payload, error", - ( + ("payload", "error"), + [ # Type ({"type": 1234}, "type should be a str"), ({"type": None}, "type should be a str"), - ), + ], ) def test_partial_update_element_wrong_param_type(mock_elements_worker, payload, error): api_payload = { @@ -1267,12 +1378,12 @@ def test_partial_update_element_wrong_param_type(mock_elements_worker, payload, @pytest.mark.parametrize( - "payload, error", - ( + ("payload", "error"), + [ # Name ({"name": 1234}, "name should be a str"), ({"name": None}, "name should be a str"), - ), + ], ) def test_partial_update_element_wrong_param_name(mock_elements_worker, payload, error): api_payload = { @@ -1287,8 +1398,8 @@ def test_partial_update_element_wrong_param_name(mock_elements_worker, payload, @pytest.mark.parametrize( - "payload, error", - ( + ("payload", "error"), + [ # Polygon ({"polygon": "not a polygon"}, "polygon should be a list"), ({"polygon": None}, "polygon should be a list"), @@ -1305,7 +1416,7 @@ def test_partial_update_element_wrong_param_name(mock_elements_worker, payload, {"polygon": [["not a coord", 1], [2, 2], [2, 1], [1, 2]]}, "polygon points should be lists of two numbers", ), - ), + ], ) def test_partial_update_element_wrong_param_polygon( mock_elements_worker, payload, error @@ -1322,8 +1433,8 @@ def test_partial_update_element_wrong_param_polygon( @pytest.mark.parametrize( - "payload, error", - ( + ("payload", "error"), + [ # Confidence ({"confidence": "lol"}, "confidence should be None or a float in [0..1] range"), ({"confidence": "0.2"}, "confidence should be None or a float in [0..1] range"), @@ -1333,7 +1444,7 @@ def test_partial_update_element_wrong_param_polygon( {"confidence": float("inf")}, "confidence should be None or a float in [0..1] range", ), - ), + ], ) def test_partial_update_element_wrong_param_conf(mock_elements_worker, payload, error): api_payload = { @@ -1348,14 +1459,14 @@ def test_partial_update_element_wrong_param_conf(mock_elements_worker, payload, @pytest.mark.parametrize( - "payload, error", - ( + ("payload", "error"), + [ # Rotation angle ({"rotation_angle": "lol"}, "rotation_angle should be a positive integer"), ({"rotation_angle": -1}, "rotation_angle should be a positive integer"), ({"rotation_angle": 0.5}, "rotation_angle should be a positive integer"), ({"rotation_angle": None}, "rotation_angle should be a positive integer"), - ), + ], ) def test_partial_update_element_wrong_param_rota(mock_elements_worker, payload, error): api_payload = { @@ -1370,13 +1481,13 @@ def test_partial_update_element_wrong_param_rota(mock_elements_worker, payload, @pytest.mark.parametrize( - "payload, error", - ( + ("payload", "error"), + [ # Mirrored ({"mirrored": "lol"}, "mirrored should be a boolean"), ({"mirrored": 1234}, "mirrored should be a boolean"), ({"mirrored": 
None}, "mirrored should be a boolean"), - ), + ], ) def test_partial_update_element_wrong_param_mir(mock_elements_worker, payload, error): api_payload = { @@ -1391,13 +1502,13 @@ def test_partial_update_element_wrong_param_mir(mock_elements_worker, payload, e @pytest.mark.parametrize( - "payload, error", - ( + ("payload", "error"), + [ # Image ({"image": "lol"}, "image should be a UUID"), ({"image": 1234}, "image should be a UUID"), ({"image": None}, "image should be a UUID"), - ), + ], ) def test_partial_update_element_wrong_param_image(mock_elements_worker, payload, error): api_payload = { @@ -1440,9 +1551,10 @@ def test_partial_update_element_api_error(responses, mock_elements_worker): ] +@pytest.mark.usefixtures("_mock_cached_elements", "_mock_cached_images") @pytest.mark.parametrize( "payload", - ( + [ ( { "polygon": [[10, 10], [20, 20], [20, 10], [10, 20]], @@ -1463,15 +1575,9 @@ def test_partial_update_element_api_error(responses, mock_elements_worker): "mirrored": False, } ), - ), + ], ) -def test_partial_update_element( - responses, - mock_elements_worker_with_cache, - mock_cached_elements, - mock_cached_images, - payload, -): +def test_partial_update_element(responses, mock_elements_worker_with_cache, payload): elt = CachedElement.select().first() new_image = CachedImage.select().first() @@ -1516,9 +1622,10 @@ def test_partial_update_element( assert getattr(cached_element, param) == elt_response[param] -@pytest.mark.parametrize("confidence", (None, 0.42)) +@pytest.mark.usefixtures("_mock_cached_elements") +@pytest.mark.parametrize("confidence", [None, 0.42]) def test_partial_update_element_confidence( - responses, mock_elements_worker_with_cache, mock_cached_elements, confidence + responses, mock_elements_worker_with_cache, confidence ): elt = CachedElement.select().first() elt_response = { @@ -1661,13 +1768,13 @@ def test_list_element_children_wrong_with_metadata(mock_elements_worker): @pytest.mark.parametrize( - "param, value", - ( + ("param", "value"), + [ ("worker_version", 1234), ("worker_run", 1234), ("transcription_worker_version", 1234), ("transcription_worker_run", 1234), - ), + ], ) def test_list_element_children_wrong_worker_version(mock_elements_worker, param, value): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) @@ -1681,12 +1788,12 @@ def test_list_element_children_wrong_worker_version(mock_elements_worker, param, @pytest.mark.parametrize( "param", - ( - ("worker_version"), - ("worker_run"), - ("transcription_worker_version"), - ("transcription_worker_run"), - ), + [ + "worker_version", + "worker_run", + "transcription_worker_version", + "transcription_worker_run", + ], ) def test_list_element_children_wrong_bool_worker_version(mock_elements_worker, param): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) @@ -1908,9 +2015,10 @@ def test_list_element_children_with_cache_unhandled_param( ) +@pytest.mark.usefixtures("_mock_cached_elements") @pytest.mark.parametrize( - "filters, expected_ids", - ( + ("filters", "expected_ids"), + [ # Filter on element should give all elements inserted ( { @@ -1977,12 +2085,11 @@ def test_list_element_children_with_cache_unhandled_param( "33333333-3333-3333-3333-333333333333", ), ), - ), + ], ) def test_list_element_children_with_cache( responses, mock_elements_worker_with_cache, - mock_cached_elements, filters, expected_ids, ): @@ -1992,7 +2099,7 @@ def test_list_element_children_with_cache( # Query database through cache elements = mock_elements_worker_with_cache.list_element_children(**filters) assert 
elements.count() == len(expected_ids) - for child, expected_id in zip(elements.order_by("id"), expected_ids): + for child, expected_id in zip(elements.order_by("id"), expected_ids, strict=True): assert child.id == UUID(expected_id) # Check the worker never hits the API for elements @@ -2109,13 +2216,13 @@ def test_list_element_parents_wrong_with_metadata(mock_elements_worker): @pytest.mark.parametrize( - "param, value", - ( + ("param", "value"), + [ ("worker_version", 1234), ("worker_run", 1234), ("transcription_worker_version", 1234), ("transcription_worker_run", 1234), - ), + ], ) def test_list_element_parents_wrong_worker_version(mock_elements_worker, param, value): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) @@ -2129,12 +2236,12 @@ def test_list_element_parents_wrong_worker_version(mock_elements_worker, param, @pytest.mark.parametrize( "param", - ( - ("worker_version"), - ("worker_run"), - ("transcription_worker_version"), - ("transcription_worker_run"), - ), + [ + "worker_version", + "worker_run", + "transcription_worker_version", + "transcription_worker_run", + ], ) def test_list_element_parents_wrong_bool_worker_version(mock_elements_worker, param): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) @@ -2356,9 +2463,10 @@ def test_list_element_parents_with_cache_unhandled_param( ) +@pytest.mark.usefixtures("_mock_cached_elements") @pytest.mark.parametrize( - "filters, expected_id", - ( + ("filters", "expected_id"), + [ # Filter on element ( { @@ -2415,12 +2523,11 @@ def test_list_element_parents_with_cache_unhandled_param( }, "99999999-9999-9999-9999-999999999999", ), - ), + ], ) def test_list_element_parents_with_cache( responses, mock_elements_worker_with_cache, - mock_cached_elements, filters, expected_id, ): diff --git a/tests/test_elements_worker/test_entities.py b/tests/test_elements_worker/test_entities.py index 426c485cedb0c8954b7a79c97d157e778c3ba9b9..5734458f6127b80eafd6c2033359b6ff421379d0 100644 --- a/tests/test_elements_worker/test_entities.py +++ b/tests/test_elements_worker/test_entities.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import json import re from uuid import UUID @@ -56,7 +55,7 @@ def test_create_entity_wrong_type(mock_elements_worker): ) -def test_create_entity_wrong_corpus(monkeypatch, mock_elements_worker): +def test_create_entity_wrong_corpus(mock_elements_worker): # Triggering an error on metas param, not giving corpus should work since # ARKINDEX_CORPUS_ID environment variable is set on mock_elements_worker with pytest.raises(AssertionError, match="metas should be of type dict"): @@ -742,12 +741,13 @@ def test_list_corpus_entities(responses, mock_elements_worker): }, ) - # list is required to actually do the request - assert list(mock_elements_worker.list_corpus_entities()) == [ - { + mock_elements_worker.list_corpus_entities() + + assert mock_elements_worker.entities == { + "fake_entity_id": { "id": "fake_entity_id", } - ] + } assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert [ @@ -760,22 +760,13 @@ def test_list_corpus_entities(responses, mock_elements_worker): ] -@pytest.mark.parametrize( - "wrong_name", - [ - 1234, - 12.5, - ], -) +@pytest.mark.parametrize("wrong_name", [1234, 12.5]) def test_list_corpus_entities_wrong_name(mock_elements_worker, wrong_name): with pytest.raises(AssertionError, match="name should be of type str"): mock_elements_worker.list_corpus_entities(name=wrong_name) -@pytest.mark.parametrize( - "wrong_parent", - [{"id": "element_id"}, 12.5, "blabla"], -) 
+@pytest.mark.parametrize("wrong_parent", [{"id": "element_id"}, 12.5, "blabla"]) def test_list_corpus_entities_wrong_parent(mock_elements_worker, wrong_parent): with pytest.raises(AssertionError, match="parent should be of type Element"): mock_elements_worker.list_corpus_entities(parent=wrong_parent) @@ -850,7 +841,7 @@ def test_check_required_entity_types_no_creation_allowed( ] == BASE_API_CALLS -@pytest.mark.parametrize("transcription", (None, "not a transcription", 1)) +@pytest.mark.parametrize("transcription", [None, "not a transcription", 1]) def test_create_transcription_entities_wrong_transcription( mock_elements_worker, transcription ): @@ -865,8 +856,8 @@ def test_create_transcription_entities_wrong_transcription( @pytest.mark.parametrize( - "entities, error", - ( + ("entities", "error"), + [ (None, "entities shouldn't be null and should be of type list"), ( "not a list of entities", @@ -886,7 +877,7 @@ def test_create_transcription_entities_wrong_transcription( * 2, "entities should be unique", ), - ), + ], ) def test_create_transcription_entities_wrong_entities( mock_elements_worker, entities, error @@ -909,8 +900,8 @@ def test_create_transcription_entities_wrong_entities_subtype(mock_elements_work @pytest.mark.parametrize( - "entity, error", - ( + ("entity", "error"), + [ ( { "name": None, @@ -989,7 +980,7 @@ def test_create_transcription_entities_wrong_entities_subtype(mock_elements_work }, "Entity at index 0 in entities: confidence should be None or a float in [0..1] range", ), - ), + ], ) def test_create_transcription_entities_wrong_entity( mock_elements_worker, entity, error diff --git a/tests/test_elements_worker/test_metadata.py b/tests/test_elements_worker/test_metadata.py index ed3d6e43ee12cff456792a0c0558c121c991b63b..3cce81f4ac194074448d81e11e5c39fc8513bec6 100644 --- a/tests/test_elements_worker/test_metadata.py +++ b/tests/test_elements_worker/test_metadata.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import json import re @@ -247,22 +246,20 @@ def test_create_metadata_cached_element(responses, mock_elements_worker_with_cac @pytest.mark.parametrize( - "metadatas", + "metadata_list", [ - ([{"type": MetaType.Text, "name": "fake_name", "value": "fake_value"}]), - ( - [ - { - "type": MetaType.Text, - "name": "fake_name", - "value": "fake_value", - "entity_id": "fake_entity_id", - } - ] - ), + [{"type": MetaType.Text, "name": "fake_name", "value": "fake_value"}], + [ + { + "type": MetaType.Text, + "name": "fake_name", + "value": "fake_value", + "entity_id": "fake_entity_id", + } + ], ], ) -def test_create_metadatas(responses, mock_elements_worker, metadatas): +def test_create_metadatas(responses, mock_elements_worker, metadata_list): element = Element({"id": "12341234-1234-1234-1234-123412341234"}) responses.add( responses.POST, @@ -273,17 +270,19 @@ def test_create_metadatas(responses, mock_elements_worker, metadatas): "metadata_list": [ { "id": "fake_metadata_id", - "type": metadatas[0]["type"].value, - "name": metadatas[0]["name"], - "value": metadatas[0]["value"], + "type": metadata_list[0]["type"].value, + "name": metadata_list[0]["name"], + "value": metadata_list[0]["value"], "dates": [], - "entity_id": metadatas[0].get("entity_id"), + "entity_id": metadata_list[0].get("entity_id"), } ], }, ) - created_metadatas = mock_elements_worker.create_metadatas(element, metadatas) + created_metadata_list = mock_elements_worker.create_metadatas( + element, metadata_list + ) assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert [ @@ -296,42 +295,40 @@ def 
test_create_metadatas(responses, mock_elements_worker, metadatas): ] assert json.loads(responses.calls[-1].request.body)["metadata_list"] == [ { - "type": metadatas[0]["type"].value, - "name": metadatas[0]["name"], - "value": metadatas[0]["value"], - "entity_id": metadatas[0].get("entity_id"), + "type": metadata_list[0]["type"].value, + "name": metadata_list[0]["name"], + "value": metadata_list[0]["value"], + "entity_id": metadata_list[0].get("entity_id"), } ] - assert created_metadatas == [ + assert created_metadata_list == [ { "id": "fake_metadata_id", - "type": metadatas[0]["type"].value, - "name": metadatas[0]["name"], - "value": metadatas[0]["value"], + "type": metadata_list[0]["type"].value, + "name": metadata_list[0]["name"], + "value": metadata_list[0]["value"], "dates": [], - "entity_id": metadatas[0].get("entity_id"), + "entity_id": metadata_list[0].get("entity_id"), } ] @pytest.mark.parametrize( - "metadatas", + "metadata_list", [ - ([{"type": MetaType.Text, "name": "fake_name", "value": "fake_value"}]), - ( - [ - { - "type": MetaType.Text, - "name": "fake_name", - "value": "fake_value", - "entity_id": "fake_entity_id", - } - ] - ), + [{"type": MetaType.Text, "name": "fake_name", "value": "fake_value"}], + [ + { + "type": MetaType.Text, + "name": "fake_name", + "value": "fake_value", + "entity_id": "fake_entity_id", + } + ], ], ) def test_create_metadatas_cached_element( - responses, mock_elements_worker_with_cache, metadatas + responses, mock_elements_worker_with_cache, metadata_list ): element = CachedElement.create( id="12341234-1234-1234-1234-123412341234", type="thing" @@ -345,18 +342,18 @@ def test_create_metadatas_cached_element( "metadata_list": [ { "id": "fake_metadata_id", - "type": metadatas[0]["type"].value, - "name": metadatas[0]["name"], - "value": metadatas[0]["value"], + "type": metadata_list[0]["type"].value, + "name": metadata_list[0]["name"], + "value": metadata_list[0]["value"], "dates": [], - "entity_id": metadatas[0].get("entity_id"), + "entity_id": metadata_list[0].get("entity_id"), } ], }, ) - created_metadatas = mock_elements_worker_with_cache.create_metadatas( - element, metadatas + created_metadata_list = mock_elements_worker_with_cache.create_metadatas( + element, metadata_list ) assert len(responses.calls) == len(BASE_API_CALLS) + 1 @@ -370,35 +367,27 @@ def test_create_metadatas_cached_element( ] assert json.loads(responses.calls[-1].request.body)["metadata_list"] == [ { - "type": metadatas[0]["type"].value, - "name": metadatas[0]["name"], - "value": metadatas[0]["value"], - "entity_id": metadatas[0].get("entity_id"), + "type": metadata_list[0]["type"].value, + "name": metadata_list[0]["name"], + "value": metadata_list[0]["value"], + "entity_id": metadata_list[0].get("entity_id"), } ] - assert created_metadatas == [ + assert created_metadata_list == [ { "id": "fake_metadata_id", - "type": metadatas[0]["type"].value, - "name": metadatas[0]["name"], - "value": metadatas[0]["value"], + "type": metadata_list[0]["type"].value, + "name": metadata_list[0]["name"], + "value": metadata_list[0]["value"], "dates": [], - "entity_id": metadatas[0].get("entity_id"), + "entity_id": metadata_list[0].get("entity_id"), } ] -@pytest.mark.parametrize( - "wrong_element", - [ - None, - "not_element_type", - 1234, - 12.5, - ], -) +@pytest.mark.parametrize("wrong_element", [None, "not_element_type", 1234, 12.5]) def test_create_metadatas_wrong_element(mock_elements_worker, wrong_element): - wrong_metadatas = [ + wrong_metadata_list = [ {"type": MetaType.Text, "name": 
"fake_name", "value": "fake_value"} ] with pytest.raises( @@ -406,48 +395,42 @@ def test_create_metadatas_wrong_element(mock_elements_worker, wrong_element): match="element shouldn't be null and should be of type Element or CachedElement", ): mock_elements_worker.create_metadatas( - element=wrong_element, metadatas=wrong_metadatas + element=wrong_element, metadatas=wrong_metadata_list ) -@pytest.mark.parametrize( - "wrong_type", - [ - None, - "not_metadata_type", - 1234, - 12.5, - ], -) +@pytest.mark.parametrize("wrong_type", [None, "not_metadata_type", 1234, 12.5]) def test_create_metadatas_wrong_type(mock_elements_worker, wrong_type): element = Element({"id": "12341234-1234-1234-1234-123412341234"}) - wrong_metadatas = [{"type": wrong_type, "name": "fake_name", "value": "fake_value"}] + wrong_metadata_list = [ + {"type": wrong_type, "name": "fake_name", "value": "fake_value"} + ] with pytest.raises( AssertionError, match="type shouldn't be null and should be of type MetaType" ): mock_elements_worker.create_metadatas( - element=element, metadatas=wrong_metadatas + element=element, metadatas=wrong_metadata_list ) -@pytest.mark.parametrize("wrong_name", [(None), (1234), (12.5), ([1, 2, 3, 4])]) +@pytest.mark.parametrize("wrong_name", [None, 1234, 12.5, [1, 2, 3, 4]]) def test_create_metadatas_wrong_name(mock_elements_worker, wrong_name): element = Element({"id": "fake_element_id"}) - wrong_metadatas = [ + wrong_metadata_list = [ {"type": MetaType.Text, "name": wrong_name, "value": "fake_value"} ] with pytest.raises( AssertionError, match="name shouldn't be null and should be of type str" ): mock_elements_worker.create_metadatas( - element=element, metadatas=wrong_metadatas + element=element, metadatas=wrong_metadata_list ) -@pytest.mark.parametrize("wrong_value", [(None), ([1, 2, 3, 4])]) +@pytest.mark.parametrize("wrong_value", [None, [1, 2, 3, 4]]) def test_create_metadatas_wrong_value(mock_elements_worker, wrong_value): element = Element({"id": "fake_element_id"}) - wrong_metadatas = [ + wrong_metadata_list = [ {"type": MetaType.Text, "name": "fake_name", "value": wrong_value} ] with pytest.raises( @@ -457,21 +440,14 @@ def test_create_metadatas_wrong_value(mock_elements_worker, wrong_value): ), ): mock_elements_worker.create_metadatas( - element=element, metadatas=wrong_metadatas + element=element, metadatas=wrong_metadata_list ) -@pytest.mark.parametrize( - "wrong_entity", - [ - [1, 2, 3, 4], - 1234, - 12.5, - ], -) +@pytest.mark.parametrize("wrong_entity", [[1, 2, 3, 4], 1234, 12.5]) def test_create_metadatas_wrong_entity(mock_elements_worker, wrong_entity): element = Element({"id": "fake_element_id"}) - wrong_metadatas = [ + wrong_metadata_list = [ { "type": MetaType.Text, "name": "fake_name", @@ -481,13 +457,13 @@ def test_create_metadatas_wrong_entity(mock_elements_worker, wrong_entity): ] with pytest.raises(AssertionError, match="entity_id should be None or a str"): mock_elements_worker.create_metadatas( - element=element, metadatas=wrong_metadatas + element=element, metadatas=wrong_metadata_list ) def test_create_metadatas_api_error(responses, mock_elements_worker): element = Element({"id": "12341234-1234-1234-1234-123412341234"}) - metadatas = [ + metadata_list = [ { "type": MetaType.Text, "name": "fake_name", @@ -502,7 +478,7 @@ def test_create_metadatas_api_error(responses, mock_elements_worker): ) with pytest.raises(ErrorResponse): - mock_elements_worker.create_metadatas(element, metadatas) + mock_elements_worker.create_metadatas(element, metadata_list) assert 
len(responses.calls) == len(BASE_API_CALLS) + 5 assert [ diff --git a/tests/test_elements_worker/test_task.py b/tests/test_elements_worker/test_task.py index decd7f2c7e6f2ff2b50ffd04b3a2d4ae054155fd..1492d02b7370f4d00c2e1328f224fb27bd420f2a 100644 --- a/tests/test_elements_worker/test_task.py +++ b/tests/test_elements_worker/test_task.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import uuid import pytest @@ -12,8 +11,8 @@ TASK_ID = uuid.UUID("cafecafe-cafe-cafe-cafe-cafecafecafe") @pytest.mark.parametrize( - "payload, error", - ( + ("payload", "error"), + [ # Task ID ( {"task_id": None}, @@ -23,7 +22,7 @@ TASK_ID = uuid.UUID("cafecafe-cafe-cafe-cafe-cafecafecafe") {"task_id": "12341234-1234-1234-1234-123412341234"}, "task_id shouldn't be null and should be an UUID", ), - ), + ], ) def test_list_artifacts_wrong_param_task_id(mock_dataset_worker, payload, error): with pytest.raises(AssertionError, match=error): @@ -97,8 +96,8 @@ def test_list_artifacts( @pytest.mark.parametrize( - "payload, error", - ( + ("payload", "error"), + [ # Task ID ( {"task_id": None}, @@ -108,7 +107,7 @@ def test_list_artifacts( {"task_id": "12341234-1234-1234-1234-123412341234"}, "task_id shouldn't be null and should be an UUID", ), - ), + ], ) def test_download_artifact_wrong_param_task_id( mock_dataset_worker, default_artifact, payload, error @@ -124,8 +123,8 @@ def test_download_artifact_wrong_param_task_id( @pytest.mark.parametrize( - "payload, error", - ( + ("payload", "error"), + [ # Artifact ( {"artifact": None}, @@ -135,7 +134,7 @@ def test_download_artifact_wrong_param_task_id( {"artifact": "not artifact type"}, "artifact shouldn't be null and should be an Artifact", ), - ), + ], ) def test_download_artifact_wrong_param_artifact( mock_dataset_worker, default_artifact, payload, error diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py index 9d919c9898b3666d8d6fb9de99fe6208e2e55b36..aa2da6f6736ff6b643deaac12995bfcd1a3699fe 100644 --- a/tests/test_elements_worker/test_training.py +++ b/tests/test_elements_worker/test_training.py @@ -1,16 +1,14 @@ -# -*- coding: utf-8 -*- import logging import sys import pytest -import responses from arkindex.mock import MockApiClient from arkindex_worker.worker import BaseWorker from arkindex_worker.worker.training import TrainingMixin, create_archive -@pytest.fixture +@pytest.fixture() def mock_training_worker(monkeypatch): class TrainingWorker(BaseWorker, TrainingMixin): """ @@ -24,7 +22,7 @@ def mock_training_worker(monkeypatch): return training_worker -@pytest.fixture +@pytest.fixture() def default_model_version(): return { "id": "model_version_id", @@ -79,23 +77,32 @@ def test_create_archive_with_subfolder(model_file_dir_with_subfolder): assert not zst_archive_path.exists(), "Auto removal failed" -def test_handle_s3_uploading_errors(mock_training_worker, model_file_dir): +def test_handle_s3_uploading_errors(responses, mock_training_worker, model_file_dir): s3_endpoint_url = "http://s3.localhost.com" responses.add_passthru(s3_endpoint_url) - responses.add(responses.Response(method="PUT", url=s3_endpoint_url, status=400)) + responses.add(responses.PUT, s3_endpoint_url, status=400) + + mock_training_worker.model_version = { + "state": "Created", + "s3_put_url": s3_endpoint_url, + } + file_path = model_file_dir / "model_file.pth" - with pytest.raises(Exception): - mock_training_worker.upload_to_s3(file_path, {"s3_put_url": s3_endpoint_url}) + with pytest.raises( + Exception, + match="400 Client Error: Bad Request for url: 
http://s3.localhost.com/", + ): + mock_training_worker.upload_to_s3(file_path) @pytest.mark.parametrize( "method", [ - ("publish_model_version"), - ("create_model_version"), - ("update_model_version"), - ("upload_to_s3"), - ("validate_model_version"), + "publish_model_version", + "create_model_version", + "update_model_version", + "upload_to_s3", + "validate_model_version", ], ) def test_training_mixin_read_only(mock_training_worker, method, caplog): diff --git a/tests/test_elements_worker/test_transcriptions.py b/tests/test_elements_worker/test_transcriptions.py index e2fdc33bee3b9160d17aedd1f2684f0068da7c59..eee2428271cd93063eaf51a18181062820e687e9 100644 --- a/tests/test_elements_worker/test_transcriptions.py +++ b/tests/test_elements_worker/test_transcriptions.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import json import re from uuid import UUID @@ -1867,9 +1866,10 @@ def test_list_transcriptions_manual_worker_version(responses, mock_elements_work ] +@pytest.mark.usefixtures("_mock_cached_transcriptions") @pytest.mark.parametrize( - "filters, expected_ids", - ( + ("filters", "expected_ids"), + [ # Filter on element should give first and sixth transcription ( { @@ -1963,14 +1963,10 @@ def test_list_transcriptions_manual_worker_version(responses, mock_elements_work }, ("66666666-6666-6666-6666-666666666666",), ), - ), + ], ) def test_list_transcriptions_with_cache( - responses, - mock_elements_worker_with_cache, - mock_cached_transcriptions, - filters, - expected_ids, + responses, mock_elements_worker_with_cache, filters, expected_ids ): # Check we have 5 elements already present in database assert CachedTranscription.select().count() == 6 @@ -1979,7 +1975,7 @@ def test_list_transcriptions_with_cache( transcriptions = mock_elements_worker_with_cache.list_transcriptions(**filters) assert transcriptions.count() == len(expected_ids) for transcription, expected_id in zip( - transcriptions.order_by(CachedTranscription.id), expected_ids + transcriptions.order_by(CachedTranscription.id), expected_ids, strict=True ): assert transcription.id == UUID(expected_id) diff --git a/tests/test_elements_worker/test_worker.py b/tests/test_elements_worker/test_worker.py index b65556a1534ad7782b8ea1c3375274dc98fad18d..22a558c79aad6e20cff504387a0286923d3aa016 100644 --- a/tests/test_elements_worker/test_worker.py +++ b/tests/test_elements_worker/test_worker.py @@ -1,4 +1,3 @@ -# . 
-*- coding: utf-8 -*- import json import sys @@ -78,7 +77,8 @@ def test_readonly(responses, mock_elements_worker): ] == BASE_API_CALLS -def test_activities_disabled(responses, monkeypatch, mock_worker_run_api): +@pytest.mark.usefixtures("_mock_worker_run_api") +def test_activities_disabled(responses, monkeypatch): """Test worker process elements without updating activities when they are disabled for the process""" monkeypatch.setattr(sys, "argv", ["worker"]) worker = ElementsWorker() @@ -105,7 +105,8 @@ def test_activities_dev_mode(mocker): assert worker.store_activity is False -def test_update_call(responses, mock_elements_worker, mock_worker_run_api): +@pytest.mark.usefixtures("_mock_worker_run_api") +def test_update_call(responses, mock_elements_worker): """Test an update call with feature enabled triggers an API call""" responses.add( responses.PUT, @@ -141,8 +142,9 @@ def test_update_call(responses, mock_elements_worker, mock_worker_run_api): } +@pytest.mark.usefixtures("_mock_activity_calls") @pytest.mark.parametrize( - "process_exception, final_state", + ("process_exception", "final_state"), [ # Successful process_element (None, "processed"), @@ -161,7 +163,6 @@ def test_run( responses, process_exception, final_state, - mock_activity_calls, ): """Check the normal runtime sends 2 API calls to update activity""" # Disable second configure call from run() @@ -210,13 +211,8 @@ def test_run( } -def test_run_cache( - monkeypatch, - mocker, - mock_elements_worker_with_cache, - mock_cached_elements, - mock_activity_calls, -): +@pytest.mark.usefixtures("_mock_cached_elements", "_mock_activity_calls") +def test_run_cache(monkeypatch, mocker, mock_elements_worker_with_cache): # Disable second configure call from run() monkeypatch.setattr(mock_elements_worker_with_cache, "configure", lambda: None) @@ -310,8 +306,14 @@ def test_start_activity_error( @pytest.mark.parametrize( - "wk_version_config,wk_version_user_config,frontend_user_config,model_config,expected_config", ( + "wk_version_config", + "wk_version_user_config", + "frontend_user_config", + "model_config", + "expected_config", + ), + [ ({}, {}, {}, {}, {}), # Keep parameters from worker version configuration ({"parameter": 0}, {}, {}, {}, {"parameter": 0}), @@ -411,7 +413,7 @@ def test_start_activity_error( {"parameter": 2}, {"parameter": 3}, ), - ), + ], ) def test_worker_config_multiple_source( monkeypatch, diff --git a/tests/test_git.py b/tests/test_git.py deleted file mode 100644 index 0c78a456ab8a7f5367c182f89cbd8d8f3c763cdf..0000000000000000000000000000000000000000 --- a/tests/test_git.py +++ /dev/null @@ -1,480 +0,0 @@ -# -*- coding: utf-8 -*- -from pathlib import Path - -import pytest -from gitlab import GitlabCreateError, GitlabError -from requests import ConnectionError -from responses import matchers - -from arkindex_worker.git import GitlabHelper - -PROJECT_ID = 21259233 -MERGE_REQUEST_ID = 7 -SOURCE_BRANCH = "new_branch" -TARGET_BRANCH = "master" -MR_TITLE = "merge request title" -CREATE_MR_RESPONSE_JSON = { - "id": 107, - "iid": MERGE_REQUEST_ID, - "project_id": PROJECT_ID, - "title": MR_TITLE, - "target_branch": TARGET_BRANCH, - "source_branch": SOURCE_BRANCH, - # several fields omitted -} - - -@pytest.fixture -def fake_responses(responses): - responses.add( - responses.GET, - "https://gitlab.com/api/v4/projects/balsac_exporter%2Fbalsac-exported-xmls-testing", - json={ - "id": PROJECT_ID, - # several fields omitted - }, - ) - return responses - - -def test_clone_done(fake_git_helper): - assert not 
fake_git_helper.is_clone_finished - fake_git_helper._clone_done(None, None, None) - assert fake_git_helper.is_clone_finished - - -def test_clone(fake_git_helper): - command = fake_git_helper.run_clone_in_background() - cmd_str = " ".join(list(map(str, command.cmd))) - - assert "git" in cmd_str - assert "clone" in cmd_str - - -def _get_fn_name_from_call(call): - # call.add(2, 3) => "add" - return str(call)[len("call.") :].split("(")[0] - - -def test_save_files(fake_git_helper, mocker): - mocker.patch("sh.wc", return_value=2) - fake_git_helper._git = mocker.MagicMock() - fake_git_helper.is_clone_finished = True - fake_git_helper.success = True - - fake_git_helper.save_files(Path("/tmp/test_1234/tmp/")) - - expected_calls = ["checkout", "add", "commit", "show", "push"] - actual_calls = list(map(_get_fn_name_from_call, fake_git_helper._git.mock_calls)) - - assert actual_calls == expected_calls - assert fake_git_helper.gitlab_helper.merge.call_count == 1 - - -def test_save_files__fail_with_failed_clone(fake_git_helper, mocker): - mocker.patch("sh.wc", return_value=2) - fake_git_helper._git = mocker.MagicMock() - fake_git_helper.is_clone_finished = True - - with pytest.raises(Exception) as execinfo: - fake_git_helper.save_files(Path("/tmp/test_1234/tmp/")) - - assert execinfo.value.args[0] == "Clone was not a success" - - -def test_merge(mocker): - api = mocker.MagicMock() - project = mocker.MagicMock() - api.projects.get.return_value = project - merqe_request = mocker.MagicMock() - project.mergerequests.create.return_value = merqe_request - mocker.patch("gitlab.Gitlab", return_value=api) - - gitlab_helper = GitlabHelper("project_id", "url", "token", "branch") - - gitlab_helper._wait_for_rebase_to_finish = mocker.MagicMock() - gitlab_helper._wait_for_rebase_to_finish.return_value = True - - success = gitlab_helper.merge("source", "merge title") - - assert success - assert project.mergerequests.create.call_count == 1 - assert merqe_request.merge.call_count == 1 - - -def test_merge__rebase_failed(mocker): - api = mocker.MagicMock() - project = mocker.MagicMock() - api.projects.get.return_value = project - merqe_request = mocker.MagicMock() - project.mergerequests.create.return_value = merqe_request - mocker.patch("gitlab.Gitlab", return_value=api) - - gitlab_helper = GitlabHelper("project_id", "url", "token", "branch") - - gitlab_helper._wait_for_rebase_to_finish = mocker.MagicMock() - gitlab_helper._wait_for_rebase_to_finish.return_value = False - - success = gitlab_helper.merge("source", "merge title") - - assert not success - assert project.mergerequests.create.call_count == 1 - assert merqe_request.merge.call_count == 0 - - -def test_wait_for_rebase_to_finish(fake_responses, fake_gitlab_helper_factory): - get_mr_url = f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}?include_rebase_in_progress=True" - - fake_responses.add( - fake_responses.GET, - get_mr_url, - json={ - "rebase_in_progress": True, - "merge_error": None, - }, - ) - - fake_responses.add( - fake_responses.GET, - get_mr_url, - json={ - "rebase_in_progress": True, - "merge_error": None, - }, - ) - - fake_responses.add( - fake_responses.GET, - get_mr_url, - json={ - "rebase_in_progress": False, - "merge_error": None, - }, - ) - - gitlab_helper = fake_gitlab_helper_factory() - - success = gitlab_helper._wait_for_rebase_to_finish(MERGE_REQUEST_ID) - - assert success - assert len(fake_responses.calls) == 4 - assert gitlab_helper.is_rebase_finished - - -def 
test_wait_for_rebase_to_finish__fail_connection_error( - fake_responses, fake_gitlab_helper_factory -): - get_mr_url = f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}?include_rebase_in_progress=True" - - fake_responses.add( - fake_responses.GET, - get_mr_url, - body=ConnectionError(), - ) - - gitlab_helper = fake_gitlab_helper_factory() - - with pytest.raises(ConnectionError): - gitlab_helper._wait_for_rebase_to_finish(MERGE_REQUEST_ID) - - assert len(fake_responses.calls) == 2 - assert not gitlab_helper.is_rebase_finished - - -def test_wait_for_rebase_to_finish__fail_server_error( - fake_responses, fake_gitlab_helper_factory -): - get_mr_url = f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}?include_rebase_in_progress=True" - - fake_responses.add( - fake_responses.GET, - get_mr_url, - body="Service Unavailable", - status=503, - ) - - gitlab_helper = fake_gitlab_helper_factory() - - with pytest.raises(GitlabError): - gitlab_helper._wait_for_rebase_to_finish(MERGE_REQUEST_ID) - - assert len(fake_responses.calls) == 2 - assert not gitlab_helper.is_rebase_finished - - -def test_merge_request(fake_responses, fake_gitlab_helper_factory, mocker): - fake_responses.add( - fake_responses.POST, - f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests", - json=CREATE_MR_RESPONSE_JSON, - ) - - fake_responses.add( - fake_responses.PUT, - f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}/rebase", - json={}, - ) - - fake_responses.add( - fake_responses.PUT, - f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}/merge", - json={ - "iid": MERGE_REQUEST_ID, - "state": "merged", - # several fields omitted - }, - match=[matchers.json_params_matcher({"should_remove_source_branch": True})], - ) - - # the fake_responses are defined in the same order as they are expected to be called - expected_http_methods = [r.method for r in fake_responses.registered()] - expected_urls = [r.url for r in fake_responses.registered()] - - gitlab_helper = fake_gitlab_helper_factory() - gitlab_helper._wait_for_rebase_to_finish = mocker.MagicMock() - gitlab_helper._wait_for_rebase_to_finish.return_value = True - - success = gitlab_helper.merge(SOURCE_BRANCH, MR_TITLE) - assert success - assert len(fake_responses.calls) == 4 - assert [c.request.method for c in fake_responses.calls] == expected_http_methods - assert [c.request.url for c in fake_responses.calls] == expected_urls - - -def test_merge_request_fail(fake_responses, fake_gitlab_helper_factory, mocker): - fake_responses.add( - fake_responses.POST, - f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests", - json=CREATE_MR_RESPONSE_JSON, - ) - - fake_responses.add( - fake_responses.PUT, - f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}/rebase", - json={}, - ) - - fake_responses.add( - fake_responses.PUT, - f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}/merge", - json={"error": "Method not allowed"}, - status=405, - match=[matchers.json_params_matcher({"should_remove_source_branch": True})], - ) - - # the fake_responses are defined in the same order as they are expected to be called - expected_http_methods = [r.method for r in fake_responses.registered()] - expected_urls = [r.url for r in fake_responses.registered()] - - gitlab_helper = fake_gitlab_helper_factory() - gitlab_helper._wait_for_rebase_to_finish = mocker.MagicMock() - 
gitlab_helper._wait_for_rebase_to_finish.return_value = True
-
-    success = gitlab_helper.merge(SOURCE_BRANCH, MR_TITLE)
-
-    assert not success
-    assert len(fake_responses.calls) == 4
-    assert [c.request.method for c in fake_responses.calls] == expected_http_methods
-    assert [c.request.url for c in fake_responses.calls] == expected_urls
-
-
-def test_merge_request__success_after_errors(
-    fake_responses, fake_gitlab_helper_factory
-):
-    fake_responses.add(
-        fake_responses.POST,
-        f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests",
-        json=CREATE_MR_RESPONSE_JSON,
-    )
-
-    rebase_url = f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}/rebase"
-
-    fake_responses.add(
-        fake_responses.PUT,
-        rebase_url,
-        json={"rebase_in_progress": True},
-    )
-
-    get_mr_url = f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}?include_rebase_in_progress=True"
-
-    fake_responses.add(
-        fake_responses.GET,
-        get_mr_url,
-        body="Service Unavailable",
-        status=503,
-    )
-
-    fake_responses.add(
-        fake_responses.PUT,
-        rebase_url,
-        json={"rebase_in_progress": True},
-    )
-
-    fake_responses.add(
-        fake_responses.GET,
-        get_mr_url,
-        body=ConnectionError(),
-    )
-
-    fake_responses.add(
-        fake_responses.PUT,
-        rebase_url,
-        json={"rebase_in_progress": True},
-    )
-
-    fake_responses.add(
-        fake_responses.GET,
-        get_mr_url,
-        json={
-            "rebase_in_progress": True,
-            "merge_error": None,
-        },
-    )
-
-    fake_responses.add(
-        fake_responses.GET,
-        get_mr_url,
-        json={
-            "rebase_in_progress": False,
-            "merge_error": None,
-        },
-    )
-
-    fake_responses.add(
-        fake_responses.PUT,
-        f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}/merge",
-        json={
-            "iid": MERGE_REQUEST_ID,
-            "state": "merged",
-            # several fields omitted
-        },
-        match=[matchers.json_params_matcher({"should_remove_source_branch": True})],
-    )
-
-    # the fake_responses are defined in the same order as they are expected to be called
-    expected_http_methods = [r.method for r in fake_responses.registered()]
-    expected_urls = [r.url for r in fake_responses.registered()]
-
-    gitlab_helper = fake_gitlab_helper_factory()
-
-    success = gitlab_helper.merge(SOURCE_BRANCH, MR_TITLE)
-
-    assert success
-    assert len(fake_responses.calls) == 10
-    assert [c.request.method for c in fake_responses.calls] == expected_http_methods
-    assert [c.request.url for c in fake_responses.calls] == expected_urls
-
-
-def test_merge_request__fail_bad_request(fake_responses, fake_gitlab_helper_factory):
-    fake_responses.add(
-        fake_responses.POST,
-        f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests",
-        json=CREATE_MR_RESPONSE_JSON,
-    )
-
-    rebase_url = f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}/rebase"
-
-    fake_responses.add(
-        fake_responses.PUT,
-        rebase_url,
-        json={"rebase_in_progress": True},
-    )
-
-    get_mr_url = f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests/{MERGE_REQUEST_ID}?include_rebase_in_progress=True"
-
-    fake_responses.add(
-        fake_responses.GET,
-        get_mr_url,
-        body="Bad Request",
-        status=400,
-    )
-
-    # the fake_responses are defined in the same order as they are expected to be called
-    expected_http_methods = [r.method for r in fake_responses.registered()]
-    expected_urls = [r.url for r in fake_responses.registered()]
-
-    gitlab_helper = fake_gitlab_helper_factory()
-
-    with pytest.raises(GitlabError):
-        gitlab_helper.merge(SOURCE_BRANCH, MR_TITLE)
-
-    assert len(fake_responses.calls) == 4
-    assert [c.request.method for c in fake_responses.calls] == expected_http_methods
-    assert [c.request.url for c in fake_responses.calls] == expected_urls
-
-
-def test_create_merge_request__no_retry_5xx_error(
-    fake_responses, fake_gitlab_helper_factory
-):
-    request_url = f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests"
-
-    fake_responses.add(
-        fake_responses.POST,
-        request_url,
-        body="Service Unavailable",
-        status=503,
-    )
-
-    # the fake_responses are defined in the same order as they are expected to be called
-    expected_http_methods = [r.method for r in fake_responses.registered()]
-    expected_urls = [r.url for r in fake_responses.registered()]
-
-    gitlab_helper = fake_gitlab_helper_factory()
-
-    with pytest.raises(GitlabCreateError):
-        gitlab_helper.project.mergerequests.create(
-            {
-                "source_branch": "branch",
-                "target_branch": gitlab_helper.branch,
-                "title": "MR title",
-            }
-        )
-
-    assert len(fake_responses.calls) == 2
-    assert [c.request.method for c in fake_responses.calls] == expected_http_methods
-    assert [c.request.url for c in fake_responses.calls] == expected_urls
-
-
-def test_create_merge_request__retry_5xx_error(
-    fake_responses, fake_gitlab_helper_factory
-):
-    request_url = f"https://gitlab.com/api/v4/projects/{PROJECT_ID}/merge_requests"
-
-    fake_responses.add(
-        fake_responses.POST,
-        request_url,
-        body="Service Unavailable",
-        status=503,
-    )
-
-    fake_responses.add(
-        fake_responses.POST,
-        request_url,
-        body="Service Unavailable",
-        status=503,
-    )
-
-    fake_responses.add(
-        fake_responses.POST,
-        request_url,
-        json=CREATE_MR_RESPONSE_JSON,
-    )
-
-    # the fake_responses are defined in the same order as they are expected to be called
-    expected_http_methods = [r.method for r in fake_responses.registered()]
-    expected_urls = [r.url for r in fake_responses.registered()]
-
-    gitlab_helper = fake_gitlab_helper_factory()
-
-    gitlab_helper.project.mergerequests.create(
-        {
-            "source_branch": "branch",
-            "target_branch": gitlab_helper.branch,
-            "title": "MR title",
-        },
-        retry_transient_errors=True,
-    )
-
-    assert len(fake_responses.calls) == 4
-    assert [c.request.method for c in fake_responses.calls] == expected_http_methods
-    assert [c.request.url for c in fake_responses.calls] == expected_urls
diff --git a/tests/test_image.py b/tests/test_image.py
index 16be7d4ac1bf98a9421aee2260fef58a3323a76b..2facf4af2349a9089ddc92f03b81507cca6d152b 100644
--- a/tests/test_image.py
+++ b/tests/test_image.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 import math
 import unittest
 import uuid
@@ -124,13 +123,13 @@ def test_download_tiles_small(responses):
 
 
 @pytest.mark.parametrize(
-    "path, is_local",
-    (
+    ("path", "is_local"),
+    [
         ("http://somewhere/test.jpg", False),
         ("https://somewhere/test.jpg", False),
         ("path/to/something", True),
         ("/absolute/path/to/something", True),
-    ),
+    ],
 )
 def test_open_image(path, is_local, monkeypatch):
     """Check if the path triggers a local load or a remote one"""
@@ -149,13 +148,13 @@ def test_open_image(path, is_local, monkeypatch):
 
 
 @pytest.mark.parametrize(
-    "rotation_angle,mirrored,expected_path",
-    (
+    ("rotation_angle", "mirrored", "expected_path"),
+    [
         (0, False, TILE),
         (45, False, ROTATED_IMAGE),
         (0, True, MIRRORED_IMAGE),
         (45, True, ROTATED_MIRRORED_IMAGE),
-    ),
+    ],
 )
 def test_open_image_rotate_mirror(rotation_angle, mirrored, expected_path):
     expected = Image.open(expected_path).convert("RGB")
@@ -245,8 +244,9 @@ class TestTrimPolygon(unittest.TestCase):
                 [99, 208],
             ]
         }
-        with self.assertRaises(
-            AssertionError, msg="Input polygon must be a valid list or tuple of points."
+        with pytest.raises(
+            AssertionError,
+            match="Input polygon must be a valid list or tuple of points.",
         ):
             trim_polygon(bad_polygon, TEST_IMAGE["width"], TEST_IMAGE["height"])
 
@@ -305,8 +305,8 @@ class TestTrimPolygon(unittest.TestCase):
             [997, 206],
             [999, 200],
         ]
-        with self.assertRaises(
-            AssertionError, msg="This polygon is entirely outside the image's bounds."
+        with pytest.raises(
+            AssertionError, match="This polygon is entirely outside the image's bounds."
         ):
             trim_polygon(bad_polygon, TEST_IMAGE["width"], TEST_IMAGE["height"])
 
@@ -328,8 +328,8 @@ class TestTrimPolygon(unittest.TestCase):
             [197, 206],
             [99, 20.8],
         ]
-        with self.assertRaises(
-            AssertionError, msg="Polygon point coordinates must be integers."
+        with pytest.raises(
+            AssertionError, match="Polygon point coordinates must be integers."
         ):
             trim_polygon(bad_polygon, TEST_IMAGE["width"], TEST_IMAGE["height"])
 
@@ -347,8 +347,8 @@ class TestTrimPolygon(unittest.TestCase):
             [72, 57],
             [12, 56],
         ]
-        with self.assertRaises(
-            AssertionError, msg="Polygon points must be tuples or lists."
+        with pytest.raises(
+            AssertionError, match="Polygon points must be tuples or lists."
         ):
             trim_polygon(bad_polygon, TEST_IMAGE["width"], TEST_IMAGE["height"])
 
@@ -366,15 +366,16 @@ class TestTrimPolygon(unittest.TestCase):
             [72, 57],
             [12, 56],
         ]
-        with self.assertRaises(
-            AssertionError, msg="Polygon points must be tuples or lists of 2 elements."
+        with pytest.raises(
+            AssertionError,
+            match="Polygon points must be tuples or lists of 2 elements.",
         ):
             trim_polygon(bad_polygon, TEST_IMAGE["width"], TEST_IMAGE["height"])
 
 
 @pytest.mark.parametrize(
-    "angle, mirrored, updated_bounds, reverse",
-    (
+    ("angle", "mirrored", "updated_bounds", "reverse"),
+    [
         (
             0,
             False,
@@ -471,7 +472,7 @@ class TestTrimPolygon(unittest.TestCase):
             {"x": 11, "y": 295, "width": 47, "height": 111},  # upper right
             False,
         ),
-    ),
+    ],
 )
 def test_revert_orientation(angle, mirrored, updated_bounds, reverse, tmp_path):
     """Test cases, for both Elements and CachedElements:
diff --git a/tests/test_merge.py b/tests/test_merge.py
index 0def6ba144da9d87c6f36c67a51605b28c6cbc42..504a8e7f19ed2bfe15368fe4d12bf5cc00a95110 100644
--- a/tests/test_merge.py
+++ b/tests/test_merge.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 from uuid import UUID
 
 import pytest
@@ -18,8 +17,8 @@ from arkindex_worker.cache import (
 
 
 @pytest.mark.parametrize(
-    "parents, expected_elements, expected_transcriptions",
-    (
+    ("parents", "expected_elements", "expected_transcriptions"),
+    [
         # Nothing happen when no parents are available
         ([], [], []),
         # Nothing happen when the parent file does not exist
@@ -73,7 +72,7 @@ from arkindex_worker.cache import (
                 UUID("22222222-2222-2222-2222-222222222222"),
             ],
         ),
-    ),
+    ],
 )
 def test_merge_databases(
     mock_databases, tmp_path, parents, expected_elements, expected_transcriptions
@@ -114,7 +113,7 @@ def test_merge_databases(
     ] == expected_transcriptions
 
 
-def test_merge_chunk(mock_databases, tmp_path, monkeypatch):
+def test_merge_chunk(mock_databases, tmp_path):
     """
     Check the db merge algorithm support two parents
     and one of them has a chunk
@@ -155,7 +154,7 @@ def test_merge_chunk(mock_databases, tmp_path):
 
 
 def test_merge_from_worker(
-    responses, mock_base_worker_with_cache, mock_databases, tmp_path, monkeypatch
+    responses, mock_base_worker_with_cache, mock_databases, tmp_path
 ):
     """
     High level merge from the base worker
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 3e6d571fba6a9ccdffca4d7576e950c6b1a29c1c..46cb7d12a19bdf3d0794bf03a4e403c1c6d44ec3 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 from pathlib import Path
 
 from arkindex_worker.utils import close_delete_file, extract_tar_zst_archive
diff --git a/worker-{{cookiecutter.slug}}/.cookiecutter.json b/worker-{{cookiecutter.slug}}/.cookiecutter.json
index 80cdd2c340a3ce399ca2d3e7e9dd5c9704547b29..20b4982c006d7432c5a3fea63c981bbbd2e945df 100644
--- a/worker-{{cookiecutter.slug}}/.cookiecutter.json
+++ b/worker-{{cookiecutter.slug}}/.cookiecutter.json
@@ -4,5 +4,5 @@
     "description": "{{ cookiecutter.description }}",
     "worker_type": "{{ cookiecutter.worker_type }}",
     "author": "{{ cookiecutter.author }}",
-    "email": "{{ cookiecutter.email}}"
+    "email": "{{ cookiecutter.email }}"
 }
diff --git a/worker-{{cookiecutter.slug}}/.gitlab-ci.yml b/worker-{{cookiecutter.slug}}/.gitlab-ci.yml
index b74ad97bba7a991f67c73ce5c681d02c69f26e64..92a0dc90cfdb68ac92de68f6b4ee3a6a0e5476c3 100644
--- a/worker-{{cookiecutter.slug}}/.gitlab-ci.yml
+++ b/worker-{{cookiecutter.slug}}/.gitlab-ci.yml
@@ -3,9 +3,18 @@ stages:
   - build
   - release
 
+# GitLab provides a template to ensure pipelines run only for branches and tags, not for merge requests
+# This prevents duplicate pipelines in merge requests.
+# See https://docs.gitlab.com/ee/ci/troubleshooting.html#job-may-allow-multiple-pipelines-to-run-for-a-single-action
+include:
+  - template: 'Workflows/Branch-Pipelines.gitlab-ci.yml'
+
+variables:
+  VERSION: $CI_COMMIT_SHA
+  DEBIAN_FRONTEND: non-interactive
+
 test:
-  # Pinned to <3.12 till next arkindex-base-worker release
-  image: python:3.11
+  image: python:slim
   stage: test
 
   cache:
@@ -19,6 +28,9 @@ test:
   before_script:
     - pip install tox
 
+    # Install curl
+    - apt-get update -q -y && apt-get install -q -y --no-install-recommends curl
+
     # Download OpenAPI schema from last backend build
     - curl https://assets.teklia.com/arkindex/openapi.yml > schema.yml
 
@@ -29,7 +41,7 @@ test:
     - tox -- --junitxml=test-report.xml --durations=50
 
 lint:
-  image: python:3
+  image: python:slim
 
   cache:
     paths:
@@ -43,6 +55,9 @@ lint:
   before_script:
     - pip install pre-commit
 
+    # Install git
+    - apt-get update -q -y && apt-get install -q -y --no-install-recommends git
+
   except:
     - schedules
 
@@ -58,8 +73,15 @@ docker-build:
     DOCKER_DRIVER: overlay2
     DOCKER_HOST: tcp://docker:2375/
 
-  except:
-    - schedules
+  rules:
+    # Never run on scheduled pipelines
+    - if: '$CI_PIPELINE_SOURCE == "schedule"'
+      when: never
+    # Use commit tag when running on tagged commit
+    - if: $CI_COMMIT_TAG
+      variables:
+        VERSION: $CI_COMMIT_TAG
+    - when: on_success
 
   script:
     - ci/build.sh
@@ -68,6 +90,7 @@ release-notes:
   stage: release
   image: registry.gitlab.teklia.com/infra/devops:latest
 
+  # Only run on tags
   only:
     - tags
 
@@ -83,3 +106,27 @@ bump-python-deps:
 
   script:
     - devops python-deps requirements.txt
+
+publish-worker:
+  stage: release
+  allow_failure: true
+  image: registry.gitlab.teklia.com/arkindex/cli:latest
+
+  script:
+    - arkindex -p "$ARKINDEX_INSTANCE" --gitlab-secure-file arkindex-cli.yaml worker publish "$CI_REGISTRY_IMAGE:$VERSION"
+
+  rules:
+    # Never run on scheduled pipelines
+    - if: '$CI_PIPELINE_SOURCE == "schedule"'
+      when: never
+    # Use commit tag when running on tagged commit
+    - if: $CI_COMMIT_TAG
+      variables:
+        VERSION: $CI_COMMIT_TAG
+    - when: on_success
+
+  parallel:
+    matrix:
+      - ARKINDEX_INSTANCE:
+          # Publish worker on https://demo.arkindex.org
+          - demo
diff --git a/worker-{{cookiecutter.slug}}/.pre-commit-config.yaml b/worker-{{cookiecutter.slug}}/.pre-commit-config.yaml
index 91c7412f7437ef4066ba4193502b99c510bb347b..0de2cee1f7f6af5a26f8d081f7b6dd2acbb7a1a6 100644
--- a/worker-{{cookiecutter.slug}}/.pre-commit-config.yaml
+++ b/worker-{{cookiecutter.slug}}/.pre-commit-config.yaml
@@ -1,16 +1,15 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.0.278
+    rev: v0.1.7
     hooks:
+      # Run the linter.
       - id: ruff
        args: [--fix, --exit-non-zero-on-fix]
-  - repo: https://github.com/ambv/black
-    rev: 23.1.0
-    hooks:
-      - id: black
+      # Run the formatter.
+      - id: ruff-format
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v4.5.0
     hooks:
       - id: check-ast
       - id: check-docstring-first
@@ -25,9 +24,10 @@ repos:
       - id: name-tests-test
        args: ['--django']
      - id: check-json
+      - id: check-toml
      - id: requirements-txt-fixer
  - repo: https://github.com/codespell-project/codespell
-    rev: v2.2.2
+    rev: v2.2.6
    hooks:
      - id: codespell
        args: ['--write-changes']
diff --git a/worker-{{cookiecutter.slug}}/Dockerfile b/worker-{{cookiecutter.slug}}/Dockerfile
index aac8903642e3bc416dbbbbbf8f23779f3f09fe6b..d8ea5fcf2609f36732e531d9b4ab77900275b0f4 100644
--- a/worker-{{cookiecutter.slug}}/Dockerfile
+++ b/worker-{{cookiecutter.slug}}/Dockerfile
@@ -7,12 +7,12 @@ ENV DEBIAN_FRONTEND=non-interactive
 RUN apt-get update -q -y && apt-get install -q -y --no-install-recommends curl
 
 # Install worker as a package
-COPY worker_{{cookiecutter.slug}} worker_{{cookiecutter.slug}}
-COPY requirements.txt setup.py VERSION ./
+COPY worker_{{cookiecutter.__module}} worker_{{cookiecutter.__module}}
+COPY requirements.txt setup.py pyproject.toml ./
 RUN pip install . --no-cache-dir
 
 # Add archi local CA
 RUN curl https://assets.teklia.com/teklia_dev_ca.pem > /usr/local/share/ca-certificates/arkindex-dev.crt && update-ca-certificates
 ENV REQUESTS_CA_BUNDLE /etc/ssl/certs/ca-certificates.crt
 
-CMD ["worker-{{ cookiecutter.slug }}"]
+CMD ["worker-{{ cookiecutter.__package }}"]
diff --git a/worker-{{cookiecutter.slug}}/MANIFEST.in b/worker-{{cookiecutter.slug}}/MANIFEST.in
index fd959fa8501e56bc4f1869e363b4a2118a86edce..f9bd1455b374de796e12d240c1211dee9829d97e 100644
--- a/worker-{{cookiecutter.slug}}/MANIFEST.in
+++ b/worker-{{cookiecutter.slug}}/MANIFEST.in
@@ -1,2 +1 @@
 include requirements.txt
-include VERSION
diff --git a/worker-{{cookiecutter.slug}}/Makefile b/worker-{{cookiecutter.slug}}/Makefile
index c98647b47591ecbd3dceecbe906b8f38328ad381..f9322fd1664aa6f32a3259456bb19a2e285859f3 100644
--- a/worker-{{cookiecutter.slug}}/Makefile
+++ b/worker-{{cookiecutter.slug}}/Makefile
@@ -1,8 +1,10 @@
 .PHONY: release
 
 release:
-	$(eval version:=$(shell cat VERSION))
+	# Grep the version from pyproject.toml, squeeze multiple spaces, delete double and single quotes, get 3rd val.
+	# This command tolerates multiple whitespace sequences around the version number.
+	$(eval version:=$(shell grep -m 1 version pyproject.toml | tr -s ' ' | tr -d '"' | tr -d "'" | cut -d' ' -f3))
 	echo Releasing version $(version)
-	git commit VERSION -m "Version $(version)"
+	git commit pyproject.toml -m "Version $(version)"
 	git tag $(version)
 	git push origin master $(version)
diff --git a/worker-{{cookiecutter.slug}}/README.md b/worker-{{cookiecutter.slug}}/README.md
index fff0e963eef20ec7eb3fbdc1c4da0e6a56392427..1611fd3b31f42ca658481007cefec95f7aeb5ef5 100644
--- a/worker-{{cookiecutter.slug}}/README.md
+++ b/worker-{{cookiecutter.slug}}/README.md
@@ -1,4 +1,4 @@
-# {{ cookiecutter.slug }}
+# {{ cookiecutter.name }}
 
 {{ cookiecutter.description }}
 
diff --git a/worker-{{cookiecutter.slug}}/VERSION b/worker-{{cookiecutter.slug}}/VERSION
deleted file mode 100644
index 6e8bf73aa550d4c57f6f35830f1bcdc7a4a62f38..0000000000000000000000000000000000000000
--- a/worker-{{cookiecutter.slug}}/VERSION
+++ /dev/null
@@ -1 +0,0 @@
-0.1.0
diff --git a/worker-{{cookiecutter.slug}}/ci/build.sh b/worker-{{cookiecutter.slug}}/ci/build.sh
index 7e0c29f23d9e2c2e1cbf2eab63b092efe573a738..fe1577dd5c1aca4db5562a07c943a8165418e1a5 100755
--- a/worker-{{cookiecutter.slug}}/ci/build.sh
+++ b/worker-{{cookiecutter.slug}}/ci/build.sh
@@ -1,17 +1,12 @@
 #!/bin/sh -e
 # Build the tasks Docker image.
 # Requires CI_PROJECT_DIR and CI_REGISTRY_IMAGE to be set.
-# VERSION defaults to latest.
 # Will automatically login to a registry if CI_REGISTRY, CI_REGISTRY_USER and CI_REGISTRY_PASSWORD are set.
 # Will only push an image if $CI_REGISTRY is set.
 
-if [ -z "$VERSION" ]; then
-  VERSION=${CI_COMMIT_TAG:-latest}
-fi
-
 if [ -z "$VERSION" -o -z "$CI_PROJECT_DIR" -o -z "$CI_REGISTRY_IMAGE" ]; then
-  echo Missing environment variables
-  exit 1
+	echo Missing environment variables
+	exit 1
 fi
 
 IMAGE_TAG="$CI_REGISTRY_IMAGE:$VERSION"
@@ -19,14 +14,9 @@ IMAGE_TAG="$CI_REGISTRY_IMAGE:$VERSION"
 cd $CI_PROJECT_DIR
 docker build -f Dockerfile . -t "$IMAGE_TAG"
 
-# Publish the image on the main branch or on a tag
-if [ "$CI_COMMIT_REF_NAME" = "$CI_DEFAULT_BRANCH" -o -n "$CI_COMMIT_TAG" ]; then
-  if [ -n "$CI_REGISTRY" -a -n "$CI_REGISTRY_USER" -a -n "$CI_REGISTRY_PASSWORD" ]; then
-    echo $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY
-    docker push $IMAGE_TAG
-  else
-    echo "Missing environment variables to log in to the container registry…"
-  fi
+if [ -n "$CI_REGISTRY" -a -n "$CI_REGISTRY_USER" -a -n "$CI_REGISTRY_PASSWORD" ]; then
+	echo $CI_REGISTRY_PASSWORD | docker login -u $CI_REGISTRY_USER --password-stdin $CI_REGISTRY
+	docker push $IMAGE_TAG
 else
-  echo "The build was not published to the repository registry (only for main branch or tags)…"
+	echo "Missing environment variables to log in to the container registry…"
 fi
diff --git a/worker-{{cookiecutter.slug}}/pyproject.toml b/worker-{{cookiecutter.slug}}/pyproject.toml
index 947ec91214fecf969a899d496d14d166af95abb3..1603002153c2b625d4f9d75ab93d15432b675758 100644
--- a/worker-{{cookiecutter.slug}}/pyproject.toml
+++ b/worker-{{cookiecutter.slug}}/pyproject.toml
@@ -1,7 +1,66 @@
+[build-system]
+requires = ["setuptools >= 61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "worker_{{ cookiecutter.__module }}"
+version = "0.1.0"
+description = "{{ cookiecutter.description }}"
+dynamic = ["dependencies"]
+authors = [
+    { name = "{{ cookiecutter.author }}", email = "{{ cookiecutter.email }}" },
+]
+maintainers = [
+    { name = "{{ cookiecutter.author }}", email = "{{ cookiecutter.email }}" },
+]
+requires-python = ">=3.10"
+readme = { file = "README.md", content-type = "text/markdown" }
+keywords = ["python"]
+classifiers = [
+    # Specify the Python versions you support here.
+    "Programming Language :: Python :: 3 :: Only",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+]
+
+[project.scripts]
+worker-{{ cookiecutter.__package }} = "worker_{{ cookiecutter.__module }}.worker:main"
+
+[tool.setuptools.dynamic]
+dependencies = { file = ["requirements.txt"] }
+
 [tool.ruff]
 exclude = [".git", "__pycache__"]
 ignore = ["E501"]
-select = ["E", "F", "T1", "W", "I"]
+select = [
+    # pycodestyle
+    "E",
+    "W",
+    # Pyflakes
+    "F",
+    # Flake8 Debugger
+    "T1",
+    # Isort
+    "I",
+    # Implicit Optional
+    "RUF013",
+    # Invalid pyproject.toml
+    "RUF200",
+    # pyupgrade
+    "UP",
+    # flake8-bugbear
+    "B",
+    # flake8-simplify
+    "SIM",
+    # flake8-pytest-style
+    "PT",
+    # flake8-use-pathlib
+    "PTH",
+]
+
+[tool.ruff.per-file-ignores]
+# Ignore `pytest-composite-assertion` rules of `flake8-pytest-style` linter for non-test files
+"worker_{{ cookiecutter.__module }}/**/*.py" = ["PT018"]
 
 [tool.ruff.isort]
 known-first-party = ["arkindex", "arkindex_worker"]
diff --git a/worker-{{cookiecutter.slug}}/requirements.txt b/worker-{{cookiecutter.slug}}/requirements.txt
index 6c9874aedf19f8cb6363134be80fa672fd00ef01..74b62843647b522a395ad37b77364d1a51c7e4ab 100644
--- a/worker-{{cookiecutter.slug}}/requirements.txt
+++ b/worker-{{cookiecutter.slug}}/requirements.txt
@@ -1 +1 @@
-arkindex-base-worker==0.3.4
+arkindex-base-worker==0.3.5
diff --git a/worker-{{cookiecutter.slug}}/setup.py b/worker-{{cookiecutter.slug}}/setup.py
index 9bc36fdf395b540596abe7a34f02a7ad38dea0a3..ca9ba4a1b5b004d42d4c001f0f8d8b8e7338f102 100755
--- a/worker-{{cookiecutter.slug}}/setup.py
+++ b/worker-{{cookiecutter.slug}}/setup.py
@@ -1,52 +1,4 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-import re
-from pathlib import Path
-from typing import List
-
 from setuptools import find_packages, setup
 
-MODULE = "worker_{{cookiecutter.slug}}"
-COMMAND = "worker-{{cookiecutter.slug}}"
-
-SUBMODULE_PATTERN = re.compile("-e ((?:(?!#egg=).)*)(?:#egg=)?(.*)")
-
-
-def parse_requirements_line(line: str) -> str:
-    # Special case for git requirements
-    if line.startswith("git+http"):
-        assert "@" in line, "Branch should be specified with suffix (ex: @master)"
-        assert (
-            "#egg=" in line
-        ), "Package name should be specified with suffix (ex: #egg=kraken)"
-        package_name: str = line.split("#egg=")[-1]
-        return f"{package_name} @ {line}"
-    # Special case for submodule requirements
-    elif line.startswith("-e"):
-        package_path, package_name = SUBMODULE_PATTERN.match(line).groups()
-        package_path: Path = Path(package_path).resolve()
-        # Package name is optional: use folder name by default
-        return f"{package_name or package_path.name} @ file://{package_path}"
-    else:
-        return line
-
-
-def parse_requirements() -> List[str]:
-    path = Path(__file__).parent.resolve() / "requirements.txt"
-    assert path.exists(), f"Missing requirements: {path}"
-    return list(
-        map(parse_requirements_line, map(str.strip, path.read_text().splitlines()))
-    )
-
-
-setup(
-    name=MODULE,
-    version=open("VERSION").read(),
-    description="{{ cookiecutter.description }}",
-    author="{{ cookiecutter.author }}",
-    author_email="{{ cookiecutter.email }}",
-    install_requires=parse_requirements(),
-    entry_points={"console_scripts": [f"{COMMAND}={MODULE}.worker:main"]},
-    packages=find_packages(),
-)
+setup(packages=find_packages())
diff --git a/worker-{{cookiecutter.slug}}/tests/conftest.py b/worker-{{cookiecutter.slug}}/tests/conftest.py
index da3b0bc18524432fad2b04e21b38f0460e8d9495..8d0ba07dc3eff3ac9f03f8f353de337c2f827ab3 100644
--- a/worker-{{cookiecutter.slug}}/tests/conftest.py
+++ b/worker-{{cookiecutter.slug}}/tests/conftest.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 import os
 
 import pytest
@@ -8,14 +7,14 @@ from arkindex_worker.worker.base import BaseWorker
 
 
 @pytest.fixture(autouse=True)
-def setup_environment(responses, monkeypatch) -> None:
+def _setup_environment(responses, monkeypatch) -> None:
     """Setup needed environment variables"""
 
     # Allow accessing remote API schemas
     # defaulting to the prod environment
     schema_url = os.environ.get(
         "ARKINDEX_API_SCHEMA_URL",
-        "https://arkindex.teklia.com/api/v1/openapi/?format=openapi-json",
+        "https://demo.arkindex.org/api/v1/openapi/?format=openapi-json",
     )
     responses.add_passthru(schema_url)
 
@@ -23,6 +22,8 @@ def setup_environment(responses, monkeypatch) -> None:
     os.environ["ARKINDEX_API_SCHEMA_URL"] = schema_url
     # Setup a fake worker run ID
     os.environ["ARKINDEX_WORKER_RUN_ID"] = "1234-{{ cookiecutter.slug }}"
+    # Setup a fake corpus ID
+    os.environ["ARKINDEX_CORPUS_ID"] = "1234-corpus-id"
 
     # Setup a mock api client instead of using a real one
     def mock_setup_api_client(self):
diff --git a/worker-{{cookiecutter.slug}}/tests/test_worker.py b/worker-{{cookiecutter.slug}}/tests/test_worker.py
index b995a94f30ef8f7864baba1946717fe923e302b8..ca9a24e4eaefa991cf1d4cc28b840b6d1fcf971d 100644
--- a/worker-{{cookiecutter.slug}}/tests/test_worker.py
+++ b/worker-{{cookiecutter.slug}}/tests/test_worker.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 import importlib
 
 
@@ -8,6 +7,6 @@ def test_dummy():
 
 def test_import():
     """Import our newly created module, through importlib to avoid parsing issues"""
-    worker = importlib.import_module("worker_{{ cookiecutter.slug }}.worker")
+    worker = importlib.import_module("worker_{{ cookiecutter.__module }}.worker")
     assert hasattr(worker, "Demo")
     assert hasattr(worker.Demo, "process_element")
diff --git a/worker-{{cookiecutter.slug}}/tox.ini b/worker-{{cookiecutter.slug}}/tox.ini
index a30f423cabbebbf543b7389e0eec7b3d359895fc..1a8890ace24fbc3e6b00d90ef585b0459bc0cfd4 100644
--- a/worker-{{cookiecutter.slug}}/tox.ini
+++ b/worker-{{cookiecutter.slug}}/tox.ini
@@ -1,5 +1,5 @@
 [tox]
-envlist = worker-{{ cookiecutter.slug }}
+envlist = worker-{{ cookiecutter.__package }}
 
 [testenv]
 passenv = ARKINDEX_API_SCHEMA_URL
diff --git a/worker-{{cookiecutter.slug}}/worker_{{cookiecutter.slug}}/__init__.py b/worker-{{cookiecutter.slug}}/worker_{{cookiecutter.__module}}/__init__.py
similarity index 83%
rename from worker-{{cookiecutter.slug}}/worker_{{cookiecutter.slug}}/__init__.py
rename to worker-{{cookiecutter.slug}}/worker_{{cookiecutter.__module}}/__init__.py
index def772032d44c4b0a53b635ae11d814ab8ebbfc8..75fdb48008b7c68853176385bc7b70ebe02c6fc8 100644
--- a/worker-{{cookiecutter.slug}}/worker_{{cookiecutter.slug}}/__init__.py
+++ b/worker-{{cookiecutter.slug}}/worker_{{cookiecutter.__module}}/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 import logging
 
 logging.basicConfig(
diff --git a/worker-{{cookiecutter.slug}}/worker_{{cookiecutter.slug}}/worker.py b/worker-{{cookiecutter.slug}}/worker_{{cookiecutter.__module}}/worker.py
similarity index 94%
rename from worker-{{cookiecutter.slug}}/worker_{{cookiecutter.slug}}/worker.py
rename to worker-{{cookiecutter.slug}}/worker_{{cookiecutter.__module}}/worker.py
index 4b02d0bcc585c494235d7c5f656a6664863f5199..081f4fc962e088e518639e80a3768e0d80a8a558 100644
--- a/worker-{{cookiecutter.slug}}/worker_{{cookiecutter.slug}}/worker.py
+++ b/worker-{{cookiecutter.slug}}/worker_{{cookiecutter.__module}}/worker.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 from logging import Logger, getLogger
 
 from arkindex_worker.models import Element