Skip to content
Snippets Groups Projects
Commit f67d28ff authored by Nolan's avatar Nolan Committed by Yoann Schneider
Browse files

support-training-model-version-publication

parent 61dd9cb6
No related branches found
No related tags found
1 merge request!184support-training-model-version-publication
Pipeline #79398 passed
......@@ -8,4 +8,4 @@ line_length = 88
default_section=FIRSTPARTY
known_first_party = arkindex,arkindex_common
known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,setuptools,sh,shapely,tenacity,yaml
known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,responses,setuptools,sh,shapely,tenacity,yaml,zstandard
# -*- coding: utf-8 -*-
import hashlib
import os
import tarfile
import tempfile
from contextlib import contextmanager
from typing import NewType, Tuple
import requests
import zstandard as zstd
from apistar.exceptions import ErrorResponse
from arkindex_worker import logger
CHUNK_SIZE = 1024
DirPath = NewType("DirPath", str)
Hash = NewType("Hash", str)
FileSize = NewType("FileSize", int)
Archive = Tuple[DirPath, Hash, FileSize]
@contextmanager
def create_archive(path: DirPath) -> Archive:
    """Create a tar archive of ``path``, compress it to zstandard, and yield it.

    Yields a 4-tuple:
    - path to the temporary ``.tar.zst`` archive,
    - MD5 hex digest of the archived files' contents (files sorted by path),
    - size in bytes of the compressed archive,
    - MD5 hex digest of the compressed archive itself.

    The intermediate tar file is removed as soon as compression is done; the
    compressed archive is removed when the context manager exits, even if the
    caller's block raised.
    """
    assert path.is_dir(), "create_archive needs a directory"
    compressor = zstd.ZstdCompressor(level=3)
    content_hasher = hashlib.md5()
    archive_hasher = hashlib.md5()
    # mkstemp returns an open OS-level file descriptor along with the path:
    # close it immediately to avoid leaking it (the path is re-opened below).
    tar_fd, path_to_tar_archive = tempfile.mkstemp(prefix="teklia-", suffix=".tar")
    os.close(tar_fd)
    # Create an uncompressed tar archive with all the needed files
    # Files hierarchy is kept in the archive.
    file_list = []
    with tarfile.open(path_to_tar_archive, "w") as tar:
        for p in path.glob("**/*"):
            x = p.relative_to(path)
            tar.add(p, arcname=x, recursive=False)
            file_list.append(p)
    # Sort by path so the content hash does not depend on filesystem ordering
    file_list.sort()
    # Compute hash of the files
    for file_path in file_list:
        with open(file_path, "rb") as file_data:
            for chunk in iter(lambda: file_data.read(CHUNK_SIZE), b""):
                content_hasher.update(chunk)
    zst_fd, path_to_zst_archive = tempfile.mkstemp(prefix="teklia-", suffix=".tar.zst")
    os.close(zst_fd)
    # Compress the archive
    with open(path_to_zst_archive, "wb") as archive_file:
        with open(path_to_tar_archive, "rb") as model_data:
            for model_chunk in iter(lambda: model_data.read(CHUNK_SIZE), b""):
                compressed_chunk = compressor.compress(model_chunk)
                archive_hasher.update(compressed_chunk)
                archive_file.write(compressed_chunk)
    # Remove the tar archive
    os.remove(path_to_tar_archive)
    # Get content hash, archive size and hash
    content_hash = content_hasher.hexdigest()
    size = os.path.getsize(path_to_zst_archive)
    archive_hash = archive_hasher.hexdigest()
    try:
        yield path_to_zst_archive, content_hash, size, archive_hash
    finally:
        # Remove the zstd archive even if the caller's block raised
        os.remove(path_to_zst_archive)
