import requests
r = requests.put(
+++ b/tests/test_elements_worker/test_training.py @@ -0,0 +1,150 @@ +import pytest
+ model_files_dir = Path("tests/samples/model_files") + # model_file_path = model_files_dir / "model_file.pth" + training = TrainingMixin() + with create_archive(path=model_files_dir) as ( + zst_archive_path, + hash, + size, + archive_hash, + ): + model_version_details = { + "id": model_version_id, + "model_id": model_id, + "hash": hash, + "archive_hash": archive_hash, + "size": size, + "s3_url": "http://hehehe.com", + "s3_put_url": "http://hehehe.com", + } + + responses.add( + responses.POST, + f"http://testserver/api/v1/model/{model_id}/versions/", + status=200, + match=[ + matchers.json_params_matcher( + {"hash": hash, "archive_hash": archive_hash, "size": size} + ) + ], + json=model_version_details, + ) + + assert ( + training.create_model_version(api_client, model_id, hash, size, archive_hash) + == model_version_details + ) + + +# def test_retrieve_created_model_version(api_client, samples_dir): +# """There is an existing model version in Created mode, A 400 was raised. +# But the model is still returned in error content +# """ +# model_id = "fake_model_id" +# model_version_id = "fake_model_version_id" +# # Create a model archive and keep its hash and size. 
+# model_file_path = samples_dir / "model_file.pth" +# training = TrainingMixin() +# with create_archive(path=model_file_path) as ( +# zst_archive_path, +# hash, +# size, +# archive_hash, +# ): +# model_version_details = { +# "id": model_version_id, +# "model_id": model_id, +# "hash": hash, +# "archive_hash": archive_hash, +# "size": size, +# "s3_url": "http://hehehe.com", +# "s3_put_url": "http://hehehe.com", +# } + +# responses.add( +# responses.POST, +# f"http://testserver/api/v1/model/{model_id}/versions/", +# status=400, +# match=[ +# matchers.json_params_matcher( +# {"hash": hash, "archive_hash": archive_hash, "size": size} +# ) +# ], +# json={"hash": model_version_details}, +# ) + +# assert ( +# training.create_model_version(api_client, model_id, hash, size, archive_hash) +# == model_version_details +# ) + + +# def test_retrieve_available_model_version(api_client, samples_dir): +# """Raise error when there is an existing model version in Available mode""" +# model_id = "fake_model_id" +# # Create a model archive and keep its hash and size. 
+# model_file_path = samples_dir / "model_file.pth" +# training = TrainingMixin() +# with create_archive(path=model_file_path) as ( +# zst_archive_path, +# hash, +# size, +# archive_hash, +# ): +# responses.add( +# responses.POST, +# f"http://testserver/api/v1/model/{model_id}/versions/", +# status=403, +# match=[ +# matchers.json_params_matcher( +# {"hash": hash, "archive_hash": archive_hash, "size": size} +# ) +# ], +# ) + +# with pytest.raises(Exception): +# training.create_model_version(api_client, model_id, hash, size, archive_hash) + + + +# def test_handle_s3_uploading_errors(samples_dir): +# s3_endpoint_url = "http://s3.localhost.com" +# responses.add_passthru(s3_endpoint_url) +# responses.add(responses.Response(method="PUT", url=s3_endpoint_url, status=400)) +# file_path = samples_dir / "model_file.pth" +# training = TrainingMixin() +# with pytest.raises(Exception): +# training.upload_to_s3(file_path, {"s3_put_url": s3_endpoint_url})