Skip to content
Snippets Groups Projects
training.py 10 KiB
Newer Older
"""
BaseWorker methods for training.
"""

import functools
from contextlib import contextmanager
from pathlib import Path
from typing import NewType
from uuid import UUID

import requests
from apistar.exceptions import ErrorResponse

from arkindex_worker import logger
from arkindex_worker.utils import close_delete_file, create_tar_zst_archive
# Distinct static types for the values exchanged by the archive helpers below.
# NewType only matters to static type checkers: at runtime, calling one of
# these simply returns its argument unchanged.

# A filesystem path that is expected to point to a directory.
DirPath = NewType("DirPath", Path)
"""Path to a directory"""

# A hash digest carried as a string.
Hash = NewType("Hash", str)
"""MD5 Hash"""

# A file size; create_archive fills it from `Path.stat().st_size`.
FileSize = NewType("FileSize", int)
"""File size"""
@contextmanager
def create_archive(path: DirPath) -> tuple[Path, Hash, FileSize, Hash]:
    """
    Create a tar archive from the files at the given location then compress it to a zst archive.

    This is a context manager: the archive is available for the duration of the
    `with` block and is removed on exit, including when the block raises.
    The return annotation describes the tuple that is yielded.

    :param path: Directory whose files are archived.
    :returns: Yields, in this order: the location of the created archive, the hash of
        its content, its size (`Path.stat().st_size`) and the hash of the archive itself.
    :raises AssertionError: If `path` is not a directory.
    """
    assert path.is_dir(), "create_archive needs a directory"

    # NOTE(review): the directory is passed positionally; confirm against the
    # signature of `create_tar_zst_archive`.
    zst_descriptor, zst_archive, archive_hash, content_hash = create_tar_zst_archive(
        path
    )

    try:
        # Callers unpack this as (archive path, content hash, size, archive hash)
        yield zst_archive, content_hash, zst_archive.stat().st_size, archive_hash
    finally:
        # Remove the zst archive, even if the caller's `with` block failed
        close_delete_file(zst_descriptor, zst_archive)
def build_clean_payload(**kwargs):
    """
    Build an API body payload from keyword arguments, dropping the null ones.

    Only attributes set to `None` are removed: other falsy values
    (0, "", {}, False) are kept, and the arguments' order is preserved.
    """
    payload = {}
    for name, item in kwargs.items():
        if item is None:
            continue
        payload[name] = item
    return payload


def skip_if_read_only(func):
    """
    Decorate a method so that it does nothing when its instance is read-only.

    The instance's `is_read_only` attribute is looked up on every call; when it is
    missing or falsy, the decorated method runs normally and its result is returned.
    Otherwise a warning is logged and `None` is returned without calling the method.
    """

    @functools.wraps(func)
    def guarded(self, *args, **kwargs):
        read_only = getattr(self, "is_read_only", False)
        if not read_only:
            return func(self, *args, **kwargs)

        logger.warning(
            "Cannot perform this operation as the worker is in read-only mode"
        )
        return None

    return guarded


