diff --git a/.isort.cfg b/.isort.cfg index f03c5435a2bf4754e2f733ada06d172d3321cef3..ad4d2fb8c0e010ee8b098fe2393efd319b3cc8b0 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -8,4 +8,4 @@ line_length = 88 default_section=FIRSTPARTY known_first_party = arkindex,arkindex_common -known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,setuptools,sh,shapely,tenacity,yaml,zstandard +known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,recommonmark,requests,responses,setuptools,sh,shapely,tenacity,yaml,zstandard diff --git a/tests/conftest.py b/tests/conftest.py index 35ac31aeae79afacfaddb69a512dac09b603a7e7..ad3598205e8acc98f0874b6c52b90543ba8d9974 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,6 +26,7 @@ from arkindex_worker.worker import BaseWorker, ElementsWorker from arkindex_worker.worker.transcription import TextOrientation FIXTURES_DIR = Path(__file__).resolve().parent / "data" +SAMPLES_DIR = Path("tests/samples") __yaml_cache = {} @@ -276,6 +277,11 @@ def fake_transcriptions_small(): return json.load(f) +@pytest.fixture +def model_file_dir(): + return SAMPLES_DIR / "model_files" + + @pytest.fixture def fake_dummy_worker(): api_client = MockApiClient() diff --git a/tests/test_elements_worker/test_training.py b/tests/test_elements_worker/test_training.py index cca442ea7c5334119a415c3a391aaa450190fd47..e1b64f4228bd824c7c9e009a6af0d53090b6fce2 100644 --- a/tests/test_elements_worker/test_training.py +++ b/tests/test_elements_worker/test_training.py @@ -1,25 +1,23 @@ # -*- coding: utf-8 -*- # import json import os -from pathlib import Path + +import pytest +import responses +from apistar.exceptions import ErrorResponse from arkindex.mock import MockApiClient from arkindex_worker.worker import BaseWorker from arkindex_worker.worker.training import TrainingMixin, create_archive -# import pytest -# import responses # from responses import matchers class TrainingWorker(BaseWorker, TrainingMixin): - def __init__(self): - super().setup_api_client() - + pass -def test_create_archive_folder(): - model_file_dir = Path("tests/samples/model_files") +def test_create_archive_folder(model_file_dir): with create_archive(path=model_file_dir) as ( zst_archive_path, hash, @@ -28,9 +26,7 @@ def test_create_archive_folder(): ): assert os.path.exists(zst_archive_path), "The archive was not created" assert ( - # hash == "7dd70931222ef0496ea75e5aee674043" - hash - == "c5aedde18a768757351068b840c8c8f9" + hash == "c5aedde18a768757351068b840c8c8f9" ), "Hash was not properly computed" assert 300 < size < 700 @@ -41,120 +37,114 @@ def test_create_model_version(): """A new model version is returned""" model_id = "fake_model_id" model_version_id = "fake_model_version_id" - # Create a model archive and keep its hash and size. - model_files_dir = Path("tests/samples/model_files") - # model_file_path = model_files_dir / "model_file.pth" training = TrainingWorker() training.api_client = MockApiClient() - with create_archive(path=model_files_dir) as ( - zst_archive_path, - hash, - size, - archive_hash, - ): - model_version_details = { - "id": model_version_id, - "model_id": model_id, - "hash": hash, - "archive_hash": archive_hash, - "size": size, - "s3_url": "http://hehehe.com", - "s3_put_url": "http://hehehe.com", - } - - response_mock = training.api_client.add_response( + hash = "hash" + archive_hash = "archive_hash" + size = "30" + model_version_details = { + "id": model_version_id, + "model_id": model_id, + "hash": hash, + "archive_hash": archive_hash, + "size": size, + "s3_url": "http://hehehe.com", + "s3_put_url": "http://hehehe.com", + } + + training.api_client.add_response( "CreateModelVersion", + id=model_id, response=model_version_details, body={"hash": hash, "archive_hash": archive_hash, "size": size}, ) - - print(response_mock) assert ( training.create_model_version(model_id, hash, size, archive_hash) - == response_mock + == model_version_details ) -# def test_retrieve_created_model_version(api_client, samples_dir): -# """There is an existing model version in Created mode, A 400 was raised. -# But the model is still returned in error content -# """ -# model_id = "fake_model_id" -# model_version_id = "fake_model_version_id" -# # Create a model archive and keep its hash and size. -# model_file_path = samples_dir / "model_file.pth" -# training = TrainingWorker() -# training.api_client = MockApiClient() -# with create_archive(path=model_file_path) as ( -# zst_archive_path, -# hash, -# size, -# archive_hash, -# ): -# model_version_details = { -# "id": model_version_id, -# "model_id": model_id, -# "hash": hash, -# "archive_hash": archive_hash, -# "size": size, -# "s3_url": "http://hehehe.com", -# "s3_put_url": "http://hehehe.com", -# } - -# response_mock = training.api_client.add_response( +@pytest.mark.parametrize( + "process_exception, status_code", + [ + # (None, 200) + ( + ErrorResponse( + title="Mock error response", status_code=400, content="Bad gateway" + ), + 400, + ), + ( + ErrorResponse( + title="Mock error response", status_code=403, content="Bad gateway" + ), + 403, + ), + ], +) +def test_retrieve_created_model_version(process_exception, status_code): + """There is an existing model version in Created mode, A 400 was raised. + But the model is still returned in error content + """ + model_id = "fake_model_id" + training = TrainingWorker() + training.api_client = MockApiClient() + hash = "hash" + archive_hash = "archive_hash" + size = "30" + training.api_client.add_error_response( + "CreateModelVersion", + id=model_id, + status_code=status_code, + body={"hash": hash, "archive_hash": archive_hash, "size": size}, + ) -# ) + if process_exception.status_code == status_code: + assert ( + training.create_model_version(model_id, hash, size, archive_hash) + == process_exception + ) -# # responses.add( -# # responses.POST, -# # f"http://testserver/api/v1/model/{model_id}/versions/", -# # status=400, -# # match=[ -# # matchers.json_params_matcher( -# # {"hash": hash, "archive_hash": archive_hash, "size": size} -# # ) -# # ], -# # json={"hash": model_version_details}, -# # ) - -# assert ( -# training.create_model_version(api_client, model_id, hash, size, archive_hash) -# == model_version_details -# ) + # with pytest.raises(Exception): + # training.create_model_version(model_id, hash, size, archive_hash) -# def test_retrieve_available_model_version(api_client, samples_dir): +# def test_retrieve_available_model_version(): # """Raise error when there is an existing model version in Available mode""" # model_id = "fake_model_id" -# # Create a model archive and keep its hash and size. -# model_file_path = samples_dir / "model_file.pth" +# model_version_id = "fake_model_version_id" # training = TrainingWorker() -# with create_archive(path=model_file_path) as ( -# zst_archive_path, -# hash, -# size, -# archive_hash, -# ): -# responses.add( -# responses.POST, -# f"http://testserver/api/v1/model/{model_id}/versions/", -# status=403, -# match=[ -# matchers.json_params_matcher( -# {"hash": hash, "archive_hash": archive_hash, "size": size} -# ) -# ], -# ) +# training.api_client = MockApiClient() +# hash = "hash" +# archive_hash = "archive_hash" +# size = "30" +# model_version_details = { +# "id": model_version_id, +# "model_id": model_id, +# "hash": hash, +# "archive_hash": archive_hash, +# "size": size, +# "s3_url": "http://hehehe.com", +# "s3_put_url": "http://hehehe.com", +# } + +# training.api_client.add_error_response( +# "CreateModelVersion", +# id=model_id, +# status_code=403, +# body={"hash": hash, "archive_hash": archive_hash, "size": size} +# ) # with pytest.raises(Exception): -# training.create_model_version(api_client, model_id, hash, size, archive_hash) +# training.create_model_version(model_id, hash, size, archive_hash) -# def test_handle_s3_uploading_errors(samples_dir): -# s3_endpoint_url = "http://s3.localhost.com" -# responses.add_passthru(s3_endpoint_url) -# responses.add(responses.Response(method="PUT", url=s3_endpoint_url, status=400)) -# file_path = samples_dir / "model_file.pth" -# training = TrainingWorker() -# with pytest.raises(Exception): -# training.upload_to_s3(file_path, {"s3_put_url": s3_endpoint_url}) +def test_handle_s3_uploading_errors(model_file_dir): + training = TrainingWorker() + training.api_client = MockApiClient() + s3_endpoint_url = "http://s3.localhost.com" + responses.add_passthru(s3_endpoint_url) + responses.add(responses.Response(method="PUT", url=s3_endpoint_url, status=400)) + file_path = model_file_dir / "model_file.pth" + with pytest.raises(Exception): + training.upload_to_s3(file_path, {"s3_put_url": s3_endpoint_url})