diff --git a/arkindex_worker/worker/training.py b/arkindex_worker/worker/training.py index bb7048a5e6cc5ac177de6b4473524dcc89586cd3..cbd3325a5616eea6fa486c114019914680862883 100644 --- a/arkindex_worker/worker/training.py +++ b/arkindex_worker/worker/training.py @@ -14,14 +14,14 @@ from arkindex_worker import logger CHUNK_SIZE = 1024 -FilePath = NewType("FilePath", str) +DirPath = NewType("DirPath", str) Hash = NewType("Hash", str) FileSize = NewType("FileSize", int) -Archive = Tuple[FilePath, Hash, FileSize] +Archive = Tuple[DirPath, Hash, FileSize] @contextmanager -def create_archive(path: FilePath) -> Archive: +def create_archive(path: DirPath) -> Archive: """First create a tar archive, then compress to a zst archive. Finally, get its hash and size """ @@ -76,7 +76,16 @@ def create_archive(path: FilePath) -> Archive: class TrainingMixin(object): - def publish_model_version(self, model_path, model_id): + """ + Mixin for the Training workers to add Model and ModelVersion helpers + """ + + def publish_model_version(self, model_path: DirPath, model_id: str): + """ + This method creates a model archive and its associated hash, + to create a unique version that will be stored in an amazon s3 and published on arkindex. 
+ """ + # Create the zst archive, get its hash and size with create_archive(path=model_path) as ( path_to_archive, @@ -110,6 +119,11 @@ class TrainingMixin(object): size: int, archive_hash: str, ) -> dict: + """ + This method creates a unique version of the model if it does not already exist, + otherwise uses the existing model version id and returns a dict including details of the model + """ + # Create a new model version with hash and size try: model_version_details = self.request( @@ -133,6 +147,10 @@ class TrainingMixin(object): return model_version_details def upload_to_s3(self, archive_path: str, model_version_details: dict) -> None: + """ + This method uploads the archive of the model to Amazon S3 + """ + s3_put_url = model_version_details.get("s3_put_url") logger.info("Uploading to s3...") # Upload the archive on s3 @@ -151,6 +169,10 @@ class TrainingMixin(object): configuration: dict = None, tag: str = None, ) -> None: + """ + This method updates the model version in Arkindex with a state "available" + """ + logger.info("Updating the model version...") try: self.request( diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py index 7e76cc707df9157e3f9e86569fe3399fb9b4e4de..147536f3e62459c5c2332ffeb0e897464c745ec3 100644 --- a/tests/test_elements_worker/test_training.py +++ b/tests/test_elements_worker/test_training.py @@ -10,10 +10,14 @@ from arkindex_worker.worker.training import TrainingMixin, create_archive class TrainingWorker(BaseWorker, TrainingMixin): + """ + This class is only needed for tests + """ + pass -def test_create_archive_folder(model_file_dir): +def test_create_archive(model_file_dir): with create_archive(path=model_file_dir) as ( zst_archive_path, hash, @@ -35,13 +39,13 @@ def test_create_model_version(): model_version_id = "fake_model_version_id" training = TrainingWorker() training.api_client = MockApiClient() - hash = "hash" + model_hash = "hash" archive_hash = "archive_hash" size = "30" 
model_version_details = { "id": model_version_id, "model_id": model_id, - "hash": hash, + "hash": model_hash, "archive_hash": archive_hash, "size": size, "s3_url": "http://hehehe.com", @@ -52,10 +56,10 @@ def test_create_model_version(): "CreateModelVersion", id=model_id, response=model_version_details, - body={"hash": hash, "archive_hash": archive_hash, "size": size}, + body={"hash": model_hash, "archive_hash": archive_hash, "size": size}, ) assert ( - training.create_model_version(model_id, hash, size, archive_hash) + training.create_model_version(model_id, model_hash, size, archive_hash) == model_version_details ) @@ -81,32 +85,34 @@ def test_create_model_version(): ], ) def test_retrieve_created_model_version(content, status_code): - """There is an existing model version in Created mode, A 400 was raised. - But the model is still returned in error content + """ + If there is an existing model version in Created mode, + A 400 was raised, but the model is still returned in error content. + Else if an existing model version in Available mode, + 403 was raised, but None will be returned """ model_id = "fake_model_id" training = TrainingWorker() training.api_client = MockApiClient() - hash = "hash" + model_hash = "hash" archive_hash = "archive_hash" size = "30" training.api_client.add_error_response( "CreateModelVersion", id=model_id, status_code=status_code, - body={"hash": hash, "archive_hash": archive_hash, "size": size}, + body={"hash": model_hash, "archive_hash": archive_hash, "size": size}, content=content, ) if status_code == 400: assert ( - training.create_model_version(model_id, hash, size, archive_hash) + training.create_model_version(model_id, model_hash, size, archive_hash) == content["hash"] ) elif status_code == 403: - none_value = None assert ( - training.create_model_version(model_id, hash, size, archive_hash) - == none_value + training.create_model_version(model_id, model_hash, size, archive_hash) + is None )