class TrainingMixin(object):
    """
    Mixin for the Training workers to add Model and ModelVersion helpers
    """

    def publish_model_version(self, model_path: "DirPath", model_id: str):
        """
        Create a .tar.zst archive from ``model_path``, create a model version
        on Arkindex with the archive's hashes and size, upload the archive to
        the S3 bucket, then mark the version as available.

        :param model_path: Directory containing the model files to publish.
        :param model_id: ID of the Arkindex model the version belongs to.
        """
        # Create the zst archive, get its hash and size
        with create_archive(path=model_path) as (
            path_to_archive,
            hash,
            size,
            archive_hash,
        ):
            # Create a new model version with hash and size
            model_version_details = self.create_model_version(
                model_id=model_id,
                hash=hash,
                size=size,
                archive_hash=archive_hash,
            )
            if model_version_details is None:
                # Version already published or creation failed: nothing to upload
                return
            self.upload_to_s3(
                archive_path=path_to_archive,
                model_version_details=model_version_details,
            )
            # Update the model version with state, configuration parsed, tag, description (defaults to name of the worker)
            self.update_model_version(
                model_version_details=model_version_details,
            )

    def create_model_version(
        self,
        model_id: str,
        hash: str,
        size: int,
        archive_hash: str,
    ) -> dict:
        """
        Create a new version of the specified model with the given information (hashes and size).

        If a version matching the information already exists, there are two cases:
        - The version is in `Created` state: this version's details are reused
        - The version is in `Available` state: you cannot create twice the same
          version; the error is logged and None is returned.

        :param model_id: ID of the model to create a version of.
        :param hash: MD5 digest of the archived files' contents.
        :param size: Size in bytes of the compressed archive.
        :param archive_hash: MD5 digest of the compressed archive.
        :returns: The model version details as a dict, or None on failure.
        """
        # Create a new model version with hash and size
        try:
            model_version_details = self.request(
                "CreateModelVersion",
                id=model_id,
                body={"hash": hash, "size": size, "archive_hash": archive_hash},
            )
        except ErrorResponse as e:
            if e.status_code >= 500:
                # Server-side failure: the payload holds no version details,
                # bail out instead of trying to parse it.
                logger.error(f"Failed to create model version: {e.content}")
                return None
            model_version_details = e.content.get("hash")
            # If the existing model is in Created state, this model is returned as a dict.
            # Else an error in a list is returned.
            if model_version_details and isinstance(
                model_version_details, (list, tuple)
            ):
                logger.error(model_version_details[0])
                return
        return model_version_details

    def upload_to_s3(self, archive_path: str, model_version_details: dict) -> None:
        """
        Upload the archive of the model's files to an Amazon s3 compatible storage

        :param archive_path: Path to the local archive to upload.
        :param model_version_details: Version details as returned by
            ``create_model_version``; must contain the ``s3_put_url`` key.
        :raises requests.HTTPError: If the upload is rejected by the storage.
        """
        s3_put_url = model_version_details.get("s3_put_url")
        logger.info("Uploading to s3...")
        # Upload the archive on s3
        with open(archive_path, "rb") as archive:
            r = requests.put(
                url=s3_put_url,
                data=archive,
                headers={"Content-Type": "application/zstd"},
            )
            r.raise_for_status()

    def update_model_version(
        self,
        model_version_details: dict,
        description: str = None,
        configuration: dict = None,
        tag: str = None,
    ) -> None:
        """
        Update the specified model version to the state `Available` and use the given information

        :param model_version_details: Version details; the ``id`` key is used.
        :param description: Optional description for the version.
        :param configuration: Optional parsed configuration for the version.
        :param tag: Optional tag for the version.
        """
        logger.info("Updating the model version...")
        try:
            self.request(
                "UpdateModelVersion",
                id=model_version_details.get("id"),
                body={
                    "state": "available",
                    "description": description,
                    "configuration": configuration,
                    "tag": tag,
                },
            )
        except ErrorResponse as e:
            logger.error(f"Failed to update model version: {e.content}")
......@@ -6,3 +6,4 @@ python-gnupg==0.4.8
sh==1.14.2
shapely==1.8.2
tenacity==8.0.1
zstandard==0.18.0
pytest==7.1.1
pytest-mock==3.7.0
pytest-responses==0.5.0
requests==2.27.1
......@@ -26,6 +26,7 @@ from arkindex_worker.worker import BaseWorker, ElementsWorker
from arkindex_worker.worker.transcription import TextOrientation
FIXTURES_DIR = Path(__file__).resolve().parent / "data"
SAMPLES_DIR = Path(__file__).resolve().parent / "samples"
__yaml_cache = {}
......@@ -276,6 +277,11 @@ def fake_transcriptions_small():
return json.load(f)
@pytest.fixture
def model_file_dir():
    """Path to the sample model files used by the training tests."""
    return SAMPLES_DIR / "model_files"
