Newer
Older
"""
BaseWorker methods for training.
"""
from contextlib import contextmanager
from typing import NewType

import requests
from apistar.exceptions import ErrorResponse

from arkindex_worker import logger
from arkindex_worker.utils import close_delete_file, create_tar_zst_archive
@contextmanager
def create_archive(path: DirPath) -> tuple[Path, Hash, FileSize, Hash]:
    """
    Create a tar archive from the files at the given location then compress it to a zst archive.

    Meant to be used as a context manager: the archive is only guaranteed to exist
    inside the `with` block and is removed on exit, even when the block raises.

    :param path: Directory whose files are packed into the compressed tar archive
    :returns: Yields the location of the created archive, its content's hash,
        its size in bytes and the hash of the archive itself, in that order
    :raises AssertionError: If `path` is not a directory
    """
    assert path.is_dir(), "create_archive needs a directory"

    zst_descriptor, zst_archive, archive_hash, content_hash = create_tar_zst_archive(
        path
    )

    try:
        # Get content hash, archive size and hash
        yield zst_archive, content_hash, zst_archive.stat().st_size, archive_hash
    finally:
        # Remove the zst archive, whatever happened in the caller's `with` block
        close_delete_file(zst_descriptor, zst_archive)
def build_clean_payload(**kwargs):
    """
    Build an API body payload from keyword arguments, leaving out the null ones.

    Only attributes that are exactly `None` are dropped: other falsy values
    (`0`, `""`, `{}`, `False`) are kept in the payload.
    """
    payload = {}
    for name, content in kwargs.items():
        if content is None:
            continue
        payload[name] = content
    return payload
