Skip to content
Snippets Groups Projects
Commit 570e2301 authored by NolanB's avatar NolanB
Browse files

Modif with the review

parent 72c2043b
No related branches found
No related tags found
No related merge requests found
Pipeline #79380 passed
......@@ -14,14 +14,14 @@ from arkindex_worker import logger
CHUNK_SIZE = 1024
FilePath = NewType("FilePath", str)
DirPath = NewType("DirPath", str)
Hash = NewType("Hash", str)
FileSize = NewType("FileSize", int)
Archive = Tuple[FilePath, Hash, FileSize]
Archive = Tuple[DirPath, Hash, FileSize]
@contextmanager
def create_archive(path: FilePath) -> Archive:
def create_archive(path: DirPath) -> Archive:
"""First create a tar archive, then compress to a zst archive.
Finally, get its hash and size
"""
......@@ -76,7 +76,16 @@ def create_archive(path: FilePath) -> Archive:
class TrainingMixin(object):
def publish_model_version(self, model_path, model_id):
"""
Mixin for the Training workers to add Model and ModelVersion helpers
"""
def publish_model_version(self, model_path: DirPath, model_id: str):
"""
This method creates a model archive and its associated hash,
to create a unique version that will be stored on Amazon S3 and published on Arkindex.
"""
# Create the zst archive, get its hash and size
with create_archive(path=model_path) as (
path_to_archive,
......@@ -110,6 +119,11 @@ class TrainingMixin(object):
size: int,
archive_hash: str,
) -> dict:
"""
This method creates a unique version of the model if it does not already exist,
otherwise it uses the existing model version id, and returns a dict including details of the model
"""
# Create a new model version with hash and size
try:
model_version_details = self.request(
......@@ -133,6 +147,10 @@ class TrainingMixin(object):
return model_version_details
def upload_to_s3(self, archive_path: str, model_version_details: dict) -> None:
"""
This method uploads the archive of the model to Amazon S3
"""
s3_put_url = model_version_details.get("s3_put_url")
logger.info("Uploading to s3...")
# Upload the archive on s3
......@@ -151,6 +169,10 @@ class TrainingMixin(object):
configuration: dict = None,
tag: str = None,
) -> None:
"""
This method updates the model version in Arkindex with the state "available"
"""
logger.info("Updating the model version...")
try:
self.request(
......
......@@ -10,10 +10,14 @@ from arkindex_worker.worker.training import TrainingMixin, create_archive
class TrainingWorker(BaseWorker, TrainingMixin):
"""
This class is only needed for tests
"""
pass
def test_create_archive_folder(model_file_dir):
def test_create_archive(model_file_dir):
with create_archive(path=model_file_dir) as (
zst_archive_path,
hash,
......@@ -35,13 +39,13 @@ def test_create_model_version():
model_version_id = "fake_model_version_id"
training = TrainingWorker()
training.api_client = MockApiClient()
hash = "hash"
model_hash = "hash"
archive_hash = "archive_hash"
size = "30"
model_version_details = {
"id": model_version_id,
"model_id": model_id,
"hash": hash,
"hash": model_hash,
"archive_hash": archive_hash,
"size": size,
"s3_url": "http://hehehe.com",
......@@ -52,10 +56,10 @@ def test_create_model_version():
"CreateModelVersion",
id=model_id,
response=model_version_details,
body={"hash": hash, "archive_hash": archive_hash, "size": size},
body={"hash": model_hash, "archive_hash": archive_hash, "size": size},
)
assert (
training.create_model_version(model_id, hash, size, archive_hash)
training.create_model_version(model_id, model_hash, size, archive_hash)
== model_version_details
)
......@@ -81,32 +85,34 @@ def test_create_model_version():
],
)
def test_retrieve_created_model_version(content, status_code):
"""There is an existing model version in Created mode, A 400 was raised.
But the model is still returned in error content
"""
If there is an existing model version in Created mode,
a 400 is raised, but the model is still returned in the error content.
Else, if there is an existing model version in Available mode,
a 403 is raised, and None is returned.
"""
model_id = "fake_model_id"
training = TrainingWorker()
training.api_client = MockApiClient()
hash = "hash"
model_hash = "hash"
archive_hash = "archive_hash"
size = "30"
training.api_client.add_error_response(
"CreateModelVersion",
id=model_id,
status_code=status_code,
body={"hash": hash, "archive_hash": archive_hash, "size": size},
body={"hash": model_hash, "archive_hash": archive_hash, "size": size},
content=content,
)
if status_code == 400:
assert (
training.create_model_version(model_id, hash, size, archive_hash)
training.create_model_version(model_id, model_hash, size, archive_hash)
== content["hash"]
)
elif status_code == 403:
none_value = None
assert (
training.create_model_version(model_id, hash, size, archive_hash)
== none_value
training.create_model_version(model_id, model_hash, size, archive_hash)
is None
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment