diff --git a/.isort.cfg b/.isort.cfg index ad4d2fb8c0e010ee8b098fe2393efd319b3cc8b0..f03c5435a2bf4754e2f733ada06d172d3321cef3 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -8,4 +8,4 @@ line_length = 88 default_section=FIRSTPARTY known_first_party = arkindex,arkindex_common -known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,responses,setuptools,sh,shapely,tenacity,yaml,zstandard +known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,setuptools,sh,shapely,tenacity,yaml,zstandard diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py index c517220b891a8957ff20e1fdaae8149afea9d49f..cca442ea7c5334119a415c3a391aaa450190fd47 100644 --- a/tests/test_elements_worker/test_training.py +++ b/tests/test_elements_worker/test_training.py @@ -1,14 +1,16 @@ # -*- coding: utf-8 -*- +# import json import os from pathlib import Path -import responses -from responses import matchers - from arkindex.mock import MockApiClient from arkindex_worker.worker import BaseWorker from arkindex_worker.worker.training import TrainingMixin, create_archive +# import pytest +# import responses +# from responses import matchers + class TrainingWorker(BaseWorker, TrainingMixin): def __init__(self): @@ -26,7 +28,9 @@ def test_create_archive_folder(): ): assert os.path.exists(zst_archive_path), "The archive was not created" assert ( - hash == "7dd70931222ef0496ea75e5aee674043" + # hash == "7dd70931222ef0496ea75e5aee674043" + hash + == "c5aedde18a768757351068b840c8c8f9" ), "Hash was not properly computed" assert 300 < size < 700 @@ -41,7 +45,7 @@ def test_create_model_version(): model_files_dir = Path("tests/samples/model_files") # model_file_path = model_files_dir / "model_file.pth" training = TrainingWorker() - client = MockApiClient() + training.api_client = MockApiClient() with create_archive(path=model_files_dir) as ( zst_archive_path, hash, @@ -58,35 +62,16 @@ def test_create_model_version(): "s3_put_url": "http://hehehe.com", } - client.__setattr__("model_version_details", model_version_details) - responses.add( - responses.POST, - f"http://testserver/api/v1/model/{model_id}/versions/", - status=200, - match=[ - matchers.json_params_matcher( - {"hash": hash, "archive_hash": archive_hash, "size": size} - ) - ], - json=model_version_details, + response_mock = training.api_client.add_response( + "CreateModelVersion", + response=model_version_details, + body={"hash": hash, "archive_hash": archive_hash, "size": size}, ) - print(model_version_details) - # responses.add( - # responses.POST, - # f"http://testserver/api/v1/model/{model_id}/versions/", - # status=200, - # match=[ - # matchers.json_params_matcher( - # {"hash": hash, "archive_hash": archive_hash, "size": size} - # ) - # ], - # json=model_version_details, - # ) - + print(response_mock) assert ( training.create_model_version(model_id, hash, size, archive_hash) - == model_version_details + == response_mock ) @@ -98,7 +83,8 @@ def test_create_model_version(): # model_version_id = "fake_model_version_id" # # Create a model archive and keep its hash and size. # model_file_path = samples_dir / "model_file.pth" -# training = TrainingMixin() +# training = TrainingWorker() +# training.api_client = MockApiClient() # with create_archive(path=model_file_path) as ( # zst_archive_path, # hash, @@ -115,18 +101,22 @@ def test_create_model_version(): # "s3_put_url": "http://hehehe.com", # } -# responses.add( -# responses.POST, -# f"http://testserver/api/v1/model/{model_id}/versions/", -# status=400, -# match=[ -# matchers.json_params_matcher( -# {"hash": hash, "archive_hash": archive_hash, "size": size} -# ) -# ], -# json={"hash": model_version_details}, +# response_mock = training.api_client.add_response( + # ) +# # responses.add( +# # responses.POST, +# # f"http://testserver/api/v1/model/{model_id}/versions/", +# # status=400, +# # match=[ +# # matchers.json_params_matcher( +# # {"hash": hash, "archive_hash": archive_hash, "size": size} +# # ) +# # ], +# # json={"hash": model_version_details}, +# # ) + # assert ( # training.create_model_version(api_client, model_id, hash, size, archive_hash) # == model_version_details @@ -138,7 +128,7 @@ def test_create_model_version(): # model_id = "fake_model_id" # # Create a model archive and keep its hash and size. # model_file_path = samples_dir / "model_file.pth" -# training = TrainingMixin() +# training = TrainingWorker() # with create_archive(path=model_file_path) as ( # zst_archive_path, # hash, @@ -165,6 +155,6 @@ def test_create_model_version(): # responses.add_passthru(s3_endpoint_url) # responses.add(responses.Response(method="PUT", url=s3_endpoint_url, status=400)) # file_path = samples_dir / "model_file.pth" -# training = TrainingMixin() +# training = TrainingWorker() # with pytest.raises(Exception): # training.upload_to_s3(file_path, {"s3_put_url": s3_endpoint_url})