def skip_if_read_only(func):
    """
    Decorator turning a method into a no-op when its instance is read-only.

    When the `is_read_only` attribute of the instance is truthy, a warning is
    logged and `None` is returned without calling the decorated method.
    Instances without that attribute are treated as not read-only.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        read_only = getattr(self, "is_read_only", False)
        if not read_only:
            # Regular mode: forward the call untouched
            return func(self, *args, **kwargs)

        logger.warning(
            "Cannot perform this operation as the worker is in read-only mode"
        )
        return None

    return wrapper
"""
A mixin helper to create a new model version easily.
You may use `publish_model_version` to publish a ready model version directly, or
separately create the model version then publish it (e.g to store training metrics).
Stores the currently handled model version as `self.model_version`.
"""
# Model version currently handled by the worker. Stays None until one is created;
# `create_model_version` stores the API response here.
model_version = None
@property
def is_finetuning(self) -> bool:
    """
    Whether `self.model_version_id` holds a truthy value.

    NOTE(review): presumably the ID of an existing model version the worker
    starts from; `model_version_id` is defined outside this section, confirm.
    """
    return True if self.model_version_id else False
@skip_if_read_only
def publish_model_version(
    self,
    model_path: DirPath,
    model_id: str,
    tag: str | None = None,
    description: str | None = None,
    configuration: dict | None = None,
    parent: str | UUID | None = None,
):
    """
    Publish a unique version of a model in Arkindex, identified by its hash.
    In case the `create_model_version` method has been called, reuses that model
    instead of creating a new one.

    Like the other methods writing to the API, nothing is done (apart from
    logging a warning) when the worker is in read-only mode.

    :param model_path: Path to the directory containing the model version's files.
    :param model_id: ID of the model
    :param tag: Tag of the model version
    :param description: Description of the model version
    :param configuration: Configuration of the model version
    :param parent: ID of the parent model version
    :raises AssertionError: If a model version already exists for another model
        and attributes were given to update it.
    """
    # True when the caller gave at least one attribute worth sending to the API
    has_attributes = bool(tag or description or configuration or parent)

    if not self.model_version:
        self.create_model_version(
            model_id=model_id,
            tag=tag,
            description=description,
            configuration=configuration,
            parent=parent,
        )
    elif has_attributes:
        assert (
            self.model_version.get("model_id") == model_id
        ), "Given `model_id` does not match the current model version"
        # If any attribute field has been defined, PATCH the current model version
        self.update_model_version(
            tag=tag,
            description=description,
            configuration=configuration,
            parent=parent,
        )

    # Create the zst archive, get its hash and size
    with create_archive(path=model_path) as (
        path_to_archive,
        content_hash,
        size,
        archive_hash,
    ):
        # Upload the archive before marking the model version as valid
        self.upload_to_s3(archive_path=path_to_archive)

        current_version_id = self.model_version["id"]
        # Mark the model as valid
        self.validate_model_version(
            hash=content_hash,
            size=size,
            archive_hash=archive_hash,
        )

        # Validation may swap `self.model_version` for an already available
        # version with the same hash: apply the given attributes to that one too.
        if self.model_version["id"] != current_version_id and has_attributes:
            logger.warning(
                "Updating the existing available model version with the given attributes."
            )
            self.update_model_version(
                tag=tag,
                description=description,
                configuration=configuration,
                parent=parent,
            )
@skip_if_read_only
def create_model_version(
    self,
    model_id: str,
    tag: str | None = None,
    description: str | None = None,
    configuration: dict | None = None,
    parent: str | UUID | None = None,
):
    """
    Create a new version of the specified model with its base attributes.
    Once successfully created, the model version is accessible via `self.model_version`.

    :param model_id: ID of the model
    :param tag: Tag of the model version
    :param description: Description of the model version
    :param configuration: Configuration of the model version
    :param parent: ID of the parent model version
    :raises AssertionError: If a model version has already been created.
    """
    assert not self.model_version, "A model version has already been created."

    # Attributes left to None are not sent to the API
    self.model_version = self.request(
        "CreateModelVersion",
        id=model_id,
        body=build_clean_payload(
            tag=tag,
            description=description,
            configuration=configuration,
            parent=parent,
        ),
    )
    logger.info(
        f"Model version ({self.model_version['id']}) was successfully created"
    )
@skip_if_read_only
def update_model_version(
    self,
    tag: str | None = None,
    description: str | None = None,
    configuration: dict | None = None,
    parent: str | UUID | None = None,
):
    """
    Patch the model version currently stored in `self.model_version`.

    Attributes left to `None` are not sent to the API. The stored model
    version is replaced by the API response.

    :param tag: Tag of the model version
    :param description: Description of the model version
    :param configuration: Configuration of the model version
    :param parent: ID of the parent model version
    :raises AssertionError: If no model version has been created yet.
    """
    assert self.model_version, "No model version has been created yet."

    version_id = self.model_version["id"]
    payload = build_clean_payload(
        tag=tag,
        description=description,
        configuration=configuration,
        parent=parent,
    )
    self.model_version = self.request(
        "UpdateModelVersion", id=version_id, body=payload
    )

    logger.info(
        f"Model version ({self.model_version['id']}) was successfully updated"
    )
def upload_to_s3(self, archive_path: Path) -> None:
    """
    Send the archive of the model's files to an Amazon S3 compatible storage,
    using the PUT URL carried by the current model version.

    :param archive_path: Location of the archive to upload
    :raises AssertionError: If no model version exists, if it is already
        `Available`, or if it does not provide an `s3_put_url`.
    :raises requests.HTTPError: If the storage answers with an error status.
    """
    version = self.model_version
    assert version, "You must create the model version before uploading an archive."
    assert version["state"] != "Available", "The model is already marked as available."

    s3_put_url = version.get("s3_put_url")
    assert s3_put_url, "S3 PUT URL is not set, please ensure you have the right to validate a model version."

    logger.info("Uploading to s3...")
    # Stream the archive straight from disk to the storage
    with archive_path.open("rb") as archive_file:
        response = requests.put(
            url=s3_put_url,
            data=archive_file,
            headers={"Content-Type": "application/zstd"},
        )
        response.raise_for_status()
@skip_if_read_only
def validate_model_version(
    self,
    hash: str,
    size: int,
    archive_hash: str,
):
    """
    Sets the model version as `Available`, once its archive has been uploaded to S3.

    When the API answers with a 409 conflict, the model version carried by the
    error response replaces the pending one in `self.model_version`, and the
    pending version is then removed (a failure of that removal is only logged).

    :param hash: MD5 hash of the files contained in the archive
    :param size: The size of the uploaded archive
    :param archive_hash: MD5 hash of the uploaded archive
    :raises AssertionError: If no model version has been created yet, or if the
        conflict response carries no model version.
    :raises ErrorResponse: For any API error other than a 409 conflict.
    """
    assert self.model_version, "You must create the model version and upload its archive before validating it."

    try:
        self.model_version = self.request(
            "ValidateModelVersion",
            id=self.model_version["id"],
            body={
                "size": size,
                "hash": hash,
                "archive_hash": archive_hash,
            },
        )
    except ErrorResponse as e:
        if e.status_code != 409:
            # Not a conflict: let the original error propagate untouched
            raise

        logger.warning(
            f"An available model version exists with hash {hash}, using it instead of the pending version."
        )
        pending_version_id = self.model_version["id"]
        self.model_version = getattr(e, "content", None)
        assert self.model_version is not None, "An unexpected error occurred."

        logger.warning("Removing the pending model version.")
        try:
            self.request("DestroyModelVersion", id=pending_version_id)
        except ErrorResponse as destroy_error:
            # Best effort: the available version is already in use at this point
            msg = getattr(destroy_error, "content", str(destroy_error))
            logger.error(
                f"An error occurred removing the pending version {pending_version_id}: {msg}."
            )

    logger.info(f"Model version {self.model_version['id']} is now available.")