diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py index c0905ae47336ac7f985395fd3482bdb63221011e..c517220b891a8957ff20e1fdaae8149afea9d49f 100644 --- a/tests/test_elements_worker/test_training.py +++ b/tests/test_elements_worker/test_training.py @@ -5,6 +5,7 @@ from pathlib import Path import responses from responses import matchers +from arkindex.mock import MockApiClient from arkindex_worker.worker import BaseWorker from arkindex_worker.worker.training import TrainingMixin, create_archive @@ -40,6 +41,7 @@ def test_create_model_version(): model_files_dir = Path("tests/samples/model_files") # model_file_path = model_files_dir / "model_file.pth" training = TrainingWorker() + client = MockApiClient() with create_archive(path=model_files_dir) as ( zst_archive_path, hash, @@ -56,6 +58,7 @@ def test_create_model_version(): "s3_put_url": "http://hehehe.com", } + client.__setattr__("model_version_details", model_version_details) responses.add( responses.POST, f"http://testserver/api/v1/model/{model_id}/versions/", @@ -68,6 +71,19 @@ def test_create_model_version(): json=model_version_details, ) + print(model_version_details) + # responses.add( + # responses.POST, + # f"http://testserver/api/v1/model/{model_id}/versions/", + # status=200, + # match=[ + # matchers.json_params_matcher( + # {"hash": hash, "archive_hash": archive_hash, "size": size} + # ) + # ], + # json=model_version_details, + # ) + assert ( training.create_model_version(model_id, hash, size, archive_hash) == model_version_details