class TrainingMixin:
    """
    A mixin helper to create a new model version easily.
    You may use `publish_model_version` to publish a ready model version directly, or
    separately create the model version then publish it (e.g to store training metrics).
    Stores the currently handled model version as `self.model_version`.

    This mixin relies on `self.request` and `self.model_version_id`, which must be
    provided by the class it is mixed into.
    """

    # API payload of the model version currently handled, None until one is created
    model_version = None

    @property
    def is_finetuning(self) -> bool:
        """
        True when `self.model_version_id` (provided by the host class) is set.
        """
        return bool(self.model_version_id)

    @skip_if_read_only
    def publish_model_version(
        self,
        model_path: DirPath,
        model_id: str,
        tag: str | None = None,
        description: str | None = None,
        configuration: dict | None = None,
        parent: str | UUID | None = None,
    ):
        """
        Publish a unique version of a model in Arkindex, identified by its hash.
        In case the `create_model_version` method has been called, reuses that model
        instead of creating a new one.

        :param model_path: Path to the directory containing the model version's files.
        :param model_id: ID of the model
        :param tag: Tag of the model version
        :param description: Description of the model version
        :param configuration: Configuration of the model version
        :param parent: ID of the parent model version
        :raises AssertionError: If a model version is already handled, attributes are
            given, and `model_id` does not match that version.
        """
        configuration = configuration or {}
        has_attributes = bool(tag or description or configuration or parent)

        if not self.model_version:
            self.create_model_version(
                model_id=model_id,
                tag=tag,
                description=description,
                configuration=configuration,
                parent=parent,
            )

        elif has_attributes:
            assert (
                self.model_version.get("model_id") == model_id
            ), "Given `model_id` does not match the current model version"
            # If any attribute field has been defined, PATCH the current model version
            self.update_model_version(
                tag=tag,
                description=description,
                configuration=configuration,
                parent=parent,
            )

        # Create the zst archive, get its hash and size
        # Validate the model version
        with create_archive(path=model_path) as (
            path_to_archive,
            content_hash,
            size,
            archive_hash,
        ):
            # Upload the archive to the storage of the current model version
            self.upload_to_s3(archive_path=path_to_archive)

            current_version_id = self.model_version["id"]
            # Mark the model as valid
            self.validate_model_version(
                size=size,
                hash=content_hash,
                archive_hash=archive_hash,
            )
            # Validation may have swapped the handled version for an existing
            # available one: apply the requested attributes to that version too
            if self.model_version["id"] != current_version_id and has_attributes:
                logger.warning(
                    "Updating the existing available model version with the given attributes."
                )
                self.update_model_version(
                    tag=tag,
                    description=description,
                    configuration=configuration,
                    parent=parent,
                )

    @skip_if_read_only
    def create_model_version(
        self,
        model_id: str,
        tag: str | None = None,
        description: str | None = None,
        configuration: dict | None = None,
        parent: str | UUID | None = None,
    ):
        """
        Create a new version of the specified model with its base attributes.
        Once successfully created, the model version is accessible via `self.model_version`.

        :param model_id: ID of the model
        :param tag: Tag of the model version
        :param description: Description of the model version
        :param configuration: Configuration of the model version
        :param parent: ID of the parent model version
        :raises AssertionError: If a model version has already been created.
        """
        assert not self.model_version, "A model version has already been created."

        # An empty configuration is always sent, as `{}` is not filtered out of the payload
        configuration = configuration or {}
        self.model_version = self.request(
            "CreateModelVersion",
            id=model_id,
            body=build_clean_payload(
                tag=tag,
                description=description,
                configuration=configuration,
                parent=parent,
            ),
        )
        logger.info(
            f"Model version ({self.model_version['id']}) was successfully created"
        )

    @skip_if_read_only
    def update_model_version(
        self,
        tag: str | None = None,
        description: str | None = None,
        configuration: dict | None = None,
        parent: str | UUID | None = None,
    ):
        """
        Update the current model version with the given attributes.
        Attributes left to `None` are not sent to the API.

        :param tag: Tag of the model version
        :param description: Description of the model version
        :param configuration: Configuration of the model version
        :param parent: ID of the parent model version
        :raises AssertionError: If no model version has been created yet.
        """
        assert self.model_version, "No model version has been created yet."
        self.model_version = self.request(
            "UpdateModelVersion",
            id=self.model_version["id"],
            body=build_clean_payload(
                tag=tag,
                description=description,
                configuration=configuration,
                parent=parent,
            ),
        )
        logger.info(
            f"Model version ({self.model_version['id']}) was successfully updated"
        )

    @skip_if_read_only
    def upload_to_s3(self, archive_path: Path) -> None:
        """
        Upload the archive of the model's files to an Amazon s3 compatible storage

        :param archive_path: Path to the archive to upload
        :raises AssertionError: If no model version has been created, if it is already
            available, or if it has no S3 PUT URL.
        :raises requests.HTTPError: If the storage answers with an error status.
        """

        assert (
            self.model_version
        ), "You must create the model version before uploading an archive."
        assert (
            self.model_version["state"] != "Available"
        ), "The model is already marked as available."

        s3_put_url = self.model_version.get("s3_put_url")
        assert s3_put_url, "S3 PUT URL is not set, please ensure you have the right to validate a model version."
        logger.info("Uploading to s3...")
        # Upload the archive on s3
        with archive_path.open("rb") as archive:
            r = requests.put(
                url=s3_put_url,
                data=archive,
                headers={"Content-Type": "application/zstd"},
            )
        r.raise_for_status()

    @skip_if_read_only
    def validate_model_version(
        self,
        hash: str,
        size: int,
        archive_hash: str,
    ):
        """
        Sets the model version as `Available`, once its archive has been uploaded to S3.

        When the API answers with a 409 error, the model version carried by the error
        replaces `self.model_version` and the pending version is destroyed.
        A failure to destroy the pending version is only logged.

        :param hash: MD5 hash of the files contained in the archive
        :param size: The size of the uploaded archive
        :param archive_hash: MD5 hash of the uploaded archive
        :raises AssertionError: If no model version has been created, or if a 409
            error carries no content.
        :raises ErrorResponse: If the API answers with any error other than a 409.
        """
        assert self.model_version, "You must create the model version and upload its archive before validating it."
        try:
            self.model_version = self.request(
                "ValidateModelVersion",
                id=self.model_version["id"],
                body={
                    "size": size,
                    "hash": hash,
                    "archive_hash": archive_hash,
                },
            )
        except ErrorResponse as e:
            if e.status_code != 409:
                raise
            logger.warning(
                f"An available model version exists with hash {hash}, using it instead of the pending version."
            )
            pending_version_id = self.model_version["id"]
            # Check the payload before replacing the handled version, so that a
            # failure here does not leave `self.model_version` set to None
            available_version = getattr(e, "content", None)
            assert available_version is not None, "An unexpected error occurred."
            self.model_version = available_version

            logger.warning("Removing the pending model version.")
            try:
                self.request("DestroyModelVersion", id=pending_version_id)
            except ErrorResponse as destroy_error:
                # Best effort: the available version is usable even if cleanup failed
                msg = getattr(destroy_error, "content", str(destroy_error))
                logger.error(
                    f"An error occurred removing the pending version {pending_version_id}: {msg}."
                )

        logger.info(f"Model version {self.model_version['id']} is now available.")