@pytest.fixture
def fake_dummy_worker():
api_client = MockApiClient()
......
Wow this is actually the data of the best model ever created on Arkindex
\ No newline at end of file
Wow this is actually the data of the best model ever created on Arkindex
\ No newline at end of file
# -*- coding: utf-8 -*-
import os
import pytest
import responses
from arkindex.mock import MockApiClient
from arkindex_worker.worker import BaseWorker
from arkindex_worker.worker.training import TrainingMixin, create_archive
class TrainingWorker(BaseWorker, TrainingMixin):
    """
    This class is only needed for tests
    """

    # No behavior of its own: it simply combines BaseWorker with the
    # TrainingMixin helpers so the mixin methods can be exercised.
    pass
def test_create_archive(model_file_dir):
    """Create an archive when the model's file is in a folder"""
    with create_archive(path=model_file_dir) as (
        archive_path,
        content_hash,
        archive_size,
        compressed_hash,
    ):
        # While inside the context, the compressed archive must exist on disk
        assert os.path.exists(archive_path), "The archive was not created"
        assert (
            content_hash == "c5aedde18a768757351068b840c8c8f9"
        ), "Hash was not properly computed"
        assert 300 < archive_size < 700
    # Leaving the context must delete the temporary archive
    assert not os.path.exists(archive_path), "Auto removal failed"
def test_create_model_version():
    """A new model version is returned"""
    fake_model = "fake_model_id"
    fake_version = "fake_model_version_id"
    worker = TrainingWorker()
    worker.api_client = MockApiClient()
    digest = "hash"
    archive_digest = "archive_hash"
    archive_size = "30"
    # Payload the mocked API returns for a successful creation
    expected_version = {
        "id": fake_version,
        "model_id": fake_model,
        "hash": digest,
        "archive_hash": archive_digest,
        "size": archive_size,
        "s3_url": "http://hehehe.com",
        "s3_put_url": "http://hehehe.com",
    }
    worker.api_client.add_response(
        "CreateModelVersion",
        id=fake_model,
        response=expected_version,
        body={"hash": digest, "archive_hash": archive_digest, "size": archive_size},
    )
    created = worker.create_model_version(
        fake_model, digest, archive_size, archive_digest
    )
    assert created == expected_version
@pytest.mark.parametrize(
    "content, status_code",
    [
        (
            {
                "hash": {
                    "id": "fake_model_version_id",
                    "model_id": "fake_model_id",
                    "hash": "hash",
                    "archive_hash": "archive_hash",
                    "size": "size",
                    "s3_url": "http://hehehe.com",
                    "s3_put_url": "http://hehehe.com",
                }
            },
            400,
        ),
        ({"hash": ["A version for this model with this hash already exists."]}, 403),
    ],
)
def test_retrieve_created_model_version(content, status_code):
    """
    A 400 means a matching version already exists in `Created` state: the
    error payload carries the version details, which must be returned.
    A 403 means a matching version exists in `Available` state: the helper
    must return None.
    """
    model_id = "fake_model_id"
    worker = TrainingWorker()
    worker.api_client = MockApiClient()
    model_hash = "hash"
    archive_hash = "archive_hash"
    size = "30"
    worker.api_client.add_error_response(
        "CreateModelVersion",
        id=model_id,
        status_code=status_code,
        body={"hash": model_hash, "archive_hash": archive_hash, "size": size},
        content=content,
    )
    result = worker.create_model_version(model_id, model_hash, size, archive_hash)
    # On 400 the "hash" key of the payload holds the reusable version dict;
    # on 403 nothing can be reused.
    expected = content["hash"] if status_code == 400 else None
    assert result == expected
def test_handle_s3_uploading_errors(model_file_dir):
    # Any non-2xx answer from the bucket must bubble up to the caller
    worker = TrainingWorker()
    worker.api_client = MockApiClient()
    bucket_url = "http://s3.localhost.com"
    responses.add_passthru(bucket_url)
    responses.add(responses.Response(method="PUT", url=bucket_url, status=400))
    model_file = model_file_dir / "model_file.pth"
    with pytest.raises(Exception):
        worker.upload_to_s3(model_file, {"s3_put_url": bucket_url